Add the ability to add tags to s3 objects we upload + more tests for artifact upload (#2684)

This commit is contained in:
Asher Foa
2025-06-11 15:52:25 -04:00
committed by GitHub
parent 2d2146948a
commit 3100ff9543
3 changed files with 107 additions and 37 deletions

View File

@@ -57,6 +57,9 @@ class AsyncAWSClient:
def _s3_client(self) -> S3Client: def _s3_client(self) -> S3Client:
return self.session.client(AWSClientType.S3, region_name=self.region_name, endpoint_url=self._endpoint_url) return self.session.client(AWSClientType.S3, region_name=self.region_name, endpoint_url=self._endpoint_url)
def _create_tag_string(self, tags: dict[str, str]) -> str:
return "&".join([f"{k}={v}" for k, v in tags.items()])
async def get_secret(self, secret_name: str) -> str | None: async def get_secret(self, secret_name: str) -> str | None:
try: try:
async with self._secrets_manager_client() as client: async with self._secrets_manager_client() as client:
@@ -95,15 +98,24 @@ class AsyncAWSClient:
raise e raise e
async def upload_file( async def upload_file(
self, uri: str, data: bytes, storage_class: S3StorageClass = S3StorageClass.STANDARD self,
uri: str,
data: bytes,
storage_class: S3StorageClass = S3StorageClass.STANDARD,
tags: dict[str, str] | None = None,
) -> str | None: ) -> str | None:
if storage_class not in S3StorageClass: if storage_class not in S3StorageClass:
raise ValueError(f"Invalid storage class: {storage_class}. Must be one of {list(S3StorageClass)}") raise ValueError(f"Invalid storage class: {storage_class}. Must be one of {list(S3StorageClass)}")
try: try:
async with self._s3_client() as client: async with self._s3_client() as client:
parsed_uri = S3Uri(uri) parsed_uri = S3Uri(uri)
extra_args = {"Tagging": self._create_tag_string(tags)} if tags else {}
await client.put_object( await client.put_object(
Body=data, Bucket=parsed_uri.bucket, Key=parsed_uri.key, StorageClass=str(storage_class) Body=data,
Bucket=parsed_uri.bucket,
Key=parsed_uri.key,
StorageClass=str(storage_class),
**extra_args,
) )
return uri return uri
except Exception: except Exception:
@@ -111,18 +123,25 @@ class AsyncAWSClient:
return None return None
async def upload_file_stream( async def upload_file_stream(
self, uri: str, file_obj: IO[bytes], storage_class: S3StorageClass = S3StorageClass.STANDARD self,
uri: str,
file_obj: IO[bytes],
storage_class: S3StorageClass = S3StorageClass.STANDARD,
tags: dict[str, str] | None = None,
) -> str | None: ) -> str | None:
if storage_class not in S3StorageClass: if storage_class not in S3StorageClass:
raise ValueError(f"Invalid storage class: {storage_class}. Must be one of {list(S3StorageClass)}") raise ValueError(f"Invalid storage class: {storage_class}. Must be one of {list(S3StorageClass)}")
try: try:
async with self._s3_client() as client: async with self._s3_client() as client:
parsed_uri = S3Uri(uri) parsed_uri = S3Uri(uri)
extra_args: dict[str, Any] = {"StorageClass": str(storage_class)}
if tags:
extra_args["Tagging"] = self._create_tag_string(tags)
await client.upload_fileobj( await client.upload_fileobj(
file_obj, file_obj,
parsed_uri.bucket, parsed_uri.bucket,
parsed_uri.key, parsed_uri.key,
ExtraArgs={"StorageClass": str(storage_class)}, ExtraArgs=extra_args,
) )
LOG.debug("Upload file stream success", uri=uri) LOG.debug("Upload file stream success", uri=uri)
return uri return uri
@@ -137,6 +156,7 @@ class AsyncAWSClient:
storage_class: S3StorageClass = S3StorageClass.STANDARD, storage_class: S3StorageClass = S3StorageClass.STANDARD,
metadata: dict | None = None, metadata: dict | None = None,
raise_exception: bool = False, raise_exception: bool = False,
tags: dict[str, str] | None = None,
) -> None: ) -> None:
try: try:
async with self._s3_client() as client: async with self._s3_client() as client:
@@ -144,6 +164,8 @@ class AsyncAWSClient:
extra_args: dict[str, Any] = {"StorageClass": str(storage_class)} extra_args: dict[str, Any] = {"StorageClass": str(storage_class)}
if metadata: if metadata:
extra_args["Metadata"] = metadata extra_args["Metadata"] = metadata
if tags:
extra_args["Tagging"] = self._create_tag_string(tags)
await client.upload_file( await client.upload_file(
Filename=file_path, Filename=file_path,
Bucket=parsed_uri.bucket, Bucket=parsed_uri.bucket,

View File

@@ -83,18 +83,23 @@ class S3Storage(BaseStorage):
async def store_artifact(self, artifact: Artifact, data: bytes) -> None: async def store_artifact(self, artifact: Artifact, data: bytes) -> None:
sc = await self._get_storage_class_for_org(artifact.organization_id) sc = await self._get_storage_class_for_org(artifact.organization_id)
tags = await self._get_tags_for_org(artifact.organization_id)
LOG.debug( LOG.debug(
"Storing artifact", "Storing artifact",
artifact_id=artifact.artifact_id, artifact_id=artifact.artifact_id,
organization_id=artifact.organization_id, organization_id=artifact.organization_id,
uri=artifact.uri, uri=artifact.uri,
storage_class=sc, storage_class=sc,
tags=tags,
) )
await self.async_client.upload_file(artifact.uri, data, storage_class=sc) await self.async_client.upload_file(artifact.uri, data, storage_class=sc, tags=tags)
async def _get_storage_class_for_org(self, organization_id: str) -> S3StorageClass: async def _get_storage_class_for_org(self, organization_id: str) -> S3StorageClass:
return S3StorageClass.STANDARD return S3StorageClass.STANDARD
async def _get_tags_for_org(self, organization_id: str) -> dict[str, str]:
return {}
async def retrieve_artifact(self, artifact: Artifact) -> bytes | None: async def retrieve_artifact(self, artifact: Artifact) -> bytes | None:
return await self.async_client.download_file(artifact.uri) return await self.async_client.download_file(artifact.uri)
@@ -107,6 +112,7 @@ class S3Storage(BaseStorage):
async def store_artifact_from_path(self, artifact: Artifact, path: str) -> None: async def store_artifact_from_path(self, artifact: Artifact, path: str) -> None:
sc = await self._get_storage_class_for_org(artifact.organization_id) sc = await self._get_storage_class_for_org(artifact.organization_id)
tags = await self._get_tags_for_org(artifact.organization_id)
LOG.debug( LOG.debug(
"Storing artifact from path", "Storing artifact from path",
artifact_id=artifact.artifact_id, artifact_id=artifact.artifact_id,
@@ -114,13 +120,15 @@ class S3Storage(BaseStorage):
uri=artifact.uri, uri=artifact.uri,
storage_class=sc, storage_class=sc,
path=path, path=path,
tags=tags,
) )
await self.async_client.upload_file_from_path(artifact.uri, path, storage_class=sc) await self.async_client.upload_file_from_path(artifact.uri, path, storage_class=sc, tags=tags)
async def save_streaming_file(self, organization_id: str, file_name: str) -> None: async def save_streaming_file(self, organization_id: str, file_name: str) -> None:
from_path = f"{get_skyvern_temp_dir()}/{organization_id}/{file_name}" from_path = f"{get_skyvern_temp_dir()}/{organization_id}/{file_name}"
to_path = f"s3://{settings.AWS_S3_BUCKET_SCREENSHOTS}/{settings.ENV}/{organization_id}/{file_name}" to_path = f"s3://{settings.AWS_S3_BUCKET_SCREENSHOTS}/{settings.ENV}/{organization_id}/{file_name}"
sc = await self._get_storage_class_for_org(organization_id) sc = await self._get_storage_class_for_org(organization_id)
tags = await self._get_tags_for_org(organization_id)
LOG.debug( LOG.debug(
"Saving streaming file", "Saving streaming file",
organization_id=organization_id, organization_id=organization_id,
@@ -128,8 +136,9 @@ class S3Storage(BaseStorage):
from_path=from_path, from_path=from_path,
to_path=to_path, to_path=to_path,
storage_class=sc, storage_class=sc,
tags=tags,
) )
await self.async_client.upload_file_from_path(to_path, from_path, storage_class=sc) await self.async_client.upload_file_from_path(to_path, from_path, storage_class=sc, tags=tags)
async def get_streaming_file(self, organization_id: str, file_name: str, use_default: bool = True) -> bytes | None: async def get_streaming_file(self, organization_id: str, file_name: str, use_default: bool = True) -> bytes | None:
path = f"s3://{settings.AWS_S3_BUCKET_SCREENSHOTS}/{settings.ENV}/{organization_id}/{file_name}" path = f"s3://{settings.AWS_S3_BUCKET_SCREENSHOTS}/{settings.ENV}/{organization_id}/{file_name}"
@@ -141,6 +150,7 @@ class S3Storage(BaseStorage):
zip_file_path = shutil.make_archive(temp_zip_file.name, "zip", directory) zip_file_path = shutil.make_archive(temp_zip_file.name, "zip", directory)
browser_session_uri = f"s3://{settings.AWS_S3_BUCKET_BROWSER_SESSIONS}/{settings.ENV}/{organization_id}/{workflow_permanent_id}.zip" browser_session_uri = f"s3://{settings.AWS_S3_BUCKET_BROWSER_SESSIONS}/{settings.ENV}/{organization_id}/{workflow_permanent_id}.zip"
sc = await self._get_storage_class_for_org(organization_id) sc = await self._get_storage_class_for_org(organization_id)
tags = await self._get_tags_for_org(organization_id)
LOG.debug( LOG.debug(
"Storing browser session", "Storing browser session",
organization_id=organization_id, organization_id=organization_id,
@@ -148,8 +158,9 @@ class S3Storage(BaseStorage):
zip_file_path=zip_file_path, zip_file_path=zip_file_path,
browser_session_uri=browser_session_uri, browser_session_uri=browser_session_uri,
storage_class=sc, storage_class=sc,
tags=tags,
) )
await self.async_client.upload_file_from_path(browser_session_uri, zip_file_path, storage_class=sc) await self.async_client.upload_file_from_path(browser_session_uri, zip_file_path, storage_class=sc, tags=tags)
async def retrieve_browser_session(self, organization_id: str, workflow_permanent_id: str) -> str | None: async def retrieve_browser_session(self, organization_id: str, workflow_permanent_id: str) -> str | None:
browser_session_uri = f"s3://{settings.AWS_S3_BUCKET_BROWSER_SESSIONS}/{settings.ENV}/{organization_id}/{workflow_permanent_id}.zip" browser_session_uri = f"s3://{settings.AWS_S3_BUCKET_BROWSER_SESSIONS}/{settings.ENV}/{organization_id}/{workflow_permanent_id}.zip"
@@ -171,27 +182,30 @@ class S3Storage(BaseStorage):
download_dir = get_download_dir(workflow_run_id=workflow_run_id, task_id=task_id) download_dir = get_download_dir(workflow_run_id=workflow_run_id, task_id=task_id)
files = os.listdir(download_dir) files = os.listdir(download_dir)
sc = await self._get_storage_class_for_org(organization_id) sc = await self._get_storage_class_for_org(organization_id)
tags = await self._get_tags_for_org(organization_id)
for file in files: for file in files:
fpath = os.path.join(download_dir, file) fpath = os.path.join(download_dir, file)
if os.path.isfile(fpath): if not os.path.isfile(fpath):
uri = f"s3://{settings.AWS_S3_BUCKET_UPLOADS}/{DOWNLOAD_FILE_PREFIX}/{settings.ENV}/{organization_id}/{workflow_run_id or task_id}/{file}" continue
uri = f"s3://{settings.AWS_S3_BUCKET_UPLOADS}/{DOWNLOAD_FILE_PREFIX}/{settings.ENV}/{organization_id}/{workflow_run_id or task_id}/{file}"
# Calculate SHA-256 checksum # Calculate SHA-256 checksum
checksum = calculate_sha256_for_file(fpath) checksum = calculate_sha256_for_file(fpath)
LOG.info( LOG.info(
"Calculated checksum for file", "Calculated checksum for file",
file=file, file=file,
checksum=checksum, checksum=checksum,
organization_id=organization_id, organization_id=organization_id,
storage_class=sc, storage_class=sc,
) )
# Upload file with checksum metadata # Upload file with checksum metadata
await self.async_client.upload_file_from_path( await self.async_client.upload_file_from_path(
uri=uri, uri=uri,
file_path=fpath, file_path=fpath,
metadata={"sha256_checksum": checksum, "original_filename": file}, metadata={"sha256_checksum": checksum, "original_filename": file},
storage_class=sc, storage_class=sc,
) tags=tags,
)
async def get_downloaded_files( async def get_downloaded_files(
self, organization_id: str, task_id: str | None, workflow_run_id: str | None self, organization_id: str, task_id: str | None, workflow_run_id: str | None
@@ -232,11 +246,12 @@ class S3Storage(BaseStorage):
todays_date = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d") todays_date = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d")
bucket = settings.AWS_S3_BUCKET_UPLOADS bucket = settings.AWS_S3_BUCKET_UPLOADS
sc = await self._get_storage_class_for_org(organization_id) sc = await self._get_storage_class_for_org(organization_id)
tags = await self._get_tags_for_org(organization_id)
# First try uploading with original filename # First try uploading with original filename
try: try:
sanitized_filename = os.path.basename(filename) # Remove any path components sanitized_filename = os.path.basename(filename) # Remove any path components
s3_uri = f"s3://{bucket}/{settings.ENV}/{organization_id}/{todays_date}/{sanitized_filename}" s3_uri = f"s3://{bucket}/{settings.ENV}/{organization_id}/{todays_date}/{sanitized_filename}"
uploaded_s3_uri = await self.async_client.upload_file_stream(s3_uri, fileObj, storage_class=sc) uploaded_s3_uri = await self.async_client.upload_file_stream(s3_uri, fileObj, storage_class=sc, tags=tags)
except Exception: except Exception:
LOG.error("Failed to upload file to S3", exc_info=True) LOG.error("Failed to upload file to S3", exc_info=True)
uploaded_s3_uri = None uploaded_s3_uri = None
@@ -246,7 +261,7 @@ class S3Storage(BaseStorage):
uuid_prefixed_filename = f"{str(uuid.uuid4())}_{filename}" uuid_prefixed_filename = f"{str(uuid.uuid4())}_{filename}"
s3_uri = f"s3://{bucket}/{settings.ENV}/{organization_id}/{todays_date}/{uuid_prefixed_filename}" s3_uri = f"s3://{bucket}/{settings.ENV}/{organization_id}/{todays_date}/{uuid_prefixed_filename}"
fileObj.seek(0) # Reset file pointer fileObj.seek(0) # Reset file pointer
uploaded_s3_uri = await self.async_client.upload_file_stream(s3_uri, fileObj, storage_class=sc) uploaded_s3_uri = await self.async_client.upload_file_stream(s3_uri, fileObj, storage_class=sc, tags=tags)
if not uploaded_s3_uri: if not uploaded_s3_uri:
LOG.error( LOG.error(

View File

@@ -6,9 +6,10 @@ import boto3
import pytest import pytest
from freezegun import freeze_time from freezegun import freeze_time
from moto.server import ThreadedMotoServer from moto.server import ThreadedMotoServer
from types_boto3_s3.client import S3Client
from skyvern.config import settings from skyvern.config import settings
from skyvern.forge.sdk.api.aws import S3Uri from skyvern.forge.sdk.api.aws import S3StorageClass, S3Uri
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType, LogEntityType from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType, LogEntityType
from skyvern.forge.sdk.artifact.storage.s3 import S3Storage from skyvern.forge.sdk.artifact.storage.s3 import S3Storage
from skyvern.forge.sdk.artifact.storage.test_helpers import ( from skyvern.forge.sdk.artifact.storage.test_helpers import (
@@ -30,9 +31,17 @@ TEST_BLOCK_ID = "block_123456789"
TEST_AI_SUGGESTION_ID = "ai_sugg_test_123" TEST_AI_SUGGESTION_ID = "ai_sugg_test_123"
class S3StorageForTests(S3Storage):
async def _get_tags_for_org(self, organization_id: str) -> dict[str, str]:
return {"dummy": f"org-{organization_id}", "test": "jerry"}
async def _get_storage_class_for_org(self, organization_id: str) -> S3StorageClass:
return S3StorageClass.ONEZONE_IA
@pytest.fixture @pytest.fixture
def s3_storage(moto_server: str) -> S3Storage: def s3_storage(moto_server: str) -> S3Storage:
return S3Storage(bucket=TEST_BUCKET, endpoint_url=moto_server) return S3StorageForTests(bucket=TEST_BUCKET, endpoint_url=moto_server)
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
@@ -53,7 +62,7 @@ def moto_server() -> Generator[str, None, None]:
@pytest.fixture(scope="module", autouse=True) @pytest.fixture(scope="module", autouse=True)
def boto3_test_client(moto_server: str) -> boto3.client: def boto3_test_client(moto_server: str) -> Generator[S3Client, None, None]:
client = boto3.client( client = boto3.client(
"s3", "s3",
aws_access_key_id="testing", aws_access_key_id="testing",
@@ -145,6 +154,23 @@ class TestS3StorageBuildURIs:
) )
def _assert_object_meta(boto3_test_client: S3Client, uri: str) -> None:
s3uri = S3Uri(uri)
assert s3uri.bucket == TEST_BUCKET
obj_meta = boto3_test_client.head_object(Bucket=TEST_BUCKET, Key=s3uri.key)
assert obj_meta["StorageClass"] == "ONEZONE_IA"
s3_tags_resp = boto3_test_client.get_object_tagging(Bucket=TEST_BUCKET, Key=s3uri.key)
tags_dict = {tag["Key"]: tag["Value"] for tag in s3_tags_resp["TagSet"]}
assert tags_dict == {"dummy": f"org-{TEST_ORGANIZATION_ID}", "test": "jerry"}
def _assert_object_content(boto3_test_client: S3Client, uri: str, expected_content: bytes) -> None:
s3uri = S3Uri(uri)
assert s3uri.bucket == TEST_BUCKET
obj_response = boto3_test_client.get_object(Bucket=TEST_BUCKET, Key=s3uri.key)
assert obj_response["Body"].read() == expected_content
@pytest.mark.asyncio @pytest.mark.asyncio
class TestS3StorageStore: class TestS3StorageStore:
"""Test S3Storage store methods.""" """Test S3Storage store methods."""
@@ -174,17 +200,24 @@ class TestS3StorageStore:
modified_at=datetime.utcnow(), modified_at=datetime.utcnow(),
) )
async def test_store_artifact_screenshot( async def test_store_artifact_from_path(
self, s3_storage: S3Storage, boto3_test_client: boto3.client, tmp_path: Path self, s3_storage: S3Storage, boto3_test_client: S3Client, tmp_path: Path
) -> None: ) -> None:
test_data = b"fake screenshot data" test_data = b"fake screenshot data"
artifact = self._create_artifact_for_ai_suggestion( artifact = self._create_artifact_for_ai_suggestion(
s3_storage, ArtifactType.SCREENSHOT_LLM, TEST_AI_SUGGESTION_ID s3_storage, ArtifactType.SCREENSHOT_LLM, TEST_AI_SUGGESTION_ID
) )
s3uri = S3Uri(artifact.uri)
assert s3uri.bucket == TEST_BUCKET
test_file = tmp_path / "test_screenshot.png" test_file = tmp_path / "test_screenshot.png"
test_file.write_bytes(test_data) test_file.write_bytes(test_data)
await s3_storage.store_artifact_from_path(artifact, str(test_file)) await s3_storage.store_artifact_from_path(artifact, str(test_file))
obj_response = boto3_test_client.get_object(Bucket=TEST_BUCKET, Key=s3uri.key) _assert_object_content(boto3_test_client, artifact.uri, test_data)
assert obj_response["Body"].read() == test_data _assert_object_meta(boto3_test_client, artifact.uri)
async def test_store_artifact(self, s3_storage: S3Storage, boto3_test_client: S3Client) -> None:
test_data = b"fake artifact data"
artifact = self._create_artifact_for_ai_suggestion(s3_storage, ArtifactType.LLM_PROMPT, TEST_AI_SUGGESTION_ID)
await s3_storage.store_artifact(artifact, test_data)
_assert_object_content(boto3_test_client, artifact.uri, test_data)
_assert_object_meta(boto3_test_client, artifact.uri)