From 55f366ba930e26271cb976611487841a4f05ae61 Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Sat, 20 Dec 2025 00:16:16 +0800 Subject: [PATCH] add azure blob storage (#4338) Signed-off-by: Benji Visser Co-authored-by: Benji Visser Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com> --- .../src/routes/tasks/detail/Artifact.tsx | 6 +- .../src/routes/tasks/detail/artifactUtils.ts | 5 +- .../nodes/HttpRequestNode/HttpRequestNode.tsx | 2 +- skyvern/config.py | 21 +- skyvern/forge/forge_app.py | 3 + skyvern/forge/sdk/api/aws.py | 9 + skyvern/forge/sdk/api/azure.py | 50 +- skyvern/forge/sdk/api/files.py | 57 +- skyvern/forge/sdk/api/real_azure.py | 289 +++++++++- skyvern/forge/sdk/artifact/storage/azure.py | 539 ++++++++++++++++++ skyvern/forge/sdk/artifact/storage/base.py | 47 ++ skyvern/forge/sdk/artifact/storage/local.py | 103 +++- skyvern/forge/sdk/artifact/storage/s3.py | 101 +++- .../artifact/storage/test_azure_storage.py | 251 ++++++++ .../sdk/artifact/storage/test_s3_storage.py | 226 ++++++++ skyvern/forge/sdk/routes/agent_protocol.py | 63 +- skyvern/forge/sdk/workflow/models/block.py | 5 +- 17 files changed, 1641 insertions(+), 136 deletions(-) create mode 100644 skyvern/forge/sdk/artifact/storage/azure.py create mode 100644 skyvern/forge/sdk/artifact/storage/test_azure_storage.py diff --git a/skyvern-frontend/src/routes/tasks/detail/Artifact.tsx b/skyvern-frontend/src/routes/tasks/detail/Artifact.tsx index f90985f4..6367d3f1 100644 --- a/skyvern-frontend/src/routes/tasks/detail/Artifact.tsx +++ b/skyvern-frontend/src/routes/tasks/detail/Artifact.tsx @@ -54,6 +54,9 @@ type Props = { function Artifact({ type, artifacts }: Props) { function fetchArtifact(artifact: ArtifactApiResponse) { + if (artifact.signed_url) { + return axios.get(artifact.signed_url).then((response) => response.data); + } if (artifact.uri.startsWith("file://")) { const endpoint = getEndpoint(type); return artifactApiClient @@ -64,9 +67,6 @@ function 
Artifact({ type, artifacts }: Props) { }) .then((response) => response.data); } - if (artifact.uri.startsWith("s3://") && artifact.signed_url) { - return axios.get(artifact.signed_url).then((response) => response.data); - } } const results = useQueries({ diff --git a/skyvern-frontend/src/routes/tasks/detail/artifactUtils.ts b/skyvern-frontend/src/routes/tasks/detail/artifactUtils.ts index f3760f29..9ce24677 100644 --- a/skyvern-frontend/src/routes/tasks/detail/artifactUtils.ts +++ b/skyvern-frontend/src/routes/tasks/detail/artifactUtils.ts @@ -2,10 +2,11 @@ import { ArtifactApiResponse, TaskApiResponse } from "@/api/types"; import { artifactApiBaseUrl } from "@/util/env"; export function getImageURL(artifact: ArtifactApiResponse): string { + if (artifact.signed_url) { + return artifact.signed_url; + } if (artifact.uri.startsWith("file://")) { return `${artifactApiBaseUrl}/artifact/image?path=${artifact.uri.slice(7)}`; - } else if (artifact.uri.startsWith("s3://") && artifact.signed_url) { - return artifact.signed_url; } return artifact.uri; } diff --git a/skyvern-frontend/src/routes/workflows/editor/nodes/HttpRequestNode/HttpRequestNode.tsx b/skyvern-frontend/src/routes/workflows/editor/nodes/HttpRequestNode/HttpRequestNode.tsx index f3324db0..2a322189 100644 --- a/skyvern-frontend/src/routes/workflows/editor/nodes/HttpRequestNode/HttpRequestNode.tsx +++ b/skyvern-frontend/src/routes/workflows/editor/nodes/HttpRequestNode/HttpRequestNode.tsx @@ -63,7 +63,7 @@ const headersTooltip = const bodyTooltip = "Request body as JSON object. Only used for POST, PUT, PATCH methods."; const filesTooltip = - 'Files to upload as multipart/form-data. Dictionary mapping field names to file paths/URLs. Supports HTTP/HTTPS URLs, S3 URIs (s3://), or limited local file access. Example: {"file": "https://example.com/file.pdf"} or {"document": "s3://bucket/path/file.pdf"}'; + 'Files to upload as multipart/form-data. Dictionary mapping field names to file paths/URLs. 
Supports HTTP/HTTPS URLs, S3 URIs (s3://), Azure blob URIs (azure://), or limited local file access. Example: {"file": "https://example.com/file.pdf"} or {"document": "s3://bucket/path/file.pdf"}'; const timeoutTooltip = "Request timeout in seconds."; const followRedirectsTooltip = "Whether to automatically follow HTTP redirects."; diff --git a/skyvern/config.py b/skyvern/config.py index b9ec0760..c3de4b02 100644 --- a/skyvern/config.py +++ b/skyvern/config.py @@ -87,23 +87,26 @@ class Settings(BaseSettings): # Artifact storage settings ARTIFACT_STORAGE_PATH: str = f"{SKYVERN_DIR}/artifacts" - GENERATE_PRESIGNED_URLS: bool = False + + # Supported storage types: local, s3cloud, azureblob + SKYVERN_STORAGE_TYPE: str = "local" + + # S3/AWS settings + AWS_REGION: str = "us-east-1" + MAX_UPLOAD_FILE_SIZE: int = 10 * 1024 * 1024 # 10 MB + PRESIGNED_URL_EXPIRATION: int = 60 * 60 * 24 # 24 hours AWS_S3_BUCKET_ARTIFACTS: str = "skyvern-artifacts" AWS_S3_BUCKET_SCREENSHOTS: str = "skyvern-screenshots" AWS_S3_BUCKET_BROWSER_SESSIONS: str = "skyvern-browser-sessions" - - # Supported storage types: local, s3 - SKYVERN_STORAGE_TYPE: str = "local" - - # S3 bucket settings - AWS_REGION: str = "us-east-1" AWS_S3_BUCKET_UPLOADS: str = "skyvern-uploads" - MAX_UPLOAD_FILE_SIZE: int = 10 * 1024 * 1024 # 10 MB - PRESIGNED_URL_EXPIRATION: int = 60 * 60 * 24 # 24 hours # Azure Blob Storage settings AZURE_STORAGE_ACCOUNT_NAME: str | None = None AZURE_STORAGE_ACCOUNT_KEY: str | None = None + AZURE_STORAGE_CONTAINER_ARTIFACTS: str = "skyvern-artifacts" + AZURE_STORAGE_CONTAINER_SCREENSHOTS: str = "skyvern-screenshots" + AZURE_STORAGE_CONTAINER_BROWSER_SESSIONS: str = "skyvern-browser-sessions" + AZURE_STORAGE_CONTAINER_UPLOADS: str = "skyvern-uploads" SKYVERN_TELEMETRY: bool = True ANALYTICS_ID: str = "anonymous" diff --git a/skyvern/forge/forge_app.py b/skyvern/forge/forge_app.py index 71463266..c428bfb1 100644 --- a/skyvern/forge/forge_app.py +++ b/skyvern/forge/forge_app.py @@ -16,6 +16,7 
@@ from skyvern.forge.sdk.api.llm.api_handler import LLMAPIHandler from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory from skyvern.forge.sdk.api.real_azure import RealAzureClientFactory from skyvern.forge.sdk.artifact.manager import ArtifactManager +from skyvern.forge.sdk.artifact.storage.azure import AzureStorage from skyvern.forge.sdk.artifact.storage.base import BaseStorage from skyvern.forge.sdk.artifact.storage.factory import StorageFactory from skyvern.forge.sdk.artifact.storage.s3 import S3Storage @@ -104,6 +105,8 @@ def create_forge_app() -> ForgeApp: if settings.SKYVERN_STORAGE_TYPE == "s3": StorageFactory.set_storage(S3Storage()) + elif settings.SKYVERN_STORAGE_TYPE == "azureblob": + StorageFactory.set_storage(AzureStorage()) app.STORAGE = StorageFactory.get_storage() app.CACHE = CacheFactory.get_cache() app.ARTIFACT_MANAGER = ArtifactManager() diff --git a/skyvern/forge/sdk/api/aws.py b/skyvern/forge/sdk/api/aws.py index e777a48c..7f83b547 100644 --- a/skyvern/forge/sdk/api/aws.py +++ b/skyvern/forge/sdk/api/aws.py @@ -1,4 +1,5 @@ from enum import StrEnum +from mimetypes import add_type, guess_type from typing import IO, Any from urllib.parse import urlparse @@ -12,6 +13,10 @@ from types_boto3_secretsmanager.client import SecretsManagerClient from skyvern.config import settings +# Register custom mime types for mimetypes guessing +add_type("application/json", ".har") +add_type("text/plain", ".log") + LOG = structlog.get_logger() @@ -188,6 +193,10 @@ class AsyncAWSClient: extra_args["Tagging"] = self._create_tag_string(tags) if content_type: extra_args["ContentType"] = content_type + else: + guessed_type, _ = guess_type(file_path) + if guessed_type: + extra_args["ContentType"] = guessed_type await client.upload_file( Filename=file_path, Bucket=parsed_uri.bucket, diff --git a/skyvern/forge/sdk/api/azure.py b/skyvern/forge/sdk/api/azure.py index af07cfb4..996e482c 100644 --- a/skyvern/forge/sdk/api/azure.py +++ 
b/skyvern/forge/sdk/api/azure.py @@ -1,8 +1,35 @@ from typing import Protocol, Self +from urllib.parse import urlparse + +from azure.storage.blob import StandardBlobTier from skyvern.forge.sdk.schemas.organizations import AzureClientSecretCredential +class AzureUri: + """Parse azure://{container}/{blob_path} URIs.""" + + def __init__(self, uri: str) -> None: + self._parsed = urlparse(uri, allow_fragments=False) + + @property + def container(self) -> str: + return self._parsed.netloc + + @property + def blob_path(self) -> str: + if self._parsed.query: + return self._parsed.path.lstrip("/") + "?" + self._parsed.query + return self._parsed.path.lstrip("/") + + @property + def uri(self) -> str: + return self._parsed.geturl() + + def __str__(self) -> str: + return self.uri + + class AsyncAzureVaultClient(Protocol): """Protocol defining the interface for Azure Vault clients. @@ -68,21 +95,24 @@ class AsyncAzureVaultClient(Protocol): class AsyncAzureStorageClient(Protocol): - """Protocol defining the interface for Azure Storage clients. + """Protocol defining the interface for Azure Storage clients.""" - This client provides methods to interact with Azure Blob Storage for file operations. - """ - - async def upload_file_from_path(self, container_name: str, blob_name: str, file_path: str) -> None: + async def upload_file_from_path( + self, + uri: str, + file_path: str, + tier: StandardBlobTier = StandardBlobTier.HOT, + tags: dict[str, str] | None = None, + metadata: dict[str, str] | None = None, + ) -> None: """Upload a file from the local filesystem to Azure Blob Storage. 
Args: - container_name: The name of the Azure Blob container - blob_name: The name to give the blob in storage + uri: The azure:// URI for the blob (azure://container/blob_path) file_path: The local path to the file to upload - - Raises: - Exception: If the upload fails + tier: The storage tier for the blob + tags: Optional tags to attach to the blob + metadata: Optional metadata to attach to the blob """ ... diff --git a/skyvern/forge/sdk/api/files.py b/skyvern/forge/sdk/api/files.py index 75866c51..4264998f 100644 --- a/skyvern/forge/sdk/api/files.py +++ b/skyvern/forge/sdk/api/files.py @@ -17,7 +17,8 @@ from yarl import URL from skyvern.config import settings from skyvern.constants import BROWSER_DOWNLOAD_TIMEOUT, BROWSER_DOWNLOADING_SUFFIX, REPO_ROOT_DIR from skyvern.exceptions import DownloadFileMaxSizeExceeded, DownloadFileMaxWaitingTime -from skyvern.forge.sdk.api.aws import AsyncAWSClient, aws_client +from skyvern.forge import app +from skyvern.forge.sdk.api.aws import AsyncAWSClient from skyvern.utils.url_validators import encode_url LOG = structlog.get_logger() @@ -97,6 +98,12 @@ def validate_download_url(url: str) -> bool: return True return False + # Allow Azure URIs for Skyvern uploads container + if scheme == "azure": + if url.startswith(f"azure://{settings.AZURE_STORAGE_CONTAINER_UPLOADS}/{settings.ENV}/o_"): + return True + return False + # Allow file:// URLs only in local environment if scheme == "file": if settings.ENV != "local": @@ -129,20 +136,41 @@ async def download_file(url: str, max_size_mb: int | None = None) -> str: url = f"https://drive.google.com/uc?export=download&id={file_id}" LOG.info("Converting Google Drive link to direct download", url=url) - # Check if URL is an S3 URI - if url.startswith(f"s3://{settings.AWS_S3_BUCKET_UPLOADS}/{settings.ENV}/o_"): - LOG.info("Downloading Skyvern file from S3", url=url) - client = AsyncAWSClient() - return await download_from_s3(client, url) + # Check if URL is a cloud storage URI (S3 or Azure) + 
parsed = urlparse(url) + if parsed.scheme == "s3": + uploads_prefix = f"s3://{settings.AWS_S3_BUCKET_UPLOADS}/{settings.ENV}/o_" + if url.startswith(uploads_prefix): + LOG.info("Downloading Skyvern file from S3", url=url) + data = await app.STORAGE.download_uploaded_file(url) + if data is None: + raise Exception(f"Failed to download file from S3: {url}") + filename = url.split("/")[-1] + temp_file = create_named_temporary_file(delete=False, file_name=filename) + LOG.info(f"Downloaded file to {temp_file.name}") + temp_file.write(data) + return temp_file.name + elif parsed.scheme == "azure": + uploads_prefix = f"azure://{settings.AZURE_STORAGE_CONTAINER_UPLOADS}/{settings.ENV}/o_" + if url.startswith(uploads_prefix): + LOG.info("Downloading Skyvern file from Azure Blob Storage", url=url) + data = await app.STORAGE.download_uploaded_file(url) + if data is None: + raise Exception(f"Failed to download file from Azure Blob Storage: {url}") + filename = url.split("/")[-1] + temp_file = create_named_temporary_file(delete=False, file_name=filename) + LOG.info(f"Downloaded file to {temp_file.name}") + temp_file.write(data) + return temp_file.name # Check if URL is a file:// URI # we only support to download local files when the environment is local # and the file is in the skyvern downloads directory if url.startswith("file://") and settings.ENV == "local": - file_path = parse_uri_to_path(url) - if file_path.startswith(f"{REPO_ROOT_DIR}/downloads"): + local_path = parse_uri_to_path(url) + if local_path.startswith(f"{REPO_ROOT_DIR}/downloads"): LOG.info("Downloading file from local file system", url=url) - return file_path + return local_path async with aiohttp.ClientSession(raise_for_status=True) as session: LOG.info("Starting to download file", url=url) @@ -262,12 +290,13 @@ async def wait_for_download_finished(downloading_files: list[str], timeout: floa while len(cur_downloading_files) > 0: new_downloading_files: list[str] = [] for path in cur_downloading_files: - if 
path.startswith("s3://"): - try: - await aws_client.get_object_info(path) - except Exception: + # Check for cloud storage URIs (S3 or Azure) + parsed = urlparse(path) + if parsed.scheme in ("s3", "azure"): + if not await app.STORAGE.file_exists(path): LOG.debug( - "downloading file is not found in s3, means the file finished downloading", path=path + "downloading file is not found in cloud storage, means the file finished downloading", + path=path, ) continue else: diff --git a/skyvern/forge/sdk/api/real_azure.py b/skyvern/forge/sdk/api/real_azure.py index ad7b8aab..41109a89 100644 --- a/skyvern/forge/sdk/api/real_azure.py +++ b/skyvern/forge/sdk/api/real_azure.py @@ -1,15 +1,29 @@ """Real implementations of Azure clients (Vault and Storage) and their factories.""" -from typing import Self +from datetime import datetime, timedelta, timezone +from mimetypes import add_type, guess_type +from typing import IO, Self import structlog +from azure.core.exceptions import ResourceNotFoundError from azure.identity.aio import ClientSecretCredential, DefaultAzureCredential from azure.keyvault.secrets.aio import SecretClient +from azure.storage.blob import BlobSasPermissions, ContentSettings, StandardBlobTier, generate_blob_sas from azure.storage.blob.aio import BlobServiceClient -from skyvern.forge.sdk.api.azure import AsyncAzureStorageClient, AsyncAzureVaultClient, AzureClientFactory +from skyvern.config import settings +from skyvern.forge.sdk.api.azure import ( + AsyncAzureStorageClient, + AsyncAzureVaultClient, + AzureClientFactory, + AzureUri, +) from skyvern.forge.sdk.schemas.organizations import AzureClientSecretCredential +# Register custom mime types for mimetypes guessing +add_type("application/json", ".har") +add_type("text/plain", ".log") + LOG = structlog.get_logger() @@ -73,37 +87,256 @@ class RealAsyncAzureVaultClient(AsyncAzureVaultClient): class RealAsyncAzureStorageClient(AsyncAzureStorageClient): - """Real implementation of Azure Storage client using Azure 
SDK.""" + """Async client for Azure Blob Storage operations. Implements AsyncAzureStorageClient protocol.""" - def __init__(self, storage_account_name: str, storage_account_key: str): - self.blob_service_client = BlobServiceClient( - account_url=f"https://{storage_account_name}.blob.core.windows.net", - credential=storage_account_key, + def __init__( + self, + account_name: str | None = None, + account_key: str | None = None, + ) -> None: + self.account_name = account_name or settings.AZURE_STORAGE_ACCOUNT_NAME + self.account_key = account_key or settings.AZURE_STORAGE_ACCOUNT_KEY + + if not self.account_name or not self.account_key: + raise ValueError("Azure Storage account name and key are required") + + self._blob_service_client: BlobServiceClient | None = None + self._verified_containers: set[str] = set() + + def _get_blob_service_client(self) -> BlobServiceClient: + if self._blob_service_client is None: + self._blob_service_client = BlobServiceClient( + account_url=f"https://{self.account_name}.blob.core.windows.net", + credential=self.account_key, + ) + return self._blob_service_client + + async def _ensure_container_exists(self, container: str) -> None: + if container in self._verified_containers: + return + client = self._get_blob_service_client() + container_client = client.get_container_client(container) + try: + if not await container_client.exists(): + await container_client.create_container() + LOG.info("Created Azure container", container=container) + except Exception: + LOG.debug("Container may already exist", container=container) + self._verified_containers.add(container) + + async def upload_file( + self, + uri: str, + data: bytes, + tier: StandardBlobTier = StandardBlobTier.HOT, + tags: dict[str, str] | None = None, + ) -> None: + parsed = AzureUri(uri) + await self._ensure_container_exists(parsed.container) + client = self._get_blob_service_client() + container_client = client.get_container_client(parsed.container) + await 
container_client.upload_blob( + name=parsed.blob_path, + data=data, + overwrite=True, + standard_blob_tier=tier, + tags=tags, ) - async def upload_file_from_path(self, container_name: str, blob_name: str, file_path: str) -> None: - try: - container_client = self.blob_service_client.get_container_client(container_name) - # Create the container if it doesn't exist - try: - await container_client.create_container() - except Exception as e: - LOG.info("Azure container already exists or failed to create", container_name=container_name, error=e) - - with open(file_path, "rb") as data: - await container_client.upload_blob(name=blob_name, data=data, overwrite=True) - LOG.info("File uploaded to Azure Blob Storage", container_name=container_name, blob_name=blob_name) - except Exception as e: - LOG.error( - "Failed to upload file to Azure Blob Storage", - container_name=container_name, - blob_name=blob_name, - error=e, + async def upload_file_from_path( + self, + uri: str, + file_path: str, + tier: StandardBlobTier = StandardBlobTier.HOT, + tags: dict[str, str] | None = None, + metadata: dict[str, str] | None = None, + ) -> None: + parsed = AzureUri(uri) + await self._ensure_container_exists(parsed.container) + client = self._get_blob_service_client() + container_client = client.get_container_client(parsed.container) + content_type, _ = guess_type(file_path) + content_settings = ContentSettings(content_type=content_type) if content_type else None + with open(file_path, "rb") as f: + await container_client.upload_blob( + name=parsed.blob_path, + data=f, + overwrite=True, + standard_blob_tier=tier, + tags=tags, + metadata=metadata, + content_settings=content_settings, ) - raise e + + async def upload_file_stream( + self, + uri: str, + file_obj: IO[bytes], + tier: StandardBlobTier = StandardBlobTier.HOT, + tags: dict[str, str] | None = None, + metadata: dict[str, str] | None = None, + ) -> str: + parsed = AzureUri(uri) + await self._ensure_container_exists(parsed.container) + 
client = self._get_blob_service_client() + container_client = client.get_container_client(parsed.container) + await container_client.upload_blob( + name=parsed.blob_path, + data=file_obj, + overwrite=True, + standard_blob_tier=tier, + tags=tags, + metadata=metadata, + ) + return uri + + async def download_file(self, uri: str, log_exception: bool = True) -> bytes | None: + parsed = AzureUri(uri) + try: + client = self._get_blob_service_client() + container_client = client.get_container_client(parsed.container) + blob_client = container_client.get_blob_client(parsed.blob_path) + download = await blob_client.download_blob() + return await download.readall() + except ResourceNotFoundError: + if log_exception: + LOG.warning("Azure blob not found", uri=uri) + return None + except Exception: + if log_exception: + LOG.exception("Failed to download from Azure", uri=uri) + return None + + async def get_blob_properties(self, uri: str) -> dict | None: + parsed = AzureUri(uri) + try: + client = self._get_blob_service_client() + container_client = client.get_container_client(parsed.container) + blob_client = container_client.get_blob_client(parsed.blob_path) + props = await blob_client.get_blob_properties() + return { + "size": props.size, + "content_type": props.content_settings.content_type if props.content_settings else None, + "last_modified": props.last_modified, + "etag": props.etag, + "metadata": props.metadata, + } + except ResourceNotFoundError: + return None + except Exception: + LOG.exception("Failed to get blob properties", uri=uri) + return None + + async def blob_exists(self, uri: str) -> bool: + parsed = AzureUri(uri) + try: + client = self._get_blob_service_client() + container_client = client.get_container_client(parsed.container) + blob_client = container_client.get_blob_client(parsed.blob_path) + return await blob_client.exists() + except Exception: + return False + + async def delete_blob(self, uri: str) -> None: + parsed = AzureUri(uri) + try: + client = 
self._get_blob_service_client() + container_client = client.get_container_client(parsed.container) + blob_client = container_client.get_blob_client(parsed.blob_path) + await blob_client.delete_blob() + except ResourceNotFoundError: + LOG.debug("Azure blob not found for deletion", uri=uri) + except Exception: + LOG.exception("Failed to delete Azure blob", uri=uri) + raise + + async def list_blobs(self, container: str, prefix: str | None = None) -> list[str]: + try: + client = self._get_blob_service_client() + container_client = client.get_container_client(container) + blobs = [] + async for blob in container_client.list_blobs(name_starts_with=prefix): + blobs.append(blob.name) + return blobs + except ResourceNotFoundError: + return [] + except Exception: + LOG.exception("Failed to list Azure blobs", container=container, prefix=prefix) + return [] + + def create_sas_url(self, uri: str, expiry_hours: int = 24) -> str | None: + parsed = AzureUri(uri) + try: + sas_token = generate_blob_sas( + account_name=self.account_name, + container_name=parsed.container, + blob_name=parsed.blob_path, + account_key=self.account_key, + permission=BlobSasPermissions(read=True), + expiry=datetime.now(timezone.utc) + timedelta(hours=expiry_hours), + ) + return ( + f"https://{self.account_name}.blob.core.windows.net/{parsed.container}/{parsed.blob_path}?{sas_token}" + ) + except Exception: + LOG.exception("Failed to create SAS URL", uri=uri) + return None + + async def create_sas_urls(self, uris: list[str], expiry_hours: int = 24) -> list[str] | None: + try: + sas_urls: list[str] = [] + for uri in uris: + url = self.create_sas_url(uri, expiry_hours) + if url is None: + LOG.warning("SAS URL generation failed, aborting batch", failed_uri=uri, uris=uris) + return None + sas_urls.append(url) + return sas_urls + except Exception: + LOG.exception("Failed to create SAS URLs") + return None async def close(self) -> None: - await self.blob_service_client.close() + if self._blob_service_client: + 
await self._blob_service_client.close() + self._blob_service_client = None + + async def list_files(self, uri: str) -> list[str]: + """List files under a URI prefix. Returns blob names relative to container.""" + parsed = AzureUri(uri) + return await self.list_blobs(parsed.container, parsed.blob_path) + + async def get_object_info(self, uri: str) -> dict | None: + """Get object info including metadata. Returns dict with Metadata and LastModified keys.""" + props = await self.get_blob_properties(uri) + if props is None: + return None + return { + "Metadata": props.get("metadata", {}), + "LastModified": props.get("last_modified"), + } + + async def delete_file(self, uri: str) -> None: + """Delete a file at the given URI.""" + await self.delete_blob(uri) + + async def get_file_metadata(self, uri: str, log_exception: bool = True) -> dict[str, str] | None: + """Get only the metadata for a file.""" + parsed = AzureUri(uri) + try: + client = self._get_blob_service_client() + container_client = client.get_container_client(parsed.container) + blob_client = container_client.get_blob_client(parsed.blob_path) + props = await blob_client.get_blob_properties() + return props.metadata or {} + except ResourceNotFoundError: + if log_exception: + LOG.warning("Azure blob not found for metadata", uri=uri) + return None + except Exception: + if log_exception: + LOG.exception("Failed to get blob metadata", uri=uri) + return None class RealAzureClientFactory(AzureClientFactory): @@ -124,4 +357,4 @@ class RealAzureClientFactory(AzureClientFactory): def create_storage_client(self, storage_account_name: str, storage_account_key: str) -> AsyncAzureStorageClient: """Create an Azure Storage client with the provided credentials.""" - return RealAsyncAzureStorageClient(storage_account_name, storage_account_key) + return RealAsyncAzureStorageClient(account_name=storage_account_name, account_key=storage_account_key) diff --git a/skyvern/forge/sdk/artifact/storage/azure.py 
b/skyvern/forge/sdk/artifact/storage/azure.py new file mode 100644 index 00000000..3b35e3eb --- /dev/null +++ b/skyvern/forge/sdk/artifact/storage/azure.py @@ -0,0 +1,539 @@ +import os +import shutil +import uuid +from datetime import datetime, timezone +from typing import BinaryIO + +import structlog + +from skyvern.config import settings +from skyvern.constants import BROWSER_DOWNLOADING_SUFFIX, DOWNLOAD_FILE_PREFIX +from skyvern.forge.sdk.api.azure import StandardBlobTier +from skyvern.forge.sdk.api.files import ( + calculate_sha256_for_file, + create_named_temporary_file, + get_download_dir, + get_skyvern_temp_dir, + make_temp_directory, + unzip_files, +) +from skyvern.forge.sdk.api.real_azure import RealAsyncAzureStorageClient +from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType, LogEntityType +from skyvern.forge.sdk.artifact.storage.base import FILE_EXTENTSION_MAP, BaseStorage +from skyvern.forge.sdk.models import Step +from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion +from skyvern.forge.sdk.schemas.files import FileInfo +from skyvern.forge.sdk.schemas.task_v2 import TaskV2, Thought +from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunBlock + +LOG = structlog.get_logger() + + +class AzureStorage(BaseStorage): + _PATH_VERSION = "v1" + + def __init__( + self, + container: str | None = None, + account_name: str | None = None, + account_key: str | None = None, + ) -> None: + self.async_client = RealAsyncAzureStorageClient(account_name=account_name, account_key=account_key) + self.container = container or settings.AZURE_STORAGE_CONTAINER_ARTIFACTS + + def build_uri(self, *, organization_id: str, artifact_id: str, step: Step, artifact_type: ArtifactType) -> str: + file_ext = FILE_EXTENTSION_MAP[artifact_type] + return f"{self._build_base_uri(organization_id)}/{step.task_id}/{step.order:02d}_{step.retry_index}_{step.step_id}/{datetime.utcnow().isoformat()}_{artifact_id}_{artifact_type}.{file_ext}" + + async def 
retrieve_global_workflows(self) -> list[str]: + uri = f"azure://{self.container}/{settings.ENV}/global_workflows.txt" + data = await self.async_client.download_file(uri, log_exception=False) + if not data: + return [] + return [line.strip() for line in data.decode("utf-8").split("\n") if line.strip()] + + def _build_base_uri(self, organization_id: str) -> str: + return f"azure://{self.container}/{self._PATH_VERSION}/{settings.ENV}/{organization_id}" + + def build_log_uri( + self, *, organization_id: str, log_entity_type: LogEntityType, log_entity_id: str, artifact_type: ArtifactType + ) -> str: + file_ext = FILE_EXTENTSION_MAP[artifact_type] + return f"{self._build_base_uri(organization_id)}/logs/{log_entity_type}/{log_entity_id}/{datetime.utcnow().isoformat()}_{artifact_type}.{file_ext}" + + def build_thought_uri( + self, *, organization_id: str, artifact_id: str, thought: Thought, artifact_type: ArtifactType + ) -> str: + file_ext = FILE_EXTENTSION_MAP[artifact_type] + return f"{self._build_base_uri(organization_id)}/observers/{thought.observer_cruise_id}/{thought.observer_thought_id}/{datetime.utcnow().isoformat()}_{artifact_id}_{artifact_type}.{file_ext}" + + def build_task_v2_uri( + self, *, organization_id: str, artifact_id: str, task_v2: TaskV2, artifact_type: ArtifactType + ) -> str: + file_ext = FILE_EXTENTSION_MAP[artifact_type] + return f"{self._build_base_uri(organization_id)}/observers/{task_v2.observer_cruise_id}/{datetime.utcnow().isoformat()}_{artifact_id}_{artifact_type}.{file_ext}" + + def build_workflow_run_block_uri( + self, + *, + organization_id: str, + artifact_id: str, + workflow_run_block: WorkflowRunBlock, + artifact_type: ArtifactType, + ) -> str: + file_ext = FILE_EXTENTSION_MAP[artifact_type] + return f"{self._build_base_uri(organization_id)}/workflow_runs/{workflow_run_block.workflow_run_id}/{workflow_run_block.workflow_run_block_id}/{datetime.utcnow().isoformat()}_{artifact_id}_{artifact_type}.{file_ext}" + + def 
build_ai_suggestion_uri( + self, *, organization_id: str, artifact_id: str, ai_suggestion: AISuggestion, artifact_type: ArtifactType + ) -> str: + file_ext = FILE_EXTENTSION_MAP[artifact_type] + return f"{self._build_base_uri(organization_id)}/ai_suggestions/{ai_suggestion.ai_suggestion_id}/{datetime.utcnow().isoformat()}_{artifact_id}_{artifact_type}.{file_ext}" + + def build_script_file_uri( + self, *, organization_id: str, script_id: str, script_version: int, file_path: str + ) -> str: + """Build the Azure URI for a script file.""" + return f"{self._build_base_uri(organization_id)}/scripts/{script_id}/{script_version}/{file_path}" + + async def store_artifact(self, artifact: Artifact, data: bytes) -> None: + tier = await self._get_storage_tier_for_org(artifact.organization_id) + tags = await self._get_tags_for_org(artifact.organization_id) + LOG.debug( + "Storing artifact", + artifact_id=artifact.artifact_id, + organization_id=artifact.organization_id, + uri=artifact.uri, + storage_tier=tier, + tags=tags, + ) + await self.async_client.upload_file(artifact.uri, data, tier=tier, tags=tags) + + async def _get_storage_tier_for_org(self, organization_id: str) -> StandardBlobTier: + return StandardBlobTier.HOT + + async def _get_tags_for_org(self, organization_id: str) -> dict[str, str]: + return {} + + async def retrieve_artifact(self, artifact: Artifact) -> bytes | None: + return await self.async_client.download_file(artifact.uri) + + async def get_share_link(self, artifact: Artifact) -> str | None: + share_urls = await self.async_client.create_sas_urls([artifact.uri]) + return share_urls[0] if share_urls else None + + async def get_share_links(self, artifacts: list[Artifact]) -> list[str] | None: + return await self.async_client.create_sas_urls([artifact.uri for artifact in artifacts]) + + async def store_artifact_from_path(self, artifact: Artifact, path: str) -> None: + tier = await self._get_storage_tier_for_org(artifact.organization_id) + tags = await 
self._get_tags_for_org(artifact.organization_id) + LOG.debug( + "Storing artifact from path", + artifact_id=artifact.artifact_id, + organization_id=artifact.organization_id, + uri=artifact.uri, + storage_tier=tier, + path=path, + tags=tags, + ) + await self.async_client.upload_file_from_path(artifact.uri, path, tier=tier, tags=tags) + + async def save_streaming_file(self, organization_id: str, file_name: str) -> None: + from_path = f"{get_skyvern_temp_dir()}/{organization_id}/{file_name}" + to_path = f"azure://{settings.AZURE_STORAGE_CONTAINER_SCREENSHOTS}/{settings.ENV}/{organization_id}/{file_name}" + tier = await self._get_storage_tier_for_org(organization_id) + tags = await self._get_tags_for_org(organization_id) + LOG.debug( + "Saving streaming file", + organization_id=organization_id, + file_name=file_name, + from_path=from_path, + to_path=to_path, + storage_tier=tier, + tags=tags, + ) + await self.async_client.upload_file_from_path(to_path, from_path, tier=tier, tags=tags) + + async def get_streaming_file(self, organization_id: str, file_name: str, use_default: bool = True) -> bytes | None: + path = f"azure://{settings.AZURE_STORAGE_CONTAINER_SCREENSHOTS}/{settings.ENV}/{organization_id}/{file_name}" + return await self.async_client.download_file(path, log_exception=False) + + async def store_browser_session(self, organization_id: str, workflow_permanent_id: str, directory: str) -> None: + # Zip the directory to a temp file + temp_zip_file = create_named_temporary_file() + zip_file_path = shutil.make_archive(temp_zip_file.name, "zip", directory) + browser_session_uri = f"azure://{settings.AZURE_STORAGE_CONTAINER_BROWSER_SESSIONS}/{settings.ENV}/{organization_id}/{workflow_permanent_id}.zip" + tier = await self._get_storage_tier_for_org(organization_id) + tags = await self._get_tags_for_org(organization_id) + LOG.debug( + "Storing browser session", + organization_id=organization_id, + workflow_permanent_id=workflow_permanent_id, + zip_file_path=zip_file_path, 
+ browser_session_uri=browser_session_uri, + storage_tier=tier, + tags=tags, + ) + await self.async_client.upload_file_from_path(browser_session_uri, zip_file_path, tier=tier, tags=tags) + + async def retrieve_browser_session(self, organization_id: str, workflow_permanent_id: str) -> str | None: + browser_session_uri = f"azure://{settings.AZURE_STORAGE_CONTAINER_BROWSER_SESSIONS}/{settings.ENV}/{organization_id}/{workflow_permanent_id}.zip" + downloaded_zip_bytes = await self.async_client.download_file(browser_session_uri, log_exception=True) + if not downloaded_zip_bytes: + return None + temp_zip_file = create_named_temporary_file(delete=False) + temp_zip_file.write(downloaded_zip_bytes) + temp_zip_file_path = temp_zip_file.name + + temp_dir = make_temp_directory(prefix="skyvern_browser_session_") + unzip_files(temp_zip_file_path, temp_dir) + temp_zip_file.close() + return temp_dir + + async def store_browser_profile(self, organization_id: str, profile_id: str, directory: str) -> None: + """Store browser profile to Azure.""" + temp_zip_file = create_named_temporary_file() + zip_file_path = shutil.make_archive(temp_zip_file.name, "zip", directory) + profile_uri = f"azure://{settings.AZURE_STORAGE_CONTAINER_BROWSER_SESSIONS}/{settings.ENV}/{organization_id}/profiles/{profile_id}.zip" + tier = await self._get_storage_tier_for_org(organization_id) + tags = await self._get_tags_for_org(organization_id) + LOG.debug( + "Storing browser profile", + organization_id=organization_id, + profile_id=profile_id, + zip_file_path=zip_file_path, + profile_uri=profile_uri, + storage_tier=tier, + tags=tags, + ) + await self.async_client.upload_file_from_path(profile_uri, zip_file_path, tier=tier, tags=tags) + + async def retrieve_browser_profile(self, organization_id: str, profile_id: str) -> str | None: + """Retrieve browser profile from Azure.""" + profile_uri = f"azure://{settings.AZURE_STORAGE_CONTAINER_BROWSER_SESSIONS}/{settings.ENV}/{organization_id}/profiles/{profile_id}.zip" 
+ downloaded_zip_bytes = await self.async_client.download_file(profile_uri, log_exception=True) + if not downloaded_zip_bytes: + return None + temp_zip_file = create_named_temporary_file(delete=False) + temp_zip_file.write(downloaded_zip_bytes) + temp_zip_file_path = temp_zip_file.name + + temp_dir = make_temp_directory(prefix="skyvern_browser_profile_") + unzip_files(temp_zip_file_path, temp_dir) + temp_zip_file.close() + return temp_dir + + async def list_downloaded_files_in_browser_session( + self, organization_id: str, browser_session_id: str + ) -> list[str]: + uri = f"azure://{settings.AZURE_STORAGE_CONTAINER_ARTIFACTS}/v1/{settings.ENV}/{organization_id}/browser_sessions/{browser_session_id}/downloads" + return [ + f"azure://{settings.AZURE_STORAGE_CONTAINER_ARTIFACTS}/{file}" + for file in await self.async_client.list_files(uri=uri) + ] + + async def get_shared_downloaded_files_in_browser_session( + self, organization_id: str, browser_session_id: str + ) -> list[FileInfo]: + object_keys = await self.list_downloaded_files_in_browser_session(organization_id, browser_session_id) + if len(object_keys) == 0: + return [] + + file_infos: list[FileInfo] = [] + for key in object_keys: + metadata = {} + modified_at: datetime | None = None + # Get metadata (including checksum) + try: + object_info = await self.async_client.get_object_info(key) + if object_info: + metadata = object_info.get("Metadata", {}) + modified_at = object_info.get("LastModified") + except Exception: + LOG.exception("Object info retrieval failed", uri=key) + + # Create FileInfo object + filename = os.path.basename(key) + checksum = metadata.get("sha256_checksum") if metadata else None + + # Get SAS URL + sas_urls = await self.async_client.create_sas_urls([key]) + if not sas_urls: + continue + + file_info = FileInfo( + url=sas_urls[0], + checksum=checksum, + filename=metadata.get("original_filename", filename) if metadata else filename, + modified_at=modified_at, + ) + file_infos.append(file_info) 
+ + return file_infos + + async def list_downloading_files_in_browser_session( + self, organization_id: str, browser_session_id: str + ) -> list[str]: + uri = f"azure://{settings.AZURE_STORAGE_CONTAINER_ARTIFACTS}/v1/{settings.ENV}/{organization_id}/browser_sessions/{browser_session_id}/downloads" + files = [ + f"azure://{settings.AZURE_STORAGE_CONTAINER_ARTIFACTS}/{file}" + for file in await self.async_client.list_files(uri=uri) + ] + return [file for file in files if file.endswith(BROWSER_DOWNLOADING_SUFFIX)] + + async def list_recordings_in_browser_session(self, organization_id: str, browser_session_id: str) -> list[str]: + """List all recording files for a browser session from Azure.""" + uri = f"azure://{settings.AZURE_STORAGE_CONTAINER_ARTIFACTS}/v1/{settings.ENV}/{organization_id}/browser_sessions/{browser_session_id}/videos" + return [ + f"azure://{settings.AZURE_STORAGE_CONTAINER_ARTIFACTS}/{file}" + for file in await self.async_client.list_files(uri=uri) + ] + + async def get_shared_recordings_in_browser_session( + self, organization_id: str, browser_session_id: str + ) -> list[FileInfo]: + """Get recording files with SAS URLs for a browser session.""" + object_keys = await self.list_recordings_in_browser_session(organization_id, browser_session_id) + if len(object_keys) == 0: + return [] + + file_infos: list[FileInfo] = [] + for key in object_keys: + # Playwright's record_video_dir should only contain .webm files. + # Filter defensively in case of unexpected files. 
+ key_lower = key.lower() + if not (key_lower.endswith(".webm") or key_lower.endswith(".mp4")): + LOG.warning( + "Skipping recording file with unsupported extension", + uri=key, + organization_id=organization_id, + browser_session_id=browser_session_id, + ) + continue + + metadata = {} + modified_at: datetime | None = None + content_length: int | None = None + # Get metadata (including checksum) + try: + object_info = await self.async_client.get_object_info(key) + if object_info: + metadata = object_info.get("Metadata", {}) + modified_at = object_info.get("LastModified") + content_length = object_info.get("ContentLength") or object_info.get("Size") + except Exception: + LOG.exception("Recording object info retrieval failed", uri=key) + + # Skip zero-byte objects (if any incomplete uploads) + if content_length == 0: + continue + + # Create FileInfo object + filename = os.path.basename(key) + checksum = metadata.get("sha256_checksum") if metadata else None + + # Get SAS URL + sas_urls = await self.async_client.create_sas_urls([key]) + if not sas_urls: + continue + + file_info = FileInfo( + url=sas_urls[0], + checksum=checksum, + filename=metadata.get("original_filename", filename) if metadata else filename, + modified_at=modified_at, + ) + file_infos.append(file_info) + + # Prefer the newest recording first (Azure list order is not guaranteed). + # Treat None as "oldest". 
+ file_infos.sort(key=lambda f: (f.modified_at is not None, f.modified_at), reverse=True) + return file_infos + + async def save_downloaded_files(self, organization_id: str, run_id: str | None) -> None: + download_dir = get_download_dir(run_id=run_id) + files = os.listdir(download_dir) + tier = await self._get_storage_tier_for_org(organization_id) + tags = await self._get_tags_for_org(organization_id) + base_uri = f"azure://{settings.AZURE_STORAGE_CONTAINER_UPLOADS}/{DOWNLOAD_FILE_PREFIX}/{settings.ENV}/{organization_id}/{run_id}" + for file in files: + fpath = os.path.join(download_dir, file) + if not os.path.isfile(fpath): + continue + uri = f"{base_uri}/{file}" + checksum = calculate_sha256_for_file(fpath) + LOG.info( + "Calculated checksum for file", + file=file, + checksum=checksum, + organization_id=organization_id, + storage_tier=tier, + ) + # Upload file with checksum metadata + await self.async_client.upload_file_from_path( + uri=uri, + file_path=fpath, + metadata={"sha256_checksum": checksum, "original_filename": file}, + tier=tier, + tags=tags, + ) + + async def get_downloaded_files(self, organization_id: str, run_id: str | None) -> list[FileInfo]: + uri = f"azure://{settings.AZURE_STORAGE_CONTAINER_UPLOADS}/{DOWNLOAD_FILE_PREFIX}/{settings.ENV}/{organization_id}/{run_id}" + object_keys = await self.async_client.list_files(uri=uri) + if len(object_keys) == 0: + return [] + + file_infos: list[FileInfo] = [] + for key in object_keys: + object_uri = f"azure://{settings.AZURE_STORAGE_CONTAINER_UPLOADS}/{key}" + + # Get metadata (including checksum) + metadata = await self.async_client.get_file_metadata(object_uri, log_exception=False) + + # Create FileInfo object + filename = os.path.basename(key) + checksum = metadata.get("sha256_checksum") if metadata else None + + # Get SAS URL + sas_urls = await self.async_client.create_sas_urls([object_uri]) + if not sas_urls: + continue + + file_info = FileInfo( + url=sas_urls[0], + checksum=checksum, + 
filename=metadata.get("original_filename", filename) if metadata else filename, + ) + file_infos.append(file_info) + + return file_infos + + async def save_legacy_file( + self, *, organization_id: str, filename: str, fileObj: BinaryIO + ) -> tuple[str, str] | None: + todays_date = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d") + container = settings.AZURE_STORAGE_CONTAINER_UPLOADS + tier = await self._get_storage_tier_for_org(organization_id) + tags = await self._get_tags_for_org(organization_id) + # First try uploading with original filename + try: + sanitized_filename = os.path.basename(filename) # Remove any path components + azure_uri = f"azure://{container}/{settings.ENV}/{organization_id}/{todays_date}/{sanitized_filename}" + uploaded_uri = await self.async_client.upload_file_stream(azure_uri, fileObj, tier=tier, tags=tags) + except Exception: + LOG.error("Failed to upload file to Azure", exc_info=True) + uploaded_uri = None + + # If upload fails, try again with UUID prefix + if not uploaded_uri: + uuid_prefixed_filename = f"{str(uuid.uuid4())}_{filename}" + azure_uri = f"azure://{container}/{settings.ENV}/{organization_id}/{todays_date}/{uuid_prefixed_filename}" + fileObj.seek(0) # Reset file pointer + uploaded_uri = await self.async_client.upload_file_stream(azure_uri, fileObj, tier=tier, tags=tags) + + if not uploaded_uri: + LOG.error( + "Failed to upload file to Azure after retrying with UUID prefix", + organization_id=organization_id, + storage_tier=tier, + filename=filename, + exc_info=True, + ) + return None + LOG.debug( + "Legacy file upload", + organization_id=organization_id, + storage_tier=tier, + filename=filename, + uploaded_uri=uploaded_uri, + ) + # Generate a SAS URL for the uploaded file + sas_urls = await self.async_client.create_sas_urls([uploaded_uri]) + if not sas_urls: + LOG.error( + "Failed to create SAS URL for uploaded file", + organization_id=organization_id, + storage_tier=tier, + uploaded_uri=uploaded_uri, + filename=filename, + 
exc_info=True, + ) + return None + return sas_urls[0], uploaded_uri + + def _build_browser_session_uri( + self, + organization_id: str, + browser_session_id: str, + artifact_type: str, + remote_path: str, + date: str | None = None, + ) -> str: + """Build the Azure URI for a browser session file.""" + base = f"azure://{self.container}/v1/{settings.ENV}/{organization_id}/browser_sessions/{browser_session_id}/{artifact_type}" + if date: + return f"{base}/{date}/{remote_path}" + return f"{base}/{remote_path}" + + async def sync_browser_session_file( + self, + organization_id: str, + browser_session_id: str, + artifact_type: str, + local_file_path: str, + remote_path: str, + date: str | None = None, + ) -> str: + """Sync a file from local browser session to Azure.""" + uri = self._build_browser_session_uri(organization_id, browser_session_id, artifact_type, remote_path, date) + tier = await self._get_storage_tier_for_org(organization_id) + tags = await self._get_tags_for_org(organization_id) + await self.async_client.upload_file_from_path(uri, local_file_path, tier=tier, tags=tags) + return uri + + async def delete_browser_session_file( + self, + organization_id: str, + browser_session_id: str, + artifact_type: str, + remote_path: str, + date: str | None = None, + ) -> None: + """Delete a file from browser session storage in Azure.""" + uri = self._build_browser_session_uri(organization_id, browser_session_id, artifact_type, remote_path, date) + await self.async_client.delete_file(uri) + + async def browser_session_file_exists( + self, + organization_id: str, + browser_session_id: str, + artifact_type: str, + remote_path: str, + date: str | None = None, + ) -> bool: + """Check if a file exists in browser session storage in Azure.""" + uri = self._build_browser_session_uri(organization_id, browser_session_id, artifact_type, remote_path, date) + try: + info = await self.async_client.get_object_info(uri) + return info is not None + except Exception: + return False + + 
async def download_uploaded_file(self, uri: str) -> bytes | None: + """Download a user-uploaded file from Azure.""" + return await self.async_client.download_file(uri, log_exception=False) + + async def file_exists(self, uri: str) -> bool: + """Check if a file exists at the given Azure URI.""" + try: + info = await self.async_client.get_object_info(uri) + return info is not None + except Exception: + return False + + @property + def storage_type(self) -> str: + """Returns 'azure' as the storage type.""" + return "azure" diff --git a/skyvern/forge/sdk/artifact/storage/base.py b/skyvern/forge/sdk/artifact/storage/base.py index de44be46..cf899320 100644 --- a/skyvern/forge/sdk/artifact/storage/base.py +++ b/skyvern/forge/sdk/artifact/storage/base.py @@ -172,3 +172,50 @@ class BaseStorage(ABC): self, *, organization_id: str, filename: str, fileObj: BinaryIO ) -> tuple[str, str] | None: pass + + @abstractmethod + async def sync_browser_session_file( + self, + organization_id: str, + browser_session_id: str, + artifact_type: str, + local_file_path: str, + remote_path: str, + date: str | None = None, + ) -> str: + pass + + @abstractmethod + async def delete_browser_session_file( + self, + organization_id: str, + browser_session_id: str, + artifact_type: str, + remote_path: str, + date: str | None = None, + ) -> None: + pass + + @abstractmethod + async def browser_session_file_exists( + self, + organization_id: str, + browser_session_id: str, + artifact_type: str, + remote_path: str, + date: str | None = None, + ) -> bool: + pass + + @abstractmethod + async def download_uploaded_file(self, uri: str) -> bytes | None: + pass + + @abstractmethod + async def file_exists(self, uri: str) -> bool: + pass + + @property + @abstractmethod + def storage_type(self) -> str: + pass diff --git a/skyvern/forge/sdk/artifact/storage/local.py b/skyvern/forge/sdk/artifact/storage/local.py index d39e5011..e9a85a9d 100644 --- a/skyvern/forge/sdk/artifact/storage/local.py +++ 
b/skyvern/forge/sdk/artifact/storage/local.py @@ -160,11 +160,11 @@ class LocalStorage(BaseStorage): ) return None - async def get_share_link(self, artifact: Artifact) -> str: - return artifact.uri + async def get_share_link(self, artifact: Artifact) -> str | None: + return None - async def get_share_links(self, artifacts: list[Artifact]) -> list[str]: - return [artifact.uri for artifact in artifacts] + async def get_share_links(self, artifacts: list[Artifact]) -> list[str] | None: + return None async def save_streaming_file(self, organization_id: str, file_name: str) -> None: return @@ -346,3 +346,98 @@ class LocalStorage(BaseStorage): raise NotImplementedError( "Legacy file storage is not implemented for LocalStorage. Please use a different storage backend." ) + + def _build_browser_session_path( + self, + organization_id: str, + browser_session_id: str, + artifact_type: str, + remote_path: str, + date: str | None = None, + ) -> Path: + """Build the local path for a browser session file.""" + base = ( + Path(self.artifact_path) + / settings.ENV + / organization_id + / "browser_sessions" + / browser_session_id + / artifact_type + ) + if date: + return base / date / remote_path + return base / remote_path + + async def sync_browser_session_file( + self, + organization_id: str, + browser_session_id: str, + artifact_type: str, + local_file_path: str, + remote_path: str, + date: str | None = None, + ) -> str: + """Sync a file from local browser session to local storage.""" + target_path = self._build_browser_session_path( + organization_id, browser_session_id, artifact_type, remote_path, date + ) + if WINDOWS: + target_path = target_path.with_name(_windows_safe_filename(target_path.name)) + self._create_directories_if_not_exists(target_path) + shutil.copy2(local_file_path, target_path) + return f"file://{target_path}" + + async def delete_browser_session_file( + self, + organization_id: str, + browser_session_id: str, + artifact_type: str, + remote_path: str, + date: 
str | None = None, + ) -> None: + """Delete a file from browser session storage in local filesystem.""" + target_path = self._build_browser_session_path( + organization_id, browser_session_id, artifact_type, remote_path, date + ) + try: + if target_path.exists(): + target_path.unlink() + except Exception: + LOG.exception("Failed to delete local browser session file", path=str(target_path)) + + async def browser_session_file_exists( + self, + organization_id: str, + browser_session_id: str, + artifact_type: str, + remote_path: str, + date: str | None = None, + ) -> bool: + """Check if a file exists in browser session storage in local filesystem.""" + target_path = self._build_browser_session_path( + organization_id, browser_session_id, artifact_type, remote_path, date + ) + return target_path.exists() + + async def download_uploaded_file(self, uri: str) -> bytes | None: + """Download a user-uploaded file from local filesystem.""" + try: + file_path = parse_uri_to_path(uri) + with open(file_path, "rb") as f: + return f.read() + except Exception: + LOG.exception("Failed to read local file", uri=uri) + return None + + async def file_exists(self, uri: str) -> bool: + """Check if a file exists at the given local URI.""" + try: + file_path = parse_uri_to_path(uri) + return os.path.exists(file_path) + except Exception: + return False + + @property + def storage_type(self) -> str: + """Returns 'file' as the storage type.""" + return "file" diff --git a/skyvern/forge/sdk/artifact/storage/s3.py b/skyvern/forge/sdk/artifact/storage/s3.py index 6bb359a0..c95a286e 100644 --- a/skyvern/forge/sdk/artifact/storage/s3.py +++ b/skyvern/forge/sdk/artifact/storage/s3.py @@ -235,10 +235,9 @@ class S3Storage(BaseStorage): async def list_downloaded_files_in_browser_session( self, organization_id: str, browser_session_id: str ) -> list[str]: - uri = f"s3://{settings.AWS_S3_BUCKET_ARTIFACTS}/v1/{settings.ENV}/{organization_id}/browser_sessions/{browser_session_id}/downloads" - return [ - 
f"s3://{settings.AWS_S3_BUCKET_ARTIFACTS}/{file}" for file in await self.async_client.list_files(uri=uri) - ] + bucket = settings.AWS_S3_BUCKET_ARTIFACTS + uri = f"s3://{bucket}/v1/{settings.ENV}/{organization_id}/browser_sessions/{browser_session_id}/downloads" + return [f"s3://{bucket}/{file}" for file in await self.async_client.list_files(uri=uri)] async def get_shared_downloaded_files_in_browser_session( self, organization_id: str, browser_session_id: str @@ -281,18 +280,16 @@ class S3Storage(BaseStorage): async def list_downloading_files_in_browser_session( self, organization_id: str, browser_session_id: str ) -> list[str]: - uri = f"s3://{settings.AWS_S3_BUCKET_ARTIFACTS}/v1/{settings.ENV}/{organization_id}/browser_sessions/{browser_session_id}/downloads" - files = [ - f"s3://{settings.AWS_S3_BUCKET_ARTIFACTS}/{file}" for file in await self.async_client.list_files(uri=uri) - ] + bucket = settings.AWS_S3_BUCKET_ARTIFACTS + uri = f"s3://{bucket}/v1/{settings.ENV}/{organization_id}/browser_sessions/{browser_session_id}/downloads" + files = [f"s3://{bucket}/{file}" for file in await self.async_client.list_files(uri=uri)] return [file for file in files if file.endswith(BROWSER_DOWNLOADING_SUFFIX)] async def list_recordings_in_browser_session(self, organization_id: str, browser_session_id: str) -> list[str]: """List all recording files for a browser session from S3.""" - uri = f"s3://{settings.AWS_S3_BUCKET_ARTIFACTS}/v1/{settings.ENV}/{organization_id}/browser_sessions/{browser_session_id}/videos" - return [ - f"s3://{settings.AWS_S3_BUCKET_ARTIFACTS}/{file}" for file in await self.async_client.list_files(uri=uri) - ] + bucket = settings.AWS_S3_BUCKET_ARTIFACTS + uri = f"s3://{bucket}/v1/{settings.ENV}/{organization_id}/browser_sessions/{browser_session_id}/videos" + return [f"s3://{bucket}/{file}" for file in await self.async_client.list_files(uri=uri)] async def get_shared_recordings_in_browser_session( self, organization_id: str, browser_session_id: str @@ 
-385,14 +382,15 @@ class S3Storage(BaseStorage): ) async def get_downloaded_files(self, organization_id: str, run_id: str | None) -> list[FileInfo]: - uri = f"s3://{settings.AWS_S3_BUCKET_UPLOADS}/{DOWNLOAD_FILE_PREFIX}/{settings.ENV}/{organization_id}/{run_id}" + bucket = settings.AWS_S3_BUCKET_UPLOADS + uri = f"s3://{bucket}/{DOWNLOAD_FILE_PREFIX}/{settings.ENV}/{organization_id}/{run_id}" object_keys = await self.async_client.list_files(uri=uri) if len(object_keys) == 0: return [] file_infos: list[FileInfo] = [] for key in object_keys: - object_uri = f"s3://{settings.AWS_S3_BUCKET_UPLOADS}/{key}" + object_uri = f"s3://{bucket}/{key}" # Get metadata (including checksum) metadata = await self.async_client.get_file_metadata(object_uri, log_exception=False) @@ -467,3 +465,78 @@ class S3Storage(BaseStorage): ) return None return presigned_urls[0], uploaded_s3_uri + + def _build_browser_session_uri( + self, + organization_id: str, + browser_session_id: str, + artifact_type: str, + remote_path: str, + date: str | None = None, + ) -> str: + """Build the S3 URI for a browser session file.""" + base = f"s3://{self.bucket}/v1/{settings.ENV}/{organization_id}/browser_sessions/{browser_session_id}/{artifact_type}" + if date: + return f"{base}/{date}/{remote_path}" + return f"{base}/{remote_path}" + + async def sync_browser_session_file( + self, + organization_id: str, + browser_session_id: str, + artifact_type: str, + local_file_path: str, + remote_path: str, + date: str | None = None, + ) -> str: + """Sync a file from local browser session to S3.""" + uri = self._build_browser_session_uri(organization_id, browser_session_id, artifact_type, remote_path, date) + sc = await self._get_storage_class_for_org(organization_id) + tags = await self._get_tags_for_org(organization_id) + await self.async_client.upload_file_from_path(uri, local_file_path, storage_class=sc, tags=tags) + return uri + + async def delete_browser_session_file( + self, + organization_id: str, + 
browser_session_id: str, + artifact_type: str, + remote_path: str, + date: str | None = None, + ) -> None: + """Delete a file from browser session storage in S3.""" + uri = self._build_browser_session_uri(organization_id, browser_session_id, artifact_type, remote_path, date) + await self.async_client.delete_file(uri, log_exception=True) + + async def browser_session_file_exists( + self, + organization_id: str, + browser_session_id: str, + artifact_type: str, + remote_path: str, + date: str | None = None, + ) -> bool: + """Check if a file exists in browser session storage in S3.""" + uri = self._build_browser_session_uri(organization_id, browser_session_id, artifact_type, remote_path, date) + try: + info = await self.async_client.get_object_info(uri) + return info is not None + except Exception: + return False + + async def download_uploaded_file(self, uri: str) -> bytes | None: + """Download a user-uploaded file from S3.""" + return await self.async_client.download_file(uri, log_exception=False) + + async def file_exists(self, uri: str) -> bool: + """Check if a file exists at the given S3 URI.""" + try: + info = await self.async_client.get_object_info(uri) + return info is not None + except Exception: + return False + + @property + def storage_type(self) -> str: + """Returns 's3' as the storage type.""" + return "s3" diff --git a/skyvern/forge/sdk/artifact/storage/test_azure_storage.py b/skyvern/forge/sdk/artifact/storage/test_azure_storage.py new file mode 100644 index 00000000..6187f49a --- /dev/null +++ b/skyvern/forge/sdk/artifact/storage/test_azure_storage.py @@ -0,0 +1,251 @@ +from pathlib import Path +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from skyvern.config import settings +from skyvern.forge.sdk.api.azure import StandardBlobTier +from skyvern.forge.sdk.api.real_azure import RealAsyncAzureStorageClient +from skyvern.forge.sdk.artifact.storage.azure import AzureStorage + +# Test constants 
+TEST_CONTAINER = "test-azure-container" +TEST_ORGANIZATION_ID = "test-org-123" +TEST_BROWSER_SESSION_ID = "bs_test_123" + + +class AzureStorageForTests(AzureStorage): + """Test subclass that overrides org-specific methods and bypasses client init.""" + + async_client: Any # Allow mock attribute access + + def __init__(self, container: str) -> None: + # Don't call super().__init__ to avoid creating real RealAsyncAzureStorageClient + self.container = container + self.async_client = AsyncMock() + + async def _get_storage_tier_for_org(self, organization_id: str) -> StandardBlobTier: + return StandardBlobTier.HOT + + async def _get_tags_for_org(self, organization_id: str) -> dict[str, str]: + return {"test": "tag"} + + +@pytest.fixture +def azure_storage() -> AzureStorageForTests: + """Create AzureStorage with mocked async_client.""" + return AzureStorageForTests(container=TEST_CONTAINER) + + +@pytest.mark.asyncio +class TestAzureStorageBrowserSessionFiles: + """Test AzureStorage browser session file methods.""" + + async def test_sync_browser_session_file_with_date( + self, azure_storage: AzureStorageForTests, tmp_path: Path + ) -> None: + """Test syncing a file with date in path (videos/har).""" + test_file = tmp_path / "recording.webm" + test_file.write_bytes(b"fake video data") + + uri = await azure_storage.sync_browser_session_file( + organization_id=TEST_ORGANIZATION_ID, + browser_session_id=TEST_BROWSER_SESSION_ID, + artifact_type="videos", + local_file_path=str(test_file), + remote_path="recording.webm", + date="2025-01-15", + ) + + expected_uri = f"azure://{TEST_CONTAINER}/v1/{settings.ENV}/{TEST_ORGANIZATION_ID}/browser_sessions/{TEST_BROWSER_SESSION_ID}/videos/2025-01-15/recording.webm" + assert uri == expected_uri + azure_storage.async_client.upload_file_from_path.assert_called_once() + + async def test_sync_browser_session_file_without_date( + self, azure_storage: AzureStorageForTests, tmp_path: Path + ) -> None: + """Test syncing a file without date 
(downloads category).""" + test_file = tmp_path / "document.pdf" + test_file.write_bytes(b"fake download data") + + uri = await azure_storage.sync_browser_session_file( + organization_id=TEST_ORGANIZATION_ID, + browser_session_id=TEST_BROWSER_SESSION_ID, + artifact_type="downloads", + local_file_path=str(test_file), + remote_path="document.pdf", + date=None, + ) + + expected_uri = f"azure://{TEST_CONTAINER}/v1/{settings.ENV}/{TEST_ORGANIZATION_ID}/browser_sessions/{TEST_BROWSER_SESSION_ID}/downloads/document.pdf" + assert uri == expected_uri + + async def test_browser_session_file_exists_returns_true(self, azure_storage: AzureStorageForTests) -> None: + """Test browser_session_file_exists returns True when file exists.""" + azure_storage.async_client.get_object_info.return_value = {"LastModified": "2025-01-15"} + + exists = await azure_storage.browser_session_file_exists( + organization_id=TEST_ORGANIZATION_ID, + browser_session_id=TEST_BROWSER_SESSION_ID, + artifact_type="videos", + remote_path="exists.webm", + date="2025-01-15", + ) + + assert exists is True + + async def test_browser_session_file_exists_returns_false_on_exception( + self, azure_storage: AzureStorageForTests + ) -> None: + """Test browser_session_file_exists returns False when exception is raised.""" + azure_storage.async_client.get_object_info.side_effect = Exception("Not found") + + exists = await azure_storage.browser_session_file_exists( + organization_id=TEST_ORGANIZATION_ID, + browser_session_id=TEST_BROWSER_SESSION_ID, + artifact_type="videos", + remote_path="nonexistent.webm", + date="2025-01-15", + ) + + assert exists is False + + async def test_delete_browser_session_file(self, azure_storage: AzureStorageForTests) -> None: + """Test deleting a browser session file.""" + await azure_storage.delete_browser_session_file( + organization_id=TEST_ORGANIZATION_ID, + browser_session_id=TEST_BROWSER_SESSION_ID, + artifact_type="videos", + remote_path="to_delete.webm", + date="2025-01-15", + ) + 
+ expected_uri = f"azure://{TEST_CONTAINER}/v1/{settings.ENV}/{TEST_ORGANIZATION_ID}/browser_sessions/{TEST_BROWSER_SESSION_ID}/videos/2025-01-15/to_delete.webm" + azure_storage.async_client.delete_file.assert_called_once_with(expected_uri) + + async def test_file_exists_returns_true(self, azure_storage: AzureStorageForTests) -> None: + """Test file_exists returns True when file exists.""" + azure_storage.async_client.get_object_info.return_value = {"LastModified": "2025-01-15"} + uri = f"azure://{TEST_CONTAINER}/test/file.txt" + + exists = await azure_storage.file_exists(uri) + + assert exists is True + + async def test_file_exists_returns_false_on_exception(self, azure_storage: AzureStorageForTests) -> None: + """Test file_exists returns False when exception is raised (404).""" + azure_storage.async_client.get_object_info.side_effect = Exception("Not found") + uri = f"azure://{TEST_CONTAINER}/nonexistent/file.txt" + + exists = await azure_storage.file_exists(uri) + + assert exists is False + + async def test_download_uploaded_file(self, azure_storage: AzureStorageForTests) -> None: + """Test downloading an uploaded file.""" + test_data = b"uploaded file content" + azure_storage.async_client.download_file.return_value = test_data + uri = f"azure://{TEST_CONTAINER}/uploads/file.pdf" + + downloaded = await azure_storage.download_uploaded_file(uri) + + assert downloaded == test_data + azure_storage.async_client.download_file.assert_called_once_with(uri, log_exception=False) + + async def test_download_uploaded_file_returns_none(self, azure_storage: AzureStorageForTests) -> None: + """Test downloading a non-existent file returns None.""" + azure_storage.async_client.download_file.return_value = None + uri = f"azure://{TEST_CONTAINER}/nonexistent/file.txt" + + downloaded = await azure_storage.download_uploaded_file(uri) + + assert downloaded is None + + def test_storage_type_property(self, azure_storage: AzureStorageForTests) -> None: + """Test storage_type returns 
'azure'.""" + assert azure_storage.storage_type == "azure" + + +class TestAzureStorageBuildUri: + """Test Azure URI building methods.""" + + def test_build_browser_session_uri_with_date(self, azure_storage: AzureStorageForTests) -> None: + """Test building URI with date.""" + uri = azure_storage._build_browser_session_uri( + organization_id=TEST_ORGANIZATION_ID, + browser_session_id=TEST_BROWSER_SESSION_ID, + artifact_type="videos", + remote_path="file.webm", + date="2025-01-15", + ) + + expected = f"azure://{TEST_CONTAINER}/v1/{settings.ENV}/{TEST_ORGANIZATION_ID}/browser_sessions/{TEST_BROWSER_SESSION_ID}/videos/2025-01-15/file.webm" + assert uri == expected + + def test_build_browser_session_uri_without_date(self, azure_storage: AzureStorageForTests) -> None: + """Test building URI without date.""" + uri = azure_storage._build_browser_session_uri( + organization_id=TEST_ORGANIZATION_ID, + browser_session_id=TEST_BROWSER_SESSION_ID, + artifact_type="downloads", + remote_path="file.pdf", + date=None, + ) + + expected = f"azure://{TEST_CONTAINER}/v1/{settings.ENV}/{TEST_ORGANIZATION_ID}/browser_sessions/{TEST_BROWSER_SESSION_ID}/downloads/file.pdf" + assert uri == expected + + +AZURE_CONTENT_TYPE_TEST_CASES = [ + # (filename, expected_content_type, artifact_type, date) + ("video.webm", "video/webm", "videos", "2025-01-15"), + ("data.json", "application/json", "har", "2025-01-15"), + ("network.har", "application/json", "har", "2025-01-15"), + ("screenshot.png", "image/png", "downloads", None), + ("output.txt", "text/plain", "downloads", None), + ("debug.log", "text/plain", "downloads", None), +] + + +@pytest.mark.asyncio +class TestAzureStorageContentType: + """Test Azure Storage content type guessing. + + Tests at two levels: + 1. High-level: sync_browser_session_file interface with artifact_type/date + 2. 
Low-level: RealAsyncAzureStorageClient to verify ContentSettings is passed
+    """
+
+    @pytest.mark.parametrize("filename,expected_content_type,artifact_type,date", AZURE_CONTENT_TYPE_TEST_CASES)
+    async def test_content_type_guessing(
+        self,
+        tmp_path: Path,
+        filename: str,
+        expected_content_type: str,
+        artifact_type: str,
+        date: str | None,
+    ) -> None:
+        """Test that RealAsyncAzureStorageClient sets correct content type based on extension."""
+        test_file = tmp_path / filename
+        test_file.write_bytes(b"test content")
+
+        with patch.object(RealAsyncAzureStorageClient, "_get_blob_service_client") as mock_get_client:
+            mock_container_client = MagicMock()
+            mock_container_client.upload_blob = AsyncMock()
+            mock_container_client.exists = AsyncMock(return_value=True)
+            mock_blob_service = MagicMock()
+            mock_blob_service.get_container_client.return_value = mock_container_client
+            mock_get_client.return_value = mock_blob_service
+
+            client = RealAsyncAzureStorageClient(account_name="test", account_key="testkey")
+            client._verified_containers.add("test-container")
+
+            await client.upload_file_from_path(
+                uri=f"azure://test-container/path/{filename}",
+                file_path=str(test_file),
+            )
+
+            call_kwargs = mock_container_client.upload_blob.call_args.kwargs
+            assert call_kwargs["content_settings"] is not None
+            assert call_kwargs["content_settings"].content_type == expected_content_type
diff --git a/skyvern/forge/sdk/artifact/storage/test_s3_storage.py b/skyvern/forge/sdk/artifact/storage/test_s3_storage.py
index 810bed0a..8ee7f9bb 100644
--- a/skyvern/forge/sdk/artifact/storage/test_s3_storage.py
+++ b/skyvern/forge/sdk/artifact/storage/test_s3_storage.py
@@ -221,3 +221,229 @@ class TestS3StorageStore:
         await s3_storage.store_artifact(artifact, test_data)
         _assert_object_content(boto3_test_client, artifact.uri, test_data)
         _assert_object_meta(boto3_test_client, artifact.uri)
+
+
+TEST_BROWSER_SESSION_ID = "bs_test_123"
+
+
+@pytest.mark.asyncio
+class 
TestS3StorageBrowserSessionFiles: + """Test S3Storage browser session file methods.""" + + async def test_sync_browser_session_file_with_date( + self, s3_storage: S3Storage, boto3_test_client: S3Client, tmp_path: Path + ) -> None: + """Test syncing a file with date in path (videos/har).""" + test_data = b"fake video data" + test_file = tmp_path / "recording.webm" + test_file.write_bytes(test_data) + + uri = await s3_storage.sync_browser_session_file( + organization_id=TEST_ORGANIZATION_ID, + browser_session_id=TEST_BROWSER_SESSION_ID, + artifact_type="videos", + local_file_path=str(test_file), + remote_path="recording.webm", + date="2025-01-15", + ) + + expected_uri = f"s3://{TEST_BUCKET}/v1/{settings.ENV}/{TEST_ORGANIZATION_ID}/browser_sessions/{TEST_BROWSER_SESSION_ID}/videos/2025-01-15/recording.webm" + assert uri == expected_uri + _assert_object_content(boto3_test_client, uri, test_data) + _assert_object_meta(boto3_test_client, uri) + + async def test_sync_browser_session_file_without_date( + self, s3_storage: S3Storage, boto3_test_client: S3Client, tmp_path: Path + ) -> None: + """Test syncing a file without date (downloads category).""" + test_data = b"fake download data" + test_file = tmp_path / "document.pdf" + test_file.write_bytes(test_data) + + uri = await s3_storage.sync_browser_session_file( + organization_id=TEST_ORGANIZATION_ID, + browser_session_id=TEST_BROWSER_SESSION_ID, + artifact_type="downloads", + local_file_path=str(test_file), + remote_path="document.pdf", + date=None, + ) + + expected_uri = f"s3://{TEST_BUCKET}/v1/{settings.ENV}/{TEST_ORGANIZATION_ID}/browser_sessions/{TEST_BROWSER_SESSION_ID}/downloads/document.pdf" + assert uri == expected_uri + _assert_object_content(boto3_test_client, uri, test_data) + + async def test_browser_session_file_exists_returns_true( + self, s3_storage: S3Storage, boto3_test_client: S3Client, tmp_path: Path + ) -> None: + """Test browser_session_file_exists returns True for existing file.""" + test_file = 
tmp_path / "exists.webm" + test_file.write_bytes(b"test data") + + await s3_storage.sync_browser_session_file( + organization_id=TEST_ORGANIZATION_ID, + browser_session_id=TEST_BROWSER_SESSION_ID, + artifact_type="videos", + local_file_path=str(test_file), + remote_path="exists.webm", + date="2025-01-15", + ) + + exists = await s3_storage.browser_session_file_exists( + organization_id=TEST_ORGANIZATION_ID, + browser_session_id=TEST_BROWSER_SESSION_ID, + artifact_type="videos", + remote_path="exists.webm", + date="2025-01-15", + ) + assert exists is True + + async def test_browser_session_file_exists_returns_false(self, s3_storage: S3Storage) -> None: + """Test browser_session_file_exists returns False for non-existent file.""" + exists = await s3_storage.browser_session_file_exists( + organization_id=TEST_ORGANIZATION_ID, + browser_session_id=TEST_BROWSER_SESSION_ID, + artifact_type="videos", + remote_path="nonexistent.webm", + date="2025-01-15", + ) + assert exists is False + + async def test_delete_browser_session_file( + self, s3_storage: S3Storage, boto3_test_client: S3Client, tmp_path: Path + ) -> None: + """Test deleting a browser session file.""" + test_file = tmp_path / "to_delete.webm" + test_file.write_bytes(b"test data") + + await s3_storage.sync_browser_session_file( + organization_id=TEST_ORGANIZATION_ID, + browser_session_id=TEST_BROWSER_SESSION_ID, + artifact_type="videos", + local_file_path=str(test_file), + remote_path="to_delete.webm", + date="2025-01-15", + ) + + exists_before = await s3_storage.browser_session_file_exists( + organization_id=TEST_ORGANIZATION_ID, + browser_session_id=TEST_BROWSER_SESSION_ID, + artifact_type="videos", + remote_path="to_delete.webm", + date="2025-01-15", + ) + assert exists_before is True + + await s3_storage.delete_browser_session_file( + organization_id=TEST_ORGANIZATION_ID, + browser_session_id=TEST_BROWSER_SESSION_ID, + artifact_type="videos", + remote_path="to_delete.webm", + date="2025-01-15", + ) + + 
exists_after = await s3_storage.browser_session_file_exists( + organization_id=TEST_ORGANIZATION_ID, + browser_session_id=TEST_BROWSER_SESSION_ID, + artifact_type="videos", + remote_path="to_delete.webm", + date="2025-01-15", + ) + assert exists_after is False + + async def test_file_exists_returns_true( + self, s3_storage: S3Storage, boto3_test_client: S3Client, tmp_path: Path + ) -> None: + """Test file_exists returns True for existing file.""" + test_file = tmp_path / "test.txt" + test_file.write_bytes(b"test data") + + uri = await s3_storage.sync_browser_session_file( + organization_id=TEST_ORGANIZATION_ID, + browser_session_id=TEST_BROWSER_SESSION_ID, + artifact_type="downloads", + local_file_path=str(test_file), + remote_path="test.txt", + ) + + exists = await s3_storage.file_exists(uri) + assert exists is True + + async def test_file_exists_returns_false(self, s3_storage: S3Storage) -> None: + """Test file_exists returns False for non-existent file.""" + uri = f"s3://{TEST_BUCKET}/nonexistent/path/file.txt" + exists = await s3_storage.file_exists(uri) + assert exists is False + + async def test_download_uploaded_file( + self, s3_storage: S3Storage, boto3_test_client: S3Client, tmp_path: Path + ) -> None: + """Test downloading an uploaded file.""" + test_data = b"uploaded file content" + test_file = tmp_path / "uploaded.pdf" + test_file.write_bytes(test_data) + + uri = await s3_storage.sync_browser_session_file( + organization_id=TEST_ORGANIZATION_ID, + browser_session_id=TEST_BROWSER_SESSION_ID, + artifact_type="downloads", + local_file_path=str(test_file), + remote_path="uploaded.pdf", + ) + + downloaded = await s3_storage.download_uploaded_file(uri) + assert downloaded == test_data + + async def test_download_uploaded_file_nonexistent(self, s3_storage: S3Storage) -> None: + """Test downloading a non-existent file returns None.""" + uri = f"s3://{TEST_BUCKET}/nonexistent/path/file.txt" + downloaded = await s3_storage.download_uploaded_file(uri) + assert 
downloaded is None + + def test_storage_type_property(self, s3_storage: S3Storage) -> None: + """Test storage_type returns 's3'.""" + assert s3_storage.storage_type == "s3" + + +CONTENT_TYPE_TEST_CASES = [ + # (filename, expected_content_type, artifact_type, date) + ("video.webm", "video/webm", "videos", "2025-01-15"), + ("data.json", "application/json", "har", "2025-01-15"), + ("network.har", "application/json", "har", "2025-01-15"), + ("screenshot.png", "image/png", "downloads", None), + ("output.txt", "text/plain", "downloads", None), + ("debug.log", "text/plain", "downloads", None), +] + + +@pytest.mark.asyncio +class TestS3StorageContentType: + """Test S3Storage content type guessing.""" + + @pytest.mark.parametrize("filename,expected_content_type,artifact_type,date", CONTENT_TYPE_TEST_CASES) + async def test_content_type_guessing( + self, + s3_storage: S3Storage, + boto3_test_client: S3Client, + tmp_path: Path, + filename: str, + expected_content_type: str, + artifact_type: str, + date: str | None, + ) -> None: + """Test that files get correct content type based on extension.""" + test_file = tmp_path / filename + test_file.write_bytes(b"test content") + + uri = await s3_storage.sync_browser_session_file( + organization_id=TEST_ORGANIZATION_ID, + browser_session_id=TEST_BROWSER_SESSION_ID, + artifact_type=artifact_type, + local_file_path=str(test_file), + remote_path=filename, + date=date, + ) + + s3uri = S3Uri(uri) + obj_meta = boto3_test_client.head_object(Bucket=TEST_BUCKET, Key=s3uri.key) + assert obj_meta["ContentType"] == expected_content_type diff --git a/skyvern/forge/sdk/routes/agent_protocol.py b/skyvern/forge/sdk/routes/agent_protocol.py index d6184d3a..aeb1dd39 100644 --- a/skyvern/forge/sdk/routes/agent_protocol.py +++ b/skyvern/forge/sdk/routes/agent_protocol.py @@ -1294,15 +1294,9 @@ async def get_artifact( status_code=http_status.HTTP_404_NOT_FOUND, detail=f"Artifact not found {artifact_id}", ) - if settings.ENV != "local" or 
settings.GENERATE_PRESIGNED_URLS: - signed_urls = await app.ARTIFACT_MANAGER.get_share_links([artifact]) - if signed_urls: - artifact.signed_url = signed_urls[0] - else: - LOG.warning( - "Failed to get signed url for artifact", - artifact_id=artifact_id, - ) + signed_urls = await app.ARTIFACT_MANAGER.get_share_links([artifact]) + if signed_urls and len(signed_urls) == 1: + artifact.signed_url = signed_urls[0] return artifact @@ -1334,23 +1328,10 @@ async def get_run_artifacts( # Ensure we have a list of artifacts (since group_by_type=False, this will always be a list) artifacts_list = artifacts if isinstance(artifacts, list) else [] - if settings.ENV != "local" or settings.GENERATE_PRESIGNED_URLS: - # Get signed URLs for all artifacts - signed_urls = await app.ARTIFACT_MANAGER.get_share_links(artifacts_list) - - if signed_urls and len(signed_urls) == len(artifacts_list): - for i, artifact in enumerate(artifacts_list): - if hasattr(artifact, "signed_url"): - artifact.signed_url = signed_urls[i] - elif signed_urls: - LOG.warning( - "Mismatch between artifacts and signed URLs count", - artifacts_count=len(artifacts_list), - urls_count=len(signed_urls), - run_id=run_id, - ) - else: - LOG.warning("Failed to get signed urls for artifacts", run_id=run_id) + signed_urls = await app.ARTIFACT_MANAGER.get_share_links(artifacts_list) + if signed_urls and len(signed_urls) == len(artifacts_list): + for i, artifact in enumerate(artifacts_list): + artifact.signed_url = signed_urls[i] return ORJSONResponse([artifact.model_dump() for artifact in artifacts_list]) @@ -1976,17 +1957,10 @@ async def get_artifacts( } artifacts = await app.DATABASE.get_artifacts_by_entity_id(organization_id=current_org.organization_id, **params) # type: ignore - if settings.ENV != "local" or settings.GENERATE_PRESIGNED_URLS: - signed_urls = await app.ARTIFACT_MANAGER.get_share_links(artifacts) - if signed_urls: - for i, artifact in enumerate(artifacts): - artifact.signed_url = signed_urls[i] - else: - 
LOG.warning( - "Failed to get signed urls for artifacts", - entity_type=entity_type, - entity_id=entity_id, - ) + signed_urls = await app.ARTIFACT_MANAGER.get_share_links(artifacts) + if signed_urls and len(signed_urls) == len(artifacts): + for i, artifact in enumerate(artifacts): + artifact.signed_url = signed_urls[i] return ORJSONResponse([artifact.model_dump() for artifact in artifacts]) @@ -2021,17 +1995,10 @@ async def get_step_artifacts( step_id, organization_id=current_org.organization_id, ) - if settings.ENV != "local" or settings.GENERATE_PRESIGNED_URLS: - signed_urls = await app.ARTIFACT_MANAGER.get_share_links(artifacts) - if signed_urls: - for i, artifact in enumerate(artifacts): - artifact.signed_url = signed_urls[i] - else: - LOG.warning( - "Failed to get signed urls for artifacts", - task_id=task_id, - step_id=step_id, - ) + signed_urls = await app.ARTIFACT_MANAGER.get_share_links(artifacts) + if signed_urls and len(signed_urls) == len(artifacts): + for i, artifact in enumerate(artifacts): + artifact.signed_url = signed_urls[i] return ORJSONResponse([artifact.model_dump() for artifact in artifacts]) diff --git a/skyvern/forge/sdk/workflow/models/block.py b/skyvern/forge/sdk/workflow/models/block.py index 1b296ab8..2b22d82f 100644 --- a/skyvern/forge/sdk/workflow/models/block.py +++ b/skyvern/forge/sdk/workflow/models/block.py @@ -2603,9 +2603,8 @@ class FileUploadBlock(Block): blob_name = self._get_azure_blob_name(workflow_run_id, file_path) azure_uri = self._get_azure_blob_uri(workflow_run_id, blob_name) uploaded_uris.append(azure_uri) - await azure_client.upload_file_from_path( - container_name=self.azure_blob_container_name or "", blob_name=blob_name, file_path=file_path - ) + uri = f"azure://{self.azure_blob_container_name or ''}/{blob_name}" + await azure_client.upload_file_from_path(uri, file_path) LOG.info("FileUploadBlock: File(s) uploaded to Azure Blob Storage", file_path=self.path) else: # This case should ideally be caught by the initial 
validation