file upload block backend (#2000)
This commit is contained in:
@@ -4,7 +4,6 @@ from urllib.parse import urlparse
|
||||
|
||||
import aioboto3
|
||||
import structlog
|
||||
from aiobotocore.client import AioBaseClient
|
||||
|
||||
from skyvern.config import settings
|
||||
|
||||
@@ -32,11 +31,25 @@ def execute_with_async_client(client_type: AWSClientType) -> Callable:
|
||||
|
||||
|
||||
class AsyncAWSClient:
|
||||
@execute_with_async_client(client_type=AWSClientType.SECRETS_MANAGER)
|
||||
async def get_secret(self, secret_name: str, client: AioBaseClient = None) -> str | None:
|
||||
def __init__(
    self,
    aws_access_key_id: str | None = None,
    aws_secret_access_key: str | None = None,
    region_name: str | None = None,
) -> None:
    """Create a client bound to one set of AWS credentials and one region.

    Args:
        aws_access_key_id: Access key handed to the aioboto3 session; may be None.
        aws_secret_access_key: Secret key handed to the aioboto3 session; may be None.
        region_name: Region used for every client opened from this instance.
            Falls back to ``settings.AWS_REGION`` when None or empty.
    """
    self.aws_access_key_id = aws_access_key_id
    self.aws_secret_access_key = aws_secret_access_key
    self.region_name = region_name or settings.AWS_REGION
    # One session per instance; service clients are opened from it per call.
    # NOTE(review): with both keys None the session presumably falls back to
    # the default credential chain — confirm against aioboto3/boto3 docs.
    self.session = aioboto3.Session(
        aws_access_key_id=self.aws_access_key_id,
        aws_secret_access_key=self.aws_secret_access_key,
    )
|
||||
|
||||
async def get_secret(self, secret_name: str) -> str | None:
|
||||
try:
|
||||
response = await client.get_secret_value(SecretId=secret_name)
|
||||
return response["SecretString"]
|
||||
async with self.session.client(AWSClientType.SECRETS_MANAGER, region_name=self.region_name) as client:
|
||||
response = await client.get_secret_value(SecretId=secret_name)
|
||||
return response["SecretString"]
|
||||
except Exception as e:
|
||||
try:
|
||||
error_code = e.response["Error"]["Code"] # type: ignore
|
||||
@@ -45,86 +58,93 @@ class AsyncAWSClient:
|
||||
LOG.exception("Failed to get secret.", secret_name=secret_name, error_code=error_code)
|
||||
return None
|
||||
|
||||
async def create_secret(self, secret_name: str, secret_value: str) -> None:
    """Create a new secret in AWS Secrets Manager.

    Args:
        secret_name: Name of the secret to create.
        secret_value: String stored as the secret's value.

    Raises:
        Exception: Any error raised while opening the client or by
            ``create_secret`` is logged and re-raised unchanged.
    """
    try:
        async with self.session.client(AWSClientType.SECRETS_MANAGER, region_name=self.region_name) as client:
            await client.create_secret(Name=secret_name, SecretString=secret_value)
    except Exception:
        LOG.exception("Failed to create secret.", secret_name=secret_name)
        # Bare raise keeps the original traceback intact.
        raise
|
||||
|
||||
async def set_secret(self, secret_name: str, secret_value: str) -> None:
    """Store a new value for an existing secret in AWS Secrets Manager.

    Args:
        secret_name: Identifier of the secret to update.
        secret_value: String stored as the secret's new value.

    Raises:
        Exception: Any error raised while opening the client or by
            ``put_secret_value`` is logged and re-raised unchanged.
    """
    try:
        async with self.session.client(AWSClientType.SECRETS_MANAGER, region_name=self.region_name) as client:
            await client.put_secret_value(SecretId=secret_name, SecretString=secret_value)
    except Exception:
        LOG.exception("Failed to set secret.", secret_name=secret_name)
        # Bare raise keeps the original traceback intact.
        raise
|
||||
|
||||
async def delete_secret(self, secret_name: str) -> None:
    """Delete a secret from AWS Secrets Manager.

    Args:
        secret_name: Identifier of the secret to delete.

    Raises:
        Exception: Any error raised while opening the client or by
            ``delete_secret`` is logged and re-raised unchanged.
    """
    try:
        async with self.session.client(AWSClientType.SECRETS_MANAGER, region_name=self.region_name) as client:
            await client.delete_secret(SecretId=secret_name)
    except Exception:
        LOG.exception("Failed to delete secret.", secret_name=secret_name)
        # Bare raise keeps the original traceback intact.
        raise
|
||||
|
||||
async def upload_file(self, uri: str, data: bytes) -> str | None:
    """Upload an in-memory payload to the S3 object named by ``uri``.

    Args:
        uri: S3 URI of the destination object; parsed with ``S3Uri``.
        data: Bytes written as the object body.

    Returns:
        ``uri`` on success, or None when anything fails (the failure is logged).
    """
    try:
        # Parse first so a malformed URI fails before a client is opened.
        parsed_uri = S3Uri(uri)
        async with self.session.client(AWSClientType.S3, region_name=self.region_name) as client:
            await client.put_object(Body=data, Bucket=parsed_uri.bucket, Key=parsed_uri.key)
            return uri
    except Exception:
        # Best-effort API: callers detect failure through the None return.
        LOG.exception("S3 upload failed.", uri=uri)
        return None
|
||||
|
||||
async def upload_file_stream(self, uri: str, file_obj: IO[bytes]) -> str | None:
    """Upload a binary file-like object to the S3 object named by ``uri``.

    Args:
        uri: S3 URI of the destination object; parsed with ``S3Uri``.
        file_obj: Binary stream passed to the client's ``upload_fileobj``.

    Returns:
        ``uri`` on success, or None when anything fails (the failure is logged).
    """
    try:
        # Parse first so a malformed URI fails before a client is opened.
        parsed_uri = S3Uri(uri)
        async with self.session.client(AWSClientType.S3, region_name=self.region_name) as client:
            await client.upload_fileobj(file_obj, parsed_uri.bucket, parsed_uri.key)
            LOG.debug("Upload file stream success", uri=uri)
            return uri
    except Exception:
        # Best-effort API: callers detect failure through the None return.
        LOG.exception("S3 upload stream failed.", uri=uri)
        return None
|
||||
|
||||
async def upload_file_from_path(
    self,
    uri: str,
    file_path: str,
    metadata: dict | None = None,
    raise_exception: bool = False,
) -> None:
    """Upload a local file to the S3 object named by ``uri``.

    Args:
        uri: S3 URI of the destination object; parsed with ``S3Uri``.
        file_path: Path of the local file to upload.
        metadata: Optional object metadata, sent as ``ExtraArgs["Metadata"]``
            when non-empty.
        raise_exception: When True, a failure is re-raised after being logged;
            when False (default) the failure is only logged.

    Raises:
        Exception: Whatever the upload raised, only if ``raise_exception`` is True.
    """
    try:
        parsed_uri = S3Uri(uri)
        params: dict[str, Any] = {
            "Filename": file_path,
            "Bucket": parsed_uri.bucket,
            "Key": parsed_uri.key,
        }
        if metadata:
            params["ExtraArgs"] = {"Metadata": metadata}

        async with self.session.client(AWSClientType.S3, region_name=self.region_name) as client:
            await client.upload_file(**params)
    except Exception:
        LOG.exception("S3 upload failed.", uri=uri)
        if raise_exception:
            # Bare raise keeps the original traceback intact.
            raise
|
||||
|
||||
async def download_file(self, uri: str, log_exception: bool = True) -> bytes | None:
    """Download the full body of the S3 object named by ``uri``.

    Args:
        uri: S3 URI of the object to read; parsed with ``S3Uri``.
        log_exception: When False, failures are swallowed without logging.

    Returns:
        The object's bytes, or None when anything fails.
    """
    try:
        parsed_uri = S3Uri(uri)
        async with self.session.client(AWSClientType.S3, region_name=self.region_name) as client:
            # Get full object including body
            response = await client.get_object(Bucket=parsed_uri.bucket, Key=parsed_uri.key)
            # The body must be read while the client is still open.
            return await response["Body"].read()
    except Exception:
        if log_exception:
            LOG.exception("S3 download failed", uri=uri)
        return None
|
||||
|
||||
@execute_with_async_client(client_type=AWSClientType.S3)
|
||||
async def get_file_metadata(
|
||||
self, uri: str, client: AioBaseClient = None, log_exception: bool = True
|
||||
self,
|
||||
uri: str,
|
||||
log_exception: bool = True,
|
||||
) -> dict | None:
|
||||
"""
|
||||
Retrieves only the metadata of a file without downloading its content.
|
||||
@@ -138,47 +158,47 @@ class AsyncAWSClient:
|
||||
The metadata dictionary or None if the request fails
|
||||
"""
|
||||
try:
|
||||
parsed_uri = S3Uri(uri)
|
||||
async with self.session.client(AWSClientType.S3, region_name=self.region_name) as client:
|
||||
parsed_uri = S3Uri(uri)
|
||||
|
||||
# Only get object metadata without the body
|
||||
response = await client.head_object(Bucket=parsed_uri.bucket, Key=parsed_uri.key)
|
||||
return response.get("Metadata", {})
|
||||
# Only get object metadata without the body
|
||||
response = await client.head_object(Bucket=parsed_uri.bucket, Key=parsed_uri.key)
|
||||
return response.get("Metadata", {})
|
||||
except Exception:
|
||||
if log_exception:
|
||||
LOG.exception("S3 metadata retrieval failed", uri=uri)
|
||||
return None
|
||||
|
||||
async def create_presigned_urls(self, uris: list[str]) -> list[str] | None:
    """Create presigned ``get_object`` URLs for a list of S3 URIs.

    Args:
        uris: S3 URIs to presign; each is parsed with ``S3Uri``.

    Returns:
        Presigned URLs in the same order as ``uris``, or None if any single
        URI fails (the failure is logged and no partial list is returned).
        Each URL expires after ``settings.PRESIGNED_URL_EXPIRATION``.
    """
    presigned_urls: list[str] = []
    try:
        async with self.session.client(AWSClientType.S3, region_name=self.region_name) as client:
            for uri in uris:
                parsed_uri = S3Uri(uri)
                url = await client.generate_presigned_url(
                    "get_object",
                    Params={"Bucket": parsed_uri.bucket, "Key": parsed_uri.key},
                    ExpiresIn=settings.PRESIGNED_URL_EXPIRATION,
                )
                presigned_urls.append(url)
            return presigned_urls
    except Exception:
        LOG.exception("Failed to create presigned url for S3 objects.", uris=uris)
        return None
|
||||
|
||||
async def list_files(self, uri: str) -> list[str]:
    """List the keys of all S3 objects under the prefix named by ``uri``.

    Args:
        uri: S3 URI whose bucket and key are used as ``Bucket`` and ``Prefix``.

    Returns:
        Object keys from every page of ``list_objects_v2``; empty when no page
        has a ``Contents`` entry. Errors are not caught here and propagate.
    """
    object_keys: list[str] = []
    parsed_uri = S3Uri(uri)
    async with self.session.client(AWSClientType.S3, region_name=self.region_name) as client:
        paginator = client.get_paginator("list_objects_v2")
        async for page in paginator.paginate(Bucket=parsed_uri.bucket, Prefix=parsed_uri.key):
            # Pages without matches carry no "Contents" key.
            object_keys.extend(obj["Key"] for obj in page.get("Contents", []))
    return object_keys
|
||||
|
||||
@execute_with_async_client(client_type=AWSClientType.ECS)
|
||||
async def run_task(
|
||||
self,
|
||||
cluster: str,
|
||||
@@ -186,43 +206,40 @@ class AsyncAWSClient:
|
||||
task_definition: str,
|
||||
subnets: list[str],
|
||||
security_groups: list[str],
|
||||
client: AioBaseClient = None,
|
||||
) -> dict:
|
||||
return await client.run_task(
|
||||
cluster=cluster,
|
||||
launchType=launch_type,
|
||||
taskDefinition=task_definition,
|
||||
networkConfiguration={
|
||||
"awsvpcConfiguration": {
|
||||
"subnets": subnets,
|
||||
"securityGroups": security_groups,
|
||||
"assignPublicIp": "DISABLED",
|
||||
}
|
||||
},
|
||||
)
|
||||
async with self.session.client(AWSClientType.ECS, region_name=self.region_name) as client:
|
||||
return await client.run_task(
|
||||
cluster=cluster,
|
||||
launchType=launch_type,
|
||||
taskDefinition=task_definition,
|
||||
networkConfiguration={
|
||||
"awsvpcConfiguration": {
|
||||
"subnets": subnets,
|
||||
"securityGroups": security_groups,
|
||||
"assignPublicIp": "DISABLED",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
async def stop_task(self, cluster: str, task: str, reason: str | None = None) -> dict:
    """Stop a running ECS task.

    Args:
        cluster: Cluster value passed to ECS ``stop_task``.
        task: Task value passed to ECS ``stop_task``.
        reason: Optional stop reason. Sent only when provided.

    Returns:
        The response returned by the ECS client's ``stop_task``.
    """
    params: dict[str, str] = {"cluster": cluster, "task": task}
    # botocore validates parameter types and rejects None for a string
    # parameter, so the optional reason must be omitted rather than sent as None.
    if reason is not None:
        params["reason"] = reason
    async with self.session.client(AWSClientType.ECS, region_name=self.region_name) as client:
        return await client.stop_task(**params)
|
||||
|
||||
async def describe_tasks(self, cluster: str, tasks: list[str]) -> dict:
    """Describe one or more ECS tasks.

    Args:
        cluster: Cluster value passed to ECS ``describe_tasks``.
        tasks: Task values passed to ECS ``describe_tasks``.

    Returns:
        The response returned by the ECS client's ``describe_tasks``.
    """
    async with self.session.client(AWSClientType.ECS, region_name=self.region_name) as client:
        return await client.describe_tasks(cluster=cluster, tasks=tasks)
|
||||
|
||||
async def list_tasks(self, cluster: str) -> dict:
    """List the tasks of an ECS cluster.

    Args:
        cluster: Cluster value passed to ECS ``list_tasks``.

    Returns:
        The response returned by the ECS client's ``list_tasks``.
        No pagination is performed here.
    """
    async with self.session.client(AWSClientType.ECS, region_name=self.region_name) as client:
        return await client.list_tasks(cluster=cluster)
|
||||
|
||||
async def describe_task_definition(self, task_definition: str) -> dict:
    """Describe an ECS task definition.

    Args:
        task_definition: Value passed as ``taskDefinition`` to ECS.

    Returns:
        The response returned by the ECS client's ``describe_task_definition``.
    """
    async with self.session.client(AWSClientType.ECS, region_name=self.region_name) as client:
        return await client.describe_task_definition(taskDefinition=task_definition)
|
||||
|
||||
async def deregister_task_definition(self, task_definition: str) -> dict:
    """Deregister an ECS task definition.

    Args:
        task_definition: Value passed as ``taskDefinition`` to ECS.

    Returns:
        The response returned by the ECS client's ``deregister_task_definition``.
    """
    async with self.session.client(AWSClientType.ECS, region_name=self.region_name) as client:
        return await client.deregister_task_definition(taskDefinition=task_definition)
|
||||
|
||||
|
||||
class S3Uri(object):
|
||||
|
||||
@@ -298,6 +298,13 @@ class WorkflowRunContext:
|
||||
LOG.error(f"Failed to get Bitwarden login credentials from AWS secrets. Error: {e}")
|
||||
raise e
|
||||
|
||||
if not client_id:
|
||||
raise ValueError("Bitwarden client ID not found")
|
||||
if not client_secret:
|
||||
raise ValueError("Bitwarden client secret not found")
|
||||
if not master_password:
|
||||
raise ValueError("Bitwarden master password not found")
|
||||
|
||||
if (
|
||||
parameter.url_parameter_key
|
||||
and self.has_parameter(parameter.url_parameter_key)
|
||||
@@ -395,6 +402,13 @@ class WorkflowRunContext:
|
||||
LOG.error(f"Failed to get Bitwarden login credentials from AWS secrets. Error: {e}")
|
||||
raise e
|
||||
|
||||
if not client_id:
|
||||
raise ValueError("Bitwarden client ID not found")
|
||||
if not client_secret:
|
||||
raise ValueError("Bitwarden client secret not found")
|
||||
if not master_password:
|
||||
raise ValueError("Bitwarden master password not found")
|
||||
|
||||
bitwarden_identity_key = parameter.bitwarden_identity_key
|
||||
if self.has_parameter(parameter.bitwarden_identity_key) and self.has_value(parameter.bitwarden_identity_key):
|
||||
bitwarden_identity_key = self.values[parameter.bitwarden_identity_key]
|
||||
@@ -456,6 +470,13 @@ class WorkflowRunContext:
|
||||
LOG.error(f"Failed to get Bitwarden login credentials from AWS secrets. Error: {e}")
|
||||
raise e
|
||||
|
||||
if not client_id:
|
||||
raise ValueError("Bitwarden client ID not found")
|
||||
if not client_secret:
|
||||
raise ValueError("Bitwarden client secret not found")
|
||||
if not master_password:
|
||||
raise ValueError("Bitwarden master password not found")
|
||||
|
||||
if self.has_parameter(parameter.bitwarden_item_id) and self.has_value(parameter.bitwarden_item_id):
|
||||
item_id = self.values[parameter.bitwarden_item_id]
|
||||
else:
|
||||
|
||||
@@ -64,6 +64,7 @@ from skyvern.forge.sdk.workflow.exceptions import (
|
||||
NoIterableValueFound,
|
||||
NoValidEmailRecipient,
|
||||
)
|
||||
from skyvern.forge.sdk.workflow.models.constants import FileStorageType
|
||||
from skyvern.forge.sdk.workflow.models.parameter import (
|
||||
PARAMETER_TYPE,
|
||||
AWSSecretParameter,
|
||||
@@ -85,6 +86,7 @@ class BlockType(StrEnum):
|
||||
TEXT_PROMPT = "text_prompt"
|
||||
DOWNLOAD_TO_S3 = "download_to_s3"
|
||||
UPLOAD_TO_S3 = "upload_to_s3"
|
||||
FILE_UPLOAD = "file_upload"
|
||||
SEND_EMAIL = "send_email"
|
||||
FILE_URL_PARSER = "file_url_parser"
|
||||
VALIDATION = "validation"
|
||||
@@ -1581,6 +1583,152 @@ class UploadToS3Block(Block):
|
||||
)
|
||||
|
||||
|
||||
class FileUploadBlock(Block):
    """Workflow block that uploads a run's downloaded files to S3.

    The files come from the workflow run's download directory and are uploaded
    with caller-supplied AWS credentials. On success the block's output is the
    list of S3 URIs it uploaded to.
    """

    block_type: Literal[BlockType.FILE_UPLOAD] = BlockType.FILE_UPLOAD

    # Storage backend selector; S3 is the default.
    storage_type: FileStorageType = FileStorageType.S3
    # Destination bucket (required at execution time).
    s3_bucket: str | None = None
    # Credentials used for the upload (both required at execution time).
    aws_access_key_id: str | None = None
    aws_secret_access_key: str | None = None
    # Optional region passed to AsyncAWSClient.
    region_name: str | None = None
    # Optional key prefix inside the bucket (not a local path).
    path: str | None = None

    def get_all_parameters(
        self,
        workflow_run_id: str,
    ) -> list[PARAMETER_TYPE]:
        """Return the workflow parameters referenced directly by this block's fields.

        A field counts as a reference when its value is the key of a parameter
        registered in the workflow run context.
        """
        workflow_run_context = self.get_workflow_run_context(workflow_run_id)
        candidate_keys = (
            self.path,
            self.s3_bucket,
            self.aws_access_key_id,
            self.aws_secret_access_key,
            # region_name is handled like the other configurable fields.
            self.region_name,
        )
        return [
            workflow_run_context.get_parameter(key)
            for key in candidate_keys
            if key and workflow_run_context.has_parameter(key)
        ]

    def format_potential_template_parameters(self, workflow_run_context: WorkflowRunContext) -> None:
        """Render template expressions in every configurable string field, in place."""
        if self.path:
            self.path = self.format_block_parameter_template_from_workflow_run_context(self.path, workflow_run_context)
        if self.s3_bucket:
            self.s3_bucket = self.format_block_parameter_template_from_workflow_run_context(
                self.s3_bucket, workflow_run_context
            )
        if self.aws_access_key_id:
            self.aws_access_key_id = self.format_block_parameter_template_from_workflow_run_context(
                self.aws_access_key_id, workflow_run_context
            )
        if self.aws_secret_access_key:
            self.aws_secret_access_key = self.format_block_parameter_template_from_workflow_run_context(
                self.aws_secret_access_key, workflow_run_context
            )
        # region_name was the only configurable field left untemplated.
        if self.region_name:
            self.region_name = self.format_block_parameter_template_from_workflow_run_context(
                self.region_name, workflow_run_context
            )

    def _get_s3_uri(self, workflow_run_id: str, path: str) -> str:
        """Build the destination S3 URI for the local file at ``path``.

        The key is ``[<prefix>/]<workflow_run_id>/<uuid4>_<file name>``; the
        random UUID keeps same-named files from overwriting each other.
        """
        s3_suffix = f"{workflow_run_id}/{uuid.uuid4()}_{Path(path).name}"
        # Strip surrounding slashes so a prefix such as "folder/" or "/folder"
        # does not produce empty key segments ("//") in the URI.
        prefix = (self.path or "").strip("/")
        if not prefix:
            return f"s3://{self.s3_bucket}/{s3_suffix}"
        return f"s3://{self.s3_bucket}/{prefix}/{s3_suffix}"

    async def execute(
        self,
        workflow_run_id: str,
        workflow_run_block_id: str,
        organization_id: str | None = None,
        browser_session_id: str | None = None,
        **kwargs: dict,
    ) -> BlockResult:
        """Upload the run's downloaded files to S3 and report the resulting URIs.

        Returns a failed ``BlockResult`` (never raises) when required fields are
        missing, template rendering fails, the download directory holds more
        than ``MAX_UPLOAD_FILE_COUNT`` entries, or any upload fails.
        """
        workflow_run_context = self.get_workflow_run_context(workflow_run_id)

        # Validate required fields before touching S3.
        missing_parameters = [
            name for name in ("s3_bucket", "aws_access_key_id", "aws_secret_access_key") if not getattr(self, name)
        ]
        if missing_parameters:
            return await self.build_block_result(
                success=False,
                failure_reason=f"Required block values are missing in the FileUploadBlock (label: {self.label}): {', '.join(missing_parameters)}",
                output_parameter_value=None,
                status=BlockStatus.failed,
                workflow_run_block_id=workflow_run_block_id,
                organization_id=organization_id,
            )

        try:
            self.format_potential_template_parameters(workflow_run_context)
        except Exception as e:
            return await self.build_block_result(
                success=False,
                failure_reason=f"Failed to format jinja template: {str(e)}",
                output_parameter_value=None,
                status=BlockStatus.failed,
                workflow_run_block_id=workflow_run_block_id,
                organization_id=organization_id,
            )

        download_files_path = str(get_path_for_workflow_download_directory(workflow_run_id).absolute())

        s3_uris = []
        try:
            client = AsyncAWSClient(
                aws_access_key_id=self.aws_access_key_id,
                aws_secret_access_key=self.aws_secret_access_key,
                region_name=self.region_name,
            )
            if os.path.isdir(download_files_path):
                # Refuse to upload when the directory holds more than
                # MAX_UPLOAD_FILE_COUNT entries (sub-directories are counted too).
                files = os.listdir(download_files_path)
                if len(files) > MAX_UPLOAD_FILE_COUNT:
                    raise ValueError("Too many files in the directory, not uploading")
                for file in files:
                    file_path = os.path.join(download_files_path, file)
                    # Only top-level files are uploaded; nested directories are skipped.
                    if os.path.isdir(file_path):
                        LOG.warning("FileUploadBlock: Skipping directory", file=file)
                        continue
                    s3_uri = self._get_s3_uri(workflow_run_id, file_path)
                    s3_uris.append(s3_uri)
                    await client.upload_file_from_path(uri=s3_uri, file_path=file_path, raise_exception=True)
            else:
                # The path is not a directory: treat it as a single file.
                s3_uri = self._get_s3_uri(workflow_run_id, download_files_path)
                s3_uris.append(s3_uri)
                await client.upload_file_from_path(uri=s3_uri, file_path=download_files_path, raise_exception=True)
        except Exception as e:
            # self.path is the S3 key prefix; log the local source path as well.
            LOG.exception(
                "FileUploadBlock: Failed to upload file to S3",
                file_path=self.path,
                download_files_path=download_files_path,
            )
            return await self.build_block_result(
                success=False,
                failure_reason=f"Failed to upload file to S3: {str(e)}",
                output_parameter_value=None,
                status=BlockStatus.failed,
                workflow_run_block_id=workflow_run_block_id,
                organization_id=organization_id,
            )

        LOG.info(
            "FileUploadBlock: File(s) uploaded to S3",
            file_path=self.path,
            download_files_path=download_files_path,
        )
        await self.record_output_parameter_value(workflow_run_context, workflow_run_id, s3_uris)
        return await self.build_block_result(
            success=True,
            failure_reason=None,
            output_parameter_value=s3_uris,
            status=BlockStatus.completed,
            workflow_run_block_id=workflow_run_block_id,
            organization_id=organization_id,
        )
|
||||
|
||||
|
||||
class SendEmailBlock(Block):
|
||||
block_type: Literal[BlockType.SEND_EMAIL] = BlockType.SEND_EMAIL
|
||||
|
||||
@@ -2348,5 +2496,6 @@ BlockSubclasses = Union[
|
||||
FileDownloadBlock,
|
||||
UrlBlock,
|
||||
TaskV2Block,
|
||||
FileUploadBlock,
|
||||
]
|
||||
BlockTypeVar = Annotated[BlockSubclasses, Field(discriminator="block_type")]
|
||||
|
||||
5
skyvern/forge/sdk/workflow/models/constants.py
Normal file
5
skyvern/forge/sdk/workflow/models/constants.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class FileStorageType(StrEnum):
|
||||
S3 = "s3"
|
||||
@@ -6,6 +6,7 @@ from pydantic import BaseModel, Field
|
||||
from skyvern.config import settings
|
||||
from skyvern.forge.sdk.schemas.tasks import ProxyLocation
|
||||
from skyvern.forge.sdk.workflow.models.block import BlockType, FileType
|
||||
from skyvern.forge.sdk.workflow.models.constants import FileStorageType
|
||||
from skyvern.forge.sdk.workflow.models.parameter import ParameterType, WorkflowParameterType
|
||||
from skyvern.forge.sdk.workflow.models.workflow import WorkflowStatus
|
||||
|
||||
@@ -200,6 +201,17 @@ class UploadToS3BlockYAML(BlockYAML):
|
||||
path: str | None = None
|
||||
|
||||
|
||||
class FileUploadBlockYAML(BlockYAML):
    """YAML schema for a file-upload block; fields mirror ``FileUploadBlock``."""

    # type: ignore works around mypy rejecting an enum member inside Literal[...]
    # ("Parameter 1 of Literal[...] cannot be of type 'Any'").
    block_type: Literal[BlockType.FILE_UPLOAD] = BlockType.FILE_UPLOAD  # type: ignore

    # Storage backend selector; S3 is the default.
    storage_type: FileStorageType = FileStorageType.S3
    # Destination bucket.
    s3_bucket: str | None = None
    # Credentials used for the upload.
    aws_access_key_id: str | None = None
    aws_secret_access_key: str | None = None
    # Optional AWS region.
    region_name: str | None = None
    # Optional key prefix inside the bucket.
    path: str | None = None
|
||||
|
||||
|
||||
class SendEmailBlockYAML(BlockYAML):
|
||||
# There is a mypy bug with Literal. Without the type: ignore, mypy will raise an error:
|
||||
# Parameter 1 of Literal[...] cannot be of type "Any"
|
||||
@@ -363,6 +375,7 @@ BLOCK_YAML_SUBCLASSES = (
|
||||
| TextPromptBlockYAML
|
||||
| DownloadToS3BlockYAML
|
||||
| UploadToS3BlockYAML
|
||||
| FileUploadBlockYAML
|
||||
| SendEmailBlockYAML
|
||||
| FileParserBlockYAML
|
||||
| ValidationBlockYAML
|
||||
|
||||
@@ -45,6 +45,7 @@ from skyvern.forge.sdk.workflow.models.block import (
|
||||
ExtractionBlock,
|
||||
FileDownloadBlock,
|
||||
FileParserBlock,
|
||||
FileUploadBlock,
|
||||
ForLoopBlock,
|
||||
LoginBlock,
|
||||
NavigationBlock,
|
||||
@@ -1668,6 +1669,18 @@ class WorkflowService:
|
||||
path=block_yaml.path,
|
||||
continue_on_failure=block_yaml.continue_on_failure,
|
||||
)
|
||||
elif block_yaml.block_type == BlockType.FILE_UPLOAD:
|
||||
return FileUploadBlock(
|
||||
label=block_yaml.label,
|
||||
output_parameter=output_parameter,
|
||||
storage_type=block_yaml.storage_type,
|
||||
s3_bucket=block_yaml.s3_bucket,
|
||||
aws_access_key_id=block_yaml.aws_access_key_id,
|
||||
aws_secret_access_key=block_yaml.aws_secret_access_key,
|
||||
region_name=block_yaml.region_name,
|
||||
path=block_yaml.path,
|
||||
continue_on_failure=block_yaml.continue_on_failure,
|
||||
)
|
||||
elif block_yaml.block_type == BlockType.SEND_EMAIL:
|
||||
return SendEmailBlock(
|
||||
label=block_yaml.label,
|
||||
|
||||
Reference in New Issue
Block a user