DRY getting boto3 clients (#2622)
This commit is contained in:
487
poetry.lock
generated
487
poetry.lock
generated
File diff suppressed because one or more lines are too long
@@ -100,6 +100,7 @@ pandas = "^2.2.3"
|
|||||||
pre-commit = "^4.2.0"
|
pre-commit = "^4.2.0"
|
||||||
ruff = "^0.11.12"
|
ruff = "^0.11.12"
|
||||||
aiosqlite = "^0.21.0"
|
aiosqlite = "^0.21.0"
|
||||||
|
types-boto3 = {extras = ["full"], version = "^1.38.31"}
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["poetry-core>=1.0.0"]
|
requires = ["poetry-core>=1.0.0"]
|
||||||
|
|||||||
@@ -4,6 +4,9 @@ from urllib.parse import urlparse
|
|||||||
|
|
||||||
import aioboto3
|
import aioboto3
|
||||||
import structlog
|
import structlog
|
||||||
|
from types_boto3_ecs.client import ECSClient
|
||||||
|
from types_boto3_s3.client import S3Client
|
||||||
|
from types_boto3_secretsmanager.client import SecretsManagerClient
|
||||||
|
|
||||||
from skyvern.config import settings
|
from skyvern.config import settings
|
||||||
|
|
||||||
@@ -41,9 +44,18 @@ class AsyncAWSClient:
|
|||||||
aws_secret_access_key=aws_secret_access_key,
|
aws_secret_access_key=aws_secret_access_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _ecs_client(self) -> ECSClient:
|
||||||
|
return self.session.client(AWSClientType.ECS, region_name=self.region_name)
|
||||||
|
|
||||||
|
def _secrets_manager_client(self) -> SecretsManagerClient:
|
||||||
|
return self.session.client(AWSClientType.SECRETS_MANAGER, region_name=self.region_name)
|
||||||
|
|
||||||
|
def _s3_client(self) -> S3Client:
|
||||||
|
return self.session.client(AWSClientType.S3, region_name=self.region_name)
|
||||||
|
|
||||||
async def get_secret(self, secret_name: str) -> str | None:
|
async def get_secret(self, secret_name: str) -> str | None:
|
||||||
try:
|
try:
|
||||||
async with self.session.client(AWSClientType.SECRETS_MANAGER, region_name=self.region_name) as client:
|
async with self._secrets_manager_client() as client:
|
||||||
response = await client.get_secret_value(SecretId=secret_name)
|
response = await client.get_secret_value(SecretId=secret_name)
|
||||||
return response["SecretString"]
|
return response["SecretString"]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -56,7 +68,7 @@ class AsyncAWSClient:
|
|||||||
|
|
||||||
async def create_secret(self, secret_name: str, secret_value: str) -> None:
|
async def create_secret(self, secret_name: str, secret_value: str) -> None:
|
||||||
try:
|
try:
|
||||||
async with self.session.client(AWSClientType.SECRETS_MANAGER, region_name=self.region_name) as client:
|
async with self._secrets_manager_client() as client:
|
||||||
await client.create_secret(Name=secret_name, SecretString=secret_value)
|
await client.create_secret(Name=secret_name, SecretString=secret_value)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
LOG.exception("Failed to create secret.", secret_name=secret_name)
|
LOG.exception("Failed to create secret.", secret_name=secret_name)
|
||||||
@@ -64,7 +76,7 @@ class AsyncAWSClient:
|
|||||||
|
|
||||||
async def set_secret(self, secret_name: str, secret_value: str) -> None:
|
async def set_secret(self, secret_name: str, secret_value: str) -> None:
|
||||||
try:
|
try:
|
||||||
async with self.session.client(AWSClientType.SECRETS_MANAGER, region_name=self.region_name) as client:
|
async with self._secrets_manager_client() as client:
|
||||||
await client.put_secret_value(SecretId=secret_name, SecretString=secret_value)
|
await client.put_secret_value(SecretId=secret_name, SecretString=secret_value)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
LOG.exception("Failed to set secret.", secret_name=secret_name)
|
LOG.exception("Failed to set secret.", secret_name=secret_name)
|
||||||
@@ -72,7 +84,7 @@ class AsyncAWSClient:
|
|||||||
|
|
||||||
async def delete_secret(self, secret_name: str) -> None:
|
async def delete_secret(self, secret_name: str) -> None:
|
||||||
try:
|
try:
|
||||||
async with self.session.client(AWSClientType.SECRETS_MANAGER, region_name=self.region_name) as client:
|
async with self._secrets_manager_client() as client:
|
||||||
await client.delete_secret(SecretId=secret_name)
|
await client.delete_secret(SecretId=secret_name)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
LOG.exception("Failed to delete secret.", secret_name=secret_name)
|
LOG.exception("Failed to delete secret.", secret_name=secret_name)
|
||||||
@@ -84,7 +96,7 @@ class AsyncAWSClient:
|
|||||||
if storage_class not in S3StorageClass:
|
if storage_class not in S3StorageClass:
|
||||||
raise ValueError(f"Invalid storage class: {storage_class}. Must be one of {list(S3StorageClass)}")
|
raise ValueError(f"Invalid storage class: {storage_class}. Must be one of {list(S3StorageClass)}")
|
||||||
try:
|
try:
|
||||||
async with self.session.client(AWSClientType.S3, region_name=self.region_name) as client:
|
async with self._s3_client() as client:
|
||||||
parsed_uri = S3Uri(uri)
|
parsed_uri = S3Uri(uri)
|
||||||
await client.put_object(
|
await client.put_object(
|
||||||
Body=data, Bucket=parsed_uri.bucket, Key=parsed_uri.key, StorageClass=str(storage_class)
|
Body=data, Bucket=parsed_uri.bucket, Key=parsed_uri.key, StorageClass=str(storage_class)
|
||||||
@@ -100,7 +112,7 @@ class AsyncAWSClient:
|
|||||||
if storage_class not in S3StorageClass:
|
if storage_class not in S3StorageClass:
|
||||||
raise ValueError(f"Invalid storage class: {storage_class}. Must be one of {list(S3StorageClass)}")
|
raise ValueError(f"Invalid storage class: {storage_class}. Must be one of {list(S3StorageClass)}")
|
||||||
try:
|
try:
|
||||||
async with self.session.client(AWSClientType.S3, region_name=self.region_name) as client:
|
async with self._s3_client() as client:
|
||||||
parsed_uri = S3Uri(uri)
|
parsed_uri = S3Uri(uri)
|
||||||
await client.upload_fileobj(
|
await client.upload_fileobj(
|
||||||
file_obj,
|
file_obj,
|
||||||
@@ -123,7 +135,7 @@ class AsyncAWSClient:
|
|||||||
raise_exception: bool = False,
|
raise_exception: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
try:
|
try:
|
||||||
async with self.session.client(AWSClientType.S3, region_name=self.region_name) as client:
|
async with self._s3_client() as client:
|
||||||
parsed_uri = S3Uri(uri)
|
parsed_uri = S3Uri(uri)
|
||||||
extra_args: dict[str, Any] = {"StorageClass": str(storage_class)}
|
extra_args: dict[str, Any] = {"StorageClass": str(storage_class)}
|
||||||
if metadata:
|
if metadata:
|
||||||
@@ -141,7 +153,7 @@ class AsyncAWSClient:
|
|||||||
|
|
||||||
async def download_file(self, uri: str, log_exception: bool = True) -> bytes | None:
|
async def download_file(self, uri: str, log_exception: bool = True) -> bytes | None:
|
||||||
try:
|
try:
|
||||||
async with self.session.client(AWSClientType.S3, region_name=self.region_name) as client:
|
async with self._s3_client() as client:
|
||||||
parsed_uri = S3Uri(uri)
|
parsed_uri = S3Uri(uri)
|
||||||
|
|
||||||
# Get full object including body
|
# Get full object including body
|
||||||
@@ -169,7 +181,7 @@ class AsyncAWSClient:
|
|||||||
The metadata dictionary or None if the request fails
|
The metadata dictionary or None if the request fails
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
async with self.session.client(AWSClientType.S3, region_name=self.region_name) as client:
|
async with self._s3_client() as client:
|
||||||
parsed_uri = S3Uri(uri)
|
parsed_uri = S3Uri(uri)
|
||||||
|
|
||||||
# Only get object metadata without the body
|
# Only get object metadata without the body
|
||||||
@@ -183,7 +195,7 @@ class AsyncAWSClient:
|
|||||||
async def create_presigned_urls(self, uris: list[str]) -> list[str] | None:
|
async def create_presigned_urls(self, uris: list[str]) -> list[str] | None:
|
||||||
presigned_urls = []
|
presigned_urls = []
|
||||||
try:
|
try:
|
||||||
async with self.session.client(AWSClientType.S3, region_name=self.region_name) as client:
|
async with self._s3_client() as client:
|
||||||
for uri in uris:
|
for uri in uris:
|
||||||
parsed_uri = S3Uri(uri)
|
parsed_uri = S3Uri(uri)
|
||||||
url = await client.generate_presigned_url(
|
url = await client.generate_presigned_url(
|
||||||
@@ -201,7 +213,7 @@ class AsyncAWSClient:
|
|||||||
async def list_files(self, uri: str) -> list[str]:
|
async def list_files(self, uri: str) -> list[str]:
|
||||||
object_keys: list[str] = []
|
object_keys: list[str] = []
|
||||||
parsed_uri = S3Uri(uri)
|
parsed_uri = S3Uri(uri)
|
||||||
async with self.session.client(AWSClientType.S3, region_name=self.region_name) as client:
|
async with self._s3_client() as client:
|
||||||
async for page in client.get_paginator("list_objects_v2").paginate(
|
async for page in client.get_paginator("list_objects_v2").paginate(
|
||||||
Bucket=parsed_uri.bucket, Prefix=parsed_uri.key
|
Bucket=parsed_uri.bucket, Prefix=parsed_uri.key
|
||||||
):
|
):
|
||||||
@@ -218,7 +230,7 @@ class AsyncAWSClient:
|
|||||||
subnets: list[str],
|
subnets: list[str],
|
||||||
security_groups: list[str],
|
security_groups: list[str],
|
||||||
) -> dict:
|
) -> dict:
|
||||||
async with self.session.client(AWSClientType.ECS, region_name=self.region_name) as client:
|
async with self._ecs_client() as client:
|
||||||
return await client.run_task(
|
return await client.run_task(
|
||||||
cluster=cluster,
|
cluster=cluster,
|
||||||
launchType=launch_type,
|
launchType=launch_type,
|
||||||
@@ -233,23 +245,23 @@ class AsyncAWSClient:
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def stop_task(self, cluster: str, task: str, reason: str | None = None) -> dict:
|
async def stop_task(self, cluster: str, task: str, reason: str | None = None) -> dict:
|
||||||
async with self.session.client(AWSClientType.ECS, region_name=self.region_name) as client:
|
async with self._ecs_client() as client:
|
||||||
return await client.stop_task(cluster=cluster, task=task, reason=reason)
|
return await client.stop_task(cluster=cluster, task=task, reason=reason)
|
||||||
|
|
||||||
async def describe_tasks(self, cluster: str, tasks: list[str]) -> dict:
|
async def describe_tasks(self, cluster: str, tasks: list[str]) -> dict:
|
||||||
async with self.session.client(AWSClientType.ECS, region_name=self.region_name) as client:
|
async with self._ecs_client() as client:
|
||||||
return await client.describe_tasks(cluster=cluster, tasks=tasks)
|
return await client.describe_tasks(cluster=cluster, tasks=tasks)
|
||||||
|
|
||||||
async def list_tasks(self, cluster: str) -> dict:
|
async def list_tasks(self, cluster: str) -> dict:
|
||||||
async with self.session.client(AWSClientType.ECS, region_name=self.region_name) as client:
|
async with self._ecs_client() as client:
|
||||||
return await client.list_tasks(cluster=cluster)
|
return await client.list_tasks(cluster=cluster)
|
||||||
|
|
||||||
async def describe_task_definition(self, task_definition: str) -> dict:
|
async def describe_task_definition(self, task_definition: str) -> dict:
|
||||||
async with self.session.client(AWSClientType.ECS, region_name=self.region_name) as client:
|
async with self._ecs_client() as client:
|
||||||
return await client.describe_task_definition(taskDefinition=task_definition)
|
return await client.describe_task_definition(taskDefinition=task_definition)
|
||||||
|
|
||||||
async def deregister_task_definition(self, task_definition: str) -> dict:
|
async def deregister_task_definition(self, task_definition: str) -> dict:
|
||||||
async with self.session.client(AWSClientType.ECS, region_name=self.region_name) as client:
|
async with self._ecs_client() as client:
|
||||||
return await client.deregister_task_definition(taskDefinition=task_definition)
|
return await client.deregister_task_definition(taskDefinition=task_definition)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user