Add the ability to add tags to s3 objects we upload + more tests for artifact upload (#2684)

This commit is contained in:
Asher Foa
2025-06-11 15:52:25 -04:00
committed by GitHub
parent 2d2146948a
commit 3100ff9543
3 changed files with 107 additions and 37 deletions

View File

@@ -57,6 +57,9 @@ class AsyncAWSClient:
def _s3_client(self) -> S3Client:
return self.session.client(AWSClientType.S3, region_name=self.region_name, endpoint_url=self._endpoint_url)
def _create_tag_string(self, tags: dict[str, str]) -> str:
return "&".join([f"{k}={v}" for k, v in tags.items()])
async def get_secret(self, secret_name: str) -> str | None:
try:
async with self._secrets_manager_client() as client:
@@ -95,15 +98,24 @@ class AsyncAWSClient:
raise e
async def upload_file(
self, uri: str, data: bytes, storage_class: S3StorageClass = S3StorageClass.STANDARD
self,
uri: str,
data: bytes,
storage_class: S3StorageClass = S3StorageClass.STANDARD,
tags: dict[str, str] | None = None,
) -> str | None:
if storage_class not in S3StorageClass:
raise ValueError(f"Invalid storage class: {storage_class}. Must be one of {list(S3StorageClass)}")
try:
async with self._s3_client() as client:
parsed_uri = S3Uri(uri)
extra_args = {"Tagging": self._create_tag_string(tags)} if tags else {}
await client.put_object(
Body=data, Bucket=parsed_uri.bucket, Key=parsed_uri.key, StorageClass=str(storage_class)
Body=data,
Bucket=parsed_uri.bucket,
Key=parsed_uri.key,
StorageClass=str(storage_class),
**extra_args,
)
return uri
except Exception:
@@ -111,18 +123,25 @@ class AsyncAWSClient:
return None
async def upload_file_stream(
self, uri: str, file_obj: IO[bytes], storage_class: S3StorageClass = S3StorageClass.STANDARD
self,
uri: str,
file_obj: IO[bytes],
storage_class: S3StorageClass = S3StorageClass.STANDARD,
tags: dict[str, str] | None = None,
) -> str | None:
if storage_class not in S3StorageClass:
raise ValueError(f"Invalid storage class: {storage_class}. Must be one of {list(S3StorageClass)}")
try:
async with self._s3_client() as client:
parsed_uri = S3Uri(uri)
extra_args: dict[str, Any] = {"StorageClass": str(storage_class)}
if tags:
extra_args["Tagging"] = self._create_tag_string(tags)
await client.upload_fileobj(
file_obj,
parsed_uri.bucket,
parsed_uri.key,
ExtraArgs={"StorageClass": str(storage_class)},
ExtraArgs=extra_args,
)
LOG.debug("Upload file stream success", uri=uri)
return uri
@@ -137,6 +156,7 @@ class AsyncAWSClient:
storage_class: S3StorageClass = S3StorageClass.STANDARD,
metadata: dict | None = None,
raise_exception: bool = False,
tags: dict[str, str] | None = None,
) -> None:
try:
async with self._s3_client() as client:
@@ -144,6 +164,8 @@ class AsyncAWSClient:
extra_args: dict[str, Any] = {"StorageClass": str(storage_class)}
if metadata:
extra_args["Metadata"] = metadata
if tags:
extra_args["Tagging"] = self._create_tag_string(tags)
await client.upload_file(
Filename=file_path,
Bucket=parsed_uri.bucket,

View File

@@ -83,18 +83,23 @@ class S3Storage(BaseStorage):
async def store_artifact(self, artifact: Artifact, data: bytes) -> None:
sc = await self._get_storage_class_for_org(artifact.organization_id)
tags = await self._get_tags_for_org(artifact.organization_id)
LOG.debug(
"Storing artifact",
artifact_id=artifact.artifact_id,
organization_id=artifact.organization_id,
uri=artifact.uri,
storage_class=sc,
tags=tags,
)
await self.async_client.upload_file(artifact.uri, data, storage_class=sc)
await self.async_client.upload_file(artifact.uri, data, storage_class=sc, tags=tags)
async def _get_storage_class_for_org(self, organization_id: str) -> S3StorageClass:
return S3StorageClass.STANDARD
async def _get_tags_for_org(self, organization_id: str) -> dict[str, str]:
return {}
async def retrieve_artifact(self, artifact: Artifact) -> bytes | None:
return await self.async_client.download_file(artifact.uri)
@@ -107,6 +112,7 @@ class S3Storage(BaseStorage):
async def store_artifact_from_path(self, artifact: Artifact, path: str) -> None:
sc = await self._get_storage_class_for_org(artifact.organization_id)
tags = await self._get_tags_for_org(artifact.organization_id)
LOG.debug(
"Storing artifact from path",
artifact_id=artifact.artifact_id,
@@ -114,13 +120,15 @@ class S3Storage(BaseStorage):
uri=artifact.uri,
storage_class=sc,
path=path,
tags=tags,
)
await self.async_client.upload_file_from_path(artifact.uri, path, storage_class=sc)
await self.async_client.upload_file_from_path(artifact.uri, path, storage_class=sc, tags=tags)
async def save_streaming_file(self, organization_id: str, file_name: str) -> None:
from_path = f"{get_skyvern_temp_dir()}/{organization_id}/{file_name}"
to_path = f"s3://{settings.AWS_S3_BUCKET_SCREENSHOTS}/{settings.ENV}/{organization_id}/{file_name}"
sc = await self._get_storage_class_for_org(organization_id)
tags = await self._get_tags_for_org(organization_id)
LOG.debug(
"Saving streaming file",
organization_id=organization_id,
@@ -128,8 +136,9 @@ class S3Storage(BaseStorage):
from_path=from_path,
to_path=to_path,
storage_class=sc,
tags=tags,
)
await self.async_client.upload_file_from_path(to_path, from_path, storage_class=sc)
await self.async_client.upload_file_from_path(to_path, from_path, storage_class=sc, tags=tags)
async def get_streaming_file(self, organization_id: str, file_name: str, use_default: bool = True) -> bytes | None:
path = f"s3://{settings.AWS_S3_BUCKET_SCREENSHOTS}/{settings.ENV}/{organization_id}/{file_name}"
@@ -141,6 +150,7 @@ class S3Storage(BaseStorage):
zip_file_path = shutil.make_archive(temp_zip_file.name, "zip", directory)
browser_session_uri = f"s3://{settings.AWS_S3_BUCKET_BROWSER_SESSIONS}/{settings.ENV}/{organization_id}/{workflow_permanent_id}.zip"
sc = await self._get_storage_class_for_org(organization_id)
tags = await self._get_tags_for_org(organization_id)
LOG.debug(
"Storing browser session",
organization_id=organization_id,
@@ -148,8 +158,9 @@ class S3Storage(BaseStorage):
zip_file_path=zip_file_path,
browser_session_uri=browser_session_uri,
storage_class=sc,
tags=tags,
)
await self.async_client.upload_file_from_path(browser_session_uri, zip_file_path, storage_class=sc)
await self.async_client.upload_file_from_path(browser_session_uri, zip_file_path, storage_class=sc, tags=tags)
async def retrieve_browser_session(self, organization_id: str, workflow_permanent_id: str) -> str | None:
browser_session_uri = f"s3://{settings.AWS_S3_BUCKET_BROWSER_SESSIONS}/{settings.ENV}/{organization_id}/{workflow_permanent_id}.zip"
@@ -171,27 +182,30 @@ class S3Storage(BaseStorage):
download_dir = get_download_dir(workflow_run_id=workflow_run_id, task_id=task_id)
files = os.listdir(download_dir)
sc = await self._get_storage_class_for_org(organization_id)
tags = await self._get_tags_for_org(organization_id)
for file in files:
fpath = os.path.join(download_dir, file)
if os.path.isfile(fpath):
uri = f"s3://{settings.AWS_S3_BUCKET_UPLOADS}/{DOWNLOAD_FILE_PREFIX}/{settings.ENV}/{organization_id}/{workflow_run_id or task_id}/{file}"
if not os.path.isfile(fpath):
continue
uri = f"s3://{settings.AWS_S3_BUCKET_UPLOADS}/{DOWNLOAD_FILE_PREFIX}/{settings.ENV}/{organization_id}/{workflow_run_id or task_id}/{file}"
# Calculate SHA-256 checksum
checksum = calculate_sha256_for_file(fpath)
LOG.info(
"Calculated checksum for file",
file=file,
checksum=checksum,
organization_id=organization_id,
storage_class=sc,
)
# Upload file with checksum metadata
await self.async_client.upload_file_from_path(
uri=uri,
file_path=fpath,
metadata={"sha256_checksum": checksum, "original_filename": file},
storage_class=sc,
)
# Calculate SHA-256 checksum
checksum = calculate_sha256_for_file(fpath)
LOG.info(
"Calculated checksum for file",
file=file,
checksum=checksum,
organization_id=organization_id,
storage_class=sc,
)
# Upload file with checksum metadata
await self.async_client.upload_file_from_path(
uri=uri,
file_path=fpath,
metadata={"sha256_checksum": checksum, "original_filename": file},
storage_class=sc,
tags=tags,
)
async def get_downloaded_files(
self, organization_id: str, task_id: str | None, workflow_run_id: str | None
@@ -232,11 +246,12 @@ class S3Storage(BaseStorage):
todays_date = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d")
bucket = settings.AWS_S3_BUCKET_UPLOADS
sc = await self._get_storage_class_for_org(organization_id)
tags = await self._get_tags_for_org(organization_id)
# First try uploading with original filename
try:
sanitized_filename = os.path.basename(filename) # Remove any path components
s3_uri = f"s3://{bucket}/{settings.ENV}/{organization_id}/{todays_date}/{sanitized_filename}"
uploaded_s3_uri = await self.async_client.upload_file_stream(s3_uri, fileObj, storage_class=sc)
uploaded_s3_uri = await self.async_client.upload_file_stream(s3_uri, fileObj, storage_class=sc, tags=tags)
except Exception:
LOG.error("Failed to upload file to S3", exc_info=True)
uploaded_s3_uri = None
@@ -246,7 +261,7 @@ class S3Storage(BaseStorage):
uuid_prefixed_filename = f"{str(uuid.uuid4())}_{filename}"
s3_uri = f"s3://{bucket}/{settings.ENV}/{organization_id}/{todays_date}/{uuid_prefixed_filename}"
fileObj.seek(0) # Reset file pointer
uploaded_s3_uri = await self.async_client.upload_file_stream(s3_uri, fileObj, storage_class=sc)
uploaded_s3_uri = await self.async_client.upload_file_stream(s3_uri, fileObj, storage_class=sc, tags=tags)
if not uploaded_s3_uri:
LOG.error(

View File

@@ -6,9 +6,10 @@ import boto3
import pytest
from freezegun import freeze_time
from moto.server import ThreadedMotoServer
from types_boto3_s3.client import S3Client
from skyvern.config import settings
from skyvern.forge.sdk.api.aws import S3Uri
from skyvern.forge.sdk.api.aws import S3StorageClass, S3Uri
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType, LogEntityType
from skyvern.forge.sdk.artifact.storage.s3 import S3Storage
from skyvern.forge.sdk.artifact.storage.test_helpers import (
@@ -30,9 +31,17 @@ TEST_BLOCK_ID = "block_123456789"
TEST_AI_SUGGESTION_ID = "ai_sugg_test_123"
class S3StorageForTests(S3Storage):
async def _get_tags_for_org(self, organization_id: str) -> dict[str, str]:
return {"dummy": f"org-{organization_id}", "test": "jerry"}
async def _get_storage_class_for_org(self, organization_id: str) -> S3StorageClass:
return S3StorageClass.ONEZONE_IA
@pytest.fixture
def s3_storage(moto_server: str) -> S3Storage:
return S3Storage(bucket=TEST_BUCKET, endpoint_url=moto_server)
return S3StorageForTests(bucket=TEST_BUCKET, endpoint_url=moto_server)
@pytest.fixture(autouse=True)
@@ -53,7 +62,7 @@ def moto_server() -> Generator[str, None, None]:
@pytest.fixture(scope="module", autouse=True)
def boto3_test_client(moto_server: str) -> boto3.client:
def boto3_test_client(moto_server: str) -> Generator[S3Client, None, None]:
client = boto3.client(
"s3",
aws_access_key_id="testing",
@@ -145,6 +154,23 @@ class TestS3StorageBuildURIs:
)
def _assert_object_meta(boto3_test_client: S3Client, uri: str) -> None:
s3uri = S3Uri(uri)
assert s3uri.bucket == TEST_BUCKET
obj_meta = boto3_test_client.head_object(Bucket=TEST_BUCKET, Key=s3uri.key)
assert obj_meta["StorageClass"] == "ONEZONE_IA"
s3_tags_resp = boto3_test_client.get_object_tagging(Bucket=TEST_BUCKET, Key=s3uri.key)
tags_dict = {tag["Key"]: tag["Value"] for tag in s3_tags_resp["TagSet"]}
assert tags_dict == {"dummy": f"org-{TEST_ORGANIZATION_ID}", "test": "jerry"}
def _assert_object_content(boto3_test_client: S3Client, uri: str, expected_content: bytes) -> None:
s3uri = S3Uri(uri)
assert s3uri.bucket == TEST_BUCKET
obj_response = boto3_test_client.get_object(Bucket=TEST_BUCKET, Key=s3uri.key)
assert obj_response["Body"].read() == expected_content
@pytest.mark.asyncio
class TestS3StorageStore:
"""Test S3Storage store methods."""
@@ -174,17 +200,24 @@ class TestS3StorageStore:
modified_at=datetime.utcnow(),
)
async def test_store_artifact_screenshot(
self, s3_storage: S3Storage, boto3_test_client: boto3.client, tmp_path: Path
async def test_store_artifact_from_path(
self, s3_storage: S3Storage, boto3_test_client: S3Client, tmp_path: Path
) -> None:
test_data = b"fake screenshot data"
artifact = self._create_artifact_for_ai_suggestion(
s3_storage, ArtifactType.SCREENSHOT_LLM, TEST_AI_SUGGESTION_ID
)
s3uri = S3Uri(artifact.uri)
assert s3uri.bucket == TEST_BUCKET
test_file = tmp_path / "test_screenshot.png"
test_file.write_bytes(test_data)
await s3_storage.store_artifact_from_path(artifact, str(test_file))
obj_response = boto3_test_client.get_object(Bucket=TEST_BUCKET, Key=s3uri.key)
assert obj_response["Body"].read() == test_data
_assert_object_content(boto3_test_client, artifact.uri, test_data)
_assert_object_meta(boto3_test_client, artifact.uri)
async def test_store_artifact(self, s3_storage: S3Storage, boto3_test_client: S3Client) -> None:
test_data = b"fake artifact data"
artifact = self._create_artifact_for_ai_suggestion(s3_storage, ArtifactType.LLM_PROMPT, TEST_AI_SUGGESTION_ID)
await s3_storage.store_artifact(artifact, test_data)
_assert_object_content(boto3_test_client, artifact.uri, test_data)
_assert_object_meta(boto3_test_client, artifact.uri)