parse filename from http header (#3059)
This commit is contained in:
@@ -32,21 +32,27 @@ async def download_from_s3(client: AsyncAWSClient, s3_uri: str) -> str:
|
|||||||
return file_path.name
|
return file_path.name
|
||||||
|
|
||||||
|
|
||||||
def get_file_extension_from_headers(headers: CIMultiDictProxy[str]) -> str:
|
def get_file_name_and_suffix_from_headers(headers: CIMultiDictProxy[str]) -> tuple[str, str]:
|
||||||
# retrieve it from Content-Disposition
|
file_stem = ""
|
||||||
|
file_suffix: str | None = ""
|
||||||
|
# retrieve the stem and suffix from Content-Disposition
|
||||||
content_disposition = headers.get("Content-Disposition")
|
content_disposition = headers.get("Content-Disposition")
|
||||||
if content_disposition:
|
if content_disposition:
|
||||||
filename = re.findall('filename="(.+)"', content_disposition, re.IGNORECASE)
|
filename = re.findall('filename="(.+)"', content_disposition, re.IGNORECASE)
|
||||||
if len(filename) > 0 and Path(filename[0]).suffix:
|
if len(filename) > 0:
|
||||||
return Path(filename[0]).suffix
|
file_stem = Path(filename[0]).stem
|
||||||
|
file_suffix = Path(filename[0]).suffix
|
||||||
|
|
||||||
# retrieve it from Content-Type
|
if file_suffix:
|
||||||
|
return file_stem, file_suffix
|
||||||
|
|
||||||
|
# retrieve the suffix from Content-Type
|
||||||
content_type = headers.get("Content-Type")
|
content_type = headers.get("Content-Type")
|
||||||
if content_type:
|
if content_type:
|
||||||
if file_extension := mimetypes.guess_extension(content_type):
|
if file_suffix := mimetypes.guess_extension(content_type):
|
||||||
return file_extension
|
return file_stem, file_suffix
|
||||||
|
|
||||||
return ""
|
return file_stem, file_suffix or ""
|
||||||
|
|
||||||
|
|
||||||
def extract_google_drive_file_id(url: str) -> str | None:
|
def extract_google_drive_file_id(url: str) -> str | None:
|
||||||
@@ -98,27 +104,30 @@ async def download_file(url: str, max_size_mb: int | None = None) -> str:
|
|||||||
# Get the file name
|
# Get the file name
|
||||||
temp_dir = make_temp_directory(prefix="skyvern_downloads_")
|
temp_dir = make_temp_directory(prefix="skyvern_downloads_")
|
||||||
|
|
||||||
|
file_name = ""
|
||||||
|
file_suffix = ""
|
||||||
|
try:
|
||||||
|
file_name, file_suffix = get_file_name_and_suffix_from_headers(response.headers)
|
||||||
|
if not file_suffix:
|
||||||
|
LOG.warning("No extension name retrieved from HTTP headers")
|
||||||
|
except Exception:
|
||||||
|
LOG.exception("Failed to retrieve the file extension from HTTP headers")
|
||||||
|
|
||||||
|
if not file_name:
|
||||||
|
LOG.info("No file name retrieved from HTTP headers, using the file name from the URL")
|
||||||
|
file_name = os.path.basename(a.path)
|
||||||
|
|
||||||
# Check for download parameter in Supabase URLs
|
# Check for download parameter in Supabase URLs
|
||||||
file_name = os.path.basename(a.path)
|
|
||||||
if "supabase.co" in a.netloc.lower():
|
if "supabase.co" in a.netloc.lower():
|
||||||
query_params = dict(parse_qsl(a.query))
|
query_params = dict(parse_qsl(a.query))
|
||||||
if "download" in query_params:
|
if "download" in query_params:
|
||||||
file_name = query_params["download"]
|
file_name = query_params["download"]
|
||||||
else:
|
|
||||||
file_name = os.path.basename(a.path)
|
if not Path(file_name).suffix and file_suffix:
|
||||||
|
LOG.info("No file extension detected, adding the extension from HTTP headers")
|
||||||
|
file_name = file_name + file_suffix
|
||||||
|
|
||||||
file_name = sanitize_filename(file_name)
|
file_name = sanitize_filename(file_name)
|
||||||
|
|
||||||
# if no suffix in the URL, we need to parse it from HTTP headers
|
|
||||||
if not Path(file_name).suffix:
|
|
||||||
LOG.info("No file extension detected, trying to retrieve it from HTTP headers")
|
|
||||||
try:
|
|
||||||
if extension_name := get_file_extension_from_headers(response.headers):
|
|
||||||
file_name = file_name + extension_name
|
|
||||||
else:
|
|
||||||
LOG.warning("No extension name retreived from HTTP headers")
|
|
||||||
except Exception:
|
|
||||||
LOG.exception("Failed to retreive the file extension from HTTP headers")
|
|
||||||
|
|
||||||
file_path = os.path.join(temp_dir, file_name)
|
file_path = os.path.join(temp_dir, file_name)
|
||||||
|
|
||||||
LOG.info(f"Downloading file to {file_path}")
|
LOG.info(f"Downloading file to {file_path}")
|
||||||
|
|||||||
Reference in New Issue
Block a user