Add the ability to add tags to s3 objects we upload + more tests for artifact upload (#2684)
This commit is contained in:
@@ -57,6 +57,9 @@ class AsyncAWSClient:
|
||||
def _s3_client(self) -> S3Client:
|
||||
return self.session.client(AWSClientType.S3, region_name=self.region_name, endpoint_url=self._endpoint_url)
|
||||
|
||||
def _create_tag_string(self, tags: dict[str, str]) -> str:
|
||||
return "&".join([f"{k}={v}" for k, v in tags.items()])
|
||||
|
||||
async def get_secret(self, secret_name: str) -> str | None:
|
||||
try:
|
||||
async with self._secrets_manager_client() as client:
|
||||
@@ -95,15 +98,24 @@ class AsyncAWSClient:
|
||||
raise e
|
||||
|
||||
async def upload_file(
|
||||
self, uri: str, data: bytes, storage_class: S3StorageClass = S3StorageClass.STANDARD
|
||||
self,
|
||||
uri: str,
|
||||
data: bytes,
|
||||
storage_class: S3StorageClass = S3StorageClass.STANDARD,
|
||||
tags: dict[str, str] | None = None,
|
||||
) -> str | None:
|
||||
if storage_class not in S3StorageClass:
|
||||
raise ValueError(f"Invalid storage class: {storage_class}. Must be one of {list(S3StorageClass)}")
|
||||
try:
|
||||
async with self._s3_client() as client:
|
||||
parsed_uri = S3Uri(uri)
|
||||
extra_args = {"Tagging": self._create_tag_string(tags)} if tags else {}
|
||||
await client.put_object(
|
||||
Body=data, Bucket=parsed_uri.bucket, Key=parsed_uri.key, StorageClass=str(storage_class)
|
||||
Body=data,
|
||||
Bucket=parsed_uri.bucket,
|
||||
Key=parsed_uri.key,
|
||||
StorageClass=str(storage_class),
|
||||
**extra_args,
|
||||
)
|
||||
return uri
|
||||
except Exception:
|
||||
@@ -111,18 +123,25 @@ class AsyncAWSClient:
|
||||
return None
|
||||
|
||||
async def upload_file_stream(
|
||||
self, uri: str, file_obj: IO[bytes], storage_class: S3StorageClass = S3StorageClass.STANDARD
|
||||
self,
|
||||
uri: str,
|
||||
file_obj: IO[bytes],
|
||||
storage_class: S3StorageClass = S3StorageClass.STANDARD,
|
||||
tags: dict[str, str] | None = None,
|
||||
) -> str | None:
|
||||
if storage_class not in S3StorageClass:
|
||||
raise ValueError(f"Invalid storage class: {storage_class}. Must be one of {list(S3StorageClass)}")
|
||||
try:
|
||||
async with self._s3_client() as client:
|
||||
parsed_uri = S3Uri(uri)
|
||||
extra_args: dict[str, Any] = {"StorageClass": str(storage_class)}
|
||||
if tags:
|
||||
extra_args["Tagging"] = self._create_tag_string(tags)
|
||||
await client.upload_fileobj(
|
||||
file_obj,
|
||||
parsed_uri.bucket,
|
||||
parsed_uri.key,
|
||||
ExtraArgs={"StorageClass": str(storage_class)},
|
||||
ExtraArgs=extra_args,
|
||||
)
|
||||
LOG.debug("Upload file stream success", uri=uri)
|
||||
return uri
|
||||
@@ -137,6 +156,7 @@ class AsyncAWSClient:
|
||||
storage_class: S3StorageClass = S3StorageClass.STANDARD,
|
||||
metadata: dict | None = None,
|
||||
raise_exception: bool = False,
|
||||
tags: dict[str, str] | None = None,
|
||||
) -> None:
|
||||
try:
|
||||
async with self._s3_client() as client:
|
||||
@@ -144,6 +164,8 @@ class AsyncAWSClient:
|
||||
extra_args: dict[str, Any] = {"StorageClass": str(storage_class)}
|
||||
if metadata:
|
||||
extra_args["Metadata"] = metadata
|
||||
if tags:
|
||||
extra_args["Tagging"] = self._create_tag_string(tags)
|
||||
await client.upload_file(
|
||||
Filename=file_path,
|
||||
Bucket=parsed_uri.bucket,
|
||||
|
||||
@@ -83,18 +83,23 @@ class S3Storage(BaseStorage):
|
||||
|
||||
async def store_artifact(self, artifact: Artifact, data: bytes) -> None:
|
||||
sc = await self._get_storage_class_for_org(artifact.organization_id)
|
||||
tags = await self._get_tags_for_org(artifact.organization_id)
|
||||
LOG.debug(
|
||||
"Storing artifact",
|
||||
artifact_id=artifact.artifact_id,
|
||||
organization_id=artifact.organization_id,
|
||||
uri=artifact.uri,
|
||||
storage_class=sc,
|
||||
tags=tags,
|
||||
)
|
||||
await self.async_client.upload_file(artifact.uri, data, storage_class=sc)
|
||||
await self.async_client.upload_file(artifact.uri, data, storage_class=sc, tags=tags)
|
||||
|
||||
async def _get_storage_class_for_org(self, organization_id: str) -> S3StorageClass:
|
||||
return S3StorageClass.STANDARD
|
||||
|
||||
async def _get_tags_for_org(self, organization_id: str) -> dict[str, str]:
|
||||
return {}
|
||||
|
||||
async def retrieve_artifact(self, artifact: Artifact) -> bytes | None:
|
||||
return await self.async_client.download_file(artifact.uri)
|
||||
|
||||
@@ -107,6 +112,7 @@ class S3Storage(BaseStorage):
|
||||
|
||||
async def store_artifact_from_path(self, artifact: Artifact, path: str) -> None:
|
||||
sc = await self._get_storage_class_for_org(artifact.organization_id)
|
||||
tags = await self._get_tags_for_org(artifact.organization_id)
|
||||
LOG.debug(
|
||||
"Storing artifact from path",
|
||||
artifact_id=artifact.artifact_id,
|
||||
@@ -114,13 +120,15 @@ class S3Storage(BaseStorage):
|
||||
uri=artifact.uri,
|
||||
storage_class=sc,
|
||||
path=path,
|
||||
tags=tags,
|
||||
)
|
||||
await self.async_client.upload_file_from_path(artifact.uri, path, storage_class=sc)
|
||||
await self.async_client.upload_file_from_path(artifact.uri, path, storage_class=sc, tags=tags)
|
||||
|
||||
async def save_streaming_file(self, organization_id: str, file_name: str) -> None:
|
||||
from_path = f"{get_skyvern_temp_dir()}/{organization_id}/{file_name}"
|
||||
to_path = f"s3://{settings.AWS_S3_BUCKET_SCREENSHOTS}/{settings.ENV}/{organization_id}/{file_name}"
|
||||
sc = await self._get_storage_class_for_org(organization_id)
|
||||
tags = await self._get_tags_for_org(organization_id)
|
||||
LOG.debug(
|
||||
"Saving streaming file",
|
||||
organization_id=organization_id,
|
||||
@@ -128,8 +136,9 @@ class S3Storage(BaseStorage):
|
||||
from_path=from_path,
|
||||
to_path=to_path,
|
||||
storage_class=sc,
|
||||
tags=tags,
|
||||
)
|
||||
await self.async_client.upload_file_from_path(to_path, from_path, storage_class=sc)
|
||||
await self.async_client.upload_file_from_path(to_path, from_path, storage_class=sc, tags=tags)
|
||||
|
||||
async def get_streaming_file(self, organization_id: str, file_name: str, use_default: bool = True) -> bytes | None:
|
||||
path = f"s3://{settings.AWS_S3_BUCKET_SCREENSHOTS}/{settings.ENV}/{organization_id}/{file_name}"
|
||||
@@ -141,6 +150,7 @@ class S3Storage(BaseStorage):
|
||||
zip_file_path = shutil.make_archive(temp_zip_file.name, "zip", directory)
|
||||
browser_session_uri = f"s3://{settings.AWS_S3_BUCKET_BROWSER_SESSIONS}/{settings.ENV}/{organization_id}/{workflow_permanent_id}.zip"
|
||||
sc = await self._get_storage_class_for_org(organization_id)
|
||||
tags = await self._get_tags_for_org(organization_id)
|
||||
LOG.debug(
|
||||
"Storing browser session",
|
||||
organization_id=organization_id,
|
||||
@@ -148,8 +158,9 @@ class S3Storage(BaseStorage):
|
||||
zip_file_path=zip_file_path,
|
||||
browser_session_uri=browser_session_uri,
|
||||
storage_class=sc,
|
||||
tags=tags,
|
||||
)
|
||||
await self.async_client.upload_file_from_path(browser_session_uri, zip_file_path, storage_class=sc)
|
||||
await self.async_client.upload_file_from_path(browser_session_uri, zip_file_path, storage_class=sc, tags=tags)
|
||||
|
||||
async def retrieve_browser_session(self, organization_id: str, workflow_permanent_id: str) -> str | None:
|
||||
browser_session_uri = f"s3://{settings.AWS_S3_BUCKET_BROWSER_SESSIONS}/{settings.ENV}/{organization_id}/{workflow_permanent_id}.zip"
|
||||
@@ -171,27 +182,30 @@ class S3Storage(BaseStorage):
|
||||
download_dir = get_download_dir(workflow_run_id=workflow_run_id, task_id=task_id)
|
||||
files = os.listdir(download_dir)
|
||||
sc = await self._get_storage_class_for_org(organization_id)
|
||||
tags = await self._get_tags_for_org(organization_id)
|
||||
for file in files:
|
||||
fpath = os.path.join(download_dir, file)
|
||||
if os.path.isfile(fpath):
|
||||
uri = f"s3://{settings.AWS_S3_BUCKET_UPLOADS}/{DOWNLOAD_FILE_PREFIX}/{settings.ENV}/{organization_id}/{workflow_run_id or task_id}/{file}"
|
||||
if not os.path.isfile(fpath):
|
||||
continue
|
||||
uri = f"s3://{settings.AWS_S3_BUCKET_UPLOADS}/{DOWNLOAD_FILE_PREFIX}/{settings.ENV}/{organization_id}/{workflow_run_id or task_id}/{file}"
|
||||
|
||||
# Calculate SHA-256 checksum
|
||||
checksum = calculate_sha256_for_file(fpath)
|
||||
LOG.info(
|
||||
"Calculated checksum for file",
|
||||
file=file,
|
||||
checksum=checksum,
|
||||
organization_id=organization_id,
|
||||
storage_class=sc,
|
||||
)
|
||||
# Upload file with checksum metadata
|
||||
await self.async_client.upload_file_from_path(
|
||||
uri=uri,
|
||||
file_path=fpath,
|
||||
metadata={"sha256_checksum": checksum, "original_filename": file},
|
||||
storage_class=sc,
|
||||
)
|
||||
# Calculate SHA-256 checksum
|
||||
checksum = calculate_sha256_for_file(fpath)
|
||||
LOG.info(
|
||||
"Calculated checksum for file",
|
||||
file=file,
|
||||
checksum=checksum,
|
||||
organization_id=organization_id,
|
||||
storage_class=sc,
|
||||
)
|
||||
# Upload file with checksum metadata
|
||||
await self.async_client.upload_file_from_path(
|
||||
uri=uri,
|
||||
file_path=fpath,
|
||||
metadata={"sha256_checksum": checksum, "original_filename": file},
|
||||
storage_class=sc,
|
||||
tags=tags,
|
||||
)
|
||||
|
||||
async def get_downloaded_files(
|
||||
self, organization_id: str, task_id: str | None, workflow_run_id: str | None
|
||||
@@ -232,11 +246,12 @@ class S3Storage(BaseStorage):
|
||||
todays_date = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d")
|
||||
bucket = settings.AWS_S3_BUCKET_UPLOADS
|
||||
sc = await self._get_storage_class_for_org(organization_id)
|
||||
tags = await self._get_tags_for_org(organization_id)
|
||||
# First try uploading with original filename
|
||||
try:
|
||||
sanitized_filename = os.path.basename(filename) # Remove any path components
|
||||
s3_uri = f"s3://{bucket}/{settings.ENV}/{organization_id}/{todays_date}/{sanitized_filename}"
|
||||
uploaded_s3_uri = await self.async_client.upload_file_stream(s3_uri, fileObj, storage_class=sc)
|
||||
uploaded_s3_uri = await self.async_client.upload_file_stream(s3_uri, fileObj, storage_class=sc, tags=tags)
|
||||
except Exception:
|
||||
LOG.error("Failed to upload file to S3", exc_info=True)
|
||||
uploaded_s3_uri = None
|
||||
@@ -246,7 +261,7 @@ class S3Storage(BaseStorage):
|
||||
uuid_prefixed_filename = f"{str(uuid.uuid4())}_{filename}"
|
||||
s3_uri = f"s3://{bucket}/{settings.ENV}/{organization_id}/{todays_date}/{uuid_prefixed_filename}"
|
||||
fileObj.seek(0) # Reset file pointer
|
||||
uploaded_s3_uri = await self.async_client.upload_file_stream(s3_uri, fileObj, storage_class=sc)
|
||||
uploaded_s3_uri = await self.async_client.upload_file_stream(s3_uri, fileObj, storage_class=sc, tags=tags)
|
||||
|
||||
if not uploaded_s3_uri:
|
||||
LOG.error(
|
||||
|
||||
@@ -6,9 +6,10 @@ import boto3
|
||||
import pytest
|
||||
from freezegun import freeze_time
|
||||
from moto.server import ThreadedMotoServer
|
||||
from types_boto3_s3.client import S3Client
|
||||
|
||||
from skyvern.config import settings
|
||||
from skyvern.forge.sdk.api.aws import S3Uri
|
||||
from skyvern.forge.sdk.api.aws import S3StorageClass, S3Uri
|
||||
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType, LogEntityType
|
||||
from skyvern.forge.sdk.artifact.storage.s3 import S3Storage
|
||||
from skyvern.forge.sdk.artifact.storage.test_helpers import (
|
||||
@@ -30,9 +31,17 @@ TEST_BLOCK_ID = "block_123456789"
|
||||
TEST_AI_SUGGESTION_ID = "ai_sugg_test_123"
|
||||
|
||||
|
||||
class S3StorageForTests(S3Storage):
|
||||
async def _get_tags_for_org(self, organization_id: str) -> dict[str, str]:
|
||||
return {"dummy": f"org-{organization_id}", "test": "jerry"}
|
||||
|
||||
async def _get_storage_class_for_org(self, organization_id: str) -> S3StorageClass:
|
||||
return S3StorageClass.ONEZONE_IA
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def s3_storage(moto_server: str) -> S3Storage:
|
||||
return S3Storage(bucket=TEST_BUCKET, endpoint_url=moto_server)
|
||||
return S3StorageForTests(bucket=TEST_BUCKET, endpoint_url=moto_server)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
@@ -53,7 +62,7 @@ def moto_server() -> Generator[str, None, None]:
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def boto3_test_client(moto_server: str) -> boto3.client:
|
||||
def boto3_test_client(moto_server: str) -> Generator[S3Client, None, None]:
|
||||
client = boto3.client(
|
||||
"s3",
|
||||
aws_access_key_id="testing",
|
||||
@@ -145,6 +154,23 @@ class TestS3StorageBuildURIs:
|
||||
)
|
||||
|
||||
|
||||
def _assert_object_meta(boto3_test_client: S3Client, uri: str) -> None:
|
||||
s3uri = S3Uri(uri)
|
||||
assert s3uri.bucket == TEST_BUCKET
|
||||
obj_meta = boto3_test_client.head_object(Bucket=TEST_BUCKET, Key=s3uri.key)
|
||||
assert obj_meta["StorageClass"] == "ONEZONE_IA"
|
||||
s3_tags_resp = boto3_test_client.get_object_tagging(Bucket=TEST_BUCKET, Key=s3uri.key)
|
||||
tags_dict = {tag["Key"]: tag["Value"] for tag in s3_tags_resp["TagSet"]}
|
||||
assert tags_dict == {"dummy": f"org-{TEST_ORGANIZATION_ID}", "test": "jerry"}
|
||||
|
||||
|
||||
def _assert_object_content(boto3_test_client: S3Client, uri: str, expected_content: bytes) -> None:
|
||||
s3uri = S3Uri(uri)
|
||||
assert s3uri.bucket == TEST_BUCKET
|
||||
obj_response = boto3_test_client.get_object(Bucket=TEST_BUCKET, Key=s3uri.key)
|
||||
assert obj_response["Body"].read() == expected_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestS3StorageStore:
|
||||
"""Test S3Storage store methods."""
|
||||
@@ -174,17 +200,24 @@ class TestS3StorageStore:
|
||||
modified_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
async def test_store_artifact_screenshot(
|
||||
self, s3_storage: S3Storage, boto3_test_client: boto3.client, tmp_path: Path
|
||||
async def test_store_artifact_from_path(
|
||||
self, s3_storage: S3Storage, boto3_test_client: S3Client, tmp_path: Path
|
||||
) -> None:
|
||||
test_data = b"fake screenshot data"
|
||||
artifact = self._create_artifact_for_ai_suggestion(
|
||||
s3_storage, ArtifactType.SCREENSHOT_LLM, TEST_AI_SUGGESTION_ID
|
||||
)
|
||||
s3uri = S3Uri(artifact.uri)
|
||||
assert s3uri.bucket == TEST_BUCKET
|
||||
|
||||
test_file = tmp_path / "test_screenshot.png"
|
||||
test_file.write_bytes(test_data)
|
||||
await s3_storage.store_artifact_from_path(artifact, str(test_file))
|
||||
obj_response = boto3_test_client.get_object(Bucket=TEST_BUCKET, Key=s3uri.key)
|
||||
assert obj_response["Body"].read() == test_data
|
||||
_assert_object_content(boto3_test_client, artifact.uri, test_data)
|
||||
_assert_object_meta(boto3_test_client, artifact.uri)
|
||||
|
||||
async def test_store_artifact(self, s3_storage: S3Storage, boto3_test_client: S3Client) -> None:
|
||||
test_data = b"fake artifact data"
|
||||
artifact = self._create_artifact_for_ai_suggestion(s3_storage, ArtifactType.LLM_PROMPT, TEST_AI_SUGGESTION_ID)
|
||||
|
||||
await s3_storage.store_artifact(artifact, test_data)
|
||||
_assert_object_content(boto3_test_client, artifact.uri, test_data)
|
||||
_assert_object_meta(boto3_test_client, artifact.uri)
|
||||
|
||||
Reference in New Issue
Block a user