diff --git a/skyvern/forge/sdk/api/files.py b/skyvern/forge/sdk/api/files.py
index 2d80b7c9..afd742b6 100644
--- a/skyvern/forge/sdk/api/files.py
+++ b/skyvern/forge/sdk/api/files.py
@@ -22,7 +22,8 @@ LOG = structlog.get_logger()
 
 async def download_from_s3(client: AsyncAWSClient, s3_uri: str) -> str:
     downloaded_bytes = await client.download_file(uri=s3_uri)
-    file_path = create_named_temporary_file(delete=False)
+    filename = s3_uri.split("/")[-1] # Extract filename from the end of S3 URI
+    file_path = create_named_temporary_file(delete=False, file_name=filename)
     file_path.write(downloaded_bytes)
     return file_path.name
 
@@ -46,8 +47,14 @@ def get_file_extension_from_headers(headers: CIMultiDictProxy[str]) -> str:
 
 async def download_file(url: str, max_size_mb: int | None = None) -> str:
     try:
+        # Check if URL is an S3 URI
+        if url.startswith("s3://skyvern-uploads/local/o_"):
+            LOG.info("Downloading Skyvern file from S3", url=url)
+            client = AsyncAWSClient()
+            return await download_from_s3(client, url)
+
         async with aiohttp.ClientSession(raise_for_status=True) as session:
-            LOG.info("Starting to download file")
+            LOG.info("Starting to download file", url=url)
             async with session.get(url) as response:
                 # Check the content length if available
                 if max_size_mb and response.content_length and response.content_length > max_size_mb * 1024 * 1024:
@@ -183,10 +190,10 @@ def make_temp_directory(
     return tempfile.mkdtemp(suffix=suffix, prefix=prefix, dir=temp_dir)
 
 
-def create_named_temporary_file(delete: bool = True) -> tempfile._TemporaryFileWrapper:
+def create_named_temporary_file(delete: bool = True, file_name: str | None = None) -> tempfile._TemporaryFileWrapper:
     temp_dir = settings.TEMP_PATH
     create_folder_if_not_exist(temp_dir)
-    return tempfile.NamedTemporaryFile(dir=temp_dir, delete=delete)
+    return tempfile.NamedTemporaryFile(dir=temp_dir, delete=delete, prefix=file_name)
 
 
 def clean_up_dir(dir: str) -> None: