upload all downloaded files when using s3 (#1289)

This commit is contained in:
LawyZheng
2024-11-29 16:05:44 +08:00
committed by GitHub
parent d697023994
commit 87061f5bb6
12 changed files with 211 additions and 50 deletions

View File

@@ -104,6 +104,18 @@ class AsyncAWSClient:
LOG.exception("Failed to create presigned url for S3 objects.", uris=uris)
return None
@execute_with_async_client(client_type=AWSClientType.S3)
async def list_files(self, uri: str, client: AioBaseClient = None) -> list[str]:
    """Return the keys of every S3 object whose key starts with the prefix in ``uri``.

    Args:
        uri: An ``s3://bucket/prefix`` style URI; it is parsed with ``S3Uri`` and
            its ``bucket`` and ``key`` are used as the listing bucket and prefix.
        client: S3 client used for the listing. Presumably injected by the
            ``execute_with_async_client`` decorator -- confirm against the decorator.

    Returns:
        The object keys (not full URIs) collected across all result pages, in the
        order the paginator yields them. Empty when nothing matches.
    """
    location = S3Uri(uri)
    paginator = client.get_paginator("list_objects_v2")
    keys: list[str] = []
    async for page in paginator.paginate(Bucket=location.bucket, Prefix=location.key):
        # A page without a "Contents" entry contributes no keys.
        keys.extend(item["Key"] for item in page.get("Contents", []))
    return keys
class S3Uri(object):
# From: https://stackoverflow.com/questions/42641315/s3-urls-get-bucket-name-and-path

View File

@@ -114,7 +114,13 @@ def unzip_files(zip_file_path: str, output_dir: str) -> None:
def get_path_for_workflow_download_directory(workflow_run_id: str) -> Path:
    """Return the download directory of a workflow run as a ``Path``.

    The location is resolved by ``get_download_dir`` so that the directory layout
    is defined in a single place; that helper also creates the directory if it
    does not exist yet.

    Args:
        workflow_run_id: Identifier of the workflow run whose downloads are wanted.

    Returns:
        Path of ``<REPO_ROOT_DIR>/downloads/<workflow_run_id>``.
    """
    # Only the delegating return is kept: the previous hard-coded
    # ``Path(f"{REPO_ROOT_DIR}/downloads/{workflow_run_id}/")`` return shadowed it
    # and left the call to get_download_dir unreachable.
    return Path(get_download_dir(workflow_run_id=workflow_run_id, task_id=None))
def get_download_dir(workflow_run_id: str | None, task_id: str | None) -> str:
    """Return the local download directory for a run, creating it when missing.

    Args:
        workflow_run_id: Preferred identifier; used whenever it is truthy.
        task_id: Fallback identifier used when ``workflow_run_id`` is falsy.

    Returns:
        ``<REPO_ROOT_DIR>/downloads/<identifier>`` as a string. When both
        identifiers are ``None`` the last path component is the literal text
        ``None``.
    """
    identifier = workflow_run_id or task_id
    target_dir = f"{REPO_ROOT_DIR}/downloads/{identifier}"
    # exist_ok=True makes repeated calls for the same run harmless.
    os.makedirs(target_dir, exist_ok=True)
    return target_dir
def list_files_in_directory(directory: Path, recursive: bool = False) -> list[str]:

View File

@@ -68,3 +68,15 @@ class BaseStorage(ABC):
@abstractmethod
async def retrieve_browser_session(self, organization_id: str, workflow_permanent_id: str) -> str | None:
pass
@abstractmethod
async def save_downloaded_files(
self, organization_id: str, task_id: str | None, workflow_run_id: str | None
) -> None:
pass
@abstractmethod
async def get_downloaded_files(
self, organization_id: str, task_id: str | None, workflow_run_id: str | None
) -> list[str]:
pass

View File

@@ -6,7 +6,7 @@ from urllib.parse import unquote, urlparse
import structlog
from skyvern.forge.sdk.api.files import get_skyvern_temp_dir
from skyvern.forge.sdk.api.files import get_download_dir, get_skyvern_temp_dir
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
from skyvern.forge.sdk.artifact.storage.base import FILE_EXTENTSION_MAP, BaseStorage
from skyvern.forge.sdk.models import Step
@@ -120,6 +120,23 @@ class LocalStorage(BaseStorage):
return None
return str(stored_folder_path)
async def save_downloaded_files(
self, organization_id: str, task_id: str | None, workflow_run_id: str | None
) -> None:
pass
async def get_downloaded_files(
    self, organization_id: str, task_id: str | None, workflow_run_id: str | None
) -> list[str]:
    """List the regular files in the run's local download directory.

    Args:
        organization_id: Unused by the local implementation.
        task_id: Task identifier, used when ``workflow_run_id`` is falsy.
        workflow_run_id: Workflow run identifier; takes precedence over ``task_id``.

    Returns:
        A ``file://<path>`` string for each regular file directly inside the
        directory returned by ``get_download_dir``. Sub-directories are skipped
        and not descended into.
    """
    base_dir = get_download_dir(workflow_run_id=workflow_run_id, task_id=task_id)
    candidates = (os.path.join(base_dir, entry) for entry in os.listdir(base_dir))
    return [f"file://{candidate}" for candidate in candidates if os.path.isfile(candidate)]
@staticmethod
def _parse_uri_to_path(uri: str) -> str:
parsed_uri = urlparse(uri)

View File

@@ -1,3 +1,4 @@
import os
import shutil
from datetime import datetime
@@ -5,6 +6,7 @@ from skyvern.config import settings
from skyvern.forge.sdk.api.aws import AsyncAWSClient
from skyvern.forge.sdk.api.files import (
create_named_temporary_file,
get_download_dir,
get_skyvern_temp_dir,
make_temp_directory,
unzip_files,
@@ -68,3 +70,31 @@ class S3Storage(BaseStorage):
unzip_files(temp_zip_file_path, temp_dir)
temp_zip_file.close()
return temp_dir
async def save_downloaded_files(
    self, organization_id: str, task_id: str | None, workflow_run_id: str | None
) -> None:
    """Upload every regular file from the run's local download directory to S3.

    Each file is stored at
    ``s3://<AWS_S3_BUCKET_DOWNLOADS>/<ENV>/<organization_id>/<run id>/<file name>``
    where the run id is ``workflow_run_id`` when truthy, otherwise ``task_id``.
    Sub-directories are skipped. Files are uploaded one after another.

    Args:
        organization_id: Organization that owns the run; part of the object key.
        task_id: Task identifier, used when ``workflow_run_id`` is falsy.
        workflow_run_id: Workflow run identifier; takes precedence over ``task_id``.
    """
    local_dir = get_download_dir(workflow_run_id=workflow_run_id, task_id=task_id)
    for name in os.listdir(local_dir):
        local_path = os.path.join(local_dir, name)
        if not os.path.isfile(local_path):
            continue
        destination = f"s3://{settings.AWS_S3_BUCKET_DOWNLOADS}/{settings.ENV}/{organization_id}/{workflow_run_id or task_id}/{name}"
        # TODO: use coroutine to speed up uploading if too many files
        await self.async_client.upload_file_from_path(destination, local_path)
async def get_downloaded_files(
    self, organization_id: str, task_id: str | None, workflow_run_id: str | None
) -> list[str]:
    """Return presigned URLs for the files stored in S3 for a task or workflow run.

    Args:
        organization_id: Organization that owns the run; part of the key prefix.
        task_id: Task identifier, used when ``workflow_run_id`` is falsy.
        workflow_run_id: Workflow run identifier; takes precedence over ``task_id``.

    Returns:
        One presigned URL per object found under the run's prefix. An empty list
        when no objects exist or when presigned URL creation yields ``None``.
    """
    bucket = settings.AWS_S3_BUCKET_DOWNLOADS
    # The trailing slash confines the listing to this run's own "folder".
    # S3 prefixes are plain string prefixes, so without it an id such as
    # "tsk_12" would also match (and expose) objects under "tsk_123/".
    prefix_uri = f"s3://{bucket}/{settings.ENV}/{organization_id}/{workflow_run_id or task_id}/"
    object_keys = await self.async_client.list_files(uri=prefix_uri)
    if not object_keys:
        return []

    object_uris = [f"s3://{bucket}/{key}" for key in object_keys]
    presigned_urls = await self.async_client.create_presigned_urls(object_uris)
    # create_presigned_urls may return None; callers always get a list.
    return presigned_urls or []

View File

@@ -252,6 +252,7 @@ class Task(TaskBase):
screenshot_url: str | None = None,
recording_url: str | None = None,
browser_console_log_url: str | None = None,
downloaded_file_urls: list[str] | None = None,
failure_reason: str | None = None,
) -> TaskResponse:
return TaskResponse(
@@ -266,6 +267,7 @@ class Task(TaskBase):
screenshot_url=screenshot_url,
recording_url=recording_url,
browser_console_log_url=browser_console_log_url,
downloaded_file_urls=downloaded_file_urls,
errors=self.errors,
max_steps_per_run=self.max_steps_per_run,
workflow_run_id=self.workflow_run_id,
@@ -283,6 +285,7 @@ class TaskResponse(BaseModel):
screenshot_url: str | None = None
recording_url: str | None = None
browser_console_log_url: str | None = None
downloaded_file_urls: list[str] | None = None
failure_reason: str | None = None
errors: list[dict[str, Any]] = []
max_steps_per_run: int | None = None

View File

@@ -124,4 +124,5 @@ class WorkflowRunStatusResponse(BaseModel):
parameters: dict[str, Any]
screenshot_urls: list[str] | None = None
recording_url: str | None = None
downloaded_file_urls: list[str] | None = None
outputs: dict[str, Any] | None = None

View File

@@ -1,3 +1,4 @@
import asyncio
import json
from datetime import datetime
from typing import Any
@@ -6,6 +7,7 @@ import httpx
import structlog
from skyvern import analytics
from skyvern.constants import GET_DOWNLOADED_FILES_TIMEOUT, SAVE_DOWNLOADED_FILES_TIMEOUT
from skyvern.exceptions import (
FailedToSendWebhook,
MissingValueForParameter,
@@ -778,6 +780,24 @@ class WorkflowService:
if recording_artifact:
recording_url = await app.ARTIFACT_MANAGER.get_share_link(recording_artifact)
downloaded_file_urls: list[str] | None = None
try:
async with asyncio.timeout(GET_DOWNLOADED_FILES_TIMEOUT):
downloaded_file_urls = await app.STORAGE.get_downloaded_files(
organization_id=workflow.organization_id, task_id=None, workflow_run_id=workflow_run.workflow_run_id
)
except asyncio.TimeoutError:
LOG.warning(
"Timeout to get downloaded files",
workflow_run_id=workflow_run.workflow_run_id,
)
except Exception:
LOG.warning(
"Failed to get downloaded files",
exc_info=True,
workflow_run_id=workflow_run.workflow_run_id,
)
workflow_parameter_tuples = await app.DATABASE.get_workflow_run_parameters(workflow_run_id=workflow_run_id)
parameters_with_value = {wfp.key: wfrp.value for wfp, wfrp in workflow_parameter_tuples}
output_parameter_tuples: list[
@@ -804,6 +824,7 @@ class WorkflowService:
parameters=parameters_with_value,
screenshot_urls=screenshot_urls,
recording_url=recording_url,
downloaded_file_urls=downloaded_file_urls,
outputs=outputs,
)
@@ -837,6 +858,23 @@ class WorkflowService:
await app.ARTIFACT_MANAGER.wait_for_upload_aiotasks_for_tasks(all_workflow_task_ids)
try:
async with asyncio.timeout(SAVE_DOWNLOADED_FILES_TIMEOUT):
await app.STORAGE.save_downloaded_files(
workflow.organization_id, task_id=None, workflow_run_id=workflow_run.workflow_run_id
)
except asyncio.TimeoutError:
LOG.warning(
"Timeout to save downloaded files",
workflow_run_id=workflow_run.workflow_run_id,
)
except Exception:
LOG.warning(
"Failed to save downloaded files",
exc_info=True,
workflow_run_id=workflow_run.workflow_run_id,
)
if not need_call_webhook:
return