Implement DownloadToS3Block (#133)
This commit is contained in:
@@ -40,6 +40,10 @@ class Settings(BaseSettings):
|
|||||||
# Artifact storage settings
|
# Artifact storage settings
|
||||||
ARTIFACT_STORAGE_PATH: str = f"{SKYVERN_DIR}/artifacts"
|
ARTIFACT_STORAGE_PATH: str = f"{SKYVERN_DIR}/artifacts"
|
||||||
|
|
||||||
|
# S3 bucket settings
|
||||||
|
AWS_REGION: str = "us-east-1"
|
||||||
|
AWS_S3_BUCKET_DOWNLOADS: str = "skyvern-downloads"
|
||||||
|
|
||||||
SKYVERN_TELEMETRY: bool = True
|
SKYVERN_TELEMETRY: bool = True
|
||||||
ANALYTICS_ID: str = "anonymous"
|
ANALYTICS_ID: str = "anonymous"
|
||||||
|
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ def execute_with_async_client(client_type: AWSClientType) -> Callable:
|
|||||||
self = args[0]
|
self = args[0]
|
||||||
assert isinstance(self, AsyncAWSClient)
|
assert isinstance(self, AsyncAWSClient)
|
||||||
session = aioboto3.Session()
|
session = aioboto3.Session()
|
||||||
async with session.client(client_type) as client:
|
async with session.client(client_type, region_name=SettingsManager.get_settings().AWS_REGION) as client:
|
||||||
return await f(*args, client=client, **kwargs)
|
return await f(*args, client=client, **kwargs)
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|||||||
@@ -29,3 +29,8 @@ class WorkflowDefinitionHasDuplicateParameterKeys(BaseWorkflowException):
|
|||||||
f"WorkflowDefinition has parameters with duplicate keys. Each parameter needs to have a unique "
|
f"WorkflowDefinition has parameters with duplicate keys. Each parameter needs to have a unique "
|
||||||
f"key. Duplicate key(s): {','.join(duplicate_keys)}"
|
f"key. Duplicate key(s): {','.join(duplicate_keys)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DownloadFileMaxSizeExceeded(BaseWorkflowException):
    """Raised when a downloaded file is larger than the allowed maximum size."""

    def __init__(self, max_size: int) -> None:
        # The limit is reported to the user in megabytes.
        message = f"Download file size exceeded the maximum allowed size of {max_size} MB."
        super().__init__(message)
|
||||||
|
|||||||
@@ -1,8 +1,12 @@
|
|||||||
import abc
|
import abc
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
|
from tempfile import NamedTemporaryFile
|
||||||
from typing import Annotated, Any, Literal, Union
|
from typing import Annotated, Any, Literal, Union
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
import structlog
|
import structlog
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
@@ -14,9 +18,12 @@ from skyvern.exceptions import (
|
|||||||
)
|
)
|
||||||
from skyvern.forge import app
|
from skyvern.forge import app
|
||||||
from skyvern.forge.prompts import prompt_engine
|
from skyvern.forge.prompts import prompt_engine
|
||||||
|
from skyvern.forge.sdk.api.aws import AsyncAWSClient
|
||||||
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory
|
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory
|
||||||
from skyvern.forge.sdk.schemas.tasks import TaskStatus
|
from skyvern.forge.sdk.schemas.tasks import TaskStatus
|
||||||
|
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||||
from skyvern.forge.sdk.workflow.context_manager import WorkflowRunContext
|
from skyvern.forge.sdk.workflow.context_manager import WorkflowRunContext
|
||||||
|
from skyvern.forge.sdk.workflow.exceptions import DownloadFileMaxSizeExceeded
|
||||||
from skyvern.forge.sdk.workflow.models.parameter import (
|
from skyvern.forge.sdk.workflow.models.parameter import (
|
||||||
PARAMETER_TYPE,
|
PARAMETER_TYPE,
|
||||||
ContextParameter,
|
ContextParameter,
|
||||||
@@ -32,6 +39,7 @@ class BlockType(StrEnum):
|
|||||||
FOR_LOOP = "for_loop"
|
FOR_LOOP = "for_loop"
|
||||||
CODE = "code"
|
CODE = "code"
|
||||||
TEXT_PROMPT = "text_prompt"
|
TEXT_PROMPT = "text_prompt"
|
||||||
|
DOWNLOAD_TO_S3 = "download_to_s3"
|
||||||
|
|
||||||
|
|
||||||
class Block(BaseModel, abc.ABC):
|
class Block(BaseModel, abc.ABC):
|
||||||
@@ -48,6 +56,10 @@ class Block(BaseModel, abc.ABC):
|
|||||||
def get_workflow_run_context(workflow_run_id: str) -> WorkflowRunContext:
|
def get_workflow_run_context(workflow_run_id: str) -> WorkflowRunContext:
|
||||||
return app.WORKFLOW_CONTEXT_MANAGER.get_workflow_run_context(workflow_run_id)
|
return app.WORKFLOW_CONTEXT_MANAGER.get_workflow_run_context(workflow_run_id)
|
||||||
|
|
||||||
|
@staticmethod
def get_async_aws_client() -> AsyncAWSClient:
    """Return the AWS client held by the application's workflow context manager."""
    context_manager = app.WORKFLOW_CONTEXT_MANAGER
    return context_manager.aws_client
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def execute(self, workflow_run_id: str, **kwargs: dict) -> OutputParameter | None:
|
async def execute(self, workflow_run_id: str, **kwargs: dict) -> OutputParameter | None:
|
||||||
pass
|
pass
|
||||||
@@ -417,5 +429,92 @@ class TextPromptBlock(Block):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
BlockSubclasses = Union[ForLoopBlock, TaskBlock, CodeBlock, TextPromptBlock]
|
class DownloadToS3Block(Block):
    """Workflow block that downloads a file from a URL and uploads it to S3.

    The file is streamed into a local temporary file (subject to a size cap),
    uploaded under the downloads bucket named in the settings, and the
    resulting ``s3://`` URI is registered as the block's output parameter
    value when an output parameter is defined.
    """

    block_type: Literal[BlockType.DOWNLOAD_TO_S3] = BlockType.DOWNLOAD_TO_S3

    # Either a literal URL, or the key of a workflow parameter whose value is
    # the URL (resolved in `execute`).
    url: str

    def get_all_parameters(
        self,
    ) -> list[PARAMETER_TYPE]:
        """Return the parameters declared by this block (always empty)."""
        return []

    async def _download_file(self, max_size_mb: int = 5) -> str:
        """Stream ``self.url`` into a temporary file and return its path.

        The temporary file is created with ``delete=False``; on success the
        caller owns it and must remove it. On any failure the file is closed
        and removed here before the exception propagates.

        Args:
            max_size_mb: Maximum allowed download size, in megabytes.

        Returns:
            Path of the temporary file holding the downloaded content.

        Raises:
            DownloadFileMaxSizeExceeded: If the advertised content length or
                the number of bytes actually received exceeds the limit.
        """
        max_size_bytes = max_size_mb * 1024 * 1024

        async with aiohttp.ClientSession() as session:
            LOG.info("Downloading file", url=self.url)
            # NOTE(review): non-2xx responses are still written to the file
            # as-is; consider response.raise_for_status() if that is unwanted.
            async with session.get(self.url) as response:
                # Fail fast when the server advertises an oversized body.
                if response.content_length and response.content_length > max_size_bytes:
                    raise DownloadFileMaxSizeExceeded(max_size_mb)

                temp_file = NamedTemporaryFile(delete=False)
                try:
                    # The `with` flushes and closes the handle; the file
                    # itself survives because of delete=False.
                    with temp_file:
                        total_bytes_downloaded = 0
                        async for chunk in response.content.iter_chunked(8192):
                            temp_file.write(chunk)
                            total_bytes_downloaded += len(chunk)
                            # Content-Length may be missing or wrong, so the
                            # limit is enforced on the bytes actually read.
                            if total_bytes_downloaded > max_size_bytes:
                                raise DownloadFileMaxSizeExceeded(max_size_mb)
                except BaseException:
                    # Includes task cancellation: never leave a partial
                    # download behind on disk.
                    os.unlink(temp_file.name)
                    raise

                return temp_file.name

    async def _upload_file_to_s3(self, uri: str, file_path: str) -> None:
        """Upload ``file_path`` to ``uri`` and always delete the local file.

        Args:
            uri: Destination ``s3://`` URI.
            file_path: Path of the local file to upload.
        """
        try:
            client = self.get_async_aws_client()
            await client.upload_file_from_path(uri=uri, file_path=file_path)
        finally:
            # Clean up the temporary file since it's created with delete=False
            os.unlink(file_path)

    async def execute(self, workflow_run_id: str, **kwargs: dict) -> OutputParameter | None:
        """Download the file, upload it to S3 and record the resulting URI.

        Args:
            workflow_run_id: Identifier of the workflow run being executed.
            **kwargs: Unused.

        Returns:
            The block's output parameter when one is defined, otherwise None.

        Raises:
            Exception: Any download or upload failure is logged and re-raised.
        """
        workflow_run_context = self.get_workflow_run_context(workflow_run_id)

        # When `url` is the key of a workflow parameter that has a value,
        # substitute that value for the URL.
        if self.url and workflow_run_context.has_parameter(self.url) and workflow_run_context.has_value(self.url):
            task_url_parameter_value = workflow_run_context.get_value(self.url)
            if task_url_parameter_value:
                LOG.info(
                    "DownloadToS3Block: Task URL is parameterized, using parameter value",
                    task_url_parameter_value=task_url_parameter_value,
                    task_url_parameter_key=self.url,
                )
                self.url = task_url_parameter_value

        try:
            file_path = await self._download_file()
        except Exception as e:
            LOG.error("DownloadToS3Block: Failed to download file", url=self.url, error=str(e))
            raise

        # Initialised up front so the error log below can always reference it.
        uri = None
        try:
            settings = SettingsManager.get_settings()
            uri = f"s3://{settings.AWS_S3_BUCKET_DOWNLOADS}/{settings.ENV}/{workflow_run_id}/{uuid.uuid4()}"
            await self._upload_file_to_s3(uri, file_path)
        except Exception as e:
            LOG.error("DownloadToS3Block: Failed to upload file to S3", uri=uri, error=str(e))
            raise

        LOG.info("DownloadToS3Block: File downloaded and uploaded to S3", uri=uri)
        if self.output_parameter:
            LOG.info("DownloadToS3Block: Output parameter defined, registering output parameter value")
            await workflow_run_context.register_output_parameter_value_post_execution(
                parameter=self.output_parameter,
                value=uri,
            )
            await app.DATABASE.create_workflow_run_output_parameter(
                workflow_run_id=workflow_run_id,
                output_parameter_id=self.output_parameter.output_parameter_id,
                value=uri,
            )
            return self.output_parameter

        LOG.info("DownloadToS3Block: No output parameter defined, returning None")
        return None
|
||||||
|
|
||||||
|
|
||||||
|
BlockSubclasses = Union[ForLoopBlock, TaskBlock, CodeBlock, TextPromptBlock, DownloadToS3Block]
|
||||||
BlockTypeVar = Annotated[BlockSubclasses, Field(discriminator="block_type")]
|
BlockTypeVar = Annotated[BlockSubclasses, Field(discriminator="block_type")]
|
||||||
|
|||||||
@@ -107,10 +107,20 @@ class TextPromptBlockYAML(BlockYAML):
|
|||||||
json_schema: dict[str, Any] | None = None
|
json_schema: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class DownloadToS3BlockYAML(BlockYAML):
    """YAML schema for a ``download_to_s3`` workflow block."""

    # mypy rejects this Literal with:
    #     Parameter 1 of Literal[...] cannot be of type "Any"
    # The same pattern type-checks in block.py, but BlockType is not defined in this file, so mypy
    # cannot infer the type of the block_type attribute here; hence the `type: ignore`.
    block_type: Literal[BlockType.DOWNLOAD_TO_S3] = BlockType.DOWNLOAD_TO_S3  # type: ignore

    # Passed through unchanged as the `url` of the constructed DownloadToS3Block.
    url: str
|
||||||
|
|
||||||
|
|
||||||
PARAMETER_YAML_SUBCLASSES = AWSSecretParameterYAML | WorkflowParameterYAML | ContextParameterYAML | OutputParameterYAML
|
PARAMETER_YAML_SUBCLASSES = AWSSecretParameterYAML | WorkflowParameterYAML | ContextParameterYAML | OutputParameterYAML
|
||||||
PARAMETER_YAML_TYPES = Annotated[PARAMETER_YAML_SUBCLASSES, Field(discriminator="parameter_type")]
|
PARAMETER_YAML_TYPES = Annotated[PARAMETER_YAML_SUBCLASSES, Field(discriminator="parameter_type")]
|
||||||
|
|
||||||
BLOCK_YAML_SUBCLASSES = TaskBlockYAML | ForLoopBlockYAML | CodeBlockYAML | TextPromptBlockYAML
|
BLOCK_YAML_SUBCLASSES = TaskBlockYAML | ForLoopBlockYAML | CodeBlockYAML | TextPromptBlockYAML | DownloadToS3BlockYAML
|
||||||
BLOCK_YAML_TYPES = Annotated[BLOCK_YAML_SUBCLASSES, Field(discriminator="block_type")]
|
BLOCK_YAML_TYPES = Annotated[BLOCK_YAML_SUBCLASSES, Field(discriminator="block_type")]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ from skyvern.forge.sdk.workflow.models.block import (
|
|||||||
BlockType,
|
BlockType,
|
||||||
BlockTypeVar,
|
BlockTypeVar,
|
||||||
CodeBlock,
|
CodeBlock,
|
||||||
|
DownloadToS3Block,
|
||||||
ForLoopBlock,
|
ForLoopBlock,
|
||||||
TaskBlock,
|
TaskBlock,
|
||||||
TextPromptBlock,
|
TextPromptBlock,
|
||||||
@@ -732,4 +733,10 @@ class WorkflowService:
|
|||||||
json_schema=block_yaml.json_schema,
|
json_schema=block_yaml.json_schema,
|
||||||
output_parameter=output_parameter,
|
output_parameter=output_parameter,
|
||||||
)
|
)
|
||||||
|
elif block_yaml.block_type == BlockType.DOWNLOAD_TO_S3:
|
||||||
|
return DownloadToS3Block(
|
||||||
|
label=block_yaml.label,
|
||||||
|
output_parameter=output_parameter,
|
||||||
|
url=block_yaml.url,
|
||||||
|
)
|
||||||
raise ValueError(f"Invalid block type {block_yaml.block_type}")
|
raise ValueError(f"Invalid block type {block_yaml.block_type}")
|
||||||
|
|||||||
Reference in New Issue
Block a user