Add the ability to add tags to s3 objects we upload + more tests for artifact upload (#2684)
This commit is contained in:
@@ -57,6 +57,9 @@ class AsyncAWSClient:
|
|||||||
def _s3_client(self) -> S3Client:
|
def _s3_client(self) -> S3Client:
|
||||||
return self.session.client(AWSClientType.S3, region_name=self.region_name, endpoint_url=self._endpoint_url)
|
return self.session.client(AWSClientType.S3, region_name=self.region_name, endpoint_url=self._endpoint_url)
|
||||||
|
|
||||||
|
def _create_tag_string(self, tags: dict[str, str]) -> str:
|
||||||
|
return "&".join([f"{k}={v}" for k, v in tags.items()])
|
||||||
|
|
||||||
async def get_secret(self, secret_name: str) -> str | None:
|
async def get_secret(self, secret_name: str) -> str | None:
|
||||||
try:
|
try:
|
||||||
async with self._secrets_manager_client() as client:
|
async with self._secrets_manager_client() as client:
|
||||||
@@ -95,15 +98,24 @@ class AsyncAWSClient:
|
|||||||
raise e
|
raise e
|
||||||
|
|
||||||
async def upload_file(
|
async def upload_file(
|
||||||
self, uri: str, data: bytes, storage_class: S3StorageClass = S3StorageClass.STANDARD
|
self,
|
||||||
|
uri: str,
|
||||||
|
data: bytes,
|
||||||
|
storage_class: S3StorageClass = S3StorageClass.STANDARD,
|
||||||
|
tags: dict[str, str] | None = None,
|
||||||
) -> str | None:
|
) -> str | None:
|
||||||
if storage_class not in S3StorageClass:
|
if storage_class not in S3StorageClass:
|
||||||
raise ValueError(f"Invalid storage class: {storage_class}. Must be one of {list(S3StorageClass)}")
|
raise ValueError(f"Invalid storage class: {storage_class}. Must be one of {list(S3StorageClass)}")
|
||||||
try:
|
try:
|
||||||
async with self._s3_client() as client:
|
async with self._s3_client() as client:
|
||||||
parsed_uri = S3Uri(uri)
|
parsed_uri = S3Uri(uri)
|
||||||
|
extra_args = {"Tagging": self._create_tag_string(tags)} if tags else {}
|
||||||
await client.put_object(
|
await client.put_object(
|
||||||
Body=data, Bucket=parsed_uri.bucket, Key=parsed_uri.key, StorageClass=str(storage_class)
|
Body=data,
|
||||||
|
Bucket=parsed_uri.bucket,
|
||||||
|
Key=parsed_uri.key,
|
||||||
|
StorageClass=str(storage_class),
|
||||||
|
**extra_args,
|
||||||
)
|
)
|
||||||
return uri
|
return uri
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -111,18 +123,25 @@ class AsyncAWSClient:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
async def upload_file_stream(
|
async def upload_file_stream(
|
||||||
self, uri: str, file_obj: IO[bytes], storage_class: S3StorageClass = S3StorageClass.STANDARD
|
self,
|
||||||
|
uri: str,
|
||||||
|
file_obj: IO[bytes],
|
||||||
|
storage_class: S3StorageClass = S3StorageClass.STANDARD,
|
||||||
|
tags: dict[str, str] | None = None,
|
||||||
) -> str | None:
|
) -> str | None:
|
||||||
if storage_class not in S3StorageClass:
|
if storage_class not in S3StorageClass:
|
||||||
raise ValueError(f"Invalid storage class: {storage_class}. Must be one of {list(S3StorageClass)}")
|
raise ValueError(f"Invalid storage class: {storage_class}. Must be one of {list(S3StorageClass)}")
|
||||||
try:
|
try:
|
||||||
async with self._s3_client() as client:
|
async with self._s3_client() as client:
|
||||||
parsed_uri = S3Uri(uri)
|
parsed_uri = S3Uri(uri)
|
||||||
|
extra_args: dict[str, Any] = {"StorageClass": str(storage_class)}
|
||||||
|
if tags:
|
||||||
|
extra_args["Tagging"] = self._create_tag_string(tags)
|
||||||
await client.upload_fileobj(
|
await client.upload_fileobj(
|
||||||
file_obj,
|
file_obj,
|
||||||
parsed_uri.bucket,
|
parsed_uri.bucket,
|
||||||
parsed_uri.key,
|
parsed_uri.key,
|
||||||
ExtraArgs={"StorageClass": str(storage_class)},
|
ExtraArgs=extra_args,
|
||||||
)
|
)
|
||||||
LOG.debug("Upload file stream success", uri=uri)
|
LOG.debug("Upload file stream success", uri=uri)
|
||||||
return uri
|
return uri
|
||||||
@@ -137,6 +156,7 @@ class AsyncAWSClient:
|
|||||||
storage_class: S3StorageClass = S3StorageClass.STANDARD,
|
storage_class: S3StorageClass = S3StorageClass.STANDARD,
|
||||||
metadata: dict | None = None,
|
metadata: dict | None = None,
|
||||||
raise_exception: bool = False,
|
raise_exception: bool = False,
|
||||||
|
tags: dict[str, str] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
try:
|
try:
|
||||||
async with self._s3_client() as client:
|
async with self._s3_client() as client:
|
||||||
@@ -144,6 +164,8 @@ class AsyncAWSClient:
|
|||||||
extra_args: dict[str, Any] = {"StorageClass": str(storage_class)}
|
extra_args: dict[str, Any] = {"StorageClass": str(storage_class)}
|
||||||
if metadata:
|
if metadata:
|
||||||
extra_args["Metadata"] = metadata
|
extra_args["Metadata"] = metadata
|
||||||
|
if tags:
|
||||||
|
extra_args["Tagging"] = self._create_tag_string(tags)
|
||||||
await client.upload_file(
|
await client.upload_file(
|
||||||
Filename=file_path,
|
Filename=file_path,
|
||||||
Bucket=parsed_uri.bucket,
|
Bucket=parsed_uri.bucket,
|
||||||
|
|||||||
@@ -83,18 +83,23 @@ class S3Storage(BaseStorage):
|
|||||||
|
|
||||||
async def store_artifact(self, artifact: Artifact, data: bytes) -> None:
|
async def store_artifact(self, artifact: Artifact, data: bytes) -> None:
|
||||||
sc = await self._get_storage_class_for_org(artifact.organization_id)
|
sc = await self._get_storage_class_for_org(artifact.organization_id)
|
||||||
|
tags = await self._get_tags_for_org(artifact.organization_id)
|
||||||
LOG.debug(
|
LOG.debug(
|
||||||
"Storing artifact",
|
"Storing artifact",
|
||||||
artifact_id=artifact.artifact_id,
|
artifact_id=artifact.artifact_id,
|
||||||
organization_id=artifact.organization_id,
|
organization_id=artifact.organization_id,
|
||||||
uri=artifact.uri,
|
uri=artifact.uri,
|
||||||
storage_class=sc,
|
storage_class=sc,
|
||||||
|
tags=tags,
|
||||||
)
|
)
|
||||||
await self.async_client.upload_file(artifact.uri, data, storage_class=sc)
|
await self.async_client.upload_file(artifact.uri, data, storage_class=sc, tags=tags)
|
||||||
|
|
||||||
async def _get_storage_class_for_org(self, organization_id: str) -> S3StorageClass:
|
async def _get_storage_class_for_org(self, organization_id: str) -> S3StorageClass:
|
||||||
return S3StorageClass.STANDARD
|
return S3StorageClass.STANDARD
|
||||||
|
|
||||||
|
async def _get_tags_for_org(self, organization_id: str) -> dict[str, str]:
|
||||||
|
return {}
|
||||||
|
|
||||||
async def retrieve_artifact(self, artifact: Artifact) -> bytes | None:
|
async def retrieve_artifact(self, artifact: Artifact) -> bytes | None:
|
||||||
return await self.async_client.download_file(artifact.uri)
|
return await self.async_client.download_file(artifact.uri)
|
||||||
|
|
||||||
@@ -107,6 +112,7 @@ class S3Storage(BaseStorage):
|
|||||||
|
|
||||||
async def store_artifact_from_path(self, artifact: Artifact, path: str) -> None:
|
async def store_artifact_from_path(self, artifact: Artifact, path: str) -> None:
|
||||||
sc = await self._get_storage_class_for_org(artifact.organization_id)
|
sc = await self._get_storage_class_for_org(artifact.organization_id)
|
||||||
|
tags = await self._get_tags_for_org(artifact.organization_id)
|
||||||
LOG.debug(
|
LOG.debug(
|
||||||
"Storing artifact from path",
|
"Storing artifact from path",
|
||||||
artifact_id=artifact.artifact_id,
|
artifact_id=artifact.artifact_id,
|
||||||
@@ -114,13 +120,15 @@ class S3Storage(BaseStorage):
|
|||||||
uri=artifact.uri,
|
uri=artifact.uri,
|
||||||
storage_class=sc,
|
storage_class=sc,
|
||||||
path=path,
|
path=path,
|
||||||
|
tags=tags,
|
||||||
)
|
)
|
||||||
await self.async_client.upload_file_from_path(artifact.uri, path, storage_class=sc)
|
await self.async_client.upload_file_from_path(artifact.uri, path, storage_class=sc, tags=tags)
|
||||||
|
|
||||||
async def save_streaming_file(self, organization_id: str, file_name: str) -> None:
|
async def save_streaming_file(self, organization_id: str, file_name: str) -> None:
|
||||||
from_path = f"{get_skyvern_temp_dir()}/{organization_id}/{file_name}"
|
from_path = f"{get_skyvern_temp_dir()}/{organization_id}/{file_name}"
|
||||||
to_path = f"s3://{settings.AWS_S3_BUCKET_SCREENSHOTS}/{settings.ENV}/{organization_id}/{file_name}"
|
to_path = f"s3://{settings.AWS_S3_BUCKET_SCREENSHOTS}/{settings.ENV}/{organization_id}/{file_name}"
|
||||||
sc = await self._get_storage_class_for_org(organization_id)
|
sc = await self._get_storage_class_for_org(organization_id)
|
||||||
|
tags = await self._get_tags_for_org(organization_id)
|
||||||
LOG.debug(
|
LOG.debug(
|
||||||
"Saving streaming file",
|
"Saving streaming file",
|
||||||
organization_id=organization_id,
|
organization_id=organization_id,
|
||||||
@@ -128,8 +136,9 @@ class S3Storage(BaseStorage):
|
|||||||
from_path=from_path,
|
from_path=from_path,
|
||||||
to_path=to_path,
|
to_path=to_path,
|
||||||
storage_class=sc,
|
storage_class=sc,
|
||||||
|
tags=tags,
|
||||||
)
|
)
|
||||||
await self.async_client.upload_file_from_path(to_path, from_path, storage_class=sc)
|
await self.async_client.upload_file_from_path(to_path, from_path, storage_class=sc, tags=tags)
|
||||||
|
|
||||||
async def get_streaming_file(self, organization_id: str, file_name: str, use_default: bool = True) -> bytes | None:
|
async def get_streaming_file(self, organization_id: str, file_name: str, use_default: bool = True) -> bytes | None:
|
||||||
path = f"s3://{settings.AWS_S3_BUCKET_SCREENSHOTS}/{settings.ENV}/{organization_id}/{file_name}"
|
path = f"s3://{settings.AWS_S3_BUCKET_SCREENSHOTS}/{settings.ENV}/{organization_id}/{file_name}"
|
||||||
@@ -141,6 +150,7 @@ class S3Storage(BaseStorage):
|
|||||||
zip_file_path = shutil.make_archive(temp_zip_file.name, "zip", directory)
|
zip_file_path = shutil.make_archive(temp_zip_file.name, "zip", directory)
|
||||||
browser_session_uri = f"s3://{settings.AWS_S3_BUCKET_BROWSER_SESSIONS}/{settings.ENV}/{organization_id}/{workflow_permanent_id}.zip"
|
browser_session_uri = f"s3://{settings.AWS_S3_BUCKET_BROWSER_SESSIONS}/{settings.ENV}/{organization_id}/{workflow_permanent_id}.zip"
|
||||||
sc = await self._get_storage_class_for_org(organization_id)
|
sc = await self._get_storage_class_for_org(organization_id)
|
||||||
|
tags = await self._get_tags_for_org(organization_id)
|
||||||
LOG.debug(
|
LOG.debug(
|
||||||
"Storing browser session",
|
"Storing browser session",
|
||||||
organization_id=organization_id,
|
organization_id=organization_id,
|
||||||
@@ -148,8 +158,9 @@ class S3Storage(BaseStorage):
|
|||||||
zip_file_path=zip_file_path,
|
zip_file_path=zip_file_path,
|
||||||
browser_session_uri=browser_session_uri,
|
browser_session_uri=browser_session_uri,
|
||||||
storage_class=sc,
|
storage_class=sc,
|
||||||
|
tags=tags,
|
||||||
)
|
)
|
||||||
await self.async_client.upload_file_from_path(browser_session_uri, zip_file_path, storage_class=sc)
|
await self.async_client.upload_file_from_path(browser_session_uri, zip_file_path, storage_class=sc, tags=tags)
|
||||||
|
|
||||||
async def retrieve_browser_session(self, organization_id: str, workflow_permanent_id: str) -> str | None:
|
async def retrieve_browser_session(self, organization_id: str, workflow_permanent_id: str) -> str | None:
|
||||||
browser_session_uri = f"s3://{settings.AWS_S3_BUCKET_BROWSER_SESSIONS}/{settings.ENV}/{organization_id}/{workflow_permanent_id}.zip"
|
browser_session_uri = f"s3://{settings.AWS_S3_BUCKET_BROWSER_SESSIONS}/{settings.ENV}/{organization_id}/{workflow_permanent_id}.zip"
|
||||||
@@ -171,27 +182,30 @@ class S3Storage(BaseStorage):
|
|||||||
download_dir = get_download_dir(workflow_run_id=workflow_run_id, task_id=task_id)
|
download_dir = get_download_dir(workflow_run_id=workflow_run_id, task_id=task_id)
|
||||||
files = os.listdir(download_dir)
|
files = os.listdir(download_dir)
|
||||||
sc = await self._get_storage_class_for_org(organization_id)
|
sc = await self._get_storage_class_for_org(organization_id)
|
||||||
|
tags = await self._get_tags_for_org(organization_id)
|
||||||
for file in files:
|
for file in files:
|
||||||
fpath = os.path.join(download_dir, file)
|
fpath = os.path.join(download_dir, file)
|
||||||
if os.path.isfile(fpath):
|
if not os.path.isfile(fpath):
|
||||||
uri = f"s3://{settings.AWS_S3_BUCKET_UPLOADS}/{DOWNLOAD_FILE_PREFIX}/{settings.ENV}/{organization_id}/{workflow_run_id or task_id}/{file}"
|
continue
|
||||||
|
uri = f"s3://{settings.AWS_S3_BUCKET_UPLOADS}/{DOWNLOAD_FILE_PREFIX}/{settings.ENV}/{organization_id}/{workflow_run_id or task_id}/{file}"
|
||||||
|
|
||||||
# Calculate SHA-256 checksum
|
# Calculate SHA-256 checksum
|
||||||
checksum = calculate_sha256_for_file(fpath)
|
checksum = calculate_sha256_for_file(fpath)
|
||||||
LOG.info(
|
LOG.info(
|
||||||
"Calculated checksum for file",
|
"Calculated checksum for file",
|
||||||
file=file,
|
file=file,
|
||||||
checksum=checksum,
|
checksum=checksum,
|
||||||
organization_id=organization_id,
|
organization_id=organization_id,
|
||||||
storage_class=sc,
|
storage_class=sc,
|
||||||
)
|
)
|
||||||
# Upload file with checksum metadata
|
# Upload file with checksum metadata
|
||||||
await self.async_client.upload_file_from_path(
|
await self.async_client.upload_file_from_path(
|
||||||
uri=uri,
|
uri=uri,
|
||||||
file_path=fpath,
|
file_path=fpath,
|
||||||
metadata={"sha256_checksum": checksum, "original_filename": file},
|
metadata={"sha256_checksum": checksum, "original_filename": file},
|
||||||
storage_class=sc,
|
storage_class=sc,
|
||||||
)
|
tags=tags,
|
||||||
|
)
|
||||||
|
|
||||||
async def get_downloaded_files(
|
async def get_downloaded_files(
|
||||||
self, organization_id: str, task_id: str | None, workflow_run_id: str | None
|
self, organization_id: str, task_id: str | None, workflow_run_id: str | None
|
||||||
@@ -232,11 +246,12 @@ class S3Storage(BaseStorage):
|
|||||||
todays_date = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d")
|
todays_date = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d")
|
||||||
bucket = settings.AWS_S3_BUCKET_UPLOADS
|
bucket = settings.AWS_S3_BUCKET_UPLOADS
|
||||||
sc = await self._get_storage_class_for_org(organization_id)
|
sc = await self._get_storage_class_for_org(organization_id)
|
||||||
|
tags = await self._get_tags_for_org(organization_id)
|
||||||
# First try uploading with original filename
|
# First try uploading with original filename
|
||||||
try:
|
try:
|
||||||
sanitized_filename = os.path.basename(filename) # Remove any path components
|
sanitized_filename = os.path.basename(filename) # Remove any path components
|
||||||
s3_uri = f"s3://{bucket}/{settings.ENV}/{organization_id}/{todays_date}/{sanitized_filename}"
|
s3_uri = f"s3://{bucket}/{settings.ENV}/{organization_id}/{todays_date}/{sanitized_filename}"
|
||||||
uploaded_s3_uri = await self.async_client.upload_file_stream(s3_uri, fileObj, storage_class=sc)
|
uploaded_s3_uri = await self.async_client.upload_file_stream(s3_uri, fileObj, storage_class=sc, tags=tags)
|
||||||
except Exception:
|
except Exception:
|
||||||
LOG.error("Failed to upload file to S3", exc_info=True)
|
LOG.error("Failed to upload file to S3", exc_info=True)
|
||||||
uploaded_s3_uri = None
|
uploaded_s3_uri = None
|
||||||
@@ -246,7 +261,7 @@ class S3Storage(BaseStorage):
|
|||||||
uuid_prefixed_filename = f"{str(uuid.uuid4())}_{filename}"
|
uuid_prefixed_filename = f"{str(uuid.uuid4())}_{filename}"
|
||||||
s3_uri = f"s3://{bucket}/{settings.ENV}/{organization_id}/{todays_date}/{uuid_prefixed_filename}"
|
s3_uri = f"s3://{bucket}/{settings.ENV}/{organization_id}/{todays_date}/{uuid_prefixed_filename}"
|
||||||
fileObj.seek(0) # Reset file pointer
|
fileObj.seek(0) # Reset file pointer
|
||||||
uploaded_s3_uri = await self.async_client.upload_file_stream(s3_uri, fileObj, storage_class=sc)
|
uploaded_s3_uri = await self.async_client.upload_file_stream(s3_uri, fileObj, storage_class=sc, tags=tags)
|
||||||
|
|
||||||
if not uploaded_s3_uri:
|
if not uploaded_s3_uri:
|
||||||
LOG.error(
|
LOG.error(
|
||||||
|
|||||||
@@ -6,9 +6,10 @@ import boto3
|
|||||||
import pytest
|
import pytest
|
||||||
from freezegun import freeze_time
|
from freezegun import freeze_time
|
||||||
from moto.server import ThreadedMotoServer
|
from moto.server import ThreadedMotoServer
|
||||||
|
from types_boto3_s3.client import S3Client
|
||||||
|
|
||||||
from skyvern.config import settings
|
from skyvern.config import settings
|
||||||
from skyvern.forge.sdk.api.aws import S3Uri
|
from skyvern.forge.sdk.api.aws import S3StorageClass, S3Uri
|
||||||
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType, LogEntityType
|
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType, LogEntityType
|
||||||
from skyvern.forge.sdk.artifact.storage.s3 import S3Storage
|
from skyvern.forge.sdk.artifact.storage.s3 import S3Storage
|
||||||
from skyvern.forge.sdk.artifact.storage.test_helpers import (
|
from skyvern.forge.sdk.artifact.storage.test_helpers import (
|
||||||
@@ -30,9 +31,17 @@ TEST_BLOCK_ID = "block_123456789"
|
|||||||
TEST_AI_SUGGESTION_ID = "ai_sugg_test_123"
|
TEST_AI_SUGGESTION_ID = "ai_sugg_test_123"
|
||||||
|
|
||||||
|
|
||||||
|
class S3StorageForTests(S3Storage):
|
||||||
|
async def _get_tags_for_org(self, organization_id: str) -> dict[str, str]:
|
||||||
|
return {"dummy": f"org-{organization_id}", "test": "jerry"}
|
||||||
|
|
||||||
|
async def _get_storage_class_for_org(self, organization_id: str) -> S3StorageClass:
|
||||||
|
return S3StorageClass.ONEZONE_IA
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def s3_storage(moto_server: str) -> S3Storage:
|
def s3_storage(moto_server: str) -> S3Storage:
|
||||||
return S3Storage(bucket=TEST_BUCKET, endpoint_url=moto_server)
|
return S3StorageForTests(bucket=TEST_BUCKET, endpoint_url=moto_server)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
@@ -53,7 +62,7 @@ def moto_server() -> Generator[str, None, None]:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module", autouse=True)
|
@pytest.fixture(scope="module", autouse=True)
|
||||||
def boto3_test_client(moto_server: str) -> boto3.client:
|
def boto3_test_client(moto_server: str) -> Generator[S3Client, None, None]:
|
||||||
client = boto3.client(
|
client = boto3.client(
|
||||||
"s3",
|
"s3",
|
||||||
aws_access_key_id="testing",
|
aws_access_key_id="testing",
|
||||||
@@ -145,6 +154,23 @@ class TestS3StorageBuildURIs:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _assert_object_meta(boto3_test_client: S3Client, uri: str) -> None:
|
||||||
|
s3uri = S3Uri(uri)
|
||||||
|
assert s3uri.bucket == TEST_BUCKET
|
||||||
|
obj_meta = boto3_test_client.head_object(Bucket=TEST_BUCKET, Key=s3uri.key)
|
||||||
|
assert obj_meta["StorageClass"] == "ONEZONE_IA"
|
||||||
|
s3_tags_resp = boto3_test_client.get_object_tagging(Bucket=TEST_BUCKET, Key=s3uri.key)
|
||||||
|
tags_dict = {tag["Key"]: tag["Value"] for tag in s3_tags_resp["TagSet"]}
|
||||||
|
assert tags_dict == {"dummy": f"org-{TEST_ORGANIZATION_ID}", "test": "jerry"}
|
||||||
|
|
||||||
|
|
||||||
|
def _assert_object_content(boto3_test_client: S3Client, uri: str, expected_content: bytes) -> None:
|
||||||
|
s3uri = S3Uri(uri)
|
||||||
|
assert s3uri.bucket == TEST_BUCKET
|
||||||
|
obj_response = boto3_test_client.get_object(Bucket=TEST_BUCKET, Key=s3uri.key)
|
||||||
|
assert obj_response["Body"].read() == expected_content
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
class TestS3StorageStore:
|
class TestS3StorageStore:
|
||||||
"""Test S3Storage store methods."""
|
"""Test S3Storage store methods."""
|
||||||
@@ -174,17 +200,24 @@ class TestS3StorageStore:
|
|||||||
modified_at=datetime.utcnow(),
|
modified_at=datetime.utcnow(),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def test_store_artifact_screenshot(
|
async def test_store_artifact_from_path(
|
||||||
self, s3_storage: S3Storage, boto3_test_client: boto3.client, tmp_path: Path
|
self, s3_storage: S3Storage, boto3_test_client: S3Client, tmp_path: Path
|
||||||
) -> None:
|
) -> None:
|
||||||
test_data = b"fake screenshot data"
|
test_data = b"fake screenshot data"
|
||||||
artifact = self._create_artifact_for_ai_suggestion(
|
artifact = self._create_artifact_for_ai_suggestion(
|
||||||
s3_storage, ArtifactType.SCREENSHOT_LLM, TEST_AI_SUGGESTION_ID
|
s3_storage, ArtifactType.SCREENSHOT_LLM, TEST_AI_SUGGESTION_ID
|
||||||
)
|
)
|
||||||
s3uri = S3Uri(artifact.uri)
|
|
||||||
assert s3uri.bucket == TEST_BUCKET
|
|
||||||
test_file = tmp_path / "test_screenshot.png"
|
test_file = tmp_path / "test_screenshot.png"
|
||||||
test_file.write_bytes(test_data)
|
test_file.write_bytes(test_data)
|
||||||
await s3_storage.store_artifact_from_path(artifact, str(test_file))
|
await s3_storage.store_artifact_from_path(artifact, str(test_file))
|
||||||
obj_response = boto3_test_client.get_object(Bucket=TEST_BUCKET, Key=s3uri.key)
|
_assert_object_content(boto3_test_client, artifact.uri, test_data)
|
||||||
assert obj_response["Body"].read() == test_data
|
_assert_object_meta(boto3_test_client, artifact.uri)
|
||||||
|
|
||||||
|
async def test_store_artifact(self, s3_storage: S3Storage, boto3_test_client: S3Client) -> None:
|
||||||
|
test_data = b"fake artifact data"
|
||||||
|
artifact = self._create_artifact_for_ai_suggestion(s3_storage, ArtifactType.LLM_PROMPT, TEST_AI_SUGGESTION_ID)
|
||||||
|
|
||||||
|
await s3_storage.store_artifact(artifact, test_data)
|
||||||
|
_assert_object_content(boto3_test_client, artifact.uri, test_data)
|
||||||
|
_assert_object_meta(boto3_test_client, artifact.uri)
|
||||||
|
|||||||
Reference in New Issue
Block a user