diff --git a/skyvern/forge/sdk/api/aws.py b/skyvern/forge/sdk/api/aws.py
index da4d8006..7b33fdfe 100644
--- a/skyvern/forge/sdk/api/aws.py
+++ b/skyvern/forge/sdk/api/aws.py
@@ -57,6 +57,14 @@ class AsyncAWSClient:
     def _s3_client(self) -> S3Client:
         return self.session.client(AWSClientType.S3, region_name=self.region_name, endpoint_url=self._endpoint_url)
 
+    def _create_tag_string(self, tags: dict[str, str]) -> str:
+        """Encode ``tags`` as the URL query string S3 expects for ``Tagging``."""
+        # Keys/values may contain "&", "=", "+", spaces or non-ASCII text;
+        # percent-encode them so one tag cannot be split into or merged with another.
+        from urllib.parse import quote, urlencode
+
+        return urlencode(tags, quote_via=quote)
+
     async def get_secret(self, secret_name: str) -> str | None:
         try:
             async with self._secrets_manager_client() as client:
@@ -95,15 +103,24 @@ class AsyncAWSClient:
             raise e
 
     async def upload_file(
-        self, uri: str, data: bytes, storage_class: S3StorageClass = S3StorageClass.STANDARD
+        self,
+        uri: str,
+        data: bytes,
+        storage_class: S3StorageClass = S3StorageClass.STANDARD,
+        tags: dict[str, str] | None = None,
     ) -> str | None:
         if storage_class not in S3StorageClass:
             raise ValueError(f"Invalid storage class: {storage_class}. Must be one of {list(S3StorageClass)}")
         try:
             async with self._s3_client() as client:
                 parsed_uri = S3Uri(uri)
+                extra_args = {"Tagging": self._create_tag_string(tags)} if tags else {}
                 await client.put_object(
-                    Body=data, Bucket=parsed_uri.bucket, Key=parsed_uri.key, StorageClass=str(storage_class)
+                    Body=data,
+                    Bucket=parsed_uri.bucket,
+                    Key=parsed_uri.key,
+                    StorageClass=str(storage_class),
+                    **extra_args,
                 )
                 return uri
         except Exception:
@@ -111,18 +128,25 @@ class AsyncAWSClient:
             return None
 
     async def upload_file_stream(
-        self, uri: str, file_obj: IO[bytes], storage_class: S3StorageClass = S3StorageClass.STANDARD
+        self,
+        uri: str,
+        file_obj: IO[bytes],
+        storage_class: S3StorageClass = S3StorageClass.STANDARD,
+        tags: dict[str, str] | None = None,
     ) -> str | None:
         if storage_class not in S3StorageClass:
             raise ValueError(f"Invalid storage class: {storage_class}. Must be one of {list(S3StorageClass)}")
         try:
             async with self._s3_client() as client:
                 parsed_uri = S3Uri(uri)
+                extra_args: dict[str, Any] = {"StorageClass": str(storage_class)}
+                if tags:
+                    extra_args["Tagging"] = self._create_tag_string(tags)
                 await client.upload_fileobj(
                     file_obj,
                     parsed_uri.bucket,
                     parsed_uri.key,
-                    ExtraArgs={"StorageClass": str(storage_class)},
+                    ExtraArgs=extra_args,
                 )
                 LOG.debug("Upload file stream success", uri=uri)
                 return uri
@@ -137,6 +161,7 @@ class AsyncAWSClient:
         storage_class: S3StorageClass = S3StorageClass.STANDARD,
         metadata: dict | None = None,
         raise_exception: bool = False,
+        tags: dict[str, str] | None = None,
     ) -> None:
         try:
             async with self._s3_client() as client:
@@ -144,6 +169,8 @@ class AsyncAWSClient:
                 extra_args: dict[str, Any] = {"StorageClass": str(storage_class)}
                 if metadata:
                     extra_args["Metadata"] = metadata
