fix task v2 download issue (#3220)
This commit is contained in:
@@ -376,7 +376,9 @@ class ForgeAgent:
|
||||
try:
|
||||
if task.workflow_run_id:
|
||||
list_files_before = list_files_in_directory(
|
||||
get_path_for_workflow_download_directory(task.workflow_run_id)
|
||||
get_path_for_workflow_download_directory(
|
||||
context.run_id if context and context.run_id else task.workflow_run_id
|
||||
)
|
||||
)
|
||||
# Check some conditions before executing the step, throw an exception if the step can't be executed
|
||||
await app.AGENT_FUNCTION.validate_step_execution(task, step)
|
||||
@@ -457,7 +459,9 @@ class ForgeAgent:
|
||||
retry = False
|
||||
|
||||
if task_block and task_block.complete_on_download and task.workflow_run_id:
|
||||
workflow_download_directory = get_path_for_workflow_download_directory(task.workflow_run_id)
|
||||
workflow_download_directory = get_path_for_workflow_download_directory(
|
||||
context.run_id if context and context.run_id else task.workflow_run_id
|
||||
)
|
||||
|
||||
downloading_files: list[Path] = list_downloading_files_in_directory(workflow_download_directory)
|
||||
if len(downloading_files) > 0:
|
||||
@@ -2219,8 +2223,10 @@ class ForgeAgent:
|
||||
if task.organization_id:
|
||||
try:
|
||||
async with asyncio.timeout(SAVE_DOWNLOADED_FILES_TIMEOUT):
|
||||
context = skyvern_context.current()
|
||||
await app.STORAGE.save_downloaded_files(
|
||||
task.organization_id, task_id=task.task_id, workflow_run_id=task.workflow_run_id
|
||||
organization_id=task.organization_id,
|
||||
run_id=context.run_id if context and context.run_id else task.workflow_run_id or task.task_id,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
LOG.warning(
|
||||
@@ -2391,8 +2397,10 @@ class ForgeAgent:
|
||||
if task.organization_id:
|
||||
try:
|
||||
async with asyncio.timeout(GET_DOWNLOADED_FILES_TIMEOUT):
|
||||
context = skyvern_context.current()
|
||||
downloaded_files = await app.STORAGE.get_downloaded_files(
|
||||
organization_id=task.organization_id, task_id=task.task_id, workflow_run_id=task.workflow_run_id
|
||||
organization_id=task.organization_id,
|
||||
run_id=context.run_id if context and context.run_id else task.workflow_run_id or task.task_id,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
LOG.warning(
|
||||
|
||||
@@ -173,12 +173,12 @@ def unzip_files(zip_file_path: str, output_dir: str) -> None:
|
||||
zip_ref.extractall(output_dir)
|
||||
|
||||
|
||||
def get_path_for_workflow_download_directory(workflow_run_id: str) -> Path:
|
||||
return Path(get_download_dir(workflow_run_id=workflow_run_id, task_id=None))
|
||||
def get_path_for_workflow_download_directory(run_id: str | None) -> Path:
|
||||
return Path(get_download_dir(run_id=run_id))
|
||||
|
||||
|
||||
def get_download_dir(workflow_run_id: str | None, task_id: str | None) -> str:
|
||||
download_dir = f"{REPO_ROOT_DIR}/downloads/{workflow_run_id or task_id}"
|
||||
def get_download_dir(run_id: str | None) -> str:
|
||||
download_dir = f"{REPO_ROOT_DIR}/downloads/{run_id}"
|
||||
os.makedirs(download_dir, exist_ok=True)
|
||||
return download_dir
|
||||
|
||||
|
||||
@@ -124,15 +124,11 @@ class BaseStorage(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def save_downloaded_files(
|
||||
self, organization_id: str, task_id: str | None, workflow_run_id: str | None
|
||||
) -> None:
|
||||
async def save_downloaded_files(self, organization_id: str, run_id: str | None) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_downloaded_files(
|
||||
self, organization_id: str, task_id: str | None, workflow_run_id: str | None
|
||||
) -> list[FileInfo]:
|
||||
async def get_downloaded_files(self, organization_id: str, run_id: str | None) -> list[FileInfo]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -171,15 +171,11 @@ class LocalStorage(BaseStorage):
|
||||
return None
|
||||
return str(stored_folder_path)
|
||||
|
||||
async def save_downloaded_files(
|
||||
self, organization_id: str, task_id: str | None, workflow_run_id: str | None
|
||||
) -> None:
|
||||
async def save_downloaded_files(self, organization_id: str, run_id: str | None) -> None:
|
||||
pass
|
||||
|
||||
async def get_downloaded_files(
|
||||
self, organization_id: str, task_id: str | None, workflow_run_id: str | None
|
||||
) -> list[FileInfo]:
|
||||
download_dir = get_download_dir(workflow_run_id=workflow_run_id, task_id=task_id)
|
||||
async def get_downloaded_files(self, organization_id: str, run_id: str | None) -> list[FileInfo]:
|
||||
download_dir = get_download_dir(run_id=run_id)
|
||||
file_infos: list[FileInfo] = []
|
||||
files_and_folders = os.listdir(download_dir)
|
||||
for file_or_folder in files_and_folders:
|
||||
|
||||
@@ -195,14 +195,14 @@ class S3Storage(BaseStorage):
|
||||
temp_zip_file.close()
|
||||
return temp_dir
|
||||
|
||||
async def save_downloaded_files(
|
||||
self, organization_id: str, task_id: str | None, workflow_run_id: str | None
|
||||
) -> None:
|
||||
download_dir = get_download_dir(workflow_run_id=workflow_run_id, task_id=task_id)
|
||||
async def save_downloaded_files(self, organization_id: str, run_id: str | None) -> None:
|
||||
download_dir = get_download_dir(run_id=run_id)
|
||||
files = os.listdir(download_dir)
|
||||
sc = await self._get_storage_class_for_org(organization_id)
|
||||
tags = await self._get_tags_for_org(organization_id)
|
||||
base_uri = f"s3://{settings.AWS_S3_BUCKET_UPLOADS}/{DOWNLOAD_FILE_PREFIX}/{settings.ENV}/{organization_id}/{workflow_run_id or task_id}"
|
||||
base_uri = (
|
||||
f"s3://{settings.AWS_S3_BUCKET_UPLOADS}/{DOWNLOAD_FILE_PREFIX}/{settings.ENV}/{organization_id}/{run_id}"
|
||||
)
|
||||
for file in files:
|
||||
fpath = os.path.join(download_dir, file)
|
||||
if not os.path.isfile(fpath):
|
||||
@@ -225,10 +225,8 @@ class S3Storage(BaseStorage):
|
||||
tags=tags,
|
||||
)
|
||||
|
||||
async def get_downloaded_files(
|
||||
self, organization_id: str, task_id: str | None, workflow_run_id: str | None
|
||||
) -> list[FileInfo]:
|
||||
uri = f"s3://{settings.AWS_S3_BUCKET_UPLOADS}/{DOWNLOAD_FILE_PREFIX}/{settings.ENV}/{organization_id}/{workflow_run_id or task_id}"
|
||||
async def get_downloaded_files(self, organization_id: str, run_id: str | None) -> list[FileInfo]:
|
||||
uri = f"s3://{settings.AWS_S3_BUCKET_UPLOADS}/{DOWNLOAD_FILE_PREFIX}/{settings.ENV}/{organization_id}/{run_id}"
|
||||
object_keys = await self.async_client.list_files(uri=uri)
|
||||
if len(object_keys) == 0:
|
||||
return []
|
||||
|
||||
@@ -1530,7 +1530,7 @@ async def get_workflow_run_with_workflow_id(
|
||||
organization_id=current_org.organization_id,
|
||||
include_cost=True,
|
||||
)
|
||||
return_dict = workflow_run_status_response.model_dump()
|
||||
return_dict = workflow_run_status_response.model_dump(by_alias=True)
|
||||
|
||||
browser_session = await app.DATABASE.get_persistent_browser_session_by_runnable_id(
|
||||
runnable_id=workflow_run_id,
|
||||
@@ -1541,14 +1541,6 @@ async def get_workflow_run_with_workflow_id(
|
||||
|
||||
return_dict["browser_session_id"] = browser_session_id or return_dict.get("browser_session_id")
|
||||
|
||||
task_v2 = await app.DATABASE.get_task_v2_by_workflow_run_id(
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=current_org.organization_id,
|
||||
)
|
||||
|
||||
if task_v2:
|
||||
return_dict["task_v2"] = task_v2.model_dump(by_alias=True)
|
||||
|
||||
return return_dict
|
||||
|
||||
|
||||
|
||||
@@ -739,8 +739,9 @@ class BaseTaskBlock(Block):
|
||||
async with asyncio.timeout(GET_DOWNLOADED_FILES_TIMEOUT):
|
||||
downloaded_files = await app.STORAGE.get_downloaded_files(
|
||||
organization_id=workflow_run.organization_id,
|
||||
task_id=updated_task.task_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
run_id=current_context.run_id
|
||||
if current_context and current_context.run_id
|
||||
else workflow_run_id or updated_task.task_id,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
LOG.warning("Timeout getting downloaded files", task_id=updated_task.task_id)
|
||||
@@ -798,8 +799,9 @@ class BaseTaskBlock(Block):
|
||||
async with asyncio.timeout(GET_DOWNLOADED_FILES_TIMEOUT):
|
||||
downloaded_files = await app.STORAGE.get_downloaded_files(
|
||||
organization_id=workflow_run.organization_id,
|
||||
task_id=updated_task.task_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
run_id=current_context.run_id
|
||||
if current_context and current_context.run_id
|
||||
else workflow_run_id or updated_task.task_id,
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
@@ -1816,7 +1818,12 @@ class UploadToS3Block(Block):
|
||||
self.path = file_path_parameter_value
|
||||
# if the path is WORKFLOW_DOWNLOAD_DIRECTORY_PARAMETER_KEY, use the download directory for the workflow run
|
||||
elif self.path == settings.WORKFLOW_DOWNLOAD_DIRECTORY_PARAMETER_KEY:
|
||||
self.path = str(get_path_for_workflow_download_directory(workflow_run_id).absolute())
|
||||
context = skyvern_context.current()
|
||||
self.path = str(
|
||||
get_path_for_workflow_download_directory(
|
||||
context.run_id if context and context.run_id else workflow_run_id
|
||||
).absolute()
|
||||
)
|
||||
|
||||
try:
|
||||
self.format_potential_template_parameters(workflow_run_context)
|
||||
@@ -2011,7 +2018,12 @@ class FileUploadBlock(Block):
|
||||
organization_id=organization_id,
|
||||
)
|
||||
|
||||
download_files_path = str(get_path_for_workflow_download_directory(workflow_run_id).absolute())
|
||||
context = skyvern_context.current()
|
||||
download_files_path = str(
|
||||
get_path_for_workflow_download_directory(
|
||||
context.run_id if context and context.run_id else workflow_run_id
|
||||
).absolute()
|
||||
)
|
||||
|
||||
uploaded_uris = []
|
||||
try:
|
||||
@@ -2197,7 +2209,12 @@ class SendEmailBlock(Block):
|
||||
|
||||
if path == settings.WORKFLOW_DOWNLOAD_DIRECTORY_PARAMETER_KEY:
|
||||
# if the path is WORKFLOW_DOWNLOAD_DIRECTORY_PARAMETER_KEY, use download directory for the workflow run
|
||||
path = str(get_path_for_workflow_download_directory(workflow_run_id).absolute())
|
||||
context = skyvern_context.current()
|
||||
path = str(
|
||||
get_path_for_workflow_download_directory(
|
||||
context.run_id if context and context.run_id else workflow_run_id
|
||||
).absolute()
|
||||
)
|
||||
LOG.info(
|
||||
"SendEmailBlock: Using download directory for the workflow run",
|
||||
workflow_run_id=workflow_run_id,
|
||||
|
||||
@@ -1155,6 +1155,11 @@ class WorkflowService:
|
||||
raise WorkflowNotFound(workflow_permanent_id=workflow_permanent_id)
|
||||
|
||||
workflow_run = await self.get_workflow_run(workflow_run_id=workflow_run_id, organization_id=organization_id)
|
||||
|
||||
task_v2 = await app.DATABASE.get_task_v2_by_workflow_run_id(
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
workflow_run_tasks = await app.DATABASE.get_tasks_by_workflow_run_id(workflow_run_id=workflow_run_id)
|
||||
screenshot_artifacts = []
|
||||
screenshot_urls: list[str] | None = None
|
||||
@@ -1184,15 +1189,22 @@ class WorkflowService:
|
||||
if recording_artifact:
|
||||
recording_url = await app.ARTIFACT_MANAGER.get_share_link(recording_artifact)
|
||||
|
||||
downloaded_files: list[FileInfo] | None = None
|
||||
downloaded_files: list[FileInfo] = []
|
||||
downloaded_file_urls: list[str] | None = None
|
||||
try:
|
||||
async with asyncio.timeout(GET_DOWNLOADED_FILES_TIMEOUT):
|
||||
context = skyvern_context.current()
|
||||
downloaded_files = await app.STORAGE.get_downloaded_files(
|
||||
organization_id=workflow_run.organization_id,
|
||||
task_id=None,
|
||||
workflow_run_id=workflow_run.workflow_run_id,
|
||||
run_id=context.run_id if context and context.run_id else workflow_run.workflow_run_id,
|
||||
)
|
||||
if task_v2:
|
||||
task_v2_downloaded_files = await app.STORAGE.get_downloaded_files(
|
||||
organization_id=workflow_run.organization_id,
|
||||
run_id=task_v2.observer_cruise_id,
|
||||
)
|
||||
if task_v2_downloaded_files:
|
||||
downloaded_files.extend(task_v2_downloaded_files)
|
||||
if downloaded_files:
|
||||
downloaded_file_urls = [file_info.url for file_info in downloaded_files]
|
||||
except asyncio.TimeoutError:
|
||||
@@ -1267,6 +1279,7 @@ class WorkflowService:
|
||||
workflow_title=workflow.title,
|
||||
browser_session_id=workflow_run.browser_session_id,
|
||||
max_screenshot_scrolls=workflow_run.max_screenshot_scrolls,
|
||||
task_v2=task_v2,
|
||||
)
|
||||
|
||||
async def clean_up_workflow(
|
||||
@@ -1304,8 +1317,10 @@ class WorkflowService:
|
||||
|
||||
try:
|
||||
async with asyncio.timeout(SAVE_DOWNLOADED_FILES_TIMEOUT):
|
||||
context = skyvern_context.current()
|
||||
await app.STORAGE.save_downloaded_files(
|
||||
workflow_run.organization_id, task_id=None, workflow_run_id=workflow_run.workflow_run_id
|
||||
organization_id=workflow_run.organization_id,
|
||||
run_id=context.run_id if context and context.run_id else workflow_run.workflow_run_id,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
LOG.warning(
|
||||
|
||||
@@ -790,7 +790,10 @@ async def handle_click_to_download_file_action(
|
||||
skyvern_element = await dom.get_skyvern_element_by_id(action.element_id)
|
||||
locator = skyvern_element.locator
|
||||
|
||||
download_dir = Path(get_download_dir(workflow_run_id=task.workflow_run_id, task_id=task.task_id))
|
||||
context = skyvern_context.current()
|
||||
download_dir = Path(
|
||||
get_download_dir(run_id=context.run_id if context and context.run_id else task.workflow_run_id or task.task_id)
|
||||
)
|
||||
list_files_before = list_files_in_directory(download_dir)
|
||||
LOG.info(
|
||||
"Number of files in download directory before click",
|
||||
|
||||
@@ -131,7 +131,9 @@ def set_download_file_listener(browser_context: BrowserContext, **kwargs: Any) -
|
||||
|
||||
def initialize_download_dir() -> str:
|
||||
context = ensure_context()
|
||||
return get_download_dir(context.workflow_run_id, context.task_id)
|
||||
return get_download_dir(
|
||||
context.run_id if context and context.run_id else context.workflow_run_id or context.task_id
|
||||
)
|
||||
|
||||
|
||||
class BrowserContextCreator(Protocol):
|
||||
|
||||
Reference in New Issue
Block a user