automatically parse file extension when downloading (#1101)
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
import hashlib
|
||||
import mimetypes
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
@@ -7,6 +9,7 @@ from urllib.parse import urlparse
|
||||
|
||||
import aiohttp
|
||||
import structlog
|
||||
from multidict import CIMultiDictProxy
|
||||
|
||||
from skyvern.constants import REPO_ROOT_DIR
|
||||
from skyvern.exceptions import DownloadFileMaxSizeExceeded
|
||||
@@ -22,6 +25,23 @@ async def download_from_s3(client: AsyncAWSClient, s3_uri: str) -> str:
|
||||
return file_path.name
|
||||
|
||||
|
||||
def get_file_extension_from_headers(headers: CIMultiDictProxy[str]) -> str:
|
||||
# retrieve it from Content-Disposition
|
||||
content_disposition = headers.get("Content-Disposition")
|
||||
if content_disposition:
|
||||
filename = re.findall('filename="(.+)"', content_disposition, re.IGNORECASE)
|
||||
if len(filename) > 0 and Path(filename[0]).suffix:
|
||||
return Path(filename[0]).suffix
|
||||
|
||||
# retrieve it from Content-Type
|
||||
content_type = headers.get("Content-Type")
|
||||
if content_type:
|
||||
if file_extension := mimetypes.guess_extension(content_type):
|
||||
return file_extension
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
async def download_file(url: str, max_size_mb: int | None = None) -> str:
|
||||
try:
|
||||
async with aiohttp.ClientSession(raise_for_status=True) as session:
|
||||
@@ -39,6 +59,17 @@ async def download_file(url: str, max_size_mb: int | None = None) -> str:
|
||||
temp_dir = tempfile.mkdtemp(prefix="skyvern_downloads_")
|
||||
|
||||
file_name = os.path.basename(a.path)
|
||||
# if no suffix in the URL, we need to parse it from HTTP headers
|
||||
if not Path(file_name).suffix:
|
||||
LOG.info("No file extension detected, trying to retrieve it from HTTP headers")
|
||||
try:
|
||||
if extension_name := get_file_extension_from_headers(response.headers):
|
||||
file_name = file_name + extension_name
|
||||
else:
|
||||
LOG.warning("No extension name retreived from HTTP headers")
|
||||
except Exception:
|
||||
LOG.exception("Failed to retreive the file extension from HTTP headers")
|
||||
|
||||
file_path = os.path.join(temp_dir, file_name)
|
||||
|
||||
LOG.info(f"Downloading file to {file_path}")
|
||||
|
||||
Reference in New Issue
Block a user