Add the ability to add tags to s3 objects we upload + more tests for artifact upload (#2684)
This commit is contained in:
@@ -57,6 +57,9 @@ class AsyncAWSClient:
|
||||
def _s3_client(self) -> S3Client:
|
||||
return self.session.client(AWSClientType.S3, region_name=self.region_name, endpoint_url=self._endpoint_url)
|
||||
|
||||
def _create_tag_string(self, tags: dict[str, str]) -> str:
|
||||
return "&".join([f"{k}={v}" for k, v in tags.items()])
|
||||
|
||||
async def get_secret(self, secret_name: str) -> str | None:
|
||||
try:
|
||||
async with self._secrets_manager_client() as client:
|
||||
@@ -95,15 +98,24 @@ class AsyncAWSClient:
|
||||
raise e
|
||||
|
||||
async def upload_file(
|
||||
self, uri: str, data: bytes, storage_class: S3StorageClass = S3StorageClass.STANDARD
|
||||
self,
|
||||
uri: str,
|
||||
data: bytes,
|
||||
storage_class: S3StorageClass = S3StorageClass.STANDARD,
|
||||
tags: dict[str, str] | None = None,
|
||||
) -> str | None:
|
||||
if storage_class not in S3StorageClass:
|
||||
raise ValueError(f"Invalid storage class: {storage_class}. Must be one of {list(S3StorageClass)}")
|
||||
try:
|
||||
async with self._s3_client() as client:
|
||||
parsed_uri = S3Uri(uri)
|
||||
extra_args = {"Tagging": self._create_tag_string(tags)} if tags else {}
|
||||
await client.put_object(
|
||||
Body=data, Bucket=parsed_uri.bucket, Key=parsed_uri.key, StorageClass=str(storage_class)
|
||||
Body=data,
|
||||
Bucket=parsed_uri.bucket,
|
||||
Key=parsed_uri.key,
|
||||
StorageClass=str(storage_class),
|
||||
**extra_args,
|
||||
)
|
||||
return uri
|
||||
except Exception:
|
||||
@@ -111,18 +123,25 @@ class AsyncAWSClient:
|
||||
return None
|
||||
|
||||
async def upload_file_stream(
|
||||
self, uri: str, file_obj: IO[bytes], storage_class: S3StorageClass = S3StorageClass.STANDARD
|
||||
self,
|
||||
uri: str,
|
||||
file_obj: IO[bytes],
|
||||
storage_class: S3StorageClass = S3StorageClass.STANDARD,
|
||||
tags: dict[str, str] | None = None,
|
||||
) -> str | None:
|
||||
if storage_class not in S3StorageClass:
|
||||
raise ValueError(f"Invalid storage class: {storage_class}. Must be one of {list(S3StorageClass)}")
|
||||
try:
|
||||
async with self._s3_client() as client:
|
||||
parsed_uri = S3Uri(uri)
|
||||
extra_args: dict[str, Any] = {"StorageClass": str(storage_class)}
|
||||
if tags:
|
||||
extra_args["Tagging"] = self._create_tag_string(tags)
|
||||
await client.upload_fileobj(
|
||||
file_obj,
|
||||
parsed_uri.bucket,
|
||||
parsed_uri.key,
|
||||
ExtraArgs={"StorageClass": str(storage_class)},
|
||||
ExtraArgs=extra_args,
|
||||
)
|
||||
LOG.debug("Upload file stream success", uri=uri)
|
||||
return uri
|
||||
@@ -137,6 +156,7 @@ class AsyncAWSClient:
|
||||
storage_class: S3StorageClass = S3StorageClass.STANDARD,
|
||||
metadata: dict | None = None,
|
||||
raise_exception: bool = False,
|
||||
tags: dict[str, str] | None = None,
|
||||
) -> None:
|
||||
try:
|
||||
async with self._s3_client() as client:
|
||||
@@ -144,6 +164,8 @@ class AsyncAWSClient:
|
||||
extra_args: dict[str, Any] = {"StorageClass": str(storage_class)}
|
||||
if metadata:
|
||||
extra_args["Metadata"] = metadata
|
||||
if tags:
|
||||
extra_args["Tagging"] = self._create_tag_string(tags)
|
||||
await client.upload_file(
|
||||
Filename=file_path,
|
||||
Bucket=parsed_uri.bucket,
|
||||
|
||||
Reference in New Issue
Block a user