parse filename from http header (#3059)

This commit is contained in:
LawyZheng
2025-07-30 16:10:02 +08:00
committed by GitHub
parent 4ec4d7d1f5
commit 0adc3078ed

View File

@@ -32,21 +32,27 @@ async def download_from_s3(client: AsyncAWSClient, s3_uri: str) -> str:
return file_path.name
def get_file_extension_from_headers(headers: CIMultiDictProxy[str]) -> str:
# retrieve it from Content-Disposition
def get_file_name_and_suffix_from_headers(headers: CIMultiDictProxy[str]) -> tuple[str, str]:
file_stem = ""
file_suffix: str | None = ""
# retrieve the stem and suffix from Content-Disposition
content_disposition = headers.get("Content-Disposition")
if content_disposition:
filename = re.findall('filename="(.+)"', content_disposition, re.IGNORECASE)
if len(filename) > 0 and Path(filename[0]).suffix:
return Path(filename[0]).suffix
if len(filename) > 0:
file_stem = Path(filename[0]).stem
file_suffix = Path(filename[0]).suffix
# retrieve it from Content-Type
if file_suffix:
return file_stem, file_suffix
# retrieve the suffix from Content-Type
content_type = headers.get("Content-Type")
if content_type:
if file_extension := mimetypes.guess_extension(content_type):
return file_extension
if file_suffix := mimetypes.guess_extension(content_type):
return file_stem, file_suffix
return ""
return file_stem, file_suffix or ""
def extract_google_drive_file_id(url: str) -> str | None:
@@ -98,27 +104,30 @@ async def download_file(url: str, max_size_mb: int | None = None) -> str:
# Get the file name
temp_dir = make_temp_directory(prefix="skyvern_downloads_")
file_name = ""
file_suffix = ""
try:
file_name, file_suffix = get_file_name_and_suffix_from_headers(response.headers)
if not file_suffix:
LOG.warning("No extension name retrieved from HTTP headers")
except Exception:
LOG.exception("Failed to retrieve the file extension from HTTP headers")
if not file_name:
LOG.info("No file name retrieved from HTTP headers, using the file name from the URL")
file_name = os.path.basename(a.path)
# Check for download parameter in Supabase URLs
file_name = os.path.basename(a.path)
if "supabase.co" in a.netloc.lower():
query_params = dict(parse_qsl(a.query))
if "download" in query_params:
file_name = query_params["download"]
else:
file_name = os.path.basename(a.path)
if not Path(file_name).suffix and file_suffix:
LOG.info("No file extension detected, adding the extension from HTTP headers")
file_name = file_name + file_suffix
file_name = sanitize_filename(file_name)
# if no suffix in the URL, we need to parse it from HTTP headers
if not Path(file_name).suffix:
LOG.info("No file extension detected, trying to retrieve it from HTTP headers")
try:
if extension_name := get_file_extension_from_headers(response.headers):
file_name = file_name + extension_name
else:
LOG.warning("No extension name retreived from HTTP headers")
except Exception:
LOG.exception("Failed to retreive the file extension from HTTP headers")
file_path = os.path.join(temp_dir, file_name)
LOG.info(f"Downloading file to {file_path}")