From 3d1b14647052861c28cc11458570ddbdebdb47e6 Mon Sep 17 00:00:00 2001 From: Kerem Yilmaz Date: Thu, 28 Mar 2024 16:46:54 -0700 Subject: [PATCH] Implement DownloadToS3Block (#133) --- skyvern/config.py | 4 + skyvern/forge/sdk/api/aws.py | 2 +- skyvern/forge/sdk/workflow/exceptions.py | 5 + skyvern/forge/sdk/workflow/models/block.py | 101 ++++++++++++++++++++- skyvern/forge/sdk/workflow/models/yaml.py | 12 ++- skyvern/forge/sdk/workflow/service.py | 7 ++ 6 files changed, 128 insertions(+), 3 deletions(-) diff --git a/skyvern/config.py b/skyvern/config.py index 8255e3c4..ecfc5114 100644 --- a/skyvern/config.py +++ b/skyvern/config.py @@ -40,6 +40,10 @@ class Settings(BaseSettings): # Artifact storage settings ARTIFACT_STORAGE_PATH: str = f"{SKYVERN_DIR}/artifacts" + # S3 bucket settings + AWS_REGION: str = "us-east-1" + AWS_S3_BUCKET_DOWNLOADS: str = "skyvern-downloads" + SKYVERN_TELEMETRY: bool = True ANALYTICS_ID: str = "anonymous" diff --git a/skyvern/forge/sdk/api/aws.py b/skyvern/forge/sdk/api/aws.py index 615ec489..f5300ddd 100644 --- a/skyvern/forge/sdk/api/aws.py +++ b/skyvern/forge/sdk/api/aws.py @@ -22,7 +22,7 @@ def execute_with_async_client(client_type: AWSClientType) -> Callable: self = args[0] assert isinstance(self, AsyncAWSClient) session = aioboto3.Session() - async with session.client(client_type) as client: + async with session.client(client_type, region_name=SettingsManager.get_settings().AWS_REGION) as client: return await f(*args, client=client, **kwargs) return wrapper diff --git a/skyvern/forge/sdk/workflow/exceptions.py b/skyvern/forge/sdk/workflow/exceptions.py index 18112ae7..8b3e9bb1 100644 --- a/skyvern/forge/sdk/workflow/exceptions.py +++ b/skyvern/forge/sdk/workflow/exceptions.py @@ -29,3 +29,8 @@ class WorkflowDefinitionHasDuplicateParameterKeys(BaseWorkflowException): f"WorkflowDefinition has parameters with duplicate keys. Each parameter needs to have a unique " f"key. 
Duplicate key(s): {','.join(duplicate_keys)}" ) + + +class DownloadFileMaxSizeExceeded(BaseWorkflowException): + def __init__(self, max_size: int) -> None: + super().__init__(f"Download file size exceeded the maximum allowed size of {max_size} MB.") diff --git a/skyvern/forge/sdk/workflow/models/block.py b/skyvern/forge/sdk/workflow/models/block.py index 0a09ccdd..861b2acc 100644 --- a/skyvern/forge/sdk/workflow/models/block.py +++ b/skyvern/forge/sdk/workflow/models/block.py @@ -1,8 +1,12 @@ import abc import json +import os +import uuid from enum import StrEnum +from tempfile import NamedTemporaryFile from typing import Annotated, Any, Literal, Union +import aiohttp import structlog from pydantic import BaseModel, Field @@ -14,9 +18,12 @@ from skyvern.exceptions import ( ) from skyvern.forge import app from skyvern.forge.prompts import prompt_engine +from skyvern.forge.sdk.api.aws import AsyncAWSClient from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory from skyvern.forge.sdk.schemas.tasks import TaskStatus +from skyvern.forge.sdk.settings_manager import SettingsManager from skyvern.forge.sdk.workflow.context_manager import WorkflowRunContext +from skyvern.forge.sdk.workflow.exceptions import DownloadFileMaxSizeExceeded from skyvern.forge.sdk.workflow.models.parameter import ( PARAMETER_TYPE, ContextParameter, @@ -32,6 +39,7 @@ class BlockType(StrEnum): FOR_LOOP = "for_loop" CODE = "code" TEXT_PROMPT = "text_prompt" + DOWNLOAD_TO_S3 = "download_to_s3" class Block(BaseModel, abc.ABC): @@ -48,6 +56,10 @@ class Block(BaseModel, abc.ABC): def get_workflow_run_context(workflow_run_id: str) -> WorkflowRunContext: return app.WORKFLOW_CONTEXT_MANAGER.get_workflow_run_context(workflow_run_id) + @staticmethod + def get_async_aws_client() -> AsyncAWSClient: + return app.WORKFLOW_CONTEXT_MANAGER.aws_client + @abc.abstractmethod async def execute(self, workflow_run_id: str, **kwargs: dict) -> OutputParameter | None: pass @@ -417,5 +429,92 @@ class 
TextPromptBlock(Block): return None -BlockSubclasses = Union[ForLoopBlock, TaskBlock, CodeBlock, TextPromptBlock] +class DownloadToS3Block(Block): + block_type: Literal[BlockType.DOWNLOAD_TO_S3] = BlockType.DOWNLOAD_TO_S3 + + url: str + + def get_all_parameters( + self, + ) -> list[PARAMETER_TYPE]: + return [] + + async def _download_file(self, max_size_mb: int = 5) -> str: + async with aiohttp.ClientSession() as session: + LOG.info("Downloading file", url=self.url) + async with session.get(self.url) as response: + # Check the content length if available + if response.content_length and response.content_length > max_size_mb * 1024 * 1024: + raise DownloadFileMaxSizeExceeded(max_size_mb) + + # Don't forget to delete the temporary file after we're done with it + temp_file = NamedTemporaryFile(delete=False) + + total_bytes_downloaded = 0 + async for chunk in response.content.iter_chunked(8192): + temp_file.write(chunk) + total_bytes_downloaded += len(chunk) + if total_bytes_downloaded > max_size_mb * 1024 * 1024: + raise DownloadFileMaxSizeExceeded(max_size_mb) + + # Seek back to the start of the file + temp_file.seek(0) + + return temp_file.name + + async def _upload_file_to_s3(self, uri: str, file_path: str) -> None: + try: + client = self.get_async_aws_client() + await client.upload_file_from_path(uri=uri, file_path=file_path) + finally: + # Clean up the temporary file since it's created with delete=False + os.unlink(file_path) + + async def execute(self, workflow_run_id: str, **kwargs: dict) -> OutputParameter | None: + # get workflow run context + workflow_run_context = self.get_workflow_run_context(workflow_run_id) + # get all parameters into a dictionary + if self.url and workflow_run_context.has_parameter(self.url) and workflow_run_context.has_value(self.url): + task_url_parameter_value = workflow_run_context.get_value(self.url) + if task_url_parameter_value: + LOG.info( + "DownloadToS3Block: Task URL is parameterized, using parameter value", + 
task_url_parameter_value=task_url_parameter_value, + task_url_parameter_key=self.url, + ) + self.url = task_url_parameter_value + + try: + file_path = await self._download_file() + except Exception as e: + LOG.error("DownloadToS3Block: Failed to download file", url=self.url, error=str(e)) + raise e + + uri = None + try: + uri = f"s3://{SettingsManager.get_settings().AWS_S3_BUCKET_DOWNLOADS}/{SettingsManager.get_settings().ENV}/{workflow_run_id}/{uuid.uuid4()}" + await self._upload_file_to_s3(uri, file_path) + except Exception as e: + LOG.error("DownloadToS3Block: Failed to upload file to S3", uri=uri, error=str(e)) + raise e + + LOG.info("DownloadToS3Block: File downloaded and uploaded to S3", uri=uri) + if self.output_parameter: + LOG.info("DownloadToS3Block: Output parameter defined, registering output parameter value") + await workflow_run_context.register_output_parameter_value_post_execution( + parameter=self.output_parameter, + value=uri, + ) + await app.DATABASE.create_workflow_run_output_parameter( + workflow_run_id=workflow_run_id, + output_parameter_id=self.output_parameter.output_parameter_id, + value=uri, + ) + return self.output_parameter + + LOG.info("DownloadToS3Block: No output parameter defined, returning None") + return None + + +BlockSubclasses = Union[ForLoopBlock, TaskBlock, CodeBlock, TextPromptBlock, DownloadToS3Block] BlockTypeVar = Annotated[BlockSubclasses, Field(discriminator="block_type")] diff --git a/skyvern/forge/sdk/workflow/models/yaml.py b/skyvern/forge/sdk/workflow/models/yaml.py index fefb9ad3..03ab4a00 100644 --- a/skyvern/forge/sdk/workflow/models/yaml.py +++ b/skyvern/forge/sdk/workflow/models/yaml.py @@ -107,10 +107,20 @@ class TextPromptBlockYAML(BlockYAML): json_schema: dict[str, Any] | None = None +class DownloadToS3BlockYAML(BlockYAML): + # There is a mypy bug with Literal. Without the type: ignore, mypy will raise an error: + # Parameter 1 of Literal[...] 
cannot be of type "Any" + # This pattern already works in block.py but since the BlockType is not defined in this file, mypy is not able + # to infer the type of the block_type attribute. + block_type: Literal[BlockType.DOWNLOAD_TO_S3] = BlockType.DOWNLOAD_TO_S3 # type: ignore + + url: str + + PARAMETER_YAML_SUBCLASSES = AWSSecretParameterYAML | WorkflowParameterYAML | ContextParameterYAML | OutputParameterYAML PARAMETER_YAML_TYPES = Annotated[PARAMETER_YAML_SUBCLASSES, Field(discriminator="parameter_type")] -BLOCK_YAML_SUBCLASSES = TaskBlockYAML | ForLoopBlockYAML | CodeBlockYAML | TextPromptBlockYAML +BLOCK_YAML_SUBCLASSES = TaskBlockYAML | ForLoopBlockYAML | CodeBlockYAML | TextPromptBlockYAML | DownloadToS3BlockYAML BLOCK_YAML_TYPES = Annotated[BLOCK_YAML_SUBCLASSES, Field(discriminator="block_type")] diff --git a/skyvern/forge/sdk/workflow/service.py b/skyvern/forge/sdk/workflow/service.py index c2c9c160..cecfc1d5 100644 --- a/skyvern/forge/sdk/workflow/service.py +++ b/skyvern/forge/sdk/workflow/service.py @@ -25,6 +25,7 @@ from skyvern.forge.sdk.workflow.models.block import ( BlockType, BlockTypeVar, CodeBlock, + DownloadToS3Block, ForLoopBlock, TaskBlock, TextPromptBlock, @@ -732,4 +733,10 @@ class WorkflowService: json_schema=block_yaml.json_schema, output_parameter=output_parameter, ) + elif block_yaml.block_type == BlockType.DOWNLOAD_TO_S3: + return DownloadToS3Block( + label=block_yaml.label, + output_parameter=output_parameter, + url=block_yaml.url, + ) raise ValueError(f"Invalid block type {block_yaml.block_type}")