Support downloading files via HTTP calls (for Centria) (#4440)
This commit is contained in:
@@ -33,7 +33,7 @@ async def download_from_s3(client: AsyncAWSClient, s3_uri: str) -> str:
|
||||
return file_path.name
|
||||
|
||||
|
||||
def get_file_name_and_suffix_from_headers(headers: CIMultiDictProxy[str]) -> tuple[str, str]:
|
||||
def get_file_name_and_suffix_from_headers(headers: CIMultiDictProxy[str] | dict[str, str]) -> tuple[str, str]:
|
||||
file_stem = ""
|
||||
file_suffix: str | None = ""
|
||||
# retrieve the stem and suffix from Content-Disposition
|
||||
@@ -70,6 +70,46 @@ def is_valid_mime_type(file_path: str) -> bool:
|
||||
return mime_type is not None
|
||||
|
||||
|
||||
def _determine_download_filename(
    filename: str | None,
    response_headers: dict,
    url: str,
) -> str:
    """Determine the filename for a downloaded file.

    Resolution order:
      1. An explicitly requested ``filename``; if it has no extension, one is
         guessed from the Content-Type header.
      2. The name/suffix advertised in the response headers.
      3. A ``download`` query parameter in the URL.
      4. The last path segment of the URL, falling back to ``"download"``.

    The result is always passed through ``sanitize_filename``.
    """
    if filename:
        file_name = filename
        if not os.path.splitext(file_name)[1]:
            # HTTP header names are case-insensitive, but the caller hands us a
            # plain dict (built from the response headers) whose keys keep the
            # server's casing — so match "Content-Type" case-insensitively.
            content_type = next(
                (value for key, value in response_headers.items() if key.lower() == "content-type"),
                "",
            )
            if content_type:
                ext = mimetypes.guess_extension(content_type.split(";")[0].strip())
                if ext:
                    file_name = file_name + ext
        return sanitize_filename(file_name)

    file_name = ""
    file_suffix = ""
    try:
        file_name, file_suffix = get_file_name_and_suffix_from_headers(response_headers)
        if not file_suffix:
            LOG.warning("No extension name retrieved from HTTP headers")
    except Exception:
        LOG.exception("Failed to retrieve the file extension from HTTP headers")

    # A "download" query parameter overrides any header-derived name.
    parsed_url = urlparse(url)
    query_params = dict(parse_qsl(parsed_url.query))
    if "download" in query_params:
        file_name = query_params["download"]

    if not file_name:
        LOG.info("No file name retrieved from HTTP headers, using the file name from the URL")
        file_name = os.path.basename(parsed_url.path) or "download"

    if not is_valid_mime_type(file_name) and file_suffix:
        LOG.info("No file extension detected, adding the extension from HTTP headers")
        file_name = file_name + file_suffix

    return sanitize_filename(file_name)
|
||||
|
||||
|
||||
def validate_download_url(url: str) -> bool:
|
||||
"""Validate if a URL is supported for downloading.
|
||||
|
||||
@@ -126,7 +166,13 @@ def validate_download_url(url: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
async def download_file(url: str, max_size_mb: int | None = None) -> str:
|
||||
async def download_file(
|
||||
url: str,
|
||||
max_size_mb: int | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
output_dir: str | None = None,
|
||||
filename: str | None = None,
|
||||
) -> str:
|
||||
try:
|
||||
# Check if URL is a Google Drive link
|
||||
if "drive.google.com" in url:
|
||||
@@ -175,42 +221,22 @@ async def download_file(url: str, max_size_mb: int | None = None) -> str:
|
||||
async with aiohttp.ClientSession(raise_for_status=True) as session:
|
||||
LOG.info("Starting to download file", url=url)
|
||||
encoded_url = encode_url(url)
|
||||
async with session.get(URL(encoded_url, encoded=True)) as response:
|
||||
async with session.get(URL(encoded_url, encoded=True), headers=headers) as response:
|
||||
# Check the content length if available
|
||||
if max_size_mb and response.content_length and response.content_length > max_size_mb * 1024 * 1024:
|
||||
# todo: move to root exception.py
|
||||
raise DownloadFileMaxSizeExceeded(max_size_mb)
|
||||
|
||||
# Parse the URL
|
||||
a = urlparse(url)
|
||||
|
||||
# Get the file name
|
||||
temp_dir = make_temp_directory(prefix="skyvern_downloads_")
|
||||
if output_dir:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
download_dir = output_dir
|
||||
else:
|
||||
download_dir = make_temp_directory(prefix="skyvern_downloads_")
|
||||
|
||||
file_name = ""
|
||||
file_suffix = ""
|
||||
try:
|
||||
file_name, file_suffix = get_file_name_and_suffix_from_headers(response.headers)
|
||||
if not file_suffix:
|
||||
LOG.warning("No extension name retrieved from HTTP headers")
|
||||
except Exception:
|
||||
LOG.exception("Failed to retrieve the file extension from HTTP headers")
|
||||
|
||||
# parse the query params to get the file name
|
||||
query_params = dict(parse_qsl(a.query))
|
||||
if "download" in query_params:
|
||||
file_name = query_params["download"]
|
||||
|
||||
if not file_name:
|
||||
LOG.info("No file name retrieved from HTTP headers, using the file name from the URL")
|
||||
file_name = os.path.basename(a.path)
|
||||
|
||||
if not is_valid_mime_type(file_name) and file_suffix:
|
||||
LOG.info("No file extension detected, adding the extension from HTTP headers")
|
||||
file_name = file_name + file_suffix
|
||||
|
||||
file_name = sanitize_filename(file_name)
|
||||
file_path = os.path.join(temp_dir, file_name)
|
||||
# Determine filename - use provided filename or derive from response/URL
|
||||
file_name = _determine_download_filename(filename, dict(response.headers), url)
|
||||
file_path = os.path.join(download_dir, file_name)
|
||||
|
||||
LOG.info(f"Downloading file to {file_path}")
|
||||
with open(file_path, "wb") as f:
|
||||
|
||||
@@ -96,11 +96,9 @@ async def aiohttp_request(
|
||||
async with session.request(method.upper(), **request_kwargs) as response:
|
||||
response_headers = dict(response.headers)
|
||||
|
||||
# Try to parse response as JSON
|
||||
try:
|
||||
response_body = await response.json()
|
||||
except (aiohttp.ContentTypeError, Exception):
|
||||
# If not JSON, get as text
|
||||
response_body = await response.text()
|
||||
|
||||
return response.status, response_headers, response_body
|
||||
|
||||
@@ -18,6 +18,7 @@ from types import SimpleNamespace
|
||||
from typing import Annotated, Any, Awaitable, Callable, ClassVar, Literal, Union, cast
|
||||
from urllib.parse import quote, urlparse
|
||||
|
||||
import aiohttp
|
||||
import filetype
|
||||
import pandas as pd
|
||||
import pyotp
|
||||
@@ -38,6 +39,7 @@ from skyvern.constants import (
|
||||
from skyvern.exceptions import (
|
||||
AzureConfigurationError,
|
||||
ContextParameterValueNotFound,
|
||||
DownloadFileMaxSizeExceeded,
|
||||
MissingBrowserState,
|
||||
MissingBrowserStatePage,
|
||||
PDFParsingError,
|
||||
@@ -54,6 +56,7 @@ from skyvern.forge.sdk.api.files import (
|
||||
create_named_temporary_file,
|
||||
download_file,
|
||||
download_from_s3,
|
||||
get_download_dir,
|
||||
get_path_for_workflow_download_directory,
|
||||
parse_uri_to_path,
|
||||
)
|
||||
@@ -4003,6 +4006,8 @@ class HttpRequestBlock(Block):
|
||||
files: dict[str, str] | None = None # Dictionary mapping field names to file paths for multipart file uploads
|
||||
timeout: int = 30
|
||||
follow_redirects: bool = True
|
||||
download_filename: str | None = None
|
||||
save_response_as_file: bool = False
|
||||
|
||||
# Parameters for templating
|
||||
parameters: list[PARAMETER_TYPE] = []
|
||||
@@ -4101,6 +4106,11 @@ class HttpRequestBlock(Block):
|
||||
if self.headers:
|
||||
self.headers = cast(dict[str, str], _render_templates_in_json(self.headers))
|
||||
|
||||
if self.download_filename:
|
||||
self.download_filename = self.format_block_parameter_template_from_workflow_run_context(
|
||||
self.download_filename, workflow_run_context, **template_kwargs
|
||||
)
|
||||
|
||||
def validate_url(self, url: str) -> bool:
|
||||
"""Validate if the URL is properly formatted"""
|
||||
try:
|
||||
@@ -4109,6 +4119,92 @@ class HttpRequestBlock(Block):
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def _execute_file_download(
    self,
    workflow_run_context: WorkflowRunContext,
    workflow_run_id: str,
    workflow_run_block_id: str,
    organization_id: str | None,
) -> BlockResult:
    """Download ``self.url`` into the workflow run's download directory.

    On success, records and returns an output containing the local
    ``file_path``, ``file_name``, and ``file_size``. On failure, records an
    error payload and returns a failed block result with a matching
    ``failure_reason``.
    """

    async def _fail(reason: str, error_data: dict) -> BlockResult:
        # Shared failure path: persist the error payload, then build the
        # failed block result. Factored out of the three except branches.
        await self.record_output_parameter_value(workflow_run_context, workflow_run_id, error_data)
        return await self.build_block_result(
            success=False,
            failure_reason=reason,
            output_parameter_value=error_data,
            status=BlockStatus.failed,
            workflow_run_block_id=workflow_run_block_id,
            organization_id=organization_id,
        )

    if not self.url:
        # No URL: fail fast without recording an output parameter value.
        return await self.build_block_result(
            success=False,
            failure_reason="URL is required for file download",
            output_parameter_value=None,
            status=BlockStatus.failed,
            workflow_run_block_id=workflow_run_block_id,
            organization_id=organization_id,
        )

    try:
        # Settings store the limit in bytes; download_file expects megabytes.
        max_size_mb = settings.MAX_HTTP_DOWNLOAD_FILE_SIZE // (1024 * 1024)
        output_dir = get_download_dir(workflow_run_id)
        file_path = await download_file(
            self.url,
            max_size_mb=max_size_mb,
            headers=self.headers,
            output_dir=output_dir,
            filename=self.download_filename,
        )

        response_data = {
            "file_path": file_path,
            "file_name": os.path.basename(file_path),
            "file_size": os.path.getsize(file_path),
        }

        await self.record_output_parameter_value(workflow_run_context, workflow_run_id, response_data)

        return await self.build_block_result(
            success=True,
            failure_reason=None,
            output_parameter_value=response_data,
            status=BlockStatus.completed,
            workflow_run_block_id=workflow_run_block_id,
            organization_id=organization_id,
        )

    except aiohttp.ClientResponseError as e:
        return await _fail(f"HTTP {e.status}", {"error": f"HTTP {e.status}", "error_type": "http_error"})
    except DownloadFileMaxSizeExceeded as e:
        reason = f"File exceeds maximum size of {e.max_size:.1f}MB"
        return await _fail(reason, {"error": reason, "error_type": "file_too_large"})
    except Exception as e:
        LOG.warning(
            "File download failed",
            error=str(e),
            url=self.url,
            workflow_run_id=workflow_run_id,
        )
        return await _fail(f"File download failed: {str(e)}", {"error": str(e), "error_type": "unknown"})
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
@@ -4280,7 +4376,14 @@ class HttpRequestBlock(Block):
|
||||
# Update self.files with local file paths
|
||||
self.files = downloaded_files
|
||||
|
||||
# Execute HTTP request using the generic aiohttp_request function
|
||||
if self.save_response_as_file:
|
||||
return await self._execute_file_download(
|
||||
workflow_run_context=workflow_run_context,
|
||||
workflow_run_id=workflow_run_id,
|
||||
workflow_run_block_id=workflow_run_block_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
|
||||
try:
|
||||
LOG.info(
|
||||
"Executing HTTP request",
|
||||
@@ -4292,7 +4395,6 @@ class HttpRequestBlock(Block):
|
||||
files=self.files,
|
||||
)
|
||||
|
||||
# Use the generic aiohttp_request function
|
||||
status_code, response_headers, response_body = await aiohttp_request(
|
||||
method=self.method,
|
||||
url=self.url,
|
||||
@@ -4304,22 +4406,18 @@ class HttpRequestBlock(Block):
|
||||
)
|
||||
|
||||
response_data = {
|
||||
# Response information
|
||||
"status_code": status_code,
|
||||
"response_headers": response_headers,
|
||||
"response_body": response_body,
|
||||
# Request information (what was sent)
|
||||
"request_method": self.method,
|
||||
"request_url": self.url,
|
||||
"request_headers": self.headers,
|
||||
"request_body": self.body,
|
||||
# Backwards compatibility
|
||||
"headers": response_headers,
|
||||
"body": response_body,
|
||||
"url": self.url,
|
||||
}
|
||||
|
||||
# Mask secrets in output to prevent credential exposure in DB/UI
|
||||
response_data = workflow_run_context.mask_secrets_in_data(response_data)
|
||||
|
||||
LOG.info(
|
||||
@@ -4331,14 +4429,14 @@ class HttpRequestBlock(Block):
|
||||
response_data=response_data,
|
||||
)
|
||||
|
||||
# Determine success based on status code
|
||||
success = 200 <= status_code < 300
|
||||
failure_reason = None if success else f"HTTP {status_code}: {response_data.get('response_body', '')}"
|
||||
|
||||
await self.record_output_parameter_value(workflow_run_context, workflow_run_id, response_data)
|
||||
|
||||
return await self.build_block_result(
|
||||
success=success,
|
||||
failure_reason=None if success else f"HTTP {status_code}: {response_body}",
|
||||
failure_reason=failure_reason,
|
||||
output_parameter_value=response_data,
|
||||
status=BlockStatus.completed if success else BlockStatus.failed,
|
||||
workflow_run_block_id=workflow_run_block_id,
|
||||
@@ -4358,7 +4456,7 @@ class HttpRequestBlock(Block):
|
||||
)
|
||||
except Exception as e:
|
||||
error_data = {"error": str(e), "error_type": "unknown"}
|
||||
LOG.warning( # Changed from LOG.exception to LOG.warning as requested
|
||||
LOG.warning(
|
||||
"HTTP request failed with unexpected error",
|
||||
error=str(e),
|
||||
url=self.url,
|
||||
|
||||
@@ -3752,6 +3752,8 @@ class WorkflowService:
|
||||
files=block_yaml.files,
|
||||
timeout=block_yaml.timeout,
|
||||
follow_redirects=block_yaml.follow_redirects,
|
||||
download_filename=block_yaml.download_filename,
|
||||
save_response_as_file=block_yaml.save_response_as_file,
|
||||
parameters=http_request_block_parameters,
|
||||
)
|
||||
elif block_yaml.block_type == BlockType.GOTO_URL:
|
||||
|
||||
Reference in New Issue
Block a user