upload all downloaded files when using s3 (#1289)
skyvern/forge/sdk/artifact/storage/base.py

@@ -68,3 +68,15 @@ class BaseStorage(ABC):
     @abstractmethod
     async def retrieve_browser_session(self, organization_id: str, workflow_permanent_id: str) -> str | None:
         pass
+
+    @abstractmethod
+    async def save_downloaded_files(
+        self, organization_id: str, task_id: str | None, workflow_run_id: str | None
+    ) -> None:
+        pass
+
+    @abstractmethod
+    async def get_downloaded_files(
+        self, organization_id: str, task_id: str | None, workflow_run_id: str | None
+    ) -> list[str]:
+        pass
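The two new abstract methods extend the storage contract: every backend must both persist a run's downloaded files and hand back URIs for them. A minimal sketch of a caller driving the interface, assuming a concrete storage instance is wired up elsewhere (the helper name and example usage are illustrative, not part of the commit):

    from skyvern.forge.sdk.artifact.storage.base import BaseStorage


    async def publish_run_downloads(storage: BaseStorage, organization_id: str, workflow_run_id: str) -> list[str]:
        # Persist whatever landed in the run's download directory...
        await storage.save_downloaded_files(
            organization_id=organization_id, task_id=None, workflow_run_id=workflow_run_id
        )
        # ...then return URIs for the stored copies (file:// paths locally,
        # presigned URLs on S3, per the implementations below).
        return await storage.get_downloaded_files(
            organization_id=organization_id, task_id=None, workflow_run_id=workflow_run_id
        )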
skyvern/forge/sdk/artifact/storage/local.py

@@ -6,7 +6,7 @@ from urllib.parse import unquote, urlparse

 import structlog

-from skyvern.forge.sdk.api.files import get_skyvern_temp_dir
+from skyvern.forge.sdk.api.files import get_download_dir, get_skyvern_temp_dir
 from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
 from skyvern.forge.sdk.artifact.storage.base import FILE_EXTENTSION_MAP, BaseStorage
 from skyvern.forge.sdk.models import Step
@@ -120,6 +120,23 @@ class LocalStorage(BaseStorage):
             return None
         return str(stored_folder_path)
+
+    async def save_downloaded_files(
+        self, organization_id: str, task_id: str | None, workflow_run_id: str | None
+    ) -> None:
+        pass
+
+    async def get_downloaded_files(
+        self, organization_id: str, task_id: str | None, workflow_run_id: str | None
+    ) -> list[str]:
+        download_dir = get_download_dir(workflow_run_id=workflow_run_id, task_id=task_id)
+        files: list[str] = []
+        files_and_folders = os.listdir(download_dir)
+        for file_or_folder in files_and_folders:
+            path = os.path.join(download_dir, file_or_folder)
+            if os.path.isfile(path):
+                files.append(f"file://{path}")
+        return files

     @staticmethod
     def _parse_uri_to_path(uri: str) -> str:
         parsed_uri = urlparse(uri)
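Note that LocalStorage leaves save_downloaded_files as a no-op (the files are already on local disk) and that get_downloaded_files lists only the files sitting directly in the run's download directory, returning them as file:// URIs. A hypothetical usage sketch, assuming LocalStorage needs no constructor arguments and that workflow run "wr_123" has finished downloading files:

    import asyncio

    from skyvern.forge.sdk.artifact.storage.local import LocalStorage


    async def main() -> None:
        storage = LocalStorage()
        uris = await storage.get_downloaded_files(
            organization_id="o_123", task_id=None, workflow_run_id="wr_123"
        )
        for uri in uris:
            print(uri)  # e.g. file:///tmp/.../downloads/wr_123/invoice.pdf


    asyncio.run(main())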
skyvern/forge/sdk/artifact/storage/s3.py

@@ -1,3 +1,4 @@
+import os
 import shutil
 from datetime import datetime

@@ -5,6 +6,7 @@ from skyvern.config import settings
 from skyvern.forge.sdk.api.aws import AsyncAWSClient
 from skyvern.forge.sdk.api.files import (
     create_named_temporary_file,
+    get_download_dir,
     get_skyvern_temp_dir,
     make_temp_directory,
     unzip_files,
@@ -68,3 +70,31 @@ class S3Storage(BaseStorage):
         unzip_files(temp_zip_file_path, temp_dir)
         temp_zip_file.close()
         return temp_dir
+
+    async def save_downloaded_files(
+        self, organization_id: str, task_id: str | None, workflow_run_id: str | None
+    ) -> None:
+        download_dir = get_download_dir(workflow_run_id=workflow_run_id, task_id=task_id)
+        files = os.listdir(download_dir)
+        for file in files:
+            fpath = os.path.join(download_dir, file)
+            if os.path.isfile(fpath):
+                uri = f"s3://{settings.AWS_S3_BUCKET_DOWNLOADS}/{settings.ENV}/{organization_id}/{workflow_run_id or task_id}/{file}"
+                # TODO: use coroutine to speed up uploading if too many files
+                await self.async_client.upload_file_from_path(uri, fpath)
+
+    async def get_downloaded_files(
+        self, organization_id: str, task_id: str | None, workflow_run_id: str | None
+    ) -> list[str]:
+        uri = f"s3://{settings.AWS_S3_BUCKET_DOWNLOADS}/{settings.ENV}/{organization_id}/{workflow_run_id or task_id}"
+        object_keys = await self.async_client.list_files(uri=uri)
+        if len(object_keys) == 0:
+            return []
+        object_uris: list[str] = []
+        for key in object_keys:
+            object_uri = f"s3://{settings.AWS_S3_BUCKET_DOWNLOADS}/{key}"
+            object_uris.append(object_uri)
+        presigned_urls = await self.async_client.create_presigned_urls(object_uris)
+        if presigned_urls is None:
+            return []
+        return presigned_urls
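The TODO in save_downloaded_files flags that files are uploaded one at a time. A hedged sketch of one way to parallelize the loop with asyncio.gather, bounding concurrency with a semaphore; the function name and the concurrency cap are illustrative assumptions, and only calls already shown in the diff (get_download_dir, AsyncAWSClient.upload_file_from_path) are reused:

    import asyncio
    import os

    from skyvern.config import settings
    from skyvern.forge.sdk.api.aws import AsyncAWSClient
    from skyvern.forge.sdk.api.files import get_download_dir

    MAX_CONCURRENT_UPLOADS = 8  # illustrative cap, not from the commit


    async def upload_downloaded_files_concurrently(
        async_client: AsyncAWSClient,
        organization_id: str,
        task_id: str | None,
        workflow_run_id: str | None,
    ) -> None:
        download_dir = get_download_dir(workflow_run_id=workflow_run_id, task_id=task_id)
        semaphore = asyncio.Semaphore(MAX_CONCURRENT_UPLOADS)

        async def upload_one(file: str) -> None:
            fpath = os.path.join(download_dir, file)
            if not os.path.isfile(fpath):
                return
            uri = (
                f"s3://{settings.AWS_S3_BUCKET_DOWNLOADS}/{settings.ENV}/"
                f"{organization_id}/{workflow_run_id or task_id}/{file}"
            )
            async with semaphore:  # never more than MAX_CONCURRENT_UPLOADS in flight
                await async_client.upload_file_from_path(uri, fpath)

        await asyncio.gather(*(upload_one(file) for file in os.listdir(download_dir)))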
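Because get_downloaded_files converts the object keys into presigned URLs, callers can fetch the files over plain HTTPS without holding AWS credentials. A hypothetical consumer, assuming an httpx-style async HTTP client is available:

    import httpx  # assumed dependency; any async HTTP client works


    async def fetch_first_download(storage: "S3Storage", organization_id: str, workflow_run_id: str) -> bytes | None:
        urls = await storage.get_downloaded_files(
            organization_id=organization_id, task_id=None, workflow_run_id=workflow_run_id
        )
        if not urls:
            return None
        async with httpx.AsyncClient() as client:
            response = await client.get(urls[0])
            response.raise_for_status()
            return response.content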