parse filename from http header (#3059)
This commit is contained in:
@@ -32,21 +32,27 @@ async def download_from_s3(client: AsyncAWSClient, s3_uri: str) -> str:
|
||||
return file_path.name
|
||||
|
||||
|
||||
def get_file_extension_from_headers(headers: CIMultiDictProxy[str]) -> str:
|
||||
# retrieve it from Content-Disposition
|
||||
def get_file_name_and_suffix_from_headers(headers: CIMultiDictProxy[str]) -> tuple[str, str]:
|
||||
file_stem = ""
|
||||
file_suffix: str | None = ""
|
||||
# retrieve the stem and suffix from Content-Disposition
|
||||
content_disposition = headers.get("Content-Disposition")
|
||||
if content_disposition:
|
||||
filename = re.findall('filename="(.+)"', content_disposition, re.IGNORECASE)
|
||||
if len(filename) > 0 and Path(filename[0]).suffix:
|
||||
return Path(filename[0]).suffix
|
||||
if len(filename) > 0:
|
||||
file_stem = Path(filename[0]).stem
|
||||
file_suffix = Path(filename[0]).suffix
|
||||
|
||||
# retrieve it from Content-Type
|
||||
if file_suffix:
|
||||
return file_stem, file_suffix
|
||||
|
||||
# retrieve the suffix from Content-Type
|
||||
content_type = headers.get("Content-Type")
|
||||
if content_type:
|
||||
if file_extension := mimetypes.guess_extension(content_type):
|
||||
return file_extension
|
||||
if file_suffix := mimetypes.guess_extension(content_type):
|
||||
return file_stem, file_suffix
|
||||
|
||||
return ""
|
||||
return file_stem, file_suffix or ""
|
||||
|
||||
|
||||
def extract_google_drive_file_id(url: str) -> str | None:
|
||||
@@ -98,27 +104,30 @@ async def download_file(url: str, max_size_mb: int | None = None) -> str:
|
||||
# Get the file name
|
||||
temp_dir = make_temp_directory(prefix="skyvern_downloads_")
|
||||
|
||||
file_name = ""
|
||||
file_suffix = ""
|
||||
try:
|
||||
file_name, file_suffix = get_file_name_and_suffix_from_headers(response.headers)
|
||||
if not file_suffix:
|
||||
LOG.warning("No extension name retrieved from HTTP headers")
|
||||
except Exception:
|
||||
LOG.exception("Failed to retrieve the file extension from HTTP headers")
|
||||
|
||||
if not file_name:
|
||||
LOG.info("No file name retrieved from HTTP headers, using the file name from the URL")
|
||||
file_name = os.path.basename(a.path)
|
||||
|
||||
# Check for download parameter in Supabase URLs
|
||||
file_name = os.path.basename(a.path)
|
||||
if "supabase.co" in a.netloc.lower():
|
||||
query_params = dict(parse_qsl(a.query))
|
||||
if "download" in query_params:
|
||||
file_name = query_params["download"]
|
||||
else:
|
||||
file_name = os.path.basename(a.path)
|
||||
|
||||
if not Path(file_name).suffix and file_suffix:
|
||||
LOG.info("No file extension detected, adding the extension from HTTP headers")
|
||||
file_name = file_name + file_suffix
|
||||
|
||||
file_name = sanitize_filename(file_name)
|
||||
|
||||
# if no suffix in the URL, we need to parse it from HTTP headers
|
||||
if not Path(file_name).suffix:
|
||||
LOG.info("No file extension detected, trying to retrieve it from HTTP headers")
|
||||
try:
|
||||
if extension_name := get_file_extension_from_headers(response.headers):
|
||||
file_name = file_name + extension_name
|
||||
else:
|
||||
LOG.warning("No extension name retreived from HTTP headers")
|
||||
except Exception:
|
||||
LOG.exception("Failed to retreive the file extension from HTTP headers")
|
||||
|
||||
file_path = os.path.join(temp_dir, file_name)
|
||||
|
||||
LOG.info(f"Downloading file to {file_path}")
|
||||
|
||||
Reference in New Issue
Block a user