def get_file_extension_from_headers(headers: "CIMultiDictProxy[str]") -> str:
    """Guess a file extension (including the leading dot) from HTTP headers.

    The ``Content-Disposition`` filename is preferred; when it is missing or
    carries no suffix, the ``Content-Type`` media type is mapped to an
    extension through ``mimetypes``.

    Args:
        headers: Response headers. Any mapping with a ``get`` method works;
            lookups use the canonical ``Content-Disposition`` and
            ``Content-Type`` spellings.

    Returns:
        The extension such as ``".pdf"``, or ``""`` when none can be derived.
        The RFC 5987 ``filename*=`` form is not interpreted.
    """
    # Retrieve it from Content-Disposition.
    content_disposition = headers.get("Content-Disposition")
    if content_disposition:
        # Accept both the quoted-string and the bare token form. The quoted
        # branch stops at the first closing quote, so later parameters
        # (e.g. `; size="123"`) cannot leak into the captured file name.
        match = re.search(
            r'\bfilename\s*=\s*(?:"([^"]*)"|([^;\s"]+))',
            content_disposition,
            re.IGNORECASE,
        )
        if match:
            filename = match.group(1) or match.group(2) or ""
            suffix = Path(filename).suffix
            if suffix:
                return suffix

    # Retrieve it from Content-Type.
    content_type = headers.get("Content-Type")
    if content_type:
        # mimetypes only knows bare media types, so drop parameters such as
        # "; charset=utf-8" before the lookup.
        media_type = content_type.split(";", 1)[0].strip()
        if media_type:
            file_extension = mimetypes.guess_extension(media_type)
            if file_extension:
                return file_extension

    return ""
extension from HTTP headers") + file_path = os.path.join(temp_dir, file_name) LOG.info(f"Downloading file to {file_path}")