From 889c8f1963f91564f506af67220629d836b6d77b Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Sun, 2 Mar 2025 02:00:41 -0500 Subject: [PATCH] aws_client ECS tasks (#1861) --- skyvern/forge/sdk/api/aws.py | 47 ++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/skyvern/forge/sdk/api/aws.py b/skyvern/forge/sdk/api/aws.py index 045e20df..e1de25a1 100644 --- a/skyvern/forge/sdk/api/aws.py +++ b/skyvern/forge/sdk/api/aws.py @@ -14,6 +14,7 @@ LOG = structlog.get_logger() class AWSClientType(StrEnum): S3 = "s3" SECRETS_MANAGER = "secretsmanager" + ECS = "ecs" def execute_with_async_client(client_type: AWSClientType) -> Callable: @@ -177,6 +178,52 @@ class AsyncAWSClient: object_keys.append(obj["Key"]) return object_keys + @execute_with_async_client(client_type=AWSClientType.ECS) + async def run_task( + self, + cluster: str, + launch_type: str, + task_definition: str, + subnets: list[str], + security_groups: list[str], + client: AioBaseClient = None, + ) -> dict: + return await client.run_task( + cluster=cluster, + launchType=launch_type, + taskDefinition=task_definition, + networkConfiguration={ + "awsvpcConfiguration": { + "subnets": subnets, + "securityGroups": security_groups, + "assignPublicIp": "DISABLED", + } + }, + ) + + @execute_with_async_client(client_type=AWSClientType.ECS) + async def stop_task(self, cluster: str, task: str, client: AioBaseClient = None) -> dict: + response = await client.stop_task(cluster=cluster, task=task) + return response + + @execute_with_async_client(client_type=AWSClientType.ECS) + async def describe_tasks(self, cluster: str, tasks: list[str], client: AioBaseClient = None) -> dict: + response = await client.describe_tasks(cluster=cluster, tasks=tasks) + return response + + @execute_with_async_client(client_type=AWSClientType.ECS) + async def list_tasks(self, cluster: str, client: AioBaseClient = None) -> dict: + response = await client.list_tasks(cluster=cluster) + return response + + @execute_with_async_client(client_type=AWSClientType.ECS) + async def describe_task_definition(self, task_definition: str, client: AioBaseClient = None) -> dict: + return await client.describe_task_definition(taskDefinition=task_definition) + + @execute_with_async_client(client_type=AWSClientType.ECS) + async def deregister_task_definition(self, task_definition: str, client: AioBaseClient = None) -> dict: + return await client.deregister_task_definition(taskDefinition=task_definition) + class S3Uri(object): # From: https://stackoverflow.com/questions/42641315/s3-urls-get-bucket-name-and-path