From 87061f5bb63dc70da53a4a36b3242d5b06531ae3 Mon Sep 17 00:00:00 2001 From: LawyZheng Date: Fri, 29 Nov 2024 16:05:44 +0800 Subject: [PATCH] upload all downloaded files when using s3 (#1289) --- skyvern/config.py | 1 + skyvern/constants.py | 2 + skyvern/forge/agent.py | 126 ++++++++++++------ skyvern/forge/sdk/api/aws.py | 12 ++ skyvern/forge/sdk/api/files.py | 8 +- skyvern/forge/sdk/artifact/storage/base.py | 12 ++ skyvern/forge/sdk/artifact/storage/local.py | 19 ++- skyvern/forge/sdk/artifact/storage/s3.py | 30 +++++ skyvern/forge/sdk/schemas/tasks.py | 3 + skyvern/forge/sdk/workflow/models/workflow.py | 1 + skyvern/forge/sdk/workflow/service.py | 38 ++++++ skyvern/webeye/browser_factory.py | 9 +- 12 files changed, 211 insertions(+), 50 deletions(-) diff --git a/skyvern/config.py b/skyvern/config.py index bce59736..2dc982b7 100644 --- a/skyvern/config.py +++ b/skyvern/config.py @@ -52,6 +52,7 @@ class Settings(BaseSettings): GENERATE_PRESIGNED_URLS: bool = False AWS_S3_BUCKET_ARTIFACTS: str = "skyvern-artifacts" AWS_S3_BUCKET_SCREENSHOTS: str = "skyvern-screenshots" + AWS_S3_BUCKET_DOWNLOADS: str = "skyvern-uploads" AWS_S3_BUCKET_BROWSER_SESSIONS: str = "skyvern-browser-sessions" # Supported storage types: local, s3 diff --git a/skyvern/constants.py b/skyvern/constants.py index 1794053c..c9a19871 100644 --- a/skyvern/constants.py +++ b/skyvern/constants.py @@ -11,6 +11,8 @@ PAGE_CONTENT_TIMEOUT = 300 # 5 mins BUILDING_ELEMENT_TREE_TIMEOUT_MS = 60 * 1000 # 1 minute BROWSER_CLOSE_TIMEOUT = 180 # 3 minute BROWSER_DOWNLOAD_TIMEOUT = 600 # 10 minute +SAVE_DOWNLOADED_FILES_TIMEOUT = 180 +GET_DOWNLOADED_FILES_TIMEOUT = 30 # reserved fields for navigation payload SPECIAL_FIELD_VERIFICATION_CODE = "verification_code" diff --git a/skyvern/forge/agent.py b/skyvern/forge/agent.py index 8e7eed1b..4d2704a3 100644 --- a/skyvern/forge/agent.py +++ b/skyvern/forge/agent.py @@ -14,7 +14,13 @@ from playwright._impl._errors import TargetClosedError from playwright.async_api import 
Page from skyvern import analytics -from skyvern.constants import SCRAPE_TYPE_ORDER, SPECIAL_FIELD_VERIFICATION_CODE, ScrapeType +from skyvern.constants import ( + GET_DOWNLOADED_FILES_TIMEOUT, + SAVE_DOWNLOADED_FILES_TIMEOUT, + SCRAPE_TYPE_ORDER, + SPECIAL_FIELD_VERIFICATION_CODE, + ScrapeType, +) from skyvern.exceptions import ( BrowserStateMissingPage, EmptyScrapePage, @@ -1491,6 +1497,26 @@ class ForgeAgent: ) return + if task.organization_id: + try: + async with asyncio.timeout(SAVE_DOWNLOADED_FILES_TIMEOUT): + await app.STORAGE.save_downloaded_files( + task.organization_id, task_id=task.task_id, workflow_run_id=None + ) + except asyncio.TimeoutError: + LOG.warning( + "Timeout to save downloaded files", + task_id=task.task_id, + workflow_run_id=task.workflow_run_id, + ) + except Exception: + LOG.warning( + "Failed to save downloaded files", + exc_info=True, + task_id=task.task_id, + workflow_run_id=task.workflow_run_id, + ) + await self.async_operation_pool.remove_task(task.task_id) await self.cleanup_browser_and_create_artifacts(close_browser_on_completion, last_step, task) @@ -1520,6 +1546,11 @@ class ForgeAgent: ) return + screenshot_url = None + recording_url = None + latest_action_screenshot_urls: list[str] | None = None + downloaded_file_urls: list[str] | None = None + # get the artifact of the screenshot and get the screenshot_url screenshot_artifact = await app.DATABASE.get_artifact( task_id=task.task_id, @@ -1527,50 +1558,65 @@ class ForgeAgent: artifact_type=ArtifactType.SCREENSHOT_FINAL, organization_id=task.organization_id, ) - if screenshot_artifact is None: - screenshot_url = None - if screenshot_artifact: - screenshot_url = await app.ARTIFACT_MANAGER.get_share_link(screenshot_artifact) + if screenshot_artifact: + screenshot_url = await app.ARTIFACT_MANAGER.get_share_link(screenshot_artifact) - recording_artifact = await app.DATABASE.get_artifact( - task_id=task.task_id, - step_id=last_step.step_id, - artifact_type=ArtifactType.RECORDING, - 
organization_id=task.organization_id, - ) - recording_url = None - if recording_artifact: - recording_url = await app.ARTIFACT_MANAGER.get_share_link(recording_artifact) + recording_artifact = await app.DATABASE.get_artifact( + task_id=task.task_id, + step_id=last_step.step_id, + artifact_type=ArtifactType.RECORDING, + organization_id=task.organization_id, + ) + if recording_artifact: + recording_url = await app.ARTIFACT_MANAGER.get_share_link(recording_artifact) # get the artifact of the last TASK_RESPONSE_ACTION_SCREENSHOT_COUNT screenshots and get the screenshot_url - latest_action_screenshot_artifacts = await app.DATABASE.get_latest_n_artifacts( - task_id=task.task_id, - organization_id=task.organization_id, - artifact_types=[ArtifactType.SCREENSHOT_ACTION], - n=SettingsManager.get_settings().TASK_RESPONSE_ACTION_SCREENSHOT_COUNT, - ) - latest_action_screenshot_urls: list[str] | None = [] - if latest_action_screenshot_artifacts: - latest_action_screenshot_urls = await app.ARTIFACT_MANAGER.get_share_links( - latest_action_screenshot_artifacts - ) - else: - LOG.error("Failed to get latest action screenshots") - - # get the latest task from the db to get the latest status, extracted_information, and failure_reason - task_from_db = await app.DATABASE.get_task(task_id=task.task_id, organization_id=task.organization_id) - if not task_from_db: - LOG.error("Failed to get task from db when sending task response") - raise TaskNotFound(task_id=task.task_id) - - task = task_from_db - task_response = task.to_task_response( - action_screenshot_urls=latest_action_screenshot_urls, - screenshot_url=screenshot_url, - recording_url=recording_url, + latest_action_screenshot_artifacts = await app.DATABASE.get_latest_n_artifacts( + task_id=task.task_id, + organization_id=task.organization_id, + artifact_types=[ArtifactType.SCREENSHOT_ACTION], + n=SettingsManager.get_settings().TASK_RESPONSE_ACTION_SCREENSHOT_COUNT, + ) + if latest_action_screenshot_artifacts: + 
latest_action_screenshot_urls = await app.ARTIFACT_MANAGER.get_share_links( + latest_action_screenshot_artifacts ) else: - task_response = task.to_task_response() + LOG.error("Failed to get latest action screenshots") + + if task.organization_id: + try: + async with asyncio.timeout(GET_DOWNLOADED_FILES_TIMEOUT): + downloaded_file_urls = await app.STORAGE.get_downloaded_files( + organization_id=task.organization_id, task_id=task.task_id, workflow_run_id=task.workflow_run_id + ) + except asyncio.TimeoutError: + LOG.warning( + "Timeout to get downloaded files", + task_id=task.task_id, + workflow_run_id=task.workflow_run_id, + ) + except Exception: + LOG.warning( + "Failed to get downloaded files", + exc_info=True, + task_id=task.task_id, + workflow_run_id=task.workflow_run_id, + ) + + # get the latest task from the db to get the latest status, extracted_information, and failure_reason + task_from_db = await app.DATABASE.get_task(task_id=task.task_id, organization_id=task.organization_id) + if not task_from_db: + LOG.error("Failed to get task from db when sending task response") + raise TaskNotFound(task_id=task.task_id) + + task = task_from_db + task_response = task.to_task_response( + action_screenshot_urls=latest_action_screenshot_urls, + screenshot_url=screenshot_url, + recording_url=recording_url, + downloaded_file_urls=downloaded_file_urls, + ) if not task.webhook_callback_url: LOG.info("Task has no webhook callback url. 
Not sending task response") diff --git a/skyvern/forge/sdk/api/aws.py b/skyvern/forge/sdk/api/aws.py index 92ed4ac4..c1f2c99b 100644 --- a/skyvern/forge/sdk/api/aws.py +++ b/skyvern/forge/sdk/api/aws.py @@ -104,6 +104,18 @@ class AsyncAWSClient: LOG.exception("Failed to create presigned url for S3 objects.", uris=uris) return None + @execute_with_async_client(client_type=AWSClientType.S3) + async def list_files(self, uri: str, client: AioBaseClient = None) -> list[str]: + object_keys: list[str] = [] + parsed_uri = S3Uri(uri) + async for page in client.get_paginator("list_objects_v2").paginate( + Bucket=parsed_uri.bucket, Prefix=parsed_uri.key + ): + if "Contents" in page: + for obj in page["Contents"]: + object_keys.append(obj["Key"]) + return object_keys + class S3Uri(object): # From: https://stackoverflow.com/questions/42641315/s3-urls-get-bucket-name-and-path diff --git a/skyvern/forge/sdk/api/files.py b/skyvern/forge/sdk/api/files.py index 98f47f87..925ab158 100644 --- a/skyvern/forge/sdk/api/files.py +++ b/skyvern/forge/sdk/api/files.py @@ -114,7 +114,13 @@ def unzip_files(zip_file_path: str, output_dir: str) -> None: def get_path_for_workflow_download_directory(workflow_run_id: str) -> Path: - return Path(f"{REPO_ROOT_DIR}/downloads/{workflow_run_id}/") + return Path(get_download_dir(workflow_run_id=workflow_run_id, task_id=None)) + + +def get_download_dir(workflow_run_id: str | None, task_id: str | None) -> str: + download_dir = f"{REPO_ROOT_DIR}/downloads/{workflow_run_id or task_id}" + os.makedirs(download_dir, exist_ok=True) + return download_dir def list_files_in_directory(directory: Path, recursive: bool = False) -> list[str]: diff --git a/skyvern/forge/sdk/artifact/storage/base.py b/skyvern/forge/sdk/artifact/storage/base.py index 4e5e6118..811c5bf1 100644 --- a/skyvern/forge/sdk/artifact/storage/base.py +++ b/skyvern/forge/sdk/artifact/storage/base.py @@ -68,3 +68,15 @@ class BaseStorage(ABC): @abstractmethod async def retrieve_browser_session(self, 
organization_id: str, workflow_permanent_id: str) -> str | None: pass + + @abstractmethod + async def save_downloaded_files( + self, organization_id: str, task_id: str | None, workflow_run_id: str | None + ) -> None: + pass + + @abstractmethod + async def get_downloaded_files( + self, organization_id: str, task_id: str | None, workflow_run_id: str | None + ) -> list[str]: + pass diff --git a/skyvern/forge/sdk/artifact/storage/local.py b/skyvern/forge/sdk/artifact/storage/local.py index 54f0c54e..1dfc3541 100644 --- a/skyvern/forge/sdk/artifact/storage/local.py +++ b/skyvern/forge/sdk/artifact/storage/local.py @@ -6,7 +6,7 @@ from urllib.parse import unquote, urlparse import structlog -from skyvern.forge.sdk.api.files import get_skyvern_temp_dir +from skyvern.forge.sdk.api.files import get_download_dir, get_skyvern_temp_dir from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType from skyvern.forge.sdk.artifact.storage.base import FILE_EXTENTSION_MAP, BaseStorage from skyvern.forge.sdk.models import Step @@ -120,6 +120,23 @@ class LocalStorage(BaseStorage): return None return str(stored_folder_path) + async def save_downloaded_files( + self, organization_id: str, task_id: str | None, workflow_run_id: str | None + ) -> None: + pass + + async def get_downloaded_files( + self, organization_id: str, task_id: str | None, workflow_run_id: str | None + ) -> list[str]: + download_dir = get_download_dir(workflow_run_id=workflow_run_id, task_id=task_id) + files: list[str] = [] + files_and_folders = os.listdir(download_dir) + for file_or_folder in files_and_folders: + path = os.path.join(download_dir, file_or_folder) + if os.path.isfile(path): + files.append(f"file://{path}") + return files + @staticmethod def _parse_uri_to_path(uri: str) -> str: parsed_uri = urlparse(uri) diff --git a/skyvern/forge/sdk/artifact/storage/s3.py b/skyvern/forge/sdk/artifact/storage/s3.py index c3b7601c..ece1c072 100644 --- a/skyvern/forge/sdk/artifact/storage/s3.py +++ 
b/skyvern/forge/sdk/artifact/storage/s3.py @@ -1,3 +1,4 @@ +import os import shutil from datetime import datetime @@ -5,6 +6,7 @@ from skyvern.config import settings from skyvern.forge.sdk.api.aws import AsyncAWSClient from skyvern.forge.sdk.api.files import ( create_named_temporary_file, + get_download_dir, get_skyvern_temp_dir, make_temp_directory, unzip_files, @@ -68,3 +70,31 @@ class S3Storage(BaseStorage): unzip_files(temp_zip_file_path, temp_dir) temp_zip_file.close() return temp_dir + + async def save_downloaded_files( + self, organization_id: str, task_id: str | None, workflow_run_id: str | None + ) -> None: + download_dir = get_download_dir(workflow_run_id=workflow_run_id, task_id=task_id) + files = os.listdir(download_dir) + for file in files: + fpath = os.path.join(download_dir, file) + if os.path.isfile(fpath): + uri = f"s3://{settings.AWS_S3_BUCKET_DOWNLOADS}/{settings.ENV}/{organization_id}/{workflow_run_id or task_id}/{file}" + # TODO: use coroutine to speed up uploading if too many files + await self.async_client.upload_file_from_path(uri, fpath) + + async def get_downloaded_files( + self, organization_id: str, task_id: str | None, workflow_run_id: str | None + ) -> list[str]: + uri = f"s3://{settings.AWS_S3_BUCKET_DOWNLOADS}/{settings.ENV}/{organization_id}/{workflow_run_id or task_id}" + object_keys = await self.async_client.list_files(uri=uri) + if len(object_keys) == 0: + return [] + object_uris: list[str] = [] + for key in object_keys: + object_uri = f"s3://{settings.AWS_S3_BUCKET_DOWNLOADS}/{key}" + object_uris.append(object_uri) + presigned_urls = await self.async_client.create_presigned_urls(object_uris) + if presigned_urls is None: + return [] + return presigned_urls diff --git a/skyvern/forge/sdk/schemas/tasks.py b/skyvern/forge/sdk/schemas/tasks.py index 1505b7a4..1df94347 100644 --- a/skyvern/forge/sdk/schemas/tasks.py +++ b/skyvern/forge/sdk/schemas/tasks.py @@ -252,6 +252,7 @@ class Task(TaskBase): screenshot_url: str | None = None,
recording_url: str | None = None, browser_console_log_url: str | None = None, + downloaded_file_urls: list[str] | None = None, failure_reason: str | None = None, ) -> TaskResponse: return TaskResponse( @@ -266,6 +267,7 @@ class Task(TaskBase): screenshot_url=screenshot_url, recording_url=recording_url, browser_console_log_url=browser_console_log_url, + downloaded_file_urls=downloaded_file_urls, errors=self.errors, max_steps_per_run=self.max_steps_per_run, workflow_run_id=self.workflow_run_id, @@ -283,6 +285,7 @@ class TaskResponse(BaseModel): screenshot_url: str | None = None recording_url: str | None = None browser_console_log_url: str | None = None + downloaded_file_urls: list[str] | None = None failure_reason: str | None = None errors: list[dict[str, Any]] = [] max_steps_per_run: int | None = None diff --git a/skyvern/forge/sdk/workflow/models/workflow.py b/skyvern/forge/sdk/workflow/models/workflow.py index e0030e2e..d8d1dc0a 100644 --- a/skyvern/forge/sdk/workflow/models/workflow.py +++ b/skyvern/forge/sdk/workflow/models/workflow.py @@ -124,4 +124,5 @@ class WorkflowRunStatusResponse(BaseModel): parameters: dict[str, Any] screenshot_urls: list[str] | None = None recording_url: str | None = None + downloaded_file_urls: list[str] | None = None outputs: dict[str, Any] | None = None diff --git a/skyvern/forge/sdk/workflow/service.py b/skyvern/forge/sdk/workflow/service.py index 5ff1eb2a..c50399ff 100644 --- a/skyvern/forge/sdk/workflow/service.py +++ b/skyvern/forge/sdk/workflow/service.py @@ -1,3 +1,4 @@ +import asyncio import json from datetime import datetime from typing import Any @@ -6,6 +7,7 @@ import httpx import structlog from skyvern import analytics +from skyvern.constants import GET_DOWNLOADED_FILES_TIMEOUT, SAVE_DOWNLOADED_FILES_TIMEOUT from skyvern.exceptions import ( FailedToSendWebhook, MissingValueForParameter, @@ -778,6 +780,24 @@ class WorkflowService: if recording_artifact: recording_url = await 
app.ARTIFACT_MANAGER.get_share_link(recording_artifact) + downloaded_file_urls: list[str] | None = None + try: + async with asyncio.timeout(GET_DOWNLOADED_FILES_TIMEOUT): + downloaded_file_urls = await app.STORAGE.get_downloaded_files( + organization_id=workflow.organization_id, task_id=None, workflow_run_id=workflow_run.workflow_run_id + ) + except asyncio.TimeoutError: + LOG.warning( + "Timeout to get downloaded files", + workflow_run_id=workflow_run.workflow_run_id, + ) + except Exception: + LOG.warning( + "Failed to get downloaded files", + exc_info=True, + workflow_run_id=workflow_run.workflow_run_id, + ) + workflow_parameter_tuples = await app.DATABASE.get_workflow_run_parameters(workflow_run_id=workflow_run_id) parameters_with_value = {wfp.key: wfrp.value for wfp, wfrp in workflow_parameter_tuples} output_parameter_tuples: list[ @@ -804,6 +824,7 @@ class WorkflowService: parameters=parameters_with_value, screenshot_urls=screenshot_urls, recording_url=recording_url, + downloaded_file_urls=downloaded_file_urls, outputs=outputs, ) @@ -837,6 +858,23 @@ class WorkflowService: await app.ARTIFACT_MANAGER.wait_for_upload_aiotasks_for_tasks(all_workflow_task_ids) + try: + async with asyncio.timeout(SAVE_DOWNLOADED_FILES_TIMEOUT): + await app.STORAGE.save_downloaded_files( + workflow.organization_id, task_id=None, workflow_run_id=workflow_run.workflow_run_id + ) + except asyncio.TimeoutError: + LOG.warning( + "Timeout to save downloaded files", + workflow_run_id=workflow_run.workflow_run_id, + ) + except Exception: + LOG.warning( + "Failed to save downloaded files", + exc_info=True, + workflow_run_id=workflow_run.workflow_run_id, + ) + if not need_call_webhook: return diff --git a/skyvern/webeye/browser_factory.py b/skyvern/webeye/browser_factory.py index da588905..d7787d76 100644 --- a/skyvern/webeye/browser_factory.py +++ b/skyvern/webeye/browser_factory.py @@ -14,7 +14,7 @@ from playwright.async_api import BrowserContext, ConsoleMessage, Download, Error from 
pydantic import BaseModel, PrivateAttr from skyvern.config import settings -from skyvern.constants import BROWSER_CLOSE_TIMEOUT, BROWSER_DOWNLOAD_TIMEOUT, REPO_ROOT_DIR +from skyvern.constants import BROWSER_CLOSE_TIMEOUT, BROWSER_DOWNLOAD_TIMEOUT from skyvern.exceptions import ( FailedToNavigateToUrl, FailedToReloadPage, @@ -35,13 +35,6 @@ LOG = structlog.get_logger() BrowserCleanupFunc = Callable[[], None] | None -def get_download_dir(workflow_run_id: str | None, task_id: str | None) -> str: - download_dir = f"{REPO_ROOT_DIR}/downloads/{workflow_run_id or task_id}" - LOG.info("Initializing download directory", download_dir=download_dir) - os.makedirs(download_dir, exist_ok=True) - return download_dir - - def set_browser_console_log(browser_context: BrowserContext, browser_artifacts: BrowserArtifacts) -> None: if browser_artifacts.browser_console_log_path is None: log_path = f"{settings.LOG_PATH}/{datetime.utcnow().strftime('%Y-%m-%d')}/{uuid.uuid4()}.log"