add zstd compression for .har (#4420)
This commit is contained in:
@@ -16,6 +16,7 @@ from skyvern.config import settings
|
||||
# Register custom mime types for mimetypes guessing
|
||||
add_type("application/json", ".har")
|
||||
add_type("text/plain", ".log")
|
||||
add_type("application/zstd", ".zst")
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
@@ -27,7 +28,6 @@ class S3StorageClass(StrEnum):
|
||||
# INTELLIGENT_TIERING = "INTELLIGENT_TIERING"
|
||||
ONEZONE_IA = "ONEZONE_IA"
|
||||
GLACIER = "GLACIER"
|
||||
GLACIER_IR = "GLACIER_IR" # Glacier Instant Retrieval
|
||||
# DEEP_ARCHIVE = "DEEP_ARCHIVE"
|
||||
# OUTPOSTS = "OUTPOSTS"
|
||||
# STANDARD_IA = "STANDARD_IA"
|
||||
|
||||
@@ -23,6 +23,7 @@ from skyvern.forge.sdk.schemas.organizations import AzureClientSecretCredential
|
||||
# Register custom mime types for mimetypes guessing
|
||||
add_type("application/json", ".har")
|
||||
add_type("text/plain", ".log")
|
||||
add_type("application/zstd", ".zst")
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from datetime import datetime, timezone
|
||||
from typing import BinaryIO
|
||||
|
||||
import structlog
|
||||
import zstandard as zstd
|
||||
|
||||
from skyvern.config import settings
|
||||
from skyvern.constants import BROWSER_DOWNLOADING_SUFFIX, DOWNLOAD_FILE_PREFIX
|
||||
@@ -101,26 +102,42 @@ class S3Storage(BaseStorage):
|
||||
return f"{self._build_base_uri(organization_id)}/scripts/{script_id}/{script_version}/{file_path}"
|
||||
|
||||
async def store_artifact(self, artifact: Artifact, data: bytes) -> None:
|
||||
sc = await self._get_storage_class_for_org(artifact.organization_id, self.bucket)
|
||||
# We compress HAR files with zstd level 3 to reduce storage size.
|
||||
# HARs are easily compressible because they are mostly JSON.
|
||||
# Other artifacts are not compressed because they are not easily compressible.
|
||||
uri = artifact.uri
|
||||
if artifact.artifact_type == ArtifactType.HAR:
|
||||
cctx = zstd.ZstdCompressor(level=3)
|
||||
data = cctx.compress(data)
|
||||
file_ext = FILE_EXTENTSION_MAP[artifact.artifact_type]
|
||||
uri = uri.replace(f".{file_ext}", f".{file_ext}.zst")
|
||||
artifact.uri = uri
|
||||
|
||||
sc = await self._get_storage_class_for_org(artifact.organization_id)
|
||||
tags = await self._get_tags_for_org(artifact.organization_id)
|
||||
LOG.debug(
|
||||
"Storing artifact",
|
||||
artifact_id=artifact.artifact_id,
|
||||
organization_id=artifact.organization_id,
|
||||
uri=artifact.uri,
|
||||
uri=uri,
|
||||
storage_class=sc,
|
||||
tags=tags,
|
||||
)
|
||||
await self.async_client.upload_file(artifact.uri, data, storage_class=sc, tags=tags)
|
||||
await self.async_client.upload_file(uri, data, storage_class=sc, tags=tags)
|
||||
|
||||
async def _get_storage_class_for_org(self, organization_id: str, bucket: str) -> S3StorageClass:
|
||||
async def _get_storage_class_for_org(self, organization_id: str) -> S3StorageClass:
|
||||
return S3StorageClass.STANDARD
|
||||
|
||||
async def _get_tags_for_org(self, organization_id: str) -> dict[str, str]:
|
||||
return {}
|
||||
|
||||
async def retrieve_artifact(self, artifact: Artifact) -> bytes | None:
|
||||
return await self.async_client.download_file(artifact.uri)
|
||||
data = await self.async_client.download_file(artifact.uri)
|
||||
# Decompress zstd-compressed files
|
||||
if data and artifact.uri.endswith(".zst"):
|
||||
dctx = zstd.ZstdDecompressor()
|
||||
data = dctx.decompress(data)
|
||||
return data
|
||||
|
||||
async def get_share_link(self, artifact: Artifact) -> str | None:
|
||||
share_urls = await self.async_client.create_presigned_urls([artifact.uri])
|
||||
@@ -130,7 +147,7 @@ class S3Storage(BaseStorage):
|
||||
return await self.async_client.create_presigned_urls([artifact.uri for artifact in artifacts])
|
||||
|
||||
async def store_artifact_from_path(self, artifact: Artifact, path: str) -> None:
|
||||
sc = await self._get_storage_class_for_org(artifact.organization_id, self.bucket)
|
||||
sc = await self._get_storage_class_for_org(artifact.organization_id)
|
||||
tags = await self._get_tags_for_org(artifact.organization_id)
|
||||
LOG.debug(
|
||||
"Storing artifact from path",
|
||||
@@ -146,7 +163,7 @@ class S3Storage(BaseStorage):
|
||||
async def save_streaming_file(self, organization_id: str, file_name: str) -> None:
|
||||
from_path = f"{get_skyvern_temp_dir()}/{organization_id}/{file_name}"
|
||||
to_path = f"s3://{settings.AWS_S3_BUCKET_SCREENSHOTS}/{settings.ENV}/{organization_id}/{file_name}"
|
||||
sc = await self._get_storage_class_for_org(organization_id, settings.AWS_S3_BUCKET_SCREENSHOTS)
|
||||
sc = await self._get_storage_class_for_org(organization_id)
|
||||
tags = await self._get_tags_for_org(organization_id)
|
||||
LOG.debug(
|
||||
"Saving streaming file",
|
||||
@@ -168,7 +185,7 @@ class S3Storage(BaseStorage):
|
||||
temp_zip_file = create_named_temporary_file()
|
||||
zip_file_path = shutil.make_archive(temp_zip_file.name, "zip", directory)
|
||||
browser_session_uri = f"s3://{settings.AWS_S3_BUCKET_BROWSER_SESSIONS}/{settings.ENV}/{organization_id}/{workflow_permanent_id}.zip"
|
||||
sc = await self._get_storage_class_for_org(organization_id, settings.AWS_S3_BUCKET_BROWSER_SESSIONS)
|
||||
sc = await self._get_storage_class_for_org(organization_id)
|
||||
tags = await self._get_tags_for_org(organization_id)
|
||||
LOG.debug(
|
||||
"Storing browser session",
|
||||
@@ -202,7 +219,7 @@ class S3Storage(BaseStorage):
|
||||
profile_uri = (
|
||||
f"s3://{settings.AWS_S3_BUCKET_BROWSER_SESSIONS}/{settings.ENV}/{organization_id}/profiles/{profile_id}.zip"
|
||||
)
|
||||
sc = await self._get_storage_class_for_org(organization_id, settings.AWS_S3_BUCKET_BROWSER_SESSIONS)
|
||||
sc = await self._get_storage_class_for_org(organization_id)
|
||||
tags = await self._get_tags_for_org(organization_id)
|
||||
LOG.debug(
|
||||
"Storing browser profile",
|
||||
@@ -354,7 +371,7 @@ class S3Storage(BaseStorage):
|
||||
async def save_downloaded_files(self, organization_id: str, run_id: str | None) -> None:
|
||||
download_dir = get_download_dir(run_id=run_id)
|
||||
files = os.listdir(download_dir)
|
||||
sc = await self._get_storage_class_for_org(organization_id, settings.AWS_S3_BUCKET_UPLOADS)
|
||||
sc = await self._get_storage_class_for_org(organization_id)
|
||||
tags = await self._get_tags_for_org(organization_id)
|
||||
base_uri = (
|
||||
f"s3://{settings.AWS_S3_BUCKET_UPLOADS}/{DOWNLOAD_FILE_PREFIX}/{settings.ENV}/{organization_id}/{run_id}"
|
||||
@@ -418,7 +435,7 @@ class S3Storage(BaseStorage):
|
||||
) -> tuple[str, str] | None:
|
||||
todays_date = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d")
|
||||
bucket = settings.AWS_S3_BUCKET_UPLOADS
|
||||
sc = await self._get_storage_class_for_org(organization_id, bucket)
|
||||
sc = await self._get_storage_class_for_org(organization_id)
|
||||
tags = await self._get_tags_for_org(organization_id)
|
||||
# First try uploading with original filename
|
||||
try:
|
||||
@@ -491,7 +508,7 @@ class S3Storage(BaseStorage):
|
||||
) -> str:
|
||||
"""Sync a file from local browser session to S3."""
|
||||
uri = self._build_browser_session_uri(organization_id, browser_session_id, artifact_type, remote_path, date)
|
||||
sc = await self._get_storage_class_for_org(organization_id, self.bucket)
|
||||
sc = await self._get_storage_class_for_org(organization_id)
|
||||
tags = await self._get_tags_for_org(organization_id)
|
||||
await self.async_client.upload_file_from_path(uri, local_file_path, storage_class=sc, tags=tags)
|
||||
return uri
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Generator
|
||||
|
||||
import boto3
|
||||
import pytest
|
||||
import zstandard as zstd
|
||||
from freezegun import freeze_time
|
||||
from moto.server import ThreadedMotoServer
|
||||
from types_boto3_s3.client import S3Client
|
||||
@@ -35,7 +36,7 @@ class S3StorageForTests(S3Storage):
|
||||
async def _get_tags_for_org(self, organization_id: str) -> dict[str, str]:
|
||||
return {"dummy": f"org-{organization_id}", "test": "jerry"}
|
||||
|
||||
async def _get_storage_class_for_org(self, organization_id: str, bucket: str) -> S3StorageClass:
|
||||
async def _get_storage_class_for_org(self, organization_id: str) -> S3StorageClass:
|
||||
return S3StorageClass.ONEZONE_IA
|
||||
|
||||
|
||||
@@ -447,3 +448,106 @@ class TestS3StorageContentType:
|
||||
s3uri = S3Uri(uri)
|
||||
obj_meta = boto3_test_client.head_object(Bucket=TEST_BUCKET, Key=s3uri.key)
|
||||
assert obj_meta["ContentType"] == expected_content_type
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestS3StorageHARCompression:
|
||||
"""Test S3Storage HAR file compression with zstd."""
|
||||
|
||||
def _create_har_artifact(self, s3_storage: S3Storage, step_id: str) -> Artifact:
|
||||
"""Helper method to create a HAR Artifact."""
|
||||
artifact_id_val = generate_artifact_id()
|
||||
step = create_fake_step(step_id)
|
||||
uri = s3_storage.build_uri(
|
||||
organization_id=TEST_ORGANIZATION_ID,
|
||||
artifact_id=artifact_id_val,
|
||||
step=step,
|
||||
artifact_type=ArtifactType.HAR,
|
||||
)
|
||||
return Artifact(
|
||||
artifact_id=artifact_id_val,
|
||||
artifact_type=ArtifactType.HAR,
|
||||
uri=uri,
|
||||
organization_id=TEST_ORGANIZATION_ID,
|
||||
step_id=step.step_id,
|
||||
task_id=step.task_id,
|
||||
created_at=datetime.utcnow(),
|
||||
modified_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
async def test_store_har_artifact_compresses_with_zstd(
|
||||
self, s3_storage: S3Storage, boto3_test_client: S3Client
|
||||
) -> None:
|
||||
"""Test that HAR artifacts are compressed with zstd and URI is updated."""
|
||||
|
||||
# Create sample HAR JSON data (easily compressible)
|
||||
har_data = b'{"log": {"version": "1.2", "entries": [{"request": {}, "response": {}}]}}'
|
||||
artifact = self._create_har_artifact(s3_storage, TEST_STEP_ID)
|
||||
original_uri = artifact.uri
|
||||
|
||||
# Store the artifact
|
||||
await s3_storage.store_artifact(artifact, har_data)
|
||||
|
||||
# Verify URI was updated to .har.zst
|
||||
assert artifact.uri.endswith(".har.zst")
|
||||
assert artifact.uri == original_uri.replace(".har", ".har.zst")
|
||||
|
||||
# Verify the stored data is compressed
|
||||
s3uri = S3Uri(artifact.uri)
|
||||
obj_response = boto3_test_client.get_object(Bucket=TEST_BUCKET, Key=s3uri.key)
|
||||
stored_data = obj_response["Body"].read()
|
||||
|
||||
# Stored data should be different from original (compressed)
|
||||
assert stored_data != har_data
|
||||
|
||||
# Verify we can decompress it back to original
|
||||
dctx = zstd.ZstdDecompressor()
|
||||
decompressed = dctx.decompress(stored_data)
|
||||
assert decompressed == har_data
|
||||
|
||||
async def test_retrieve_har_artifact_decompresses_zstd(
|
||||
self, s3_storage: S3Storage, boto3_test_client: S3Client
|
||||
) -> None:
|
||||
"""Test that retrieving a .zst HAR artifact auto-decompresses it."""
|
||||
# Create and store HAR artifact
|
||||
har_data = b'{"log": {"version": "1.2", "creator": {"name": "test"}}}'
|
||||
artifact = self._create_har_artifact(s3_storage, TEST_STEP_ID)
|
||||
|
||||
await s3_storage.store_artifact(artifact, har_data)
|
||||
|
||||
# Retrieve should auto-decompress
|
||||
retrieved_data = await s3_storage.retrieve_artifact(artifact)
|
||||
assert retrieved_data == har_data
|
||||
|
||||
async def test_non_har_artifact_not_compressed(self, s3_storage: S3Storage, boto3_test_client: S3Client) -> None:
|
||||
"""Test that non-HAR artifacts are NOT compressed."""
|
||||
test_data = b"fake screenshot data"
|
||||
artifact_id_val = generate_artifact_id()
|
||||
step = create_fake_step(TEST_STEP_ID)
|
||||
uri = s3_storage.build_uri(
|
||||
organization_id=TEST_ORGANIZATION_ID,
|
||||
artifact_id=artifact_id_val,
|
||||
step=step,
|
||||
artifact_type=ArtifactType.SCREENSHOT_LLM,
|
||||
)
|
||||
artifact = Artifact(
|
||||
artifact_id=artifact_id_val,
|
||||
artifact_type=ArtifactType.SCREENSHOT_LLM,
|
||||
uri=uri,
|
||||
organization_id=TEST_ORGANIZATION_ID,
|
||||
step_id=step.step_id,
|
||||
task_id=step.task_id,
|
||||
created_at=datetime.utcnow(),
|
||||
modified_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
await s3_storage.store_artifact(artifact, test_data)
|
||||
|
||||
# URI should NOT have .zst extension
|
||||
assert not artifact.uri.endswith(".zst")
|
||||
|
||||
# Stored data should be identical to original
|
||||
s3uri = S3Uri(artifact.uri)
|
||||
obj_response = boto3_test_client.get_object(Bucket=TEST_BUCKET, Key=s3uri.key)
|
||||
stored_data = obj_response["Body"].read()
|
||||
assert stored_data == test_data
|
||||
|
||||
Reference in New Issue
Block a user