Use suffix instead of prefix when creating temp file (#1622)
This commit is contained in:
@@ -24,6 +24,7 @@ async def download_from_s3(client: AsyncAWSClient, s3_uri: str) -> str:
|
|||||||
downloaded_bytes = await client.download_file(uri=s3_uri)
|
downloaded_bytes = await client.download_file(uri=s3_uri)
|
||||||
filename = s3_uri.split("/")[-1] # Extract filename from the end of S3 URI
|
filename = s3_uri.split("/")[-1] # Extract filename from the end of S3 URI
|
||||||
file_path = create_named_temporary_file(delete=False, file_name=filename)
|
file_path = create_named_temporary_file(delete=False, file_name=filename)
|
||||||
|
LOG.info(f"Downloaded file to {file_path.name}")
|
||||||
file_path.write(downloaded_bytes)
|
file_path.write(downloaded_bytes)
|
||||||
return file_path.name
|
return file_path.name
|
||||||
|
|
||||||
@@ -48,7 +49,7 @@ def get_file_extension_from_headers(headers: CIMultiDictProxy[str]) -> str:
|
|||||||
async def download_file(url: str, max_size_mb: int | None = None) -> str:
|
async def download_file(url: str, max_size_mb: int | None = None) -> str:
|
||||||
try:
|
try:
|
||||||
# Check if URL is an S3 URI
|
# Check if URL is an S3 URI
|
||||||
if url.startswith("s3://skyvern-uploads/local/o_"):
|
if url.startswith(f"s3://{settings.AWS_S3_BUCKET_UPLOADS}/{settings.ENV}/o_"):
|
||||||
LOG.info("Downloading Skyvern file from S3", url=url)
|
LOG.info("Downloading Skyvern file from S3", url=url)
|
||||||
client = AsyncAWSClient()
|
client = AsyncAWSClient()
|
||||||
return await download_from_s3(client, url)
|
return await download_from_s3(client, url)
|
||||||
@@ -193,7 +194,17 @@ def make_temp_directory(
|
|||||||
def create_named_temporary_file(delete: bool = True, file_name: str | None = None) -> tempfile._TemporaryFileWrapper:
    """Create a binary, writable temporary file inside ``settings.TEMP_PATH``.

    When ``file_name`` is given, the file is created under that exact
    (sanitized) name, with no random characters added, so the caller keeps the
    original name and extension. When it is missing, or nothing usable is left
    after sanitizing, a randomly named ``NamedTemporaryFile`` is returned.

    Args:
        delete: Forwarded unchanged to the tempfile wrapper in both branches.
        file_name: Desired name of the file inside the temp directory.

    Returns:
        An open file object whose ``.name`` attribute is the file's path.
    """
    temp_dir = settings.TEMP_PATH
    create_folder_if_not_exist(temp_dir)

    if file_name:
        # Sanitize the filename to remove any dangerous characters. basename()
        # is defense in depth: whatever sanitize_filename lets through, the
        # file can only be created directly inside temp_dir.
        safe_file_name = os.path.basename(sanitize_filename(file_name))

        # An empty, "." or ".." result would make the joined path point at a
        # directory and open() would raise; use a random name instead.
        if safe_file_name not in ("", ".", ".."):
            # Create file with exact name (without random characters).
            # NOTE(review): "wb" truncates an existing file of the same name,
            # so two callers passing the same file_name share one path --
            # confirm that callers tolerate this.
            file_path = os.path.join(temp_dir, safe_file_name)
            # Open in binary mode and return a NamedTemporaryFile-like object
            file = open(file_path, "wb")
            return tempfile._TemporaryFileWrapper(file, file_path, delete=delete)

    return tempfile.NamedTemporaryFile(dir=temp_dir, delete=delete)
|
||||||
|
|
||||||
|
|
||||||
def clean_up_dir(dir: str) -> None:
|
def clean_up_dir(dir: str) -> None:
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import datetime
|
import datetime
|
||||||
import hashlib
|
import hashlib
|
||||||
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Annotated, Any
|
from typing import Annotated, Any
|
||||||
@@ -1094,12 +1095,25 @@ async def upload_file(
|
|||||||
) -> Response:
|
) -> Response:
|
||||||
bucket = app.SETTINGS_MANAGER.AWS_S3_BUCKET_UPLOADS
|
bucket = app.SETTINGS_MANAGER.AWS_S3_BUCKET_UPLOADS
|
||||||
todays_date = datetime.datetime.now().strftime("%Y-%m-%d")
|
todays_date = datetime.datetime.now().strftime("%Y-%m-%d")
|
||||||
uuid_prefixed_filename = f"{str(uuid.uuid4())}_{file.filename}"
|
|
||||||
s3_uri = (
|
# First try uploading with original filename
|
||||||
f"s3://{bucket}/{app.SETTINGS_MANAGER.ENV}/{current_org.organization_id}/{todays_date}/{uuid_prefixed_filename}"
|
try:
|
||||||
)
|
sanitized_filename = os.path.basename(file.filename) # Remove any path components
|
||||||
# Stream the file to S3
|
s3_uri = (
|
||||||
uploaded_s3_uri = await aws_client.upload_file_stream(s3_uri, file.file)
|
f"s3://{bucket}/{app.SETTINGS_MANAGER.ENV}/{current_org.organization_id}/{todays_date}/{sanitized_filename}"
|
||||||
|
)
|
||||||
|
uploaded_s3_uri = await aws_client.upload_file_stream(s3_uri, file.file)
|
||||||
|
except Exception:
|
||||||
|
LOG.error("Failed to upload file to S3", exc_info=True)
|
||||||
|
uploaded_s3_uri = None
|
||||||
|
|
||||||
|
# If upload fails, try again with UUID prefix
|
||||||
|
if not uploaded_s3_uri:
|
||||||
|
uuid_prefixed_filename = f"{str(uuid.uuid4())}_{file.filename}"
|
||||||
|
s3_uri = f"s3://{bucket}/{app.SETTINGS_MANAGER.ENV}/{current_org.organization_id}/{todays_date}/{uuid_prefixed_filename}"
|
||||||
|
file.file.seek(0) # Reset file pointer
|
||||||
|
uploaded_s3_uri = await aws_client.upload_file_stream(s3_uri, file.file)
|
||||||
|
|
||||||
if not uploaded_s3_uri:
|
if not uploaded_s3_uri:
|
||||||
raise HTTPException(status_code=500, detail="Failed to upload file to S3.")
|
raise HTTPException(status_code=500, detail="Failed to upload file to S3.")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user