Support uploading files from S3 (#1618)
This commit is contained in:
@@ -22,7 +22,8 @@ LOG = structlog.get_logger()
|
|||||||
|
|
||||||
async def download_from_s3(client: AsyncAWSClient, s3_uri: str) -> str:
|
async def download_from_s3(client: AsyncAWSClient, s3_uri: str) -> str:
|
||||||
downloaded_bytes = await client.download_file(uri=s3_uri)
|
downloaded_bytes = await client.download_file(uri=s3_uri)
|
||||||
file_path = create_named_temporary_file(delete=False)
|
filename = s3_uri.split("/")[-1] # Extract filename from the end of S3 URI
|
||||||
|
file_path = create_named_temporary_file(delete=False, file_name=filename)
|
||||||
file_path.write(downloaded_bytes)
|
file_path.write(downloaded_bytes)
|
||||||
return file_path.name
|
return file_path.name
|
||||||
|
|
||||||
@@ -46,8 +47,14 @@ def get_file_extension_from_headers(headers: CIMultiDictProxy[str]) -> str:
|
|||||||
|
|
||||||
async def download_file(url: str, max_size_mb: int | None = None) -> str:
|
async def download_file(url: str, max_size_mb: int | None = None) -> str:
|
||||||
try:
|
try:
|
||||||
|
# Check if URL is an S3 URI
|
||||||
|
if url.startswith("s3://skyvern-uploads/local/o_"):
|
||||||
|
LOG.info("Downloading Skyvern file from S3", url=url)
|
||||||
|
client = AsyncAWSClient()
|
||||||
|
return await download_from_s3(client, url)
|
||||||
|
|
||||||
async with aiohttp.ClientSession(raise_for_status=True) as session:
|
async with aiohttp.ClientSession(raise_for_status=True) as session:
|
||||||
LOG.info("Starting to download file")
|
LOG.info("Starting to download file", url=url)
|
||||||
async with session.get(url) as response:
|
async with session.get(url) as response:
|
||||||
# Check the content length if available
|
# Check the content length if available
|
||||||
if max_size_mb and response.content_length and response.content_length > max_size_mb * 1024 * 1024:
|
if max_size_mb and response.content_length and response.content_length > max_size_mb * 1024 * 1024:
|
||||||
@@ -183,10 +190,10 @@ def make_temp_directory(
|
|||||||
return tempfile.mkdtemp(suffix=suffix, prefix=prefix, dir=temp_dir)
|
return tempfile.mkdtemp(suffix=suffix, prefix=prefix, dir=temp_dir)
|
||||||
|
|
||||||
|
|
||||||
def create_named_temporary_file(delete: bool = True) -> tempfile._TemporaryFileWrapper:
|
def create_named_temporary_file(delete: bool = True, file_name: str | None = None) -> tempfile._TemporaryFileWrapper:
|
||||||
temp_dir = settings.TEMP_PATH
|
temp_dir = settings.TEMP_PATH
|
||||||
create_folder_if_not_exist(temp_dir)
|
create_folder_if_not_exist(temp_dir)
|
||||||
return tempfile.NamedTemporaryFile(dir=temp_dir, delete=delete)
|
return tempfile.NamedTemporaryFile(dir=temp_dir, delete=delete, prefix=file_name)
|
||||||
|
|
||||||
|
|
||||||
def clean_up_dir(dir: str) -> None:
|
def clean_up_dir(dir: str) -> None:
|
||||||
|
|||||||
Reference in New Issue
Block a user