Add the ability to add tags to s3 objects we upload + more tests for artifact upload (#2684)

This commit is contained in:
Asher Foa
2025-06-11 15:52:25 -04:00
committed by GitHub
parent 2d2146948a
commit 3100ff9543
3 changed files with 107 additions and 37 deletions

View File

@@ -83,18 +83,23 @@ class S3Storage(BaseStorage):
async def store_artifact(self, artifact: Artifact, data: bytes) -> None:
sc = await self._get_storage_class_for_org(artifact.organization_id)
tags = await self._get_tags_for_org(artifact.organization_id)
LOG.debug(
"Storing artifact",
artifact_id=artifact.artifact_id,
organization_id=artifact.organization_id,
uri=artifact.uri,
storage_class=sc,
tags=tags,
)
await self.async_client.upload_file(artifact.uri, data, storage_class=sc)
await self.async_client.upload_file(artifact.uri, data, storage_class=sc, tags=tags)
async def _get_storage_class_for_org(self, organization_id: str) -> S3StorageClass:
return S3StorageClass.STANDARD
async def _get_tags_for_org(self, organization_id: str) -> dict[str, str]:
return {}
async def retrieve_artifact(self, artifact: Artifact) -> bytes | None:
return await self.async_client.download_file(artifact.uri)
@@ -107,6 +112,7 @@ class S3Storage(BaseStorage):
async def store_artifact_from_path(self, artifact: Artifact, path: str) -> None:
sc = await self._get_storage_class_for_org(artifact.organization_id)
tags = await self._get_tags_for_org(artifact.organization_id)
LOG.debug(
"Storing artifact from path",
artifact_id=artifact.artifact_id,
@@ -114,13 +120,15 @@ class S3Storage(BaseStorage):
uri=artifact.uri,
storage_class=sc,
path=path,
tags=tags,
)
await self.async_client.upload_file_from_path(artifact.uri, path, storage_class=sc)
await self.async_client.upload_file_from_path(artifact.uri, path, storage_class=sc, tags=tags)
async def save_streaming_file(self, organization_id: str, file_name: str) -> None:
from_path = f"{get_skyvern_temp_dir()}/{organization_id}/{file_name}"
to_path = f"s3://{settings.AWS_S3_BUCKET_SCREENSHOTS}/{settings.ENV}/{organization_id}/{file_name}"
sc = await self._get_storage_class_for_org(organization_id)
tags = await self._get_tags_for_org(organization_id)
LOG.debug(
"Saving streaming file",
organization_id=organization_id,
@@ -128,8 +136,9 @@ class S3Storage(BaseStorage):
from_path=from_path,
to_path=to_path,
storage_class=sc,
tags=tags,
)
await self.async_client.upload_file_from_path(to_path, from_path, storage_class=sc)
await self.async_client.upload_file_from_path(to_path, from_path, storage_class=sc, tags=tags)
async def get_streaming_file(self, organization_id: str, file_name: str, use_default: bool = True) -> bytes | None:
path = f"s3://{settings.AWS_S3_BUCKET_SCREENSHOTS}/{settings.ENV}/{organization_id}/{file_name}"
@@ -141,6 +150,7 @@ class S3Storage(BaseStorage):
zip_file_path = shutil.make_archive(temp_zip_file.name, "zip", directory)
browser_session_uri = f"s3://{settings.AWS_S3_BUCKET_BROWSER_SESSIONS}/{settings.ENV}/{organization_id}/{workflow_permanent_id}.zip"
sc = await self._get_storage_class_for_org(organization_id)
tags = await self._get_tags_for_org(organization_id)
LOG.debug(
"Storing browser session",
organization_id=organization_id,
@@ -148,8 +158,9 @@ class S3Storage(BaseStorage):
zip_file_path=zip_file_path,
browser_session_uri=browser_session_uri,
storage_class=sc,
tags=tags,
)
await self.async_client.upload_file_from_path(browser_session_uri, zip_file_path, storage_class=sc)
await self.async_client.upload_file_from_path(browser_session_uri, zip_file_path, storage_class=sc, tags=tags)
async def retrieve_browser_session(self, organization_id: str, workflow_permanent_id: str) -> str | None:
browser_session_uri = f"s3://{settings.AWS_S3_BUCKET_BROWSER_SESSIONS}/{settings.ENV}/{organization_id}/{workflow_permanent_id}.zip"
@@ -171,27 +182,30 @@ class S3Storage(BaseStorage):
download_dir = get_download_dir(workflow_run_id=workflow_run_id, task_id=task_id)
files = os.listdir(download_dir)
sc = await self._get_storage_class_for_org(organization_id)
tags = await self._get_tags_for_org(organization_id)
for file in files:
fpath = os.path.join(download_dir, file)
if os.path.isfile(fpath):
uri = f"s3://{settings.AWS_S3_BUCKET_UPLOADS}/{DOWNLOAD_FILE_PREFIX}/{settings.ENV}/{organization_id}/{workflow_run_id or task_id}/{file}"
if not os.path.isfile(fpath):
continue
uri = f"s3://{settings.AWS_S3_BUCKET_UPLOADS}/{DOWNLOAD_FILE_PREFIX}/{settings.ENV}/{organization_id}/{workflow_run_id or task_id}/{file}"
# Calculate SHA-256 checksum
checksum = calculate_sha256_for_file(fpath)
LOG.info(
"Calculated checksum for file",
file=file,
checksum=checksum,
organization_id=organization_id,
storage_class=sc,
)
# Upload file with checksum metadata
await self.async_client.upload_file_from_path(
uri=uri,
file_path=fpath,
metadata={"sha256_checksum": checksum, "original_filename": file},
storage_class=sc,
)
# Calculate SHA-256 checksum
checksum = calculate_sha256_for_file(fpath)
LOG.info(
"Calculated checksum for file",
file=file,
checksum=checksum,
organization_id=organization_id,
storage_class=sc,
)
# Upload file with checksum metadata
await self.async_client.upload_file_from_path(
uri=uri,
file_path=fpath,
metadata={"sha256_checksum": checksum, "original_filename": file},
storage_class=sc,
tags=tags,
)
async def get_downloaded_files(
self, organization_id: str, task_id: str | None, workflow_run_id: str | None
@@ -232,11 +246,12 @@ class S3Storage(BaseStorage):
todays_date = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d")
bucket = settings.AWS_S3_BUCKET_UPLOADS
sc = await self._get_storage_class_for_org(organization_id)
tags = await self._get_tags_for_org(organization_id)
# First try uploading with original filename
try:
sanitized_filename = os.path.basename(filename) # Remove any path components
s3_uri = f"s3://{bucket}/{settings.ENV}/{organization_id}/{todays_date}/{sanitized_filename}"
uploaded_s3_uri = await self.async_client.upload_file_stream(s3_uri, fileObj, storage_class=sc)
uploaded_s3_uri = await self.async_client.upload_file_stream(s3_uri, fileObj, storage_class=sc, tags=tags)
except Exception:
LOG.error("Failed to upload file to S3", exc_info=True)
uploaded_s3_uri = None
@@ -246,7 +261,7 @@ class S3Storage(BaseStorage):
uuid_prefixed_filename = f"{str(uuid.uuid4())}_{filename}"
s3_uri = f"s3://{bucket}/{settings.ENV}/{organization_id}/{todays_date}/{uuid_prefixed_filename}"
fileObj.seek(0) # Reset file pointer
uploaded_s3_uri = await self.async_client.upload_file_stream(s3_uri, fileObj, storage_class=sc)
uploaded_s3_uri = await self.async_client.upload_file_stream(s3_uri, fileObj, storage_class=sc, tags=tags)
if not uploaded_s3_uri:
LOG.error(

View File

@@ -6,9 +6,10 @@ import boto3
import pytest
from freezegun import freeze_time
from moto.server import ThreadedMotoServer
from types_boto3_s3.client import S3Client
from skyvern.config import settings
from skyvern.forge.sdk.api.aws import S3Uri
from skyvern.forge.sdk.api.aws import S3StorageClass, S3Uri
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType, LogEntityType
from skyvern.forge.sdk.artifact.storage.s3 import S3Storage
from skyvern.forge.sdk.artifact.storage.test_helpers import (
@@ -30,9 +31,17 @@ TEST_BLOCK_ID = "block_123456789"
TEST_AI_SUGGESTION_ID = "ai_sugg_test_123"
class S3StorageForTests(S3Storage):
async def _get_tags_for_org(self, organization_id: str) -> dict[str, str]:
return {"dummy": f"org-{organization_id}", "test": "jerry"}
async def _get_storage_class_for_org(self, organization_id: str) -> S3StorageClass:
return S3StorageClass.ONEZONE_IA
@pytest.fixture
def s3_storage(moto_server: str) -> S3Storage:
return S3Storage(bucket=TEST_BUCKET, endpoint_url=moto_server)
return S3StorageForTests(bucket=TEST_BUCKET, endpoint_url=moto_server)
@pytest.fixture(autouse=True)
@@ -53,7 +62,7 @@ def moto_server() -> Generator[str, None, None]:
@pytest.fixture(scope="module", autouse=True)
def boto3_test_client(moto_server: str) -> boto3.client:
def boto3_test_client(moto_server: str) -> Generator[S3Client, None, None]:
client = boto3.client(
"s3",
aws_access_key_id="testing",
@@ -145,6 +154,23 @@ class TestS3StorageBuildURIs:
)
def _assert_object_meta(boto3_test_client: S3Client, uri: str) -> None:
s3uri = S3Uri(uri)
assert s3uri.bucket == TEST_BUCKET
obj_meta = boto3_test_client.head_object(Bucket=TEST_BUCKET, Key=s3uri.key)
assert obj_meta["StorageClass"] == "ONEZONE_IA"
s3_tags_resp = boto3_test_client.get_object_tagging(Bucket=TEST_BUCKET, Key=s3uri.key)
tags_dict = {tag["Key"]: tag["Value"] for tag in s3_tags_resp["TagSet"]}
assert tags_dict == {"dummy": f"org-{TEST_ORGANIZATION_ID}", "test": "jerry"}
def _assert_object_content(boto3_test_client: S3Client, uri: str, expected_content: bytes) -> None:
s3uri = S3Uri(uri)
assert s3uri.bucket == TEST_BUCKET
obj_response = boto3_test_client.get_object(Bucket=TEST_BUCKET, Key=s3uri.key)
assert obj_response["Body"].read() == expected_content
@pytest.mark.asyncio
class TestS3StorageStore:
"""Test S3Storage store methods."""
@@ -174,17 +200,24 @@ class TestS3StorageStore:
modified_at=datetime.utcnow(),
)
async def test_store_artifact_screenshot(
self, s3_storage: S3Storage, boto3_test_client: boto3.client, tmp_path: Path
async def test_store_artifact_from_path(
self, s3_storage: S3Storage, boto3_test_client: S3Client, tmp_path: Path
) -> None:
test_data = b"fake screenshot data"
artifact = self._create_artifact_for_ai_suggestion(
s3_storage, ArtifactType.SCREENSHOT_LLM, TEST_AI_SUGGESTION_ID
)
s3uri = S3Uri(artifact.uri)
assert s3uri.bucket == TEST_BUCKET
test_file = tmp_path / "test_screenshot.png"
test_file.write_bytes(test_data)
await s3_storage.store_artifact_from_path(artifact, str(test_file))
obj_response = boto3_test_client.get_object(Bucket=TEST_BUCKET, Key=s3uri.key)
assert obj_response["Body"].read() == test_data
_assert_object_content(boto3_test_client, artifact.uri, test_data)
_assert_object_meta(boto3_test_client, artifact.uri)
async def test_store_artifact(self, s3_storage: S3Storage, boto3_test_client: S3Client) -> None:
test_data = b"fake artifact data"
artifact = self._create_artifact_for_ai_suggestion(s3_storage, ArtifactType.LLM_PROMPT, TEST_AI_SUGGESTION_ID)
await s3_storage.store_artifact(artifact, test_data)
_assert_object_content(boto3_test_client, artifact.uri, test_data)
_assert_object_meta(boto3_test_client, artifact.uri)