Implement DownloadToS3Block (#133)
This commit is contained in:
@@ -40,6 +40,10 @@ class Settings(BaseSettings):
|
||||
# Artifact storage settings
|
||||
ARTIFACT_STORAGE_PATH: str = f"{SKYVERN_DIR}/artifacts"
|
||||
|
||||
# S3 bucket settings
|
||||
AWS_REGION: str = "us-east-1"
|
||||
AWS_S3_BUCKET_DOWNLOADS: str = "skyvern-downloads"
|
||||
|
||||
SKYVERN_TELEMETRY: bool = True
|
||||
ANALYTICS_ID: str = "anonymous"
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ def execute_with_async_client(client_type: AWSClientType) -> Callable:
|
||||
self = args[0]
|
||||
assert isinstance(self, AsyncAWSClient)
|
||||
session = aioboto3.Session()
|
||||
async with session.client(client_type) as client:
|
||||
async with session.client(client_type, region_name=SettingsManager.get_settings().AWS_REGION) as client:
|
||||
return await f(*args, client=client, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
@@ -29,3 +29,8 @@ class WorkflowDefinitionHasDuplicateParameterKeys(BaseWorkflowException):
|
||||
f"WorkflowDefinition has parameters with duplicate keys. Each parameter needs to have a unique "
|
||||
f"key. Duplicate key(s): {','.join(duplicate_keys)}"
|
||||
)
|
||||
|
||||
|
||||
class DownloadFileMaxSizeExceeded(BaseWorkflowException):
    """Raised when a downloaded file would exceed the configured size cap."""

    def __init__(self, max_size: int) -> None:
        # max_size is expressed in megabytes.
        message = f"Download file size exceeded the maximum allowed size of {max_size} MB."
        super().__init__(message)
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
import abc
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from enum import StrEnum
|
||||
from tempfile import NamedTemporaryFile
|
||||
from typing import Annotated, Any, Literal, Union
|
||||
|
||||
import aiohttp
|
||||
import structlog
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -14,9 +18,12 @@ from skyvern.exceptions import (
|
||||
)
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.prompts import prompt_engine
|
||||
from skyvern.forge.sdk.api.aws import AsyncAWSClient
|
||||
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory
|
||||
from skyvern.forge.sdk.schemas.tasks import TaskStatus
|
||||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||
from skyvern.forge.sdk.workflow.context_manager import WorkflowRunContext
|
||||
from skyvern.forge.sdk.workflow.exceptions import DownloadFileMaxSizeExceeded
|
||||
from skyvern.forge.sdk.workflow.models.parameter import (
|
||||
PARAMETER_TYPE,
|
||||
ContextParameter,
|
||||
@@ -32,6 +39,7 @@ class BlockType(StrEnum):
|
||||
FOR_LOOP = "for_loop"
|
||||
CODE = "code"
|
||||
TEXT_PROMPT = "text_prompt"
|
||||
DOWNLOAD_TO_S3 = "download_to_s3"
|
||||
|
||||
|
||||
class Block(BaseModel, abc.ABC):
|
||||
@@ -48,6 +56,10 @@ class Block(BaseModel, abc.ABC):
|
||||
def get_workflow_run_context(workflow_run_id: str) -> WorkflowRunContext:
|
||||
return app.WORKFLOW_CONTEXT_MANAGER.get_workflow_run_context(workflow_run_id)
|
||||
|
||||
@staticmethod
|
||||
def get_async_aws_client() -> AsyncAWSClient:
|
||||
return app.WORKFLOW_CONTEXT_MANAGER.aws_client
|
||||
|
||||
@abc.abstractmethod
|
||||
async def execute(self, workflow_run_id: str, **kwargs: dict) -> OutputParameter | None:
|
||||
pass
|
||||
@@ -417,5 +429,92 @@ class TextPromptBlock(Block):
|
||||
return None
|
||||
|
||||
|
||||
BlockSubclasses = Union[ForLoopBlock, TaskBlock, CodeBlock, TextPromptBlock]
|
||||
class DownloadToS3Block(Block):
    """Download a file from a URL (bounded in size) and upload it to the
    configured S3 downloads bucket, recording the resulting S3 URI in the
    block's output parameter when one is defined."""

    block_type: Literal[BlockType.DOWNLOAD_TO_S3] = BlockType.DOWNLOAD_TO_S3

    # URL to download. May also be a workflow parameter key; execute()
    # resolves it against the workflow run context before downloading.
    url: str

    def get_all_parameters(
        self,
    ) -> list[PARAMETER_TYPE]:
        # This block declares no workflow parameters of its own.
        return []

    async def _download_file(self, max_size_mb: int = 5) -> str:
        """Stream self.url into a temporary file and return the file's path.

        Raises:
            DownloadFileMaxSizeExceeded: if the payload exceeds max_size_mb.

        The caller owns the returned file and must delete it when done
        (it is created with delete=False).
        """
        max_bytes = max_size_mb * 1024 * 1024
        async with aiohttp.ClientSession() as session:
            LOG.info("Downloading file", url=self.url)
            async with session.get(self.url) as response:
                # Fast path: reject oversized downloads up front when the
                # server advertises a Content-Length.
                if response.content_length and response.content_length > max_bytes:
                    raise DownloadFileMaxSizeExceeded(max_size_mb)

                # delete=False so the file survives past this function; the
                # upload step is responsible for unlinking it afterwards.
                temp_file = NamedTemporaryFile(delete=False)
                try:
                    total_bytes_downloaded = 0
                    async for chunk in response.content.iter_chunked(8192):
                        temp_file.write(chunk)
                        total_bytes_downloaded += len(chunk)
                        # Content-Length can be absent or wrong, so enforce
                        # the limit while streaming as well.
                        if total_bytes_downloaded > max_bytes:
                            raise DownloadFileMaxSizeExceeded(max_size_mb)
                except Exception:
                    # Don't leak the temp file when the download fails midway.
                    temp_file.close()
                    os.unlink(temp_file.name)
                    raise

                # Close so every buffered byte is flushed to disk before the
                # file is re-opened by path for the S3 upload.
                temp_file.close()
                return temp_file.name

    async def _upload_file_to_s3(self, uri: str, file_path: str) -> None:
        """Upload file_path to the given s3:// uri, then delete the local file."""
        try:
            client = self.get_async_aws_client()
            await client.upload_file_from_path(uri=uri, file_path=file_path)
        finally:
            # Clean up the temporary file since it's created with delete=False
            os.unlink(file_path)

    async def execute(self, workflow_run_id: str, **kwargs: dict) -> OutputParameter | None:
        """Resolve the URL, download the file, upload it to S3, and register
        the S3 URI as this block's output parameter value (if defined)."""
        # get workflow run context
        workflow_run_context = self.get_workflow_run_context(workflow_run_id)
        # If self.url names a workflow parameter with a value, use that value
        # as the actual download URL.
        if self.url and workflow_run_context.has_parameter(self.url) and workflow_run_context.has_value(self.url):
            task_url_parameter_value = workflow_run_context.get_value(self.url)
            if task_url_parameter_value:
                LOG.info(
                    "DownloadToS3Block: Task URL is parameterized, using parameter value",
                    task_url_parameter_value=task_url_parameter_value,
                    task_url_parameter_key=self.url,
                )
                self.url = task_url_parameter_value

        try:
            file_path = await self._download_file()
        except Exception as e:
            LOG.error("DownloadToS3Block: Failed to download file", url=self.url, error=str(e))
            raise

        uri = None
        try:
            # Key layout: <bucket>/<env>/<workflow_run_id>/<random-uuid>
            uri = f"s3://{SettingsManager.get_settings().AWS_S3_BUCKET_DOWNLOADS}/{SettingsManager.get_settings().ENV}/{workflow_run_id}/{uuid.uuid4()}"
            await self._upload_file_to_s3(uri, file_path)
        except Exception as e:
            LOG.error("DownloadToS3Block: Failed to upload file to S3", uri=uri, error=str(e))
            raise

        LOG.info("DownloadToS3Block: File downloaded and uploaded to S3", uri=uri)
        if self.output_parameter:
            LOG.info("DownloadToS3Block: Output parameter defined, registering output parameter value")
            await workflow_run_context.register_output_parameter_value_post_execution(
                parameter=self.output_parameter,
                value=uri,
            )
            await app.DATABASE.create_workflow_run_output_parameter(
                workflow_run_id=workflow_run_id,
                output_parameter_id=self.output_parameter.output_parameter_id,
                value=uri,
            )
            return self.output_parameter

        LOG.info("DownloadToS3Block: No output parameter defined, returning None")
        return None
|
||||
|
||||
|
||||
BlockSubclasses = Union[ForLoopBlock, TaskBlock, CodeBlock, TextPromptBlock, DownloadToS3Block]
|
||||
BlockTypeVar = Annotated[BlockSubclasses, Field(discriminator="block_type")]
|
||||
|
||||
@@ -107,10 +107,20 @@ class TextPromptBlockYAML(BlockYAML):
|
||||
json_schema: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class DownloadToS3BlockYAML(BlockYAML):
    """YAML schema for a download_to_s3 block definition."""

    # mypy bug workaround: because BlockType is defined in another module,
    # mypy raises 'Parameter 1 of Literal[...] cannot be of type "Any"' here,
    # even though the identical pattern type-checks fine in block.py.
    block_type: Literal[BlockType.DOWNLOAD_TO_S3] = BlockType.DOWNLOAD_TO_S3  # type: ignore

    url: str
|
||||
|
||||
|
||||
PARAMETER_YAML_SUBCLASSES = AWSSecretParameterYAML | WorkflowParameterYAML | ContextParameterYAML | OutputParameterYAML
|
||||
PARAMETER_YAML_TYPES = Annotated[PARAMETER_YAML_SUBCLASSES, Field(discriminator="parameter_type")]
|
||||
|
||||
BLOCK_YAML_SUBCLASSES = TaskBlockYAML | ForLoopBlockYAML | CodeBlockYAML | TextPromptBlockYAML
|
||||
BLOCK_YAML_SUBCLASSES = TaskBlockYAML | ForLoopBlockYAML | CodeBlockYAML | TextPromptBlockYAML | DownloadToS3BlockYAML
|
||||
BLOCK_YAML_TYPES = Annotated[BLOCK_YAML_SUBCLASSES, Field(discriminator="block_type")]
|
||||
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ from skyvern.forge.sdk.workflow.models.block import (
|
||||
BlockType,
|
||||
BlockTypeVar,
|
||||
CodeBlock,
|
||||
DownloadToS3Block,
|
||||
ForLoopBlock,
|
||||
TaskBlock,
|
||||
TextPromptBlock,
|
||||
@@ -732,4 +733,10 @@ class WorkflowService:
|
||||
json_schema=block_yaml.json_schema,
|
||||
output_parameter=output_parameter,
|
||||
)
|
||||
elif block_yaml.block_type == BlockType.DOWNLOAD_TO_S3:
|
||||
return DownloadToS3Block(
|
||||
label=block_yaml.label,
|
||||
output_parameter=output_parameter,
|
||||
url=block_yaml.url,
|
||||
)
|
||||
raise ValueError(f"Invalid block type {block_yaml.block_type}")
|
||||
|
||||
Reference in New Issue
Block a user