Use suffix instead of prefix when creating temp file (#1622)

This commit is contained in:
Shuchang Zheng
2025-01-23 04:13:40 +08:00
committed by GitHub
parent 43cd5a8119
commit 31885113a4
2 changed files with 33 additions and 8 deletions

View File

@@ -24,6 +24,7 @@ async def download_from_s3(client: AsyncAWSClient, s3_uri: str) -> str:
     downloaded_bytes = await client.download_file(uri=s3_uri)
     filename = s3_uri.split("/")[-1]  # Extract filename from the end of S3 URI
     file_path = create_named_temporary_file(delete=False, file_name=filename)
+    LOG.info(f"Downloaded file to {file_path.name}")
     file_path.write(downloaded_bytes)
     return file_path.name
@@ -48,7 +49,7 @@ def get_file_extension_from_headers(headers: CIMultiDictProxy[str]) -> str:
 async def download_file(url: str, max_size_mb: int | None = None) -> str:
     try:
         # Check if URL is an S3 URI
-        if url.startswith("s3://skyvern-uploads/local/o_"):
+        if url.startswith(f"s3://{settings.AWS_S3_BUCKET_UPLOADS}/{settings.ENV}/o_"):
             LOG.info("Downloading Skyvern file from S3", url=url)
             client = AsyncAWSClient()
             return await download_from_s3(client, url)
@@ -193,7 +194,17 @@ def make_temp_directory(
 def create_named_temporary_file(delete: bool = True, file_name: str | None = None) -> tempfile._TemporaryFileWrapper:
     temp_dir = settings.TEMP_PATH
     create_folder_if_not_exist(temp_dir)
-    return tempfile.NamedTemporaryFile(dir=temp_dir, delete=delete, prefix=file_name)
+    if file_name:
+        # Sanitize the filename to remove any dangerous characters
+        safe_file_name = sanitize_filename(file_name)
+        # Create file with exact name (without random characters)
+        file_path = os.path.join(temp_dir, safe_file_name)
+        # Open in binary mode and return a NamedTemporaryFile-like object
+        file = open(file_path, "wb")
+        return tempfile._TemporaryFileWrapper(file, file_path, delete=delete)
+    return tempfile.NamedTemporaryFile(dir=temp_dir, delete=delete)


 def clean_up_dir(dir: str) -> None:

View File

@@ -1,5 +1,6 @@
 import datetime
 import hashlib
+import os
 import uuid
 from enum import Enum
 from typing import Annotated, Any
@@ -1094,12 +1095,25 @@ async def upload_file(
 ) -> Response:
     bucket = app.SETTINGS_MANAGER.AWS_S3_BUCKET_UPLOADS
     todays_date = datetime.datetime.now().strftime("%Y-%m-%d")
-    uuid_prefixed_filename = f"{str(uuid.uuid4())}_{file.filename}"
-    s3_uri = (
-        f"s3://{bucket}/{app.SETTINGS_MANAGER.ENV}/{current_org.organization_id}/{todays_date}/{uuid_prefixed_filename}"
-    )
-    # Stream the file to S3
-    uploaded_s3_uri = await aws_client.upload_file_stream(s3_uri, file.file)
+    # First try uploading with original filename
+    try:
+        sanitized_filename = os.path.basename(file.filename)  # Remove any path components
+        s3_uri = (
+            f"s3://{bucket}/{app.SETTINGS_MANAGER.ENV}/{current_org.organization_id}/{todays_date}/{sanitized_filename}"
+        )
+        uploaded_s3_uri = await aws_client.upload_file_stream(s3_uri, file.file)
+    except Exception:
+        LOG.error("Failed to upload file to S3", exc_info=True)
+        uploaded_s3_uri = None
+    # If upload fails, try again with UUID prefix
+    if not uploaded_s3_uri:
+        uuid_prefixed_filename = f"{str(uuid.uuid4())}_{file.filename}"
+        s3_uri = f"s3://{bucket}/{app.SETTINGS_MANAGER.ENV}/{current_org.organization_id}/{todays_date}/{uuid_prefixed_filename}"
+        file.file.seek(0)  # Reset file pointer
+        uploaded_s3_uri = await aws_client.upload_file_stream(s3_uri, file.file)
     if not uploaded_s3_uri:
         raise HTTPException(status_code=500, detail="Failed to upload file to S3.")