Implement DownloadToS3Block (#133)

This commit is contained in:
Kerem Yilmaz
2024-03-28 16:46:54 -07:00
committed by GitHub
parent 57062952b8
commit 3d1b146470
6 changed files with 128 additions and 3 deletions

View File

@@ -40,6 +40,10 @@ class Settings(BaseSettings):
# Artifact storage settings
ARTIFACT_STORAGE_PATH: str = f"{SKYVERN_DIR}/artifacts"
# S3 bucket settings
AWS_REGION: str = "us-east-1"
AWS_S3_BUCKET_DOWNLOADS: str = "skyvern-downloads"
SKYVERN_TELEMETRY: bool = True
ANALYTICS_ID: str = "anonymous"

View File

@@ -22,7 +22,7 @@ def execute_with_async_client(client_type: AWSClientType) -> Callable:
self = args[0]
assert isinstance(self, AsyncAWSClient)
session = aioboto3.Session()
async with session.client(client_type) as client:
async with session.client(client_type, region_name=SettingsManager.get_settings().AWS_REGION) as client:
return await f(*args, client=client, **kwargs)
return wrapper

View File

@@ -29,3 +29,8 @@ class WorkflowDefinitionHasDuplicateParameterKeys(BaseWorkflowException):
f"WorkflowDefinition has parameters with duplicate keys. Each parameter needs to have a unique "
f"key. Duplicate key(s): {','.join(duplicate_keys)}"
)
class DownloadFileMaxSizeExceeded(BaseWorkflowException):
    """Raised when a file being downloaded is larger than the allowed limit."""

    def __init__(self, max_size: int) -> None:
        # max_size is expressed in megabytes.
        message = f"Download file size exceeded the maximum allowed size of {max_size} MB."
        super().__init__(message)

View File

@@ -1,8 +1,12 @@
import abc
import json
import os
import uuid
from enum import StrEnum
from tempfile import NamedTemporaryFile
from typing import Annotated, Any, Literal, Union
import aiohttp
import structlog
from pydantic import BaseModel, Field
@@ -14,9 +18,12 @@ from skyvern.exceptions import (
)
from skyvern.forge import app
from skyvern.forge.prompts import prompt_engine
from skyvern.forge.sdk.api.aws import AsyncAWSClient
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory
from skyvern.forge.sdk.schemas.tasks import TaskStatus
from skyvern.forge.sdk.settings_manager import SettingsManager
from skyvern.forge.sdk.workflow.context_manager import WorkflowRunContext
from skyvern.forge.sdk.workflow.exceptions import DownloadFileMaxSizeExceeded
from skyvern.forge.sdk.workflow.models.parameter import (
PARAMETER_TYPE,
ContextParameter,
@@ -32,6 +39,7 @@ class BlockType(StrEnum):
FOR_LOOP = "for_loop"
CODE = "code"
TEXT_PROMPT = "text_prompt"
DOWNLOAD_TO_S3 = "download_to_s3"
class Block(BaseModel, abc.ABC):
@@ -48,6 +56,10 @@ class Block(BaseModel, abc.ABC):
def get_workflow_run_context(workflow_run_id: str) -> WorkflowRunContext:
return app.WORKFLOW_CONTEXT_MANAGER.get_workflow_run_context(workflow_run_id)
@staticmethod
def get_async_aws_client() -> AsyncAWSClient:
    """Return the shared async AWS client held by the app's workflow context manager."""
    return app.WORKFLOW_CONTEXT_MANAGER.aws_client
@abc.abstractmethod
async def execute(self, workflow_run_id: str, **kwargs: dict) -> OutputParameter | None:
    """Run this block within the given workflow run.

    Subclasses must implement this; it returns the block's OutputParameter
    when one was produced, otherwise None.
    """
    pass
@@ -417,5 +429,92 @@ class TextPromptBlock(Block):
return None
BlockSubclasses = Union[ForLoopBlock, TaskBlock, CodeBlock, TextPromptBlock]
class DownloadToS3Block(Block):
    """Download a file from a URL (with a size cap) and upload it to the
    configured S3 downloads bucket, registering the resulting S3 URI as this
    block's output parameter value when one is defined."""

    block_type: Literal[BlockType.DOWNLOAD_TO_S3] = BlockType.DOWNLOAD_TO_S3
    # Either a literal URL or a workflow parameter key that resolves to the
    # actual URL at execution time (see execute()).
    url: str

    def get_all_parameters(
        self,
    ) -> list[PARAMETER_TYPE]:
        # This block declares no workflow parameters of its own.
        return []

    async def _download_file(self, max_size_mb: int = 5) -> str:
        """Stream self.url into a temporary file and return the file's path.

        Raises DownloadFileMaxSizeExceeded when the payload exceeds
        max_size_mb. The caller is responsible for deleting the returned file
        (it is created with delete=False).
        """
        max_bytes = max_size_mb * 1024 * 1024
        async with aiohttp.ClientSession() as session:
            LOG.info("Downloading file", url=self.url)
            async with session.get(self.url) as response:
                # Fast-fail on the declared length when the server provides one.
                if response.content_length and response.content_length > max_bytes:
                    raise DownloadFileMaxSizeExceeded(max_size_mb)
                # delete=False so the file outlives this function; the uploader
                # unlinks it when it's done (see _upload_file_to_s3).
                temp_file = NamedTemporaryFile(delete=False)
                try:
                    total_bytes_downloaded = 0
                    async for chunk in response.content.iter_chunked(8192):
                        temp_file.write(chunk)
                        total_bytes_downloaded += len(chunk)
                        # Enforce the cap even when no Content-Length was sent.
                        if total_bytes_downloaded > max_bytes:
                            raise DownloadFileMaxSizeExceeded(max_size_mb)
                except Exception:
                    # Don't leak the partially written temp file on failure.
                    temp_file.close()
                    os.unlink(temp_file.name)
                    raise
                # Close (rather than seek) so buffered bytes are flushed to
                # disk before the uploader reads the file back by path.
                temp_file.close()
                return temp_file.name

    async def _upload_file_to_s3(self, uri: str, file_path: str) -> None:
        """Upload file_path to the given S3 URI, always deleting the local file."""
        try:
            client = self.get_async_aws_client()
            await client.upload_file_from_path(uri=uri, file_path=file_path)
        finally:
            # Clean up the temporary file since it's created with delete=False
            os.unlink(file_path)

    async def execute(self, workflow_run_id: str, **kwargs: dict) -> OutputParameter | None:
        """Resolve the URL, download the file, upload it to S3, and record the URI.

        Returns the block's output parameter when one is defined, else None.
        Re-raises any download or upload failure after logging it.
        """
        workflow_run_context = self.get_workflow_run_context(workflow_run_id)
        # If self.url is actually a parameter key with a resolved value,
        # substitute that value for the URL.
        if self.url and workflow_run_context.has_parameter(self.url) and workflow_run_context.has_value(self.url):
            task_url_parameter_value = workflow_run_context.get_value(self.url)
            if task_url_parameter_value:
                LOG.info(
                    "DownloadToS3Block: Task URL is parameterized, using parameter value",
                    task_url_parameter_value=task_url_parameter_value,
                    task_url_parameter_key=self.url,
                )
                self.url = task_url_parameter_value

        try:
            file_path = await self._download_file()
        except Exception as e:
            LOG.error("DownloadToS3Block: Failed to download file", url=self.url, error=str(e))
            raise

        uri = None
        try:
            settings = SettingsManager.get_settings()
            # Unique object key per execution: env/workflow_run_id/uuid.
            uri = f"s3://{settings.AWS_S3_BUCKET_DOWNLOADS}/{settings.ENV}/{workflow_run_id}/{uuid.uuid4()}"
            await self._upload_file_to_s3(uri, file_path)
        except Exception as e:
            LOG.error("DownloadToS3Block: Failed to upload file to S3", uri=uri, error=str(e))
            raise

        LOG.info("DownloadToS3Block: File downloaded and uploaded to S3", uri=uri)
        if self.output_parameter:
            LOG.info("DownloadToS3Block: Output parameter defined, registering output parameter value")
            await workflow_run_context.register_output_parameter_value_post_execution(
                parameter=self.output_parameter,
                value=uri,
            )
            await app.DATABASE.create_workflow_run_output_parameter(
                workflow_run_id=workflow_run_id,
                output_parameter_id=self.output_parameter.output_parameter_id,
                value=uri,
            )
            return self.output_parameter
        LOG.info("DownloadToS3Block: No output parameter defined, returning None")
        return None
BlockSubclasses = Union[ForLoopBlock, TaskBlock, CodeBlock, TextPromptBlock, DownloadToS3Block]
BlockTypeVar = Annotated[BlockSubclasses, Field(discriminator="block_type")]

View File

@@ -107,10 +107,20 @@ class TextPromptBlockYAML(BlockYAML):
json_schema: dict[str, Any] | None = None
class DownloadToS3BlockYAML(BlockYAML):
    """YAML schema for a download_to_s3 block: the URL to download from."""

    # There is a mypy bug with Literal. Without the type: ignore, mypy will raise an error:
    # Parameter 1 of Literal[...] cannot be of type "Any"
    # This pattern already works in block.py but since the BlockType is not defined in this file, mypy is not able
    # to infer the type of the block_type attribute.
    block_type: Literal[BlockType.DOWNLOAD_TO_S3] = BlockType.DOWNLOAD_TO_S3  # type: ignore
    # URL (or workflow parameter key resolving to a URL) of the file to download.
    url: str
PARAMETER_YAML_SUBCLASSES = AWSSecretParameterYAML | WorkflowParameterYAML | ContextParameterYAML | OutputParameterYAML
PARAMETER_YAML_TYPES = Annotated[PARAMETER_YAML_SUBCLASSES, Field(discriminator="parameter_type")]
BLOCK_YAML_SUBCLASSES = TaskBlockYAML | ForLoopBlockYAML | CodeBlockYAML | TextPromptBlockYAML
BLOCK_YAML_SUBCLASSES = TaskBlockYAML | ForLoopBlockYAML | CodeBlockYAML | TextPromptBlockYAML | DownloadToS3BlockYAML
BLOCK_YAML_TYPES = Annotated[BLOCK_YAML_SUBCLASSES, Field(discriminator="block_type")]

View File

@@ -25,6 +25,7 @@ from skyvern.forge.sdk.workflow.models.block import (
BlockType,
BlockTypeVar,
CodeBlock,
DownloadToS3Block,
ForLoopBlock,
TaskBlock,
TextPromptBlock,
@@ -732,4 +733,10 @@ class WorkflowService:
json_schema=block_yaml.json_schema,
output_parameter=output_parameter,
)
elif block_yaml.block_type == BlockType.DOWNLOAD_TO_S3:
return DownloadToS3Block(
label=block_yaml.label,
output_parameter=output_parameter,
url=block_yaml.url,
)
raise ValueError(f"Invalid block type {block_yaml.block_type}")