Support downloading files via HTTP Calls (for Centria) (#4440)

This commit is contained in:
Marc Kelechava
2026-01-13 12:12:38 -08:00
committed by GitHub
parent a6f0781491
commit e6a3858096
16 changed files with 240 additions and 48 deletions

View File

@@ -33,7 +33,7 @@ async def download_from_s3(client: AsyncAWSClient, s3_uri: str) -> str:
return file_path.name
def get_file_name_and_suffix_from_headers(headers: CIMultiDictProxy[str]) -> tuple[str, str]:
def get_file_name_and_suffix_from_headers(headers: CIMultiDictProxy[str] | dict[str, str]) -> tuple[str, str]:
file_stem = ""
file_suffix: str | None = ""
# retrieve the stem and suffix from Content-Disposition
@@ -70,6 +70,46 @@ def is_valid_mime_type(file_path: str) -> bool:
return mime_type is not None
def _determine_download_filename(
filename: str | None,
response_headers: dict,
url: str,
) -> str:
"""Determine the filename for a downloaded file."""
if filename:
file_name = filename
if not os.path.splitext(file_name)[1]:
content_type = response_headers.get("Content-Type", "")
if content_type:
ext = mimetypes.guess_extension(content_type.split(";")[0].strip())
if ext:
file_name = file_name + ext
return sanitize_filename(file_name)
file_name = ""
file_suffix = ""
try:
file_name, file_suffix = get_file_name_and_suffix_from_headers(response_headers)
if not file_suffix:
LOG.warning("No extension name retrieved from HTTP headers")
except Exception:
LOG.exception("Failed to retrieve the file extension from HTTP headers")
query_params = dict(parse_qsl(urlparse(url).query))
if "download" in query_params:
file_name = query_params["download"]
if not file_name:
LOG.info("No file name retrieved from HTTP headers, using the file name from the URL")
file_name = os.path.basename(urlparse(url).path) or "download"
if not is_valid_mime_type(file_name) and file_suffix:
LOG.info("No file extension detected, adding the extension from HTTP headers")
file_name = file_name + file_suffix
return sanitize_filename(file_name)
def validate_download_url(url: str) -> bool:
"""Validate if a URL is supported for downloading.
@@ -126,7 +166,13 @@ def validate_download_url(url: str) -> bool:
return False
async def download_file(url: str, max_size_mb: int | None = None) -> str:
async def download_file(
url: str,
max_size_mb: int | None = None,
headers: dict[str, str] | None = None,
output_dir: str | None = None,
filename: str | None = None,
) -> str:
try:
# Check if URL is a Google Drive link
if "drive.google.com" in url:
@@ -175,42 +221,22 @@ async def download_file(url: str, max_size_mb: int | None = None) -> str:
async with aiohttp.ClientSession(raise_for_status=True) as session:
LOG.info("Starting to download file", url=url)
encoded_url = encode_url(url)
async with session.get(URL(encoded_url, encoded=True)) as response:
async with session.get(URL(encoded_url, encoded=True), headers=headers) as response:
# Check the content length if available
if max_size_mb and response.content_length and response.content_length > max_size_mb * 1024 * 1024:
# todo: move to root exception.py
raise DownloadFileMaxSizeExceeded(max_size_mb)
# Parse the URL
a = urlparse(url)
# Get the file name
temp_dir = make_temp_directory(prefix="skyvern_downloads_")
if output_dir:
os.makedirs(output_dir, exist_ok=True)
download_dir = output_dir
else:
download_dir = make_temp_directory(prefix="skyvern_downloads_")
file_name = ""
file_suffix = ""
try:
file_name, file_suffix = get_file_name_and_suffix_from_headers(response.headers)
if not file_suffix:
LOG.warning("No extension name retrieved from HTTP headers")
except Exception:
LOG.exception("Failed to retrieve the file extension from HTTP headers")
# parse the query params to get the file name
query_params = dict(parse_qsl(a.query))
if "download" in query_params:
file_name = query_params["download"]
if not file_name:
LOG.info("No file name retrieved from HTTP headers, using the file name from the URL")
file_name = os.path.basename(a.path)
if not is_valid_mime_type(file_name) and file_suffix:
LOG.info("No file extension detected, adding the extension from HTTP headers")
file_name = file_name + file_suffix
file_name = sanitize_filename(file_name)
file_path = os.path.join(temp_dir, file_name)
# Determine filename - use provided filename or derive from response/URL
file_name = _determine_download_filename(filename, dict(response.headers), url)
file_path = os.path.join(download_dir, file_name)
LOG.info(f"Downloading file to {file_path}")
with open(file_path, "wb") as f: