diff --git a/skyvern/forge/sdk/api/files.py b/skyvern/forge/sdk/api/files.py
index 116417f2..ea5c61de 100644
--- a/skyvern/forge/sdk/api/files.py
+++ b/skyvern/forge/sdk/api/files.py
@@ -7,7 +7,7 @@ import shutil
 import tempfile
 import zipfile
 from pathlib import Path
-from urllib.parse import urlparse
+from urllib.parse import unquote, urlparse
 
 import aiohttp
 import structlog
@@ -72,6 +72,18 @@ async def download_file(url: str, max_size_mb: int | None = None) -> str:
             client = AsyncAWSClient()
             return await download_from_s3(client, url)
 
+        # Check if URL is a file:// URI
+        # Local files are only served when the environment is local and the
+        # file lives inside the skyvern downloads directory.
+        if url.startswith("file://") and settings.ENV == "local":
+            # Resolve both sides before comparing so that ".." segments, symlinks and
+            # sibling directories such as "downloads_evil" cannot pass the check.
+            downloads_dir = Path(f"{REPO_ROOT_DIR}/downloads").resolve()
+            local_path = Path(parse_uri_to_path(url)).resolve()
+            if local_path.is_relative_to(downloads_dir):
+                LOG.info("Downloading file from local file system", url=url)
+                return str(local_path)
+
         async with aiohttp.ClientSession(raise_for_status=True) as session:
             LOG.info("Starting to download file", url=url)
             async with session.get(url) as response:
@@ -273,3 +285,16 @@ def clean_up_dir(dir: str) -> None:
 
 def clean_up_skyvern_temp_dir() -> None:
     return clean_up_dir(get_skyvern_temp_dir())
+
+
+def parse_uri_to_path(uri: str) -> str:
+    """Convert a ``file://`` URI into a percent-decoded local path string.
+
+    Raises:
+        ValueError: if the URI scheme is not ``file``.
+    """
+    parsed_uri = urlparse(uri)
+    if parsed_uri.scheme != "file":
+        raise ValueError(f"Invalid URI scheme: {parsed_uri.scheme} expected: file")
+    path = parsed_uri.netloc + parsed_uri.path
+    return unquote(path)
diff --git a/skyvern/forge/sdk/artifact/storage/local.py b/skyvern/forge/sdk/artifact/storage/local.py
index e4483a9f..035d0587 100644
--- a/skyvern/forge/sdk/artifact/storage/local.py
+++ b/skyvern/forge/sdk/artifact/storage/local.py
@@ -2,12 +2,11 @@ import os
 import shutil
 from datetime import datetime
 from pathlib import Path
-from urllib.parse import unquote, urlparse
 
 import structlog
 
 from skyvern.config import settings
-from skyvern.forge.sdk.api.files import get_download_dir, get_skyvern_temp_dir
+from skyvern.forge.sdk.api.files import get_download_dir, get_skyvern_temp_dir, parse_uri_to_path
 from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType, LogEntityType
 from skyvern.forge.sdk.artifact.storage.base import FILE_EXTENTSION_MAP, BaseStorage
 from skyvern.forge.sdk.models import Step
@@ -68,7 +67,7 @@ class LocalStorage(BaseStorage):
     async def store_artifact(self, artifact: Artifact, data: bytes) -> None:
         file_path = None
         try:
-            file_path = Path(self._parse_uri_to_path(artifact.uri))
+            file_path = Path(parse_uri_to_path(artifact.uri))
             self._create_directories_if_not_exists(file_path)
             with open(file_path, "wb") as f:
                 f.write(data)
@@ -82,7 +81,7 @@ class LocalStorage(BaseStorage):
     async def store_artifact_from_path(self, artifact: Artifact, path: str) -> None:
         file_path = None
         try:
-            file_path = Path(self._parse_uri_to_path(artifact.uri))
+            file_path = Path(parse_uri_to_path(artifact.uri))
             self._create_directories_if_not_exists(file_path)
             Path(path).replace(file_path)
         except Exception:
@@ -95,7 +94,7 @@ class LocalStorage(BaseStorage):
     async def retrieve_artifact(self, artifact: Artifact) -> bytes | None:
         file_path = None
         try:
-            file_path = self._parse_uri_to_path(artifact.uri)
+            file_path = parse_uri_to_path(artifact.uri)
             with open(file_path, "rb") as f:
                 return f.read()
         except Exception:
@@ -170,14 +169,6 @@ class LocalStorage(BaseStorage):
                 files.append(f"file://{path}")
         return files
 
-    @staticmethod
-    def _parse_uri_to_path(uri: str) -> str:
-        parsed_uri = urlparse(uri)
-        if parsed_uri.scheme != "file":
-            raise ValueError("Invalid URI scheme: {parsed_uri.scheme} expected: file")
-        path = parsed_uri.netloc + parsed_uri.path
-        return unquote(path)
-
     @staticmethod
     def _create_directories_if_not_exists(path_including_file_name: Path) -> None:
         path = path_including_file_name.parent
diff --git a/skyvern/forge/sdk/schemas/tasks.py b/skyvern/forge/sdk/schemas/tasks.py
index 104383a7..74147872 100644
--- a/skyvern/forge/sdk/schemas/tasks.py
+++ b/skyvern/forge/sdk/schemas/tasks.py
@@ -352,15 +352,17 @@ class TaskOutput(BaseModel):
     extracted_information: list | dict[str, Any] | str | None = None
     failure_reason: str | None = None
     errors: list[dict[str, Any]] = []
+    downloaded_file_urls: list[str] | None = None
 
     @staticmethod
-    def from_task(task: Task) -> TaskOutput:
+    def from_task(task: Task, downloaded_file_urls: list[str] | None = None) -> TaskOutput:
         return TaskOutput(
             task_id=task.task_id,
             status=task.status,
             extracted_information=task.extracted_information,
             failure_reason=task.failure_reason,
             errors=task.errors,
+            downloaded_file_urls=downloaded_file_urls,
         )
 
 
diff --git a/skyvern/forge/sdk/workflow/models/block.py b/skyvern/forge/sdk/workflow/models/block.py
index 7cffde14..81312409 100644
--- a/skyvern/forge/sdk/workflow/models/block.py
+++ b/skyvern/forge/sdk/workflow/models/block.py
@@ -25,7 +25,7 @@ from pypdf import PdfReader
 from pypdf.errors import PdfReadError
 
 from skyvern.config import settings
-from skyvern.constants import MAX_UPLOAD_FILE_COUNT
+from skyvern.constants import GET_DOWNLOADED_FILES_TIMEOUT, MAX_UPLOAD_FILE_COUNT
 from skyvern.exceptions import (
     ContextParameterValueNotFound,
     DisabledBlockExecutionError,
@@ -633,7 +633,18 @@ class BaseTaskBlock(Block):
                     organization_id=workflow_run.organization_id,
                 )
                 success = updated_task.status == TaskStatus.completed
-                task_output = TaskOutput.from_task(updated_task)
+
+                downloaded_file_urls = []
+                try:
+                    async with asyncio.timeout(GET_DOWNLOADED_FILES_TIMEOUT):
+                        downloaded_file_urls = await app.STORAGE.get_downloaded_files(
+                            organization_id=workflow_run.organization_id,
+                            task_id=updated_task.task_id,
+                            workflow_run_id=workflow_run_id,
+                        )
+                except asyncio.TimeoutError:
+                    LOG.warning("Timeout getting downloaded files", task_id=updated_task.task_id)
+                task_output = TaskOutput.from_task(updated_task, downloaded_file_urls)
                 output_parameter_value = task_output.model_dump()
                 await self.record_output_parameter_value(workflow_run_context, workflow_run_id, output_parameter_value)
                 return await self.build_block_result(
@@ -682,7 +693,18 @@ class BaseTaskBlock(Block):
                 current_retry += 1
                 will_retry = current_retry <= self.max_retries
                 retry_message = f", retrying task {current_retry}/{self.max_retries}" if will_retry else ""
-                task_output = TaskOutput.from_task(updated_task)
+                downloaded_file_urls = []
+                try:
+                    async with asyncio.timeout(GET_DOWNLOADED_FILES_TIMEOUT):
+                        downloaded_file_urls = await app.STORAGE.get_downloaded_files(
+                            organization_id=workflow_run.organization_id,
+                            task_id=updated_task.task_id,
+                            workflow_run_id=workflow_run_id,
+                        )
+                except asyncio.TimeoutError:
+                    LOG.warning("Timeout getting downloaded files", task_id=updated_task.task_id)
+
+                task_output = TaskOutput.from_task(updated_task, downloaded_file_urls)
                 LOG.warning(
                     f"Task failed with status {updated_task.status}{retry_message}",
                     task_id=updated_task.task_id,