basic test for s3 artifact upload logic (#2681)

This commit is contained in:
Asher Foa
2025-06-11 14:23:58 -04:00
committed by GitHub
parent ef19c0265e
commit c1e19d27d3
5 changed files with 628 additions and 63 deletions

View File

@@ -37,21 +37,25 @@ class AsyncAWSClient:
aws_access_key_id: str | None = None,
aws_secret_access_key: str | None = None,
region_name: str | None = None,
endpoint_url: str | None = None,
) -> None:
self.region_name = region_name or settings.AWS_REGION
self._endpoint_url = endpoint_url
self.session = aioboto3.Session(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
)
def _ecs_client(self) -> ECSClient:
    """Return an ECS client for this session's region, honoring the endpoint override.

    The `endpoint_url` override (set for localstack/moto in tests) must be
    forwarded here; otherwise the client silently targets real AWS.
    """
    return self.session.client(AWSClientType.ECS, region_name=self.region_name, endpoint_url=self._endpoint_url)
def _secrets_manager_client(self) -> SecretsManagerClient:
    """Return a Secrets Manager client for this session's region, honoring the endpoint override."""
    return self.session.client(
        AWSClientType.SECRETS_MANAGER, region_name=self.region_name, endpoint_url=self._endpoint_url
    )
def _s3_client(self) -> S3Client:
    """Return an S3 client for this session's region, honoring the endpoint override."""
    return self.session.client(AWSClientType.S3, region_name=self.region_name, endpoint_url=self._endpoint_url)
async def get_secret(self, secret_name: str) -> str | None:
try:

View File

@@ -31,8 +31,8 @@ LOG = structlog.get_logger()
class S3Storage(BaseStorage):
_PATH_VERSION = "v1"
def __init__(self, bucket: str | None = None, endpoint_url: str | None = None) -> None:
    """Configure S3-backed artifact storage.

    Args:
        bucket: Target S3 bucket; defaults to ``settings.AWS_S3_BUCKET_ARTIFACTS``.
        endpoint_url: Optional endpoint override (e.g. moto/localstack),
            forwarded to the underlying AsyncAWSClient.
    """
    self.async_client = AsyncAWSClient(endpoint_url=endpoint_url)
    self.bucket = bucket or settings.AWS_S3_BUCKET_ARTIFACTS
def build_uri(self, *, organization_id: str, artifact_id: str, step: Step, artifact_type: ArtifactType) -> str:

View File

@@ -1,8 +1,15 @@
from datetime import datetime
from pathlib import Path
from typing import Generator
import boto3
import pytest
from freezegun import freeze_time
from moto.server import ThreadedMotoServer
from skyvern.config import settings
from skyvern.forge.sdk.artifact.models import ArtifactType, LogEntityType
from skyvern.forge.sdk.api.aws import S3Uri
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType, LogEntityType
from skyvern.forge.sdk.artifact.storage.s3 import S3Storage
from skyvern.forge.sdk.artifact.storage.test_helpers import (
create_fake_for_ai_suggestion,
@@ -11,6 +18,7 @@ from skyvern.forge.sdk.artifact.storage.test_helpers import (
create_fake_thought,
create_fake_workflow_run_block,
)
from skyvern.forge.sdk.db.id import generate_artifact_id
# Test constants
TEST_BUCKET = "test-skyvern-bucket"
@@ -23,8 +31,38 @@ TEST_AI_SUGGESTION_ID = "ai_sugg_test_123"
@pytest.fixture
def s3_storage(moto_server: str) -> S3Storage:
    """S3Storage instance pointed at the local moto S3 endpoint."""
    return S3Storage(bucket=TEST_BUCKET, endpoint_url=moto_server)
@pytest.fixture(autouse=True)
def aws_credentials(monkeypatch: pytest.MonkeyPatch) -> None:
    """Export dummy AWS credentials so moto-backed clients never reach real AWS."""
    for env_var in ("AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"):
        monkeypatch.setenv(env_var, "testing")
@pytest.fixture(scope="module")
def moto_server() -> Generator[str, None, None]:
    """Run a module-scoped in-process moto server.

    Yields:
        The ``http://host:port`` endpoint URL of the moto server.
    """
    # Note: pass `port=0` to get a random free port.
    server = ThreadedMotoServer(port=0)
    server.start()
    try:
        host, port = server.get_host_and_port()
        yield f"http://{host}:{port}"
    finally:
        # Stop the server thread even if teardown is reached via an
        # exception thrown into the generator, so it cannot leak.
        server.stop()
@pytest.fixture(scope="module", autouse=True)
def boto3_test_client(moto_server: str) -> boto3.client:
    """Module-scoped boto3 S3 client bound to the local moto server.

    autouse=True so the test bucket is created before any test in this
    module runs, even tests that don't request this fixture directly.
    """
    client = boto3.client(
        "s3",
        aws_access_key_id="testing",
        aws_secret_access_key="testing",
        region_name=settings.AWS_REGION,
        endpoint_url=moto_server,
    )
    client.create_bucket(Bucket=TEST_BUCKET)  # Ensure the bucket exists for the test
    yield client
@freeze_time("2025-06-09T12:00:00")
@@ -105,3 +143,48 @@ class TestS3StorageBuildURIs:
uri
== f"s3://{TEST_BUCKET}/v1/{settings.ENV}/{TEST_ORGANIZATION_ID}/ai_suggestions/{TEST_AI_SUGGESTION_ID}/2025-06-09T12:00:00_artifact123_screenshot_llm.png"
)
@pytest.mark.asyncio
class TestS3StorageStore:
    """Tests for S3Storage store methods against a moto-backed bucket."""

    def _create_artifact_for_ai_suggestion(
        self,
        s3_storage: S3Storage,
        artifact_type: ArtifactType,
        ai_suggestion_id: str,
    ) -> Artifact:
        """Build an Artifact record whose URI points at an AI-suggestion path."""
        new_artifact_id = generate_artifact_id()
        suggestion = create_fake_for_ai_suggestion(ai_suggestion_id)
        artifact_uri = s3_storage.build_ai_suggestion_uri(
            organization_id=TEST_ORGANIZATION_ID,
            artifact_id=new_artifact_id,
            ai_suggestion=suggestion,
            artifact_type=artifact_type,
        )
        # NOTE(review): datetime.utcnow() is deprecated in 3.12 — confirm
        # naive UTC timestamps are what the Artifact model expects.
        return Artifact(
            artifact_id=new_artifact_id,
            artifact_type=artifact_type,
            uri=artifact_uri,
            organization_id=TEST_ORGANIZATION_ID,
            ai_suggestion_id=suggestion.ai_suggestion_id,
            created_at=datetime.utcnow(),
            modified_at=datetime.utcnow(),
        )

    async def test_store_artifact_screenshot(
        self, s3_storage: S3Storage, boto3_test_client: boto3.client, tmp_path: Path
    ) -> None:
        """Round-trip a screenshot: write a local file, store via S3Storage, read back via boto3."""
        payload = b"fake screenshot data"
        artifact = self._create_artifact_for_ai_suggestion(
            s3_storage, ArtifactType.SCREENSHOT_LLM, TEST_AI_SUGGESTION_ID
        )
        parsed_uri = S3Uri(artifact.uri)
        assert parsed_uri.bucket == TEST_BUCKET
        source_file = tmp_path / "test_screenshot.png"
        source_file.write_bytes(payload)
        await s3_storage.store_artifact_from_path(artifact, str(source_file))
        stored = boto3_test_client.get_object(Bucket=TEST_BUCKET, Key=parsed_uri.key)
        assert stored["Body"].read() == payload