diff --git a/skyvern/forge/sdk/api/files.py b/skyvern/forge/sdk/api/files.py index 36e4a90d..dfd4e34c 100644 --- a/skyvern/forge/sdk/api/files.py +++ b/skyvern/forge/sdk/api/files.py @@ -24,6 +24,7 @@ async def download_from_s3(client: AsyncAWSClient, s3_uri: str) -> str: downloaded_bytes = await client.download_file(uri=s3_uri) filename = s3_uri.split("/")[-1] # Extract filename from the end of S3 URI file_path = create_named_temporary_file(delete=False, file_name=filename) + LOG.info(f"Downloaded file to {file_path.name}") file_path.write(downloaded_bytes) return file_path.name @@ -48,7 +49,7 @@ def get_file_extension_from_headers(headers: CIMultiDictProxy[str]) -> str: async def download_file(url: str, max_size_mb: int | None = None) -> str: try: # Check if URL is an S3 URI - if url.startswith("s3://skyvern-uploads/local/o_"): + if url.startswith(f"s3://{settings.AWS_S3_BUCKET_UPLOADS}/{settings.ENV}/o_"): LOG.info("Downloading Skyvern file from S3", url=url) client = AsyncAWSClient() return await download_from_s3(client, url) @@ -193,7 +194,17 @@ def make_temp_directory( def create_named_temporary_file(delete: bool = True, file_name: str | None = None) -> tempfile._TemporaryFileWrapper: temp_dir = settings.TEMP_PATH create_folder_if_not_exist(temp_dir) - return tempfile.NamedTemporaryFile(dir=temp_dir, delete=delete, prefix=file_name) + + if file_name: + # Sanitize the filename to remove any dangerous characters + safe_file_name = sanitize_filename(file_name) + # Create file with exact name (without random characters) + file_path = os.path.join(temp_dir, safe_file_name) + # Open in binary mode and return a NamedTemporaryFile-like object + file = open(file_path, "wb") + return tempfile._TemporaryFileWrapper(file, file_path, delete=delete) + + return tempfile.NamedTemporaryFile(dir=temp_dir, delete=delete) def clean_up_dir(dir: str) -> None: diff --git a/skyvern/forge/sdk/routes/agent_protocol.py b/skyvern/forge/sdk/routes/agent_protocol.py index 35299527..5a0fab34 100644 --- a/skyvern/forge/sdk/routes/agent_protocol.py +++ b/skyvern/forge/sdk/routes/agent_protocol.py @@ -1,5 +1,6 @@ import datetime import hashlib +import os import uuid from enum import Enum from typing import Annotated, Any @@ -1094,12 +1095,25 @@ async def upload_file( ) -> Response: bucket = app.SETTINGS_MANAGER.AWS_S3_BUCKET_UPLOADS todays_date = datetime.datetime.now().strftime("%Y-%m-%d") - uuid_prefixed_filename = f"{str(uuid.uuid4())}_{file.filename}" - s3_uri = ( - f"s3://{bucket}/{app.SETTINGS_MANAGER.ENV}/{current_org.organization_id}/{todays_date}/{uuid_prefixed_filename}" - ) - # Stream the file to S3 - uploaded_s3_uri = await aws_client.upload_file_stream(s3_uri, file.file) + + # First try uploading with original filename + try: + sanitized_filename = os.path.basename(file.filename) # Remove any path components + s3_uri = ( + f"s3://{bucket}/{app.SETTINGS_MANAGER.ENV}/{current_org.organization_id}/{todays_date}/{sanitized_filename}" + ) + uploaded_s3_uri = await aws_client.upload_file_stream(s3_uri, file.file) + except Exception: + LOG.error("Failed to upload file to S3", exc_info=True) + uploaded_s3_uri = None + + # If upload fails, try again with UUID prefix + if not uploaded_s3_uri: + uuid_prefixed_filename = f"{str(uuid.uuid4())}_{file.filename}" + s3_uri = f"s3://{bucket}/{app.SETTINGS_MANAGER.ENV}/{current_org.organization_id}/{todays_date}/{uuid_prefixed_filename}" + file.file.seek(0) # Reset file pointer + uploaded_s3_uri = await aws_client.upload_file_stream(s3_uri, file.file) + if not uploaded_s3_uri: raise HTTPException(status_code=500, detail="Failed to upload file to S3.")