file upload block backend (#2000)
This commit is contained in:
@@ -4,7 +4,6 @@ from urllib.parse import urlparse
|
||||
|
||||
import aioboto3
|
||||
import structlog
|
||||
from aiobotocore.client import AioBaseClient
|
||||
|
||||
from skyvern.config import settings
|
||||
|
||||
@@ -32,11 +31,25 @@ def execute_with_async_client(client_type: AWSClientType) -> Callable:
|
||||
|
||||
|
||||
class AsyncAWSClient:
|
||||
@execute_with_async_client(client_type=AWSClientType.SECRETS_MANAGER)
|
||||
async def get_secret(self, secret_name: str, client: AioBaseClient = None) -> str | None:
|
||||
def __init__(
|
||||
self,
|
||||
aws_access_key_id: str | None = None,
|
||||
aws_secret_access_key: str | None = None,
|
||||
region_name: str | None = None,
|
||||
) -> None:
|
||||
self.aws_access_key_id = aws_access_key_id
|
||||
self.aws_secret_access_key = aws_secret_access_key
|
||||
self.region_name = region_name or settings.AWS_REGION
|
||||
self.session = aioboto3.Session(
|
||||
aws_access_key_id=self.aws_access_key_id,
|
||||
aws_secret_access_key=self.aws_secret_access_key,
|
||||
)
|
||||
|
||||
async def get_secret(self, secret_name: str) -> str | None:
|
||||
try:
|
||||
response = await client.get_secret_value(SecretId=secret_name)
|
||||
return response["SecretString"]
|
||||
async with self.session.client(AWSClientType.SECRETS_MANAGER, region_name=self.region_name) as client:
|
||||
response = await client.get_secret_value(SecretId=secret_name)
|
||||
return response["SecretString"]
|
||||
except Exception as e:
|
||||
try:
|
||||
error_code = e.response["Error"]["Code"] # type: ignore
|
||||
@@ -45,86 +58,93 @@ class AsyncAWSClient:
|
||||
LOG.exception("Failed to get secret.", secret_name=secret_name, error_code=error_code)
|
||||
return None
|
||||
|
||||
@execute_with_async_client(client_type=AWSClientType.SECRETS_MANAGER)
|
||||
async def create_secret(self, secret_name: str, secret_value: str, client: AioBaseClient = None) -> None:
|
||||
async def create_secret(self, secret_name: str, secret_value: str) -> None:
|
||||
try:
|
||||
await client.create_secret(Name=secret_name, SecretString=secret_value)
|
||||
async with self.session.client(AWSClientType.SECRETS_MANAGER, region_name=self.region_name) as client:
|
||||
await client.create_secret(Name=secret_name, SecretString=secret_value)
|
||||
except Exception as e:
|
||||
LOG.exception("Failed to create secret.", secret_name=secret_name)
|
||||
raise e
|
||||
|
||||
@execute_with_async_client(client_type=AWSClientType.SECRETS_MANAGER)
|
||||
async def set_secret(self, secret_name: str, secret_value: str, client: AioBaseClient = None) -> None:
|
||||
async def set_secret(self, secret_name: str, secret_value: str) -> None:
|
||||
try:
|
||||
await client.put_secret_value(SecretId=secret_name, SecretString=secret_value)
|
||||
async with self.session.client(AWSClientType.SECRETS_MANAGER, region_name=self.region_name) as client:
|
||||
await client.put_secret_value(SecretId=secret_name, SecretString=secret_value)
|
||||
except Exception as e:
|
||||
LOG.exception("Failed to set secret.", secret_name=secret_name)
|
||||
raise e
|
||||
|
||||
@execute_with_async_client(client_type=AWSClientType.SECRETS_MANAGER)
|
||||
async def delete_secret(self, secret_name: str, client: AioBaseClient = None) -> None:
|
||||
async def delete_secret(self, secret_name: str) -> None:
|
||||
try:
|
||||
await client.delete_secret(SecretId=secret_name)
|
||||
async with self.session.client(AWSClientType.SECRETS_MANAGER, region_name=self.region_name) as client:
|
||||
await client.delete_secret(SecretId=secret_name)
|
||||
except Exception as e:
|
||||
LOG.exception("Failed to delete secret.", secret_name=secret_name)
|
||||
raise e
|
||||
|
||||
@execute_with_async_client(client_type=AWSClientType.S3)
|
||||
async def upload_file(self, uri: str, data: bytes, client: AioBaseClient = None) -> str | None:
|
||||
async def upload_file(self, uri: str, data: bytes) -> str | None:
|
||||
try:
|
||||
parsed_uri = S3Uri(uri)
|
||||
await client.put_object(Body=data, Bucket=parsed_uri.bucket, Key=parsed_uri.key)
|
||||
return uri
|
||||
async with self.session.client(AWSClientType.S3, region_name=self.region_name) as client:
|
||||
parsed_uri = S3Uri(uri)
|
||||
await client.put_object(Body=data, Bucket=parsed_uri.bucket, Key=parsed_uri.key)
|
||||
return uri
|
||||
except Exception:
|
||||
LOG.exception("S3 upload failed.", uri=uri)
|
||||
return None
|
||||
|
||||
@execute_with_async_client(client_type=AWSClientType.S3)
|
||||
async def upload_file_stream(self, uri: str, file_obj: IO[bytes], client: AioBaseClient = None) -> str | None:
|
||||
async def upload_file_stream(self, uri: str, file_obj: IO[bytes]) -> str | None:
|
||||
try:
|
||||
parsed_uri = S3Uri(uri)
|
||||
await client.upload_fileobj(file_obj, parsed_uri.bucket, parsed_uri.key)
|
||||
LOG.debug("Upload file stream success", uri=uri)
|
||||
return uri
|
||||
async with self.session.client(AWSClientType.S3, region_name=self.region_name) as client:
|
||||
parsed_uri = S3Uri(uri)
|
||||
await client.upload_fileobj(file_obj, parsed_uri.bucket, parsed_uri.key)
|
||||
LOG.debug("Upload file stream success", uri=uri)
|
||||
return uri
|
||||
except Exception:
|
||||
LOG.exception("S3 upload stream failed.", uri=uri)
|
||||
return None
|
||||
|
||||
@execute_with_async_client(client_type=AWSClientType.S3)
|
||||
async def upload_file_from_path(
|
||||
self, uri: str, file_path: str, client: AioBaseClient = None, metadata: dict | None = None
|
||||
self,
|
||||
uri: str,
|
||||
file_path: str,
|
||||
metadata: dict | None = None,
|
||||
raise_exception: bool = False,
|
||||
) -> None:
|
||||
try:
|
||||
parsed_uri = S3Uri(uri)
|
||||
params: dict[str, Any] = {
|
||||
"Filename": file_path,
|
||||
"Bucket": parsed_uri.bucket,
|
||||
"Key": parsed_uri.key,
|
||||
}
|
||||
async with self.session.client(AWSClientType.S3, region_name=self.region_name) as client:
|
||||
parsed_uri = S3Uri(uri)
|
||||
params: dict[str, Any] = {
|
||||
"Filename": file_path,
|
||||
"Bucket": parsed_uri.bucket,
|
||||
"Key": parsed_uri.key,
|
||||
}
|
||||
|
||||
if metadata:
|
||||
params["ExtraArgs"] = {"Metadata": metadata}
|
||||
if metadata:
|
||||
params["ExtraArgs"] = {"Metadata": metadata}
|
||||
|
||||
await client.upload_file(**params)
|
||||
except Exception:
|
||||
await client.upload_file(**params)
|
||||
except Exception as e:
|
||||
LOG.exception("S3 upload failed.", uri=uri)
|
||||
if raise_exception:
|
||||
raise e
|
||||
|
||||
@execute_with_async_client(client_type=AWSClientType.S3)
|
||||
async def download_file(self, uri: str, client: AioBaseClient = None, log_exception: bool = True) -> bytes | None:
|
||||
async def download_file(self, uri: str, log_exception: bool = True) -> bytes | None:
|
||||
try:
|
||||
parsed_uri = S3Uri(uri)
|
||||
async with self.session.client(AWSClientType.S3, region_name=self.region_name) as client:
|
||||
parsed_uri = S3Uri(uri)
|
||||
|
||||
# Get full object including body
|
||||
response = await client.get_object(Bucket=parsed_uri.bucket, Key=parsed_uri.key)
|
||||
return await response["Body"].read()
|
||||
# Get full object including body
|
||||
response = await client.get_object(Bucket=parsed_uri.bucket, Key=parsed_uri.key)
|
||||
return await response["Body"].read()
|
||||
except Exception:
|
||||
if log_exception:
|
||||
LOG.exception("S3 download failed", uri=uri)
|
||||
return None
|
||||
|
||||
@execute_with_async_client(client_type=AWSClientType.S3)
|
||||
async def get_file_metadata(
|
||||
self, uri: str, client: AioBaseClient = None, log_exception: bool = True
|
||||
self,
|
||||
uri: str,
|
||||
log_exception: bool = True,
|
||||
) -> dict | None:
|
||||
"""
|
||||
Retrieves only the metadata of a file without downloading its content.
|
||||
@@ -138,47 +158,47 @@ class AsyncAWSClient:
|
||||
The metadata dictionary or None if the request fails
|
||||
"""
|
||||
try:
|
||||
parsed_uri = S3Uri(uri)
|
||||
async with self.session.client(AWSClientType.S3, region_name=self.region_name) as client:
|
||||
parsed_uri = S3Uri(uri)
|
||||
|
||||
# Only get object metadata without the body
|
||||
response = await client.head_object(Bucket=parsed_uri.bucket, Key=parsed_uri.key)
|
||||
return response.get("Metadata", {})
|
||||
# Only get object metadata without the body
|
||||
response = await client.head_object(Bucket=parsed_uri.bucket, Key=parsed_uri.key)
|
||||
return response.get("Metadata", {})
|
||||
except Exception:
|
||||
if log_exception:
|
||||
LOG.exception("S3 metadata retrieval failed", uri=uri)
|
||||
return None
|
||||
|
||||
@execute_with_async_client(client_type=AWSClientType.S3)
|
||||
async def create_presigned_urls(self, uris: list[str], client: AioBaseClient = None) -> list[str] | None:
|
||||
async def create_presigned_urls(self, uris: list[str]) -> list[str] | None:
|
||||
presigned_urls = []
|
||||
try:
|
||||
for uri in uris:
|
||||
parsed_uri = S3Uri(uri)
|
||||
url = await client.generate_presigned_url(
|
||||
"get_object",
|
||||
Params={"Bucket": parsed_uri.bucket, "Key": parsed_uri.key},
|
||||
ExpiresIn=settings.PRESIGNED_URL_EXPIRATION,
|
||||
)
|
||||
presigned_urls.append(url)
|
||||
async with self.session.client(AWSClientType.S3, region_name=self.region_name) as client:
|
||||
for uri in uris:
|
||||
parsed_uri = S3Uri(uri)
|
||||
url = await client.generate_presigned_url(
|
||||
"get_object",
|
||||
Params={"Bucket": parsed_uri.bucket, "Key": parsed_uri.key},
|
||||
ExpiresIn=settings.PRESIGNED_URL_EXPIRATION,
|
||||
)
|
||||
presigned_urls.append(url)
|
||||
|
||||
return presigned_urls
|
||||
return presigned_urls
|
||||
except Exception:
|
||||
LOG.exception("Failed to create presigned url for S3 objects.", uris=uris)
|
||||
return None
|
||||
|
||||
@execute_with_async_client(client_type=AWSClientType.S3)
|
||||
async def list_files(self, uri: str, client: AioBaseClient = None) -> list[str]:
|
||||
async def list_files(self, uri: str) -> list[str]:
|
||||
object_keys: list[str] = []
|
||||
parsed_uri = S3Uri(uri)
|
||||
async for page in client.get_paginator("list_objects_v2").paginate(
|
||||
Bucket=parsed_uri.bucket, Prefix=parsed_uri.key
|
||||
):
|
||||
if "Contents" in page:
|
||||
for obj in page["Contents"]:
|
||||
object_keys.append(obj["Key"])
|
||||
return object_keys
|
||||
async with self.session.client(AWSClientType.S3, region_name=self.region_name) as client:
|
||||
async for page in client.get_paginator("list_objects_v2").paginate(
|
||||
Bucket=parsed_uri.bucket, Prefix=parsed_uri.key
|
||||
):
|
||||
if "Contents" in page:
|
||||
for obj in page["Contents"]:
|
||||
object_keys.append(obj["Key"])
|
||||
return object_keys
|
||||
|
||||
@execute_with_async_client(client_type=AWSClientType.ECS)
|
||||
async def run_task(
|
||||
self,
|
||||
cluster: str,
|
||||
@@ -186,43 +206,40 @@ class AsyncAWSClient:
|
||||
task_definition: str,
|
||||
subnets: list[str],
|
||||
security_groups: list[str],
|
||||
client: AioBaseClient = None,
|
||||
) -> dict:
|
||||
return await client.run_task(
|
||||
cluster=cluster,
|
||||
launchType=launch_type,
|
||||
taskDefinition=task_definition,
|
||||
networkConfiguration={
|
||||
"awsvpcConfiguration": {
|
||||
"subnets": subnets,
|
||||
"securityGroups": security_groups,
|
||||
"assignPublicIp": "DISABLED",
|
||||
}
|
||||
},
|
||||
)
|
||||
async with self.session.client(AWSClientType.ECS, region_name=self.region_name) as client:
|
||||
return await client.run_task(
|
||||
cluster=cluster,
|
||||
launchType=launch_type,
|
||||
taskDefinition=task_definition,
|
||||
networkConfiguration={
|
||||
"awsvpcConfiguration": {
|
||||
"subnets": subnets,
|
||||
"securityGroups": security_groups,
|
||||
"assignPublicIp": "DISABLED",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
@execute_with_async_client(client_type=AWSClientType.ECS)
|
||||
async def stop_task(self, cluster: str, task: str, client: AioBaseClient = None) -> dict:
|
||||
response = await client.stop_task(cluster=cluster, task=task)
|
||||
return response
|
||||
async def stop_task(self, cluster: str, task: str, reason: str | None = None) -> dict:
|
||||
async with self.session.client(AWSClientType.ECS, region_name=self.region_name) as client:
|
||||
return await client.stop_task(cluster=cluster, task=task, reason=reason)
|
||||
|
||||
@execute_with_async_client(client_type=AWSClientType.ECS)
|
||||
async def describe_tasks(self, cluster: str, tasks: list[str], client: AioBaseClient = None) -> dict:
|
||||
response = await client.describe_tasks(cluster=cluster, tasks=tasks)
|
||||
return response
|
||||
async def describe_tasks(self, cluster: str, tasks: list[str]) -> dict:
|
||||
async with self.session.client(AWSClientType.ECS, region_name=self.region_name) as client:
|
||||
return await client.describe_tasks(cluster=cluster, tasks=tasks)
|
||||
|
||||
@execute_with_async_client(client_type=AWSClientType.ECS)
|
||||
async def list_tasks(self, cluster: str, client: AioBaseClient = None) -> dict:
|
||||
response = await client.list_tasks(cluster=cluster)
|
||||
return response
|
||||
async def list_tasks(self, cluster: str) -> dict:
|
||||
async with self.session.client(AWSClientType.ECS, region_name=self.region_name) as client:
|
||||
return await client.list_tasks(cluster=cluster)
|
||||
|
||||
@execute_with_async_client(client_type=AWSClientType.ECS)
|
||||
async def describe_task_definition(self, task_definition: str, client: AioBaseClient = None) -> dict:
|
||||
return await client.describe_task_definition(taskDefinition=task_definition)
|
||||
async def describe_task_definition(self, task_definition: str) -> dict:
|
||||
async with self.session.client(AWSClientType.ECS, region_name=self.region_name) as client:
|
||||
return await client.describe_task_definition(taskDefinition=task_definition)
|
||||
|
||||
@execute_with_async_client(client_type=AWSClientType.ECS)
|
||||
async def deregister_task_definition(self, task_definition: str, client: AioBaseClient = None) -> dict:
|
||||
return await client.deregister_task_definition(taskDefinition=task_definition)
|
||||
async def deregister_task_definition(self, task_definition: str) -> dict:
|
||||
async with self.session.client(AWSClientType.ECS, region_name=self.region_name) as client:
|
||||
return await client.deregister_task_definition(taskDefinition=task_definition)
|
||||
|
||||
|
||||
class S3Uri(object):
|
||||
|
||||
Reference in New Issue
Block a user