basic test for s3 artifact upload logic (#2681)
This commit is contained in:
@@ -37,21 +37,25 @@ class AsyncAWSClient:
|
||||
aws_access_key_id: str | None = None,
|
||||
aws_secret_access_key: str | None = None,
|
||||
region_name: str | None = None,
|
||||
endpoint_url: str | None = None,
|
||||
) -> None:
|
||||
self.region_name = region_name or settings.AWS_REGION
|
||||
self._endpoint_url = endpoint_url
|
||||
self.session = aioboto3.Session(
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
)
|
||||
|
||||
def _ecs_client(self) -> ECSClient:
|
||||
return self.session.client(AWSClientType.ECS, region_name=self.region_name)
|
||||
return self.session.client(AWSClientType.ECS, region_name=self.region_name, endpoint_url=self._endpoint_url)
|
||||
|
||||
def _secrets_manager_client(self) -> SecretsManagerClient:
|
||||
return self.session.client(AWSClientType.SECRETS_MANAGER, region_name=self.region_name)
|
||||
return self.session.client(
|
||||
AWSClientType.SECRETS_MANAGER, region_name=self.region_name, endpoint_url=self._endpoint_url
|
||||
)
|
||||
|
||||
def _s3_client(self) -> S3Client:
|
||||
return self.session.client(AWSClientType.S3, region_name=self.region_name)
|
||||
return self.session.client(AWSClientType.S3, region_name=self.region_name, endpoint_url=self._endpoint_url)
|
||||
|
||||
async def get_secret(self, secret_name: str) -> str | None:
|
||||
try:
|
||||
|
||||
@@ -31,8 +31,8 @@ LOG = structlog.get_logger()
|
||||
class S3Storage(BaseStorage):
|
||||
_PATH_VERSION = "v1"
|
||||
|
||||
def __init__(self, bucket: str | None = None) -> None:
|
||||
self.async_client = AsyncAWSClient()
|
||||
def __init__(self, bucket: str | None = None, endpoint_url: str | None = None) -> None:
|
||||
self.async_client = AsyncAWSClient(endpoint_url=endpoint_url)
|
||||
self.bucket = bucket or settings.AWS_S3_BUCKET_ARTIFACTS
|
||||
|
||||
def build_uri(self, *, organization_id: str, artifact_id: str, step: Step, artifact_type: ArtifactType) -> str:
|
||||
|
||||
@@ -1,8 +1,15 @@
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Generator
|
||||
|
||||
import boto3
|
||||
import pytest
|
||||
from freezegun import freeze_time
|
||||
from moto.server import ThreadedMotoServer
|
||||
|
||||
from skyvern.config import settings
|
||||
from skyvern.forge.sdk.artifact.models import ArtifactType, LogEntityType
|
||||
from skyvern.forge.sdk.api.aws import S3Uri
|
||||
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType, LogEntityType
|
||||
from skyvern.forge.sdk.artifact.storage.s3 import S3Storage
|
||||
from skyvern.forge.sdk.artifact.storage.test_helpers import (
|
||||
create_fake_for_ai_suggestion,
|
||||
@@ -11,6 +18,7 @@ from skyvern.forge.sdk.artifact.storage.test_helpers import (
|
||||
create_fake_thought,
|
||||
create_fake_workflow_run_block,
|
||||
)
|
||||
from skyvern.forge.sdk.db.id import generate_artifact_id
|
||||
|
||||
# Test constants
|
||||
TEST_BUCKET = "test-skyvern-bucket"
|
||||
@@ -23,8 +31,38 @@ TEST_AI_SUGGESTION_ID = "ai_sugg_test_123"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def s3_storage() -> S3Storage:
|
||||
return S3Storage(bucket=TEST_BUCKET)
|
||||
def s3_storage(moto_server: str) -> S3Storage:
|
||||
return S3Storage(bucket=TEST_BUCKET, endpoint_url=moto_server)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def aws_credentials(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Mocked AWS Credentials for moto."""
|
||||
monkeypatch.setenv("AWS_ACCESS_KEY_ID", "testing")
|
||||
monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "testing")
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def moto_server() -> Generator[str, None, None]:
|
||||
# Note: pass `port=0` to get a random free port.
|
||||
server = ThreadedMotoServer(port=0)
|
||||
server.start()
|
||||
host, port = server.get_host_and_port()
|
||||
yield f"http://{host}:{port}"
|
||||
server.stop()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def boto3_test_client(moto_server: str) -> boto3.client:
|
||||
client = boto3.client(
|
||||
"s3",
|
||||
aws_access_key_id="testing",
|
||||
aws_secret_access_key="testing",
|
||||
region_name=settings.AWS_REGION,
|
||||
endpoint_url=moto_server,
|
||||
)
|
||||
client.create_bucket(Bucket=TEST_BUCKET) # Ensure the bucket exists for the test
|
||||
yield client
|
||||
|
||||
|
||||
@freeze_time("2025-06-09T12:00:00")
|
||||
@@ -105,3 +143,48 @@ class TestS3StorageBuildURIs:
|
||||
uri
|
||||
== f"s3://{TEST_BUCKET}/v1/{settings.ENV}/{TEST_ORGANIZATION_ID}/ai_suggestions/{TEST_AI_SUGGESTION_ID}/2025-06-09T12:00:00_artifact123_screenshot_llm.png"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestS3StorageStore:
|
||||
"""Test S3Storage store methods."""
|
||||
|
||||
def _create_artifact_for_ai_suggestion(
|
||||
self,
|
||||
s3_storage: S3Storage,
|
||||
artifact_type: ArtifactType,
|
||||
ai_suggestion_id: str,
|
||||
) -> Artifact:
|
||||
"""Helper method to create an Artifact for an AI suggestion."""
|
||||
artifact_id_val = generate_artifact_id()
|
||||
ai_suggestion = create_fake_for_ai_suggestion(ai_suggestion_id)
|
||||
uri = s3_storage.build_ai_suggestion_uri(
|
||||
organization_id=TEST_ORGANIZATION_ID,
|
||||
artifact_id=artifact_id_val,
|
||||
ai_suggestion=ai_suggestion,
|
||||
artifact_type=artifact_type,
|
||||
)
|
||||
return Artifact(
|
||||
artifact_id=artifact_id_val,
|
||||
artifact_type=artifact_type,
|
||||
uri=uri,
|
||||
organization_id=TEST_ORGANIZATION_ID,
|
||||
ai_suggestion_id=ai_suggestion.ai_suggestion_id,
|
||||
created_at=datetime.utcnow(),
|
||||
modified_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
async def test_store_artifact_screenshot(
|
||||
self, s3_storage: S3Storage, boto3_test_client: boto3.client, tmp_path: Path
|
||||
) -> None:
|
||||
test_data = b"fake screenshot data"
|
||||
artifact = self._create_artifact_for_ai_suggestion(
|
||||
s3_storage, ArtifactType.SCREENSHOT_LLM, TEST_AI_SUGGESTION_ID
|
||||
)
|
||||
s3uri = S3Uri(artifact.uri)
|
||||
assert s3uri.bucket == TEST_BUCKET
|
||||
test_file = tmp_path / "test_screenshot.png"
|
||||
test_file.write_bytes(test_data)
|
||||
await s3_storage.store_artifact_from_path(artifact, str(test_file))
|
||||
obj_response = boto3_test_client.get_object(Bucket=TEST_BUCKET, Key=s3uri.key)
|
||||
assert obj_response["Body"].read() == test_data
|
||||
|
||||
Reference in New Issue
Block a user