+                if tags:
+                    extra_args["Tagging"] = self._create_tag_string(tags)
                 await client.upload_file(
                     Filename=file_path,
                     Bucket=parsed_uri.bucket,
diff --git a/skyvern/forge/sdk/artifact/storage/s3.py b/skyvern/forge/sdk/artifact/storage/s3.py
index ac5192c2..e97b59a9 100644
--- a/skyvern/forge/sdk/artifact/storage/s3.py
+++ b/skyvern/forge/sdk/artifact/storage/s3.py
@@ -83,18 +83,23 @@ class S3Storage(BaseStorage):
 
     async def store_artifact(self, artifact: Artifact, data: bytes) -> None:
         sc = await self._get_storage_class_for_org(artifact.organization_id)
+        tags = await self._get_tags_for_org(artifact.organization_id)
         LOG.debug(
             "Storing artifact",
             artifact_id=artifact.artifact_id,
             organization_id=artifact.organization_id,
             uri=artifact.uri,
             storage_class=sc,
+            tags=tags,
         )
-        await self.async_client.upload_file(artifact.uri, data, storage_class=sc)
+        await self.async_client.upload_file(artifact.uri, data, storage_class=sc, tags=tags)
 
     async def _get_storage_class_for_org(self, organization_id: str) -> S3StorageClass:
         return S3StorageClass.STANDARD
 
+    async def _get_tags_for_org(self, organization_id: str) -> dict[str, str]:
+        return {}
+
     async def retrieve_artifact(self, artifact: Artifact) -> bytes | None:
         return await self.async_client.download_file(artifact.uri)
 
@@ -107,6 +112,7 @@ class S3Storage(BaseStorage):
 
     async def store_artifact_from_path(self, artifact: Artifact, path: str) -> None:
         sc = await self._get_storage_class_for_org(artifact.organization_id)
+        tags = await self._get_tags_for_org(artifact.organization_id)
         LOG.debug(
             "Storing artifact from path",
             artifact_id=artifact.artifact_id,
@@ -114,13 +120,15 @@ class S3Storage(BaseStorage):
             uri=artifact.uri,
             storage_class=sc,
             path=path,
+            tags=tags,
         )
-        await self.async_client.upload_file_from_path(artifact.uri, path, storage_class=sc)
+        await self.async_client.upload_file_from_path(artifact.uri, path, storage_class=sc, tags=tags)
 
     async def save_streaming_file(self, organization_id: str, file_name: str) -> None:
         from_path = f"{get_skyvern_temp_dir()}/{organization_id}/{file_name}"
         to_path = f"s3://{settings.AWS_S3_BUCKET_SCREENSHOTS}/{settings.ENV}/{organization_id}/{file_name}"
         sc = await self._get_storage_class_for_org(organization_id)
+        tags = await self._get_tags_for_org(organization_id)
         LOG.debug(
             "Saving streaming file",
             organization_id=organization_id,
@@ -128,8 +136,9 @@ class S3Storage(BaseStorage):
             from_path=from_path,
             to_path=to_path,
             storage_class=sc,
+            tags=tags,
         )
-        await self.async_client.upload_file_from_path(to_path, from_path, storage_class=sc)
+        await self.async_client.upload_file_from_path(to_path, from_path, storage_class=sc, tags=tags)
 
     async def get_streaming_file(self, organization_id: str, file_name: str, use_default: bool = True) -> bytes | None:
         path = f"s3://{settings.AWS_S3_BUCKET_SCREENSHOTS}/{settings.ENV}/{organization_id}/{file_name}"
@@ -141,6 +150,7 @@ class S3Storage(BaseStorage):
         zip_file_path = shutil.make_archive(temp_zip_file.name, "zip", directory)
         browser_session_uri = f"s3://{settings.AWS_S3_BUCKET_BROWSER_SESSIONS}/{settings.ENV}/{organization_id}/{workflow_permanent_id}.zip"
         sc = await self._get_storage_class_for_org(organization_id)
+        tags = await self._get_tags_for_org(organization_id)
         LOG.debug(
             "Storing browser session",
             organization_id=organization_id,
@@ -148,8 +158,9 @@ class S3Storage(BaseStorage):
             zip_file_path=zip_file_path,
             browser_session_uri=browser_session_uri,
             storage_class=sc,
+            tags=tags,
         )
-        await self.async_client.upload_file_from_path(browser_session_uri, zip_file_path, storage_class=sc)
+        await self.async_client.upload_file_from_path(browser_session_uri, zip_file_path, storage_class=sc, tags=tags)
 
     async def retrieve_browser_session(self, organization_id: str, workflow_permanent_id: str) -> str | None:
         browser_session_uri = f"s3://{settings.AWS_S3_BUCKET_BROWSER_SESSIONS}/{settings.ENV}/{organization_id}/{workflow_permanent_id}.zip"
@@ -171,27 +182,30 @@ class S3Storage(BaseStorage):
         download_dir = get_download_dir(workflow_run_id=workflow_run_id, task_id=task_id)
         files = os.listdir(download_dir)
         sc = await self._get_storage_class_for_org(organization_id)
+        tags = await self._get_tags_for_org(organization_id)
         for file in files:
             fpath = os.path.join(download_dir, file)
-            if os.path.isfile(fpath):
-                uri = f"s3://{settings.AWS_S3_BUCKET_UPLOADS}/{DOWNLOAD_FILE_PREFIX}/{settings.ENV}/{organization_id}/{workflow_run_id or task_id}/{file}"
+            if not os.path.isfile(fpath):
+                continue
+            uri = f"s3://{settings.AWS_S3_BUCKET_UPLOADS}/{DOWNLOAD_FILE_PREFIX}/{settings.ENV}/{organization_id}/{workflow_run_id or task_id}/{file}"
 
-                # Calculate SHA-256 checksum
-                checksum = calculate_sha256_for_file(fpath)
-                LOG.info(
-                    "Calculated checksum for file",
-                    file=file,
-                    checksum=checksum,
-                    organization_id=organization_id,
-                    storage_class=sc,
-                )
-                # Upload file with checksum metadata
-                await self.async_client.upload_file_from_path(
-                    uri=uri,
-                    file_path=fpath,
-                    metadata={"sha256_checksum": checksum, "original_filename": file},
-                    storage_class=sc,
-                )
+            # Calculate SHA-256 checksum
+            checksum = calculate_sha256_for_file(fpath)
+            LOG.info(
+                "Calculated checksum for file",
+                file=file,
+                checksum=checksum,
+                organization_id=organization_id,
+                storage_class=sc,
+            )
+            # Upload file with checksum metadata
+            await self.async_client.upload_file_from_path(
+                uri=uri,
+                file_path=fpath,
+                metadata={"sha256_checksum": checksum, "original_filename": file},
+                storage_class=sc,
+                tags=tags,
+            )
 
     async def get_downloaded_files(
         self, organization_id: str, task_id: str | None, workflow_run_id: str | None
@@ -232,11 +246,12 @@ class S3Storage(BaseStorage):
         todays_date = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d")
         bucket = settings.AWS_S3_BUCKET_UPLOADS
         sc = await self._get_storage_class_for_org(organization_id)
+        tags = await self._get_tags_for_org(organization_id)
         # First try uploading with original filename
         try:
             sanitized_filename = os.path.basename(filename)  # Remove any path components
             s3_uri = f"s3://{bucket}/{settings.ENV}/{organization_id}/{todays_date}/{sanitized_filename}"
-            uploaded_s3_uri = await self.async_client.upload_file_stream(s3_uri, fileObj, storage_class=sc)
+            uploaded_s3_uri = await self.async_client.upload_file_stream(s3_uri, fileObj, storage_class=sc, tags=tags)
         except Exception:
             LOG.error("Failed to upload file to S3", exc_info=True)
             uploaded_s3_uri = None
@@ -246,7 +261,7 @@ class S3Storage(BaseStorage):
             uuid_prefixed_filename = f"{str(uuid.uuid4())}_{filename}"
             s3_uri = f"s3://{bucket}/{settings.ENV}/{organization_id}/{todays_date}/{uuid_prefixed_filename}"
             fileObj.seek(0)  # Reset file pointer
-            uploaded_s3_uri = await self.async_client.upload_file_stream(s3_uri, fileObj, storage_class=sc)
+            uploaded_s3_uri = await self.async_client.upload_file_stream(s3_uri, fileObj, storage_class=sc, tags=tags)
 
         if not uploaded_s3_uri:
             LOG.error(
diff --git a/skyvern/forge/sdk/artifact/storage/test_s3_storage.py b/skyvern/forge/sdk/artifact/storage/test_s3_storage.py
index 0403abdc..82c0c7e6 100644
--- a/skyvern/forge/sdk/artifact/storage/test_s3_storage.py
+++ b/skyvern/forge/sdk/artifact/storage/test_s3_storage.py
@@ -6,9 +6,10 @@ import boto3
 import pytest
 from freezegun import freeze_time
 from moto.server import ThreadedMotoServer
+from types_boto3_s3.client import S3Client
 
 from skyvern.config import settings
-from skyvern.forge.sdk.api.aws import S3Uri
+from skyvern.forge.sdk.api.aws import S3StorageClass, S3Uri
 from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType, LogEntityType
 from skyvern.forge.sdk.artifact.storage.s3 import S3Storage
 from skyvern.forge.sdk.artifact.storage.test_helpers import (
@@ -30,9 +31,17 @@ TEST_BLOCK_ID = "block_123456789"
 TEST_AI_SUGGESTION_ID = "ai_sugg_test_123"
 
 
+class S3StorageForTests(S3Storage):
+    async def _get_tags_for_org(self, organization_id: str) -> dict[str, str]:
+        return {"dummy": f"org-{organization_id}", "test": "jerry"}
+
+    async def _get_storage_class_for_org(self, organization_id: str) -> S3StorageClass:
+        return S3StorageClass.ONEZONE_IA
+
+
 @pytest.fixture
 def s3_storage(moto_server: str) -> S3Storage:
-    return S3Storage(bucket=TEST_BUCKET, endpoint_url=moto_server)
+    return S3StorageForTests(bucket=TEST_BUCKET, endpoint_url=moto_server)
 
 
 @pytest.fixture(autouse=True)
@@ -53,7 +62,7 @@ def moto_server() -> Generator[str, None, None]:
 
 
 @pytest.fixture(scope="module", autouse=True)
-def boto3_test_client(moto_server: str) -> boto3.client:
+def boto3_test_client(moto_server: str) -> Generator[S3Client, None, None]:
     client = boto3.client(
         "s3",
         aws_access_key_id="testing",
@@ -145,6 +154,23 @@ class TestS3StorageBuildURIs:
         )
 
 
+def _assert_object_meta(boto3_test_client: S3Client, uri: str) -> None:
+    s3uri = S3Uri(uri)
+    assert s3uri.bucket == TEST_BUCKET
+    obj_meta = boto3_test_client.head_object(Bucket=TEST_BUCKET, Key=s3uri.key)
+    assert obj_meta["StorageClass"] == "ONEZONE_IA"
+    s3_tags_resp = boto3_test_client.get_object_tagging(Bucket=TEST_BUCKET, Key=s3uri.key)
+    tags_dict = {tag["Key"]: tag["Value"] for tag in s3_tags_resp["TagSet"]}
+    assert tags_dict == {"dummy": f"org-{TEST_ORGANIZATION_ID}", "test": "jerry"}
+
+
+def _assert_object_content(boto3_test_client: S3Client, uri: str, expected_content: bytes) -> None:
+    s3uri = S3Uri(uri)
+    assert s3uri.bucket == TEST_BUCKET
+    obj_response = boto3_test_client.get_object(Bucket=TEST_BUCKET, Key=s3uri.key)
+    assert obj_response["Body"].read() == expected_content
+
+
 @pytest.mark.asyncio
 class TestS3StorageStore:
     """Test S3Storage store methods."""
@@ -174,17 +200,24 @@ class TestS3StorageStore:
             modified_at=datetime.utcnow(),
         )
 
-    async def test_store_artifact_screenshot(
-        self, s3_storage: S3Storage, boto3_test_client: boto3.client, tmp_path: Path
+    async def test_store_artifact_from_path(
+        self, s3_storage: S3Storage, boto3_test_client: S3Client, tmp_path: Path
     ) -> None:
         test_data = b"fake screenshot data"
         artifact = self._create_artifact_for_ai_suggestion(
             s3_storage, ArtifactType.SCREENSHOT_LLM, TEST_AI_SUGGESTION_ID
         )
-        s3uri = S3Uri(artifact.uri)
-        assert s3uri.bucket == TEST_BUCKET
+
         test_file = tmp_path / "test_screenshot.png"
         test_file.write_bytes(test_data)
         await s3_storage.store_artifact_from_path(artifact, str(test_file))
-        obj_response = boto3_test_client.get_object(Bucket=TEST_BUCKET, Key=s3uri.key)
-        assert obj_response["Body"].read() == test_data
+        _assert_object_content(boto3_test_client, artifact.uri, test_data)
+        _assert_object_meta(boto3_test_client, artifact.uri)
+
+    async def test_store_artifact(self, s3_storage: S3Storage, boto3_test_client: S3Client) -> None:
+        test_data = b"fake artifact data"
+        artifact = self._create_artifact_for_ai_suggestion(s3_storage, ArtifactType.LLM_PROMPT, TEST_AI_SUGGESTION_ID)
+
+        await s3_storage.store_artifact(artifact, test_data)
+        _assert_object_content(boto3_test_client, artifact.uri, test_data)
+        _assert_object_meta(boto3_test_client, artifact.uri)