automatically parse file extension when downloading (#1101)

This commit is contained in:
LawyZheng
2024-11-01 01:24:34 +08:00
committed by GitHub
parent ce8963bb7d
commit d649699619

View File

@@ -1,5 +1,7 @@
import hashlib
import mimetypes
import os
import re
import tempfile
import zipfile
from pathlib import Path
@@ -7,6 +9,7 @@ from urllib.parse import urlparse
import aiohttp
import structlog
from multidict import CIMultiDictProxy
from skyvern.constants import REPO_ROOT_DIR
from skyvern.exceptions import DownloadFileMaxSizeExceeded
@@ -22,6 +25,23 @@ async def download_from_s3(client: AsyncAWSClient, s3_uri: str) -> str:
return file_path.name
def get_file_extension_from_headers(headers: CIMultiDictProxy[str]) -> str:
    """Guess a file extension (with its leading dot) from HTTP response headers.

    The filename advertised in ``Content-Disposition`` is preferred. If that
    header is missing, carries no filename, or the filename has no suffix, the
    media type from ``Content-Type`` is mapped to an extension via ``mimetypes``.

    Args:
        headers: Response headers; only ``.get()`` is called on it.

    Returns:
        The extension including the leading dot (e.g. ``".pdf"``), or ``""``
        when neither header yields one.
    """
    # Function-local import: only this helper needs the header parameter parser.
    from email.message import Message

    # retrieve it from Content-Disposition
    content_disposition = headers.get("Content-Disposition")
    if content_disposition:
        # Message's parameter parser copes with quoted values followed by other
        # parameters, unquoted tokens and RFC 2231 ``filename*=`` forms, which a
        # greedy ``filename="(.+)"`` regex does not.
        parser = Message()
        parser["Content-Disposition"] = content_disposition
        filename = parser.get_filename()
        if filename:
            suffix = Path(filename).suffix
            if suffix:
                return suffix

    # retrieve it from Content-Type
    content_type = headers.get("Content-Type")
    if content_type:
        # Drop parameters such as "; charset=utf-8": guess_extension only
        # recognises the bare "type/subtype" value.
        media_type = content_type.split(";", 1)[0].strip().lower()
        if file_extension := mimetypes.guess_extension(media_type):
            return file_extension

    return ""
async def download_file(url: str, max_size_mb: int | None = None) -> str:
try:
async with aiohttp.ClientSession(raise_for_status=True) as session:
@@ -39,6 +59,17 @@ async def download_file(url: str, max_size_mb: int | None = None) -> str:
temp_dir = tempfile.mkdtemp(prefix="skyvern_downloads_")
file_name = os.path.basename(a.path)
# if no suffix in the URL, we need to parse it from HTTP headers
if not Path(file_name).suffix:
LOG.info("No file extension detected, trying to retrieve it from HTTP headers")
try:
if extension_name := get_file_extension_from_headers(response.headers):
file_name = file_name + extension_name
else:
LOG.warning("No extension name retreived from HTTP headers")
except Exception:
LOG.exception("Failed to retreive the file extension from HTTP headers")
file_path = os.path.join(temp_dir, file_name)
LOG.info(f"Downloading file to {file_path}")