add azure blob storage (#4338)
Signed-off-by: Benji Visser <benji@093b.org>
Co-authored-by: Benji Visser <benji@093b.org>
Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
@@ -87,23 +87,26 @@ class Settings(BaseSettings):
# Artifact storage settings
ARTIFACT_STORAGE_PATH: str = f"{SKYVERN_DIR}/artifacts"
GENERATE_PRESIGNED_URLS: bool = False

# Supported storage types: local, s3cloud, azureblob
SKYVERN_STORAGE_TYPE: str = "local"

# S3/AWS settings
AWS_REGION: str = "us-east-1"
MAX_UPLOAD_FILE_SIZE: int = 10 * 1024 * 1024 # 10 MB
PRESIGNED_URL_EXPIRATION: int = 60 * 60 * 24 # 24 hours
AWS_S3_BUCKET_ARTIFACTS: str = "skyvern-artifacts"
AWS_S3_BUCKET_SCREENSHOTS: str = "skyvern-screenshots"
AWS_S3_BUCKET_BROWSER_SESSIONS: str = "skyvern-browser-sessions"

# Supported storage types: local, s3
SKYVERN_STORAGE_TYPE: str = "local"

# S3 bucket settings
AWS_REGION: str = "us-east-1"
AWS_S3_BUCKET_UPLOADS: str = "skyvern-uploads"
MAX_UPLOAD_FILE_SIZE: int = 10 * 1024 * 1024 # 10 MB
PRESIGNED_URL_EXPIRATION: int = 60 * 60 * 24 # 24 hours

# Azure Blob Storage settings
AZURE_STORAGE_ACCOUNT_NAME: str | None = None
AZURE_STORAGE_ACCOUNT_KEY: str | None = None
AZURE_STORAGE_CONTAINER_ARTIFACTS: str = "skyvern-artifacts"
AZURE_STORAGE_CONTAINER_SCREENSHOTS: str = "skyvern-screenshots"
AZURE_STORAGE_CONTAINER_BROWSER_SESSIONS: str = "skyvern-browser-sessions"
AZURE_STORAGE_CONTAINER_UPLOADS: str = "skyvern-uploads"

SKYVERN_TELEMETRY: bool = True
ANALYTICS_ID: str = "anonymous"
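With these fields in place, switching a deployment to the Azure backend is purely a configuration change. A minimal sketch of the environment that would drive the pydantic settings above (account values are placeholders):

# Minimal sketch (placeholder values); the pydantic Settings fields above read these
# from the environment, so no code changes are needed to switch backends.
import os

os.environ["SKYVERN_STORAGE_TYPE"] = "azureblob"
os.environ["AZURE_STORAGE_ACCOUNT_NAME"] = "mystorageaccount"  # placeholder
os.environ["AZURE_STORAGE_ACCOUNT_KEY"] = "<account-key>"  # placeholder
# The AZURE_STORAGE_CONTAINER_* names default to the skyvern-* values above and only
# need overriding if the containers were created under different names.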
@@ -16,6 +16,7 @@ from skyvern.forge.sdk.api.llm.api_handler import LLMAPIHandler
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory
from skyvern.forge.sdk.api.real_azure import RealAzureClientFactory
from skyvern.forge.sdk.artifact.manager import ArtifactManager
from skyvern.forge.sdk.artifact.storage.azure import AzureStorage
from skyvern.forge.sdk.artifact.storage.base import BaseStorage
from skyvern.forge.sdk.artifact.storage.factory import StorageFactory
from skyvern.forge.sdk.artifact.storage.s3 import S3Storage

@@ -104,6 +105,8 @@ def create_forge_app() -> ForgeApp:
if settings.SKYVERN_STORAGE_TYPE == "s3":
StorageFactory.set_storage(S3Storage())
elif settings.SKYVERN_STORAGE_TYPE == "azureblob":
StorageFactory.set_storage(AzureStorage())
app.STORAGE = StorageFactory.get_storage()
app.CACHE = CacheFactory.get_cache()
app.ARTIFACT_MANAGER = ArtifactManager()
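Because the backend is chosen once at startup, downstream code consumes it through app.STORAGE and stays agnostic of the URI scheme; an illustrative sketch using two BaseStorage methods added in this change:

# Illustrative only: both s3:// and azure:// URIs resolve through the shared interface
# configured above, so callers never construct a storage class directly.
from skyvern.forge import app

async def fetch_uploaded_file(uri: str) -> bytes | None:
    if await app.STORAGE.file_exists(uri):
        return await app.STORAGE.download_uploaded_file(uri)
    return None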
@@ -1,4 +1,5 @@
from enum import StrEnum
from mimetypes import add_type, guess_type
from typing import IO, Any
from urllib.parse import urlparse

@@ -12,6 +13,10 @@ from types_boto3_secretsmanager.client import SecretsManagerClient
from skyvern.config import settings

# Register custom mime types for mimetypes guessing
add_type("application/json", ".har")
add_type("text/plain", ".log")

LOG = structlog.get_logger()

@@ -188,6 +193,10 @@ class AsyncAWSClient:
extra_args["Tagging"] = self._create_tag_string(tags)
if content_type:
extra_args["ContentType"] = content_type
else:
guessed_type, _ = guess_type(file_path)
if guessed_type:
extra_args["ContentType"] = guessed_type
await client.upload_file(
Filename=file_path,
Bucket=parsed_uri.bucket,
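The add_type registrations matter because .har and .log may not be present in the default mime table; after registration guess_type resolves them, and unknown extensions still fall back to None so the upload keeps its default content type (file names below are made up):

from mimetypes import add_type, guess_type

add_type("application/json", ".har")
add_type("text/plain", ".log")

assert guess_type("session_trace.har")[0] == "application/json"
assert guess_type("forge_server.log")[0] == "text/plain"
assert guess_type("artifact.unknown-ext")[0] is None  # no ContentType override in this case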
@@ -1,8 +1,35 @@
from typing import Protocol, Self
from urllib.parse import urlparse

from azure.storage.blob import StandardBlobTier

from skyvern.forge.sdk.schemas.organizations import AzureClientSecretCredential

class AzureUri:
"""Parse azure://{container}/{blob_path} URIs."""

def __init__(self, uri: str) -> None:
self._parsed = urlparse(uri, allow_fragments=False)

@property
def container(self) -> str:
return self._parsed.netloc

@property
def blob_path(self) -> str:
if self._parsed.query:
return self._parsed.path.lstrip("/") + "?" + self._parsed.query
return self._parsed.path.lstrip("/")

@property
def uri(self) -> str:
return self._parsed.geturl()

def __str__(self) -> str:
return self.uri
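Concretely, the parser above splits the scheme-style URI into its container and blob components; for example (the path is illustrative):

# Illustrative round trip through AzureUri.
uri = AzureUri("azure://skyvern-artifacts/v1/local/o_123/screenshot.png")
assert uri.container == "skyvern-artifacts"
assert uri.blob_path == "v1/local/o_123/screenshot.png"
assert str(uri) == "azure://skyvern-artifacts/v1/local/o_123/screenshot.png"  # geturl() reproduces the input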
class AsyncAzureVaultClient(Protocol):
"""Protocol defining the interface for Azure Vault clients.

@@ -68,21 +95,24 @@ class AsyncAzureVaultClient(Protocol):

class AsyncAzureStorageClient(Protocol):
"""Protocol defining the interface for Azure Storage clients.
"""Protocol defining the interface for Azure Storage clients."""

This client provides methods to interact with Azure Blob Storage for file operations.
"""

async def upload_file_from_path(self, container_name: str, blob_name: str, file_path: str) -> None:
async def upload_file_from_path(
self,
uri: str,
file_path: str,
tier: StandardBlobTier = StandardBlobTier.HOT,
tags: dict[str, str] | None = None,
metadata: dict[str, str] | None = None,
) -> None:
"""Upload a file from the local filesystem to Azure Blob Storage.

Args:
container_name: The name of the Azure Blob container
blob_name: The name to give the blob in storage
uri: The azure:// URI for the blob (azure://container/blob_path)
file_path: The local path to the file to upload

Raises:
Exception: If the upload fails
tier: The storage tier for the blob
tags: Optional tags to attach to the blob
metadata: Optional metadata to attach to the blob
"""
...
@@ -17,7 +17,8 @@ from yarl import URL
from skyvern.config import settings
from skyvern.constants import BROWSER_DOWNLOAD_TIMEOUT, BROWSER_DOWNLOADING_SUFFIX, REPO_ROOT_DIR
from skyvern.exceptions import DownloadFileMaxSizeExceeded, DownloadFileMaxWaitingTime
from skyvern.forge.sdk.api.aws import AsyncAWSClient, aws_client
from skyvern.forge import app
from skyvern.forge.sdk.api.aws import AsyncAWSClient
from skyvern.utils.url_validators import encode_url

LOG = structlog.get_logger()

@@ -97,6 +98,12 @@ def validate_download_url(url: str) -> bool:
return True
return False

# Allow Azure URIs for Skyvern uploads container
if scheme == "azure":
if url.startswith(f"azure://{settings.AZURE_STORAGE_CONTAINER_UPLOADS}/{settings.ENV}/o_"):
return True
return False
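Illustratively, with the default container name and ENV set to "local", the new azure branch accepts Skyvern's own upload URIs and rejects anything else (the org id and file names are made up):

# Illustrative only; assumes AZURE_STORAGE_CONTAINER_UPLOADS="skyvern-uploads" and ENV="local".
assert validate_download_url("azure://skyvern-uploads/local/o_123/2024-06-01/form.pdf") is True
assert validate_download_url("azure://some-other-container/form.pdf") is False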
# Allow file:// URLs only in local environment
if scheme == "file":
if settings.ENV != "local":
@@ -129,20 +136,41 @@ async def download_file(url: str, max_size_mb: int | None = None) -> str:
url = f"https://drive.google.com/uc?export=download&id={file_id}"
LOG.info("Converting Google Drive link to direct download", url=url)

# Check if URL is an S3 URI
if url.startswith(f"s3://{settings.AWS_S3_BUCKET_UPLOADS}/{settings.ENV}/o_"):
LOG.info("Downloading Skyvern file from S3", url=url)
client = AsyncAWSClient()
return await download_from_s3(client, url)
# Check if URL is a cloud storage URI (S3 or Azure)
parsed = urlparse(url)
if parsed.scheme == "s3":
uploads_prefix = f"s3://{settings.AWS_S3_BUCKET_UPLOADS}/{settings.ENV}/o_"
if url.startswith(uploads_prefix):
LOG.info("Downloading Skyvern file from S3", url=url)
data = await app.STORAGE.download_uploaded_file(url)
if data is None:
raise Exception(f"Failed to download file from S3: {url}")
filename = url.split("/")[-1]
temp_file = create_named_temporary_file(delete=False, file_name=filename)
LOG.info(f"Downloaded file to {temp_file.name}")
temp_file.write(data)
return temp_file.name
elif parsed.scheme == "azure":
uploads_prefix = f"azure://{settings.AZURE_STORAGE_CONTAINER_UPLOADS}/{settings.ENV}/o_"
if url.startswith(uploads_prefix):
LOG.info("Downloading Skyvern file from Azure Blob Storage", url=url)
data = await app.STORAGE.download_uploaded_file(url)
if data is None:
raise Exception(f"Failed to download file from Azure Blob Storage: {url}")
filename = url.split("/")[-1]
temp_file = create_named_temporary_file(delete=False, file_name=filename)
LOG.info(f"Downloaded file to {temp_file.name}")
temp_file.write(data)
return temp_file.name

# Check if URL is a file:// URI
# we only support to download local files when the environment is local
# and the file is in the skyvern downloads directory
if url.startswith("file://") and settings.ENV == "local":
file_path = parse_uri_to_path(url)
if file_path.startswith(f"{REPO_ROOT_DIR}/downloads"):
local_path = parse_uri_to_path(url)
if local_path.startswith(f"{REPO_ROOT_DIR}/downloads"):
LOG.info("Downloading file from local file system", url=url)
return file_path
return local_path

async with aiohttp.ClientSession(raise_for_status=True) as session:
LOG.info("Starting to download file", url=url)

@@ -262,12 +290,13 @@ async def wait_for_download_finished(downloading_files: list[str], timeout: floa
while len(cur_downloading_files) > 0:
new_downloading_files: list[str] = []
for path in cur_downloading_files:
if path.startswith("s3://"):
try:
await aws_client.get_object_info(path)
except Exception:
# Check for cloud storage URIs (S3 or Azure)
parsed = urlparse(path)
if parsed.scheme in ("s3", "azure"):
if not await app.STORAGE.file_exists(path):
LOG.debug(
"downloading file is not found in s3, means the file finished downloading", path=path
"downloading file is not found in cloud storage, means the file finished downloading",
path=path,
)
continue
else:
@@ -1,15 +1,29 @@
"""Real implementations of Azure clients (Vault and Storage) and their factories."""

from typing import Self
from datetime import datetime, timedelta, timezone
from mimetypes import add_type, guess_type
from typing import IO, Self

import structlog
from azure.core.exceptions import ResourceNotFoundError
from azure.identity.aio import ClientSecretCredential, DefaultAzureCredential
from azure.keyvault.secrets.aio import SecretClient
from azure.storage.blob import BlobSasPermissions, ContentSettings, StandardBlobTier, generate_blob_sas
from azure.storage.blob.aio import BlobServiceClient

from skyvern.forge.sdk.api.azure import AsyncAzureStorageClient, AsyncAzureVaultClient, AzureClientFactory
from skyvern.config import settings
from skyvern.forge.sdk.api.azure import (
AsyncAzureStorageClient,
AsyncAzureVaultClient,
AzureClientFactory,
AzureUri,
)
from skyvern.forge.sdk.schemas.organizations import AzureClientSecretCredential

# Register custom mime types for mimetypes guessing
add_type("application/json", ".har")
add_type("text/plain", ".log")

LOG = structlog.get_logger()
@@ -73,37 +87,256 @@ class RealAsyncAzureVaultClient(AsyncAzureVaultClient):

class RealAsyncAzureStorageClient(AsyncAzureStorageClient):
"""Real implementation of Azure Storage client using Azure SDK."""
"""Async client for Azure Blob Storage operations. Implements AsyncAzureStorageClient protocol."""

def __init__(self, storage_account_name: str, storage_account_key: str):
self.blob_service_client = BlobServiceClient(
account_url=f"https://{storage_account_name}.blob.core.windows.net",
credential=storage_account_key,
def __init__(
self,
account_name: str | None = None,
account_key: str | None = None,
) -> None:
self.account_name = account_name or settings.AZURE_STORAGE_ACCOUNT_NAME
self.account_key = account_key or settings.AZURE_STORAGE_ACCOUNT_KEY

if not self.account_name or not self.account_key:
raise ValueError("Azure Storage account name and key are required")

self._blob_service_client: BlobServiceClient | None = None
self._verified_containers: set[str] = set()

def _get_blob_service_client(self) -> BlobServiceClient:
if self._blob_service_client is None:
self._blob_service_client = BlobServiceClient(
account_url=f"https://{self.account_name}.blob.core.windows.net",
credential=self.account_key,
)
return self._blob_service_client

async def _ensure_container_exists(self, container: str) -> None:
if container in self._verified_containers:
return
client = self._get_blob_service_client()
container_client = client.get_container_client(container)
try:
if not await container_client.exists():
await container_client.create_container()
LOG.info("Created Azure container", container=container)
except Exception:
LOG.debug("Container may already exist", container=container)
self._verified_containers.add(container)

async def upload_file(
self,
uri: str,
data: bytes,
tier: StandardBlobTier = StandardBlobTier.HOT,
tags: dict[str, str] | None = None,
) -> None:
parsed = AzureUri(uri)
await self._ensure_container_exists(parsed.container)
client = self._get_blob_service_client()
container_client = client.get_container_client(parsed.container)
await container_client.upload_blob(
name=parsed.blob_path,
data=data,
overwrite=True,
standard_blob_tier=tier,
tags=tags,
)

async def upload_file_from_path(self, container_name: str, blob_name: str, file_path: str) -> None:
try:
container_client = self.blob_service_client.get_container_client(container_name)
# Create the container if it doesn't exist
try:
await container_client.create_container()
except Exception as e:
LOG.info("Azure container already exists or failed to create", container_name=container_name, error=e)

with open(file_path, "rb") as data:
await container_client.upload_blob(name=blob_name, data=data, overwrite=True)
LOG.info("File uploaded to Azure Blob Storage", container_name=container_name, blob_name=blob_name)
except Exception as e:
LOG.error(
"Failed to upload file to Azure Blob Storage",
container_name=container_name,
blob_name=blob_name,
error=e,
async def upload_file_from_path(
self,
uri: str,
file_path: str,
tier: StandardBlobTier = StandardBlobTier.HOT,
tags: dict[str, str] | None = None,
metadata: dict[str, str] | None = None,
) -> None:
parsed = AzureUri(uri)
await self._ensure_container_exists(parsed.container)
client = self._get_blob_service_client()
container_client = client.get_container_client(parsed.container)
content_type, _ = guess_type(file_path)
content_settings = ContentSettings(content_type=content_type) if content_type else None
with open(file_path, "rb") as f:
await container_client.upload_blob(
name=parsed.blob_path,
data=f,
overwrite=True,
standard_blob_tier=tier,
tags=tags,
metadata=metadata,
content_settings=content_settings,
)
raise e
async def upload_file_stream(
|
||||
self,
|
||||
uri: str,
|
||||
file_obj: IO[bytes],
|
||||
tier: StandardBlobTier = StandardBlobTier.HOT,
|
||||
tags: dict[str, str] | None = None,
|
||||
metadata: dict[str, str] | None = None,
|
||||
) -> str:
|
||||
parsed = AzureUri(uri)
|
||||
await self._ensure_container_exists(parsed.container)
|
||||
client = self._get_blob_service_client()
|
||||
container_client = client.get_container_client(parsed.container)
|
||||
await container_client.upload_blob(
|
||||
name=parsed.blob_path,
|
||||
data=file_obj,
|
||||
overwrite=True,
|
||||
standard_blob_tier=tier,
|
||||
tags=tags,
|
||||
metadata=metadata,
|
||||
)
|
||||
return uri
|
||||
|
||||
async def download_file(self, uri: str, log_exception: bool = True) -> bytes | None:
|
||||
parsed = AzureUri(uri)
|
||||
try:
|
||||
client = self._get_blob_service_client()
|
||||
container_client = client.get_container_client(parsed.container)
|
||||
blob_client = container_client.get_blob_client(parsed.blob_path)
|
||||
download = await blob_client.download_blob()
|
||||
return await download.readall()
|
||||
except ResourceNotFoundError:
|
||||
if log_exception:
|
||||
LOG.warning("Azure blob not found", uri=uri)
|
||||
return None
|
||||
except Exception:
|
||||
if log_exception:
|
||||
LOG.exception("Failed to download from Azure", uri=uri)
|
||||
return None
|
||||
|
||||
async def get_blob_properties(self, uri: str) -> dict | None:
|
||||
parsed = AzureUri(uri)
|
||||
try:
|
||||
client = self._get_blob_service_client()
|
||||
container_client = client.get_container_client(parsed.container)
|
||||
blob_client = container_client.get_blob_client(parsed.blob_path)
|
||||
props = await blob_client.get_blob_properties()
|
||||
return {
|
||||
"size": props.size,
|
||||
"content_type": props.content_settings.content_type if props.content_settings else None,
|
||||
"last_modified": props.last_modified,
|
||||
"etag": props.etag,
|
||||
"metadata": props.metadata,
|
||||
}
|
||||
except ResourceNotFoundError:
|
||||
return None
|
||||
except Exception:
|
||||
LOG.exception("Failed to get blob properties", uri=uri)
|
||||
return None
|
||||
|
||||
async def blob_exists(self, uri: str) -> bool:
|
||||
parsed = AzureUri(uri)
|
||||
try:
|
||||
client = self._get_blob_service_client()
|
||||
container_client = client.get_container_client(parsed.container)
|
||||
blob_client = container_client.get_blob_client(parsed.blob_path)
|
||||
return await blob_client.exists()
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def delete_blob(self, uri: str) -> None:
|
||||
parsed = AzureUri(uri)
|
||||
try:
|
||||
client = self._get_blob_service_client()
|
||||
container_client = client.get_container_client(parsed.container)
|
||||
blob_client = container_client.get_blob_client(parsed.blob_path)
|
||||
await blob_client.delete_blob()
|
||||
except ResourceNotFoundError:
|
||||
LOG.debug("Azure blob not found for deletion", uri=uri)
|
||||
except Exception:
|
||||
LOG.exception("Failed to delete Azure blob", uri=uri)
|
||||
raise
|
||||
|
||||
async def list_blobs(self, container: str, prefix: str | None = None) -> list[str]:
|
||||
try:
|
||||
client = self._get_blob_service_client()
|
||||
container_client = client.get_container_client(container)
|
||||
blobs = []
|
||||
async for blob in container_client.list_blobs(name_starts_with=prefix):
|
||||
blobs.append(blob.name)
|
||||
return blobs
|
||||
except ResourceNotFoundError:
|
||||
return []
|
||||
except Exception:
|
||||
LOG.exception("Failed to list Azure blobs", container=container, prefix=prefix)
|
||||
return []
|
||||
|
||||
def create_sas_url(self, uri: str, expiry_hours: int = 24) -> str | None:
|
||||
parsed = AzureUri(uri)
|
||||
try:
|
||||
sas_token = generate_blob_sas(
|
||||
account_name=self.account_name,
|
||||
container_name=parsed.container,
|
||||
blob_name=parsed.blob_path,
|
||||
account_key=self.account_key,
|
||||
permission=BlobSasPermissions(read=True),
|
||||
expiry=datetime.now(timezone.utc) + timedelta(hours=expiry_hours),
|
||||
)
|
||||
return (
|
||||
f"https://{self.account_name}.blob.core.windows.net/{parsed.container}/{parsed.blob_path}?{sas_token}"
|
||||
)
|
||||
except Exception:
|
||||
LOG.exception("Failed to create SAS URL", uri=uri)
|
||||
return None
|
||||
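A SAS URL produced by create_sas_url is simply the blob's HTTPS URL with a read-only, time-limited token appended; a rough sketch of the round trip (account name, key, and paths are placeholders):

# Rough sketch only; credentials and URI are placeholders.
client = RealAsyncAzureStorageClient(account_name="myaccount", account_key="<account-key>")
share_url = client.create_sas_url("azure://skyvern-artifacts/v1/local/o_123/action.json", expiry_hours=1)
# -> "https://myaccount.blob.core.windows.net/skyvern-artifacts/v1/local/o_123/action.json?<sas-token>"
# The token carries BlobSasPermissions(read=True) and expires after expiry_hours.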
|
||||
async def create_sas_urls(self, uris: list[str], expiry_hours: int = 24) -> list[str] | None:
|
||||
try:
|
||||
sas_urls: list[str] = []
|
||||
for uri in uris:
|
||||
url = self.create_sas_url(uri, expiry_hours)
|
||||
if url is None:
|
||||
LOG.warning("SAS URL generation failed, aborting batch", failed_uri=uri, uris=uris)
|
||||
return None
|
||||
sas_urls.append(url)
|
||||
return sas_urls
|
||||
except Exception:
|
||||
LOG.exception("Failed to create SAS URLs")
|
||||
return None
|
||||
|
||||
async def close(self) -> None:
|
||||
await self.blob_service_client.close()
|
||||
if self._blob_service_client:
|
||||
await self._blob_service_client.close()
|
||||
self._blob_service_client = None
|
||||
|
||||
async def list_files(self, uri: str) -> list[str]:
|
||||
"""List files under a URI prefix. Returns blob names relative to container."""
|
||||
parsed = AzureUri(uri)
|
||||
return await self.list_blobs(parsed.container, parsed.blob_path)
|
||||
|
||||
async def get_object_info(self, uri: str) -> dict | None:
|
||||
"""Get object info including metadata. Returns dict with Metadata and LastModified keys."""
|
||||
props = await self.get_blob_properties(uri)
|
||||
if props is None:
|
||||
return None
|
||||
return {
|
||||
"Metadata": props.get("metadata", {}),
|
||||
"LastModified": props.get("last_modified"),
|
||||
}
|
||||
|
||||
async def delete_file(self, uri: str) -> None:
|
||||
"""Delete a file at the given URI."""
|
||||
await self.delete_blob(uri)
|
||||
|
||||
async def get_file_metadata(self, uri: str, log_exception: bool = True) -> dict[str, str] | None:
|
||||
"""Get only the metadata for a file."""
|
||||
parsed = AzureUri(uri)
|
||||
try:
|
||||
client = self._get_blob_service_client()
|
||||
container_client = client.get_container_client(parsed.container)
|
||||
blob_client = container_client.get_blob_client(parsed.blob_path)
|
||||
props = await blob_client.get_blob_properties()
|
||||
return props.metadata or {}
|
||||
except ResourceNotFoundError:
|
||||
if log_exception:
|
||||
LOG.warning("Azure blob not found for metadata", uri=uri)
|
||||
return None
|
||||
except Exception:
|
||||
if log_exception:
|
||||
LOG.exception("Failed to get blob metadata", uri=uri)
|
||||
return None
|
||||
|
||||
|
||||
class RealAzureClientFactory(AzureClientFactory):
|
||||
@@ -124,4 +357,4 @@ class RealAzureClientFactory(AzureClientFactory):
|
||||
|
||||
def create_storage_client(self, storage_account_name: str, storage_account_key: str) -> AsyncAzureStorageClient:
|
||||
"""Create an Azure Storage client with the provided credentials."""
|
||||
return RealAsyncAzureStorageClient(storage_account_name, storage_account_key)
|
||||
return RealAsyncAzureStorageClient(account_name=storage_account_name, account_key=storage_account_key)
|
||||
|
||||
skyvern/forge/sdk/artifact/storage/azure.py (new file, 539 lines)
@@ -0,0 +1,539 @@
|
||||
import os
|
||||
import shutil
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import BinaryIO
|
||||
|
||||
import structlog
|
||||
|
||||
from skyvern.config import settings
|
||||
from skyvern.constants import BROWSER_DOWNLOADING_SUFFIX, DOWNLOAD_FILE_PREFIX
|
||||
from skyvern.forge.sdk.api.azure import StandardBlobTier
|
||||
from skyvern.forge.sdk.api.files import (
|
||||
calculate_sha256_for_file,
|
||||
create_named_temporary_file,
|
||||
get_download_dir,
|
||||
get_skyvern_temp_dir,
|
||||
make_temp_directory,
|
||||
unzip_files,
|
||||
)
|
||||
from skyvern.forge.sdk.api.real_azure import RealAsyncAzureStorageClient
|
||||
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType, LogEntityType
|
||||
from skyvern.forge.sdk.artifact.storage.base import FILE_EXTENTSION_MAP, BaseStorage
|
||||
from skyvern.forge.sdk.models import Step
|
||||
from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion
|
||||
from skyvern.forge.sdk.schemas.files import FileInfo
|
||||
from skyvern.forge.sdk.schemas.task_v2 import TaskV2, Thought
|
||||
from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunBlock
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class AzureStorage(BaseStorage):
|
||||
_PATH_VERSION = "v1"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
container: str | None = None,
|
||||
account_name: str | None = None,
|
||||
account_key: str | None = None,
|
||||
) -> None:
|
||||
self.async_client = RealAsyncAzureStorageClient(account_name=account_name, account_key=account_key)
|
||||
self.container = container or settings.AZURE_STORAGE_CONTAINER_ARTIFACTS
|
||||
|
||||
def build_uri(self, *, organization_id: str, artifact_id: str, step: Step, artifact_type: ArtifactType) -> str:
|
||||
file_ext = FILE_EXTENTSION_MAP[artifact_type]
|
||||
return f"{self._build_base_uri(organization_id)}/{step.task_id}/{step.order:02d}_{step.retry_index}_{step.step_id}/{datetime.utcnow().isoformat()}_{artifact_id}_{artifact_type}.{file_ext}"
|
||||
|
||||
async def retrieve_global_workflows(self) -> list[str]:
|
||||
uri = f"azure://{self.container}/{settings.ENV}/global_workflows.txt"
|
||||
data = await self.async_client.download_file(uri, log_exception=False)
|
||||
if not data:
|
||||
return []
|
||||
return [line.strip() for line in data.decode("utf-8").split("\n") if line.strip()]
|
||||
|
||||
def _build_base_uri(self, organization_id: str) -> str:
|
||||
return f"azure://{self.container}/{self._PATH_VERSION}/{settings.ENV}/{organization_id}"
|
||||
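The URI helpers above pin every artifact under a versioned, environment-scoped prefix; the resulting layout looks roughly like this (organization, task, step, and artifact IDs, the timestamp, and the extension are all illustrative):

# azure://{container}/v1/{ENV}/{organization_id}/{task_id}/{order}_{retry}_{step_id}/{timestamp}_{artifact_id}_{artifact_type}.{ext}
example_step_artifact_uri = (
    "azure://skyvern-artifacts/v1/local/o_123/"
    "tsk_456/00_0_stp_789/2024-01-01T00:00:00_a_abc_screenshot_action.png"
)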
|
||||
def build_log_uri(
|
||||
self, *, organization_id: str, log_entity_type: LogEntityType, log_entity_id: str, artifact_type: ArtifactType
|
||||
) -> str:
|
||||
file_ext = FILE_EXTENTSION_MAP[artifact_type]
|
||||
return f"{self._build_base_uri(organization_id)}/logs/{log_entity_type}/{log_entity_id}/{datetime.utcnow().isoformat()}_{artifact_type}.{file_ext}"
|
||||
|
||||
def build_thought_uri(
|
||||
self, *, organization_id: str, artifact_id: str, thought: Thought, artifact_type: ArtifactType
|
||||
) -> str:
|
||||
file_ext = FILE_EXTENTSION_MAP[artifact_type]
|
||||
return f"{self._build_base_uri(organization_id)}/observers/{thought.observer_cruise_id}/{thought.observer_thought_id}/{datetime.utcnow().isoformat()}_{artifact_id}_{artifact_type}.{file_ext}"
|
||||
|
||||
def build_task_v2_uri(
|
||||
self, *, organization_id: str, artifact_id: str, task_v2: TaskV2, artifact_type: ArtifactType
|
||||
) -> str:
|
||||
file_ext = FILE_EXTENTSION_MAP[artifact_type]
|
||||
return f"{self._build_base_uri(organization_id)}/observers/{task_v2.observer_cruise_id}/{datetime.utcnow().isoformat()}_{artifact_id}_{artifact_type}.{file_ext}"
|
||||
|
||||
def build_workflow_run_block_uri(
|
||||
self,
|
||||
*,
|
||||
organization_id: str,
|
||||
artifact_id: str,
|
||||
workflow_run_block: WorkflowRunBlock,
|
||||
artifact_type: ArtifactType,
|
||||
) -> str:
|
||||
file_ext = FILE_EXTENTSION_MAP[artifact_type]
|
||||
return f"{self._build_base_uri(organization_id)}/workflow_runs/{workflow_run_block.workflow_run_id}/{workflow_run_block.workflow_run_block_id}/{datetime.utcnow().isoformat()}_{artifact_id}_{artifact_type}.{file_ext}"
|
||||
|
||||
def build_ai_suggestion_uri(
|
||||
self, *, organization_id: str, artifact_id: str, ai_suggestion: AISuggestion, artifact_type: ArtifactType
|
||||
) -> str:
|
||||
file_ext = FILE_EXTENTSION_MAP[artifact_type]
|
||||
return f"{self._build_base_uri(organization_id)}/ai_suggestions/{ai_suggestion.ai_suggestion_id}/{datetime.utcnow().isoformat()}_{artifact_id}_{artifact_type}.{file_ext}"
|
||||
|
||||
def build_script_file_uri(
|
||||
self, *, organization_id: str, script_id: str, script_version: int, file_path: str
|
||||
) -> str:
|
||||
"""Build the Azure URI for a script file."""
|
||||
return f"{self._build_base_uri(organization_id)}/scripts/{script_id}/{script_version}/{file_path}"
|
||||
|
||||
async def store_artifact(self, artifact: Artifact, data: bytes) -> None:
|
||||
tier = await self._get_storage_tier_for_org(artifact.organization_id)
|
||||
tags = await self._get_tags_for_org(artifact.organization_id)
|
||||
LOG.debug(
|
||||
"Storing artifact",
|
||||
artifact_id=artifact.artifact_id,
|
||||
organization_id=artifact.organization_id,
|
||||
uri=artifact.uri,
|
||||
storage_tier=tier,
|
||||
tags=tags,
|
||||
)
|
||||
await self.async_client.upload_file(artifact.uri, data, tier=tier, tags=tags)
|
||||
|
||||
async def _get_storage_tier_for_org(self, organization_id: str) -> StandardBlobTier:
|
||||
return StandardBlobTier.HOT
|
||||
|
||||
async def _get_tags_for_org(self, organization_id: str) -> dict[str, str]:
|
||||
return {}
|
||||
|
||||
async def retrieve_artifact(self, artifact: Artifact) -> bytes | None:
|
||||
return await self.async_client.download_file(artifact.uri)
|
||||
|
||||
async def get_share_link(self, artifact: Artifact) -> str | None:
|
||||
share_urls = await self.async_client.create_sas_urls([artifact.uri])
|
||||
return share_urls[0] if share_urls else None
|
||||
|
||||
async def get_share_links(self, artifacts: list[Artifact]) -> list[str] | None:
|
||||
return await self.async_client.create_sas_urls([artifact.uri for artifact in artifacts])
|
||||
|
||||
async def store_artifact_from_path(self, artifact: Artifact, path: str) -> None:
|
||||
tier = await self._get_storage_tier_for_org(artifact.organization_id)
|
||||
tags = await self._get_tags_for_org(artifact.organization_id)
|
||||
LOG.debug(
|
||||
"Storing artifact from path",
|
||||
artifact_id=artifact.artifact_id,
|
||||
organization_id=artifact.organization_id,
|
||||
uri=artifact.uri,
|
||||
storage_tier=tier,
|
||||
path=path,
|
||||
tags=tags,
|
||||
)
|
||||
await self.async_client.upload_file_from_path(artifact.uri, path, tier=tier, tags=tags)
|
||||
|
||||
async def save_streaming_file(self, organization_id: str, file_name: str) -> None:
|
||||
from_path = f"{get_skyvern_temp_dir()}/{organization_id}/{file_name}"
|
||||
to_path = f"azure://{settings.AZURE_STORAGE_CONTAINER_SCREENSHOTS}/{settings.ENV}/{organization_id}/{file_name}"
|
||||
tier = await self._get_storage_tier_for_org(organization_id)
|
||||
tags = await self._get_tags_for_org(organization_id)
|
||||
LOG.debug(
|
||||
"Saving streaming file",
|
||||
organization_id=organization_id,
|
||||
file_name=file_name,
|
||||
from_path=from_path,
|
||||
to_path=to_path,
|
||||
storage_tier=tier,
|
||||
tags=tags,
|
||||
)
|
||||
await self.async_client.upload_file_from_path(to_path, from_path, tier=tier, tags=tags)
|
||||
|
||||
async def get_streaming_file(self, organization_id: str, file_name: str, use_default: bool = True) -> bytes | None:
|
||||
path = f"azure://{settings.AZURE_STORAGE_CONTAINER_SCREENSHOTS}/{settings.ENV}/{organization_id}/{file_name}"
|
||||
return await self.async_client.download_file(path, log_exception=False)
|
||||
|
||||
async def store_browser_session(self, organization_id: str, workflow_permanent_id: str, directory: str) -> None:
|
||||
# Zip the directory to a temp file
|
||||
temp_zip_file = create_named_temporary_file()
|
||||
zip_file_path = shutil.make_archive(temp_zip_file.name, "zip", directory)
|
||||
browser_session_uri = f"azure://{settings.AZURE_STORAGE_CONTAINER_BROWSER_SESSIONS}/{settings.ENV}/{organization_id}/{workflow_permanent_id}.zip"
|
||||
tier = await self._get_storage_tier_for_org(organization_id)
|
||||
tags = await self._get_tags_for_org(organization_id)
|
||||
LOG.debug(
|
||||
"Storing browser session",
|
||||
organization_id=organization_id,
|
||||
workflow_permanent_id=workflow_permanent_id,
|
||||
zip_file_path=zip_file_path,
|
||||
browser_session_uri=browser_session_uri,
|
||||
storage_tier=tier,
|
||||
tags=tags,
|
||||
)
|
||||
await self.async_client.upload_file_from_path(browser_session_uri, zip_file_path, tier=tier, tags=tags)
|
||||
|
||||
async def retrieve_browser_session(self, organization_id: str, workflow_permanent_id: str) -> str | None:
|
||||
browser_session_uri = f"azure://{settings.AZURE_STORAGE_CONTAINER_BROWSER_SESSIONS}/{settings.ENV}/{organization_id}/{workflow_permanent_id}.zip"
|
||||
downloaded_zip_bytes = await self.async_client.download_file(browser_session_uri, log_exception=True)
|
||||
if not downloaded_zip_bytes:
|
||||
return None
|
||||
temp_zip_file = create_named_temporary_file(delete=False)
|
||||
temp_zip_file.write(downloaded_zip_bytes)
|
||||
temp_zip_file_path = temp_zip_file.name
|
||||
|
||||
temp_dir = make_temp_directory(prefix="skyvern_browser_session_")
|
||||
unzip_files(temp_zip_file_path, temp_dir)
|
||||
temp_zip_file.close()
|
||||
return temp_dir
|
||||
|
||||
async def store_browser_profile(self, organization_id: str, profile_id: str, directory: str) -> None:
|
||||
"""Store browser profile to Azure."""
|
||||
temp_zip_file = create_named_temporary_file()
|
||||
zip_file_path = shutil.make_archive(temp_zip_file.name, "zip", directory)
|
||||
profile_uri = f"azure://{settings.AZURE_STORAGE_CONTAINER_BROWSER_SESSIONS}/{settings.ENV}/{organization_id}/profiles/{profile_id}.zip"
|
||||
tier = await self._get_storage_tier_for_org(organization_id)
|
||||
tags = await self._get_tags_for_org(organization_id)
|
||||
LOG.debug(
|
||||
"Storing browser profile",
|
||||
organization_id=organization_id,
|
||||
profile_id=profile_id,
|
||||
zip_file_path=zip_file_path,
|
||||
profile_uri=profile_uri,
|
||||
storage_tier=tier,
|
||||
tags=tags,
|
||||
)
|
||||
await self.async_client.upload_file_from_path(profile_uri, zip_file_path, tier=tier, tags=tags)
|
||||
|
||||
async def retrieve_browser_profile(self, organization_id: str, profile_id: str) -> str | None:
|
||||
"""Retrieve browser profile from Azure."""
|
||||
profile_uri = f"azure://{settings.AZURE_STORAGE_CONTAINER_BROWSER_SESSIONS}/{settings.ENV}/{organization_id}/profiles/{profile_id}.zip"
|
||||
downloaded_zip_bytes = await self.async_client.download_file(profile_uri, log_exception=True)
|
||||
if not downloaded_zip_bytes:
|
||||
return None
|
||||
temp_zip_file = create_named_temporary_file(delete=False)
|
||||
temp_zip_file.write(downloaded_zip_bytes)
|
||||
temp_zip_file_path = temp_zip_file.name
|
||||
|
||||
temp_dir = make_temp_directory(prefix="skyvern_browser_profile_")
|
||||
unzip_files(temp_zip_file_path, temp_dir)
|
||||
temp_zip_file.close()
|
||||
return temp_dir
|
||||
|
||||
async def list_downloaded_files_in_browser_session(
|
||||
self, organization_id: str, browser_session_id: str
|
||||
) -> list[str]:
|
||||
uri = f"azure://{settings.AZURE_STORAGE_CONTAINER_ARTIFACTS}/v1/{settings.ENV}/{organization_id}/browser_sessions/{browser_session_id}/downloads"
|
||||
return [
|
||||
f"azure://{settings.AZURE_STORAGE_CONTAINER_ARTIFACTS}/{file}"
|
||||
for file in await self.async_client.list_files(uri=uri)
|
||||
]
|
||||
|
||||
async def get_shared_downloaded_files_in_browser_session(
|
||||
self, organization_id: str, browser_session_id: str
|
||||
) -> list[FileInfo]:
|
||||
object_keys = await self.list_downloaded_files_in_browser_session(organization_id, browser_session_id)
|
||||
if len(object_keys) == 0:
|
||||
return []
|
||||
|
||||
file_infos: list[FileInfo] = []
|
||||
for key in object_keys:
|
||||
metadata = {}
|
||||
modified_at: datetime | None = None
|
||||
# Get metadata (including checksum)
|
||||
try:
|
||||
object_info = await self.async_client.get_object_info(key)
|
||||
if object_info:
|
||||
metadata = object_info.get("Metadata", {})
|
||||
modified_at = object_info.get("LastModified")
|
||||
except Exception:
|
||||
LOG.exception("Object info retrieval failed", uri=key)
|
||||
|
||||
# Create FileInfo object
|
||||
filename = os.path.basename(key)
|
||||
checksum = metadata.get("sha256_checksum") if metadata else None
|
||||
|
||||
# Get SAS URL
|
||||
sas_urls = await self.async_client.create_sas_urls([key])
|
||||
if not sas_urls:
|
||||
continue
|
||||
|
||||
file_info = FileInfo(
|
||||
url=sas_urls[0],
|
||||
checksum=checksum,
|
||||
filename=metadata.get("original_filename", filename) if metadata else filename,
|
||||
modified_at=modified_at,
|
||||
)
|
||||
file_infos.append(file_info)
|
||||
|
||||
return file_infos
|
||||
|
||||
async def list_downloading_files_in_browser_session(
|
||||
self, organization_id: str, browser_session_id: str
|
||||
) -> list[str]:
|
||||
uri = f"azure://{settings.AZURE_STORAGE_CONTAINER_ARTIFACTS}/v1/{settings.ENV}/{organization_id}/browser_sessions/{browser_session_id}/downloads"
|
||||
files = [
|
||||
f"azure://{settings.AZURE_STORAGE_CONTAINER_ARTIFACTS}/{file}"
|
||||
for file in await self.async_client.list_files(uri=uri)
|
||||
]
|
||||
return [file for file in files if file.endswith(BROWSER_DOWNLOADING_SUFFIX)]
|
||||
|
||||
async def list_recordings_in_browser_session(self, organization_id: str, browser_session_id: str) -> list[str]:
|
||||
"""List all recording files for a browser session from Azure."""
|
||||
uri = f"azure://{settings.AZURE_STORAGE_CONTAINER_ARTIFACTS}/v1/{settings.ENV}/{organization_id}/browser_sessions/{browser_session_id}/videos"
|
||||
return [
|
||||
f"azure://{settings.AZURE_STORAGE_CONTAINER_ARTIFACTS}/{file}"
|
||||
for file in await self.async_client.list_files(uri=uri)
|
||||
]
|
||||
|
||||
async def get_shared_recordings_in_browser_session(
|
||||
self, organization_id: str, browser_session_id: str
|
||||
) -> list[FileInfo]:
|
||||
"""Get recording files with SAS URLs for a browser session."""
|
||||
object_keys = await self.list_recordings_in_browser_session(organization_id, browser_session_id)
|
||||
if len(object_keys) == 0:
|
||||
return []
|
||||
|
||||
file_infos: list[FileInfo] = []
|
||||
for key in object_keys:
|
||||
# Playwright's record_video_dir should only contain .webm files.
|
||||
# Filter defensively in case of unexpected files.
|
||||
key_lower = key.lower()
|
||||
if not (key_lower.endswith(".webm") or key_lower.endswith(".mp4")):
|
||||
LOG.warning(
|
||||
"Skipping recording file with unsupported extension",
|
||||
uri=key,
|
||||
organization_id=organization_id,
|
||||
browser_session_id=browser_session_id,
|
||||
)
|
||||
continue
|
||||
|
||||
metadata = {}
|
||||
modified_at: datetime | None = None
|
||||
content_length: int | None = None
|
||||
# Get metadata (including checksum)
|
||||
try:
|
||||
object_info = await self.async_client.get_object_info(key)
|
||||
if object_info:
|
||||
metadata = object_info.get("Metadata", {})
|
||||
modified_at = object_info.get("LastModified")
|
||||
content_length = object_info.get("ContentLength") or object_info.get("Size")
|
||||
except Exception:
|
||||
LOG.exception("Recording object info retrieval failed", uri=key)
|
||||
|
||||
# Skip zero-byte objects (if any incomplete uploads)
|
||||
if content_length == 0:
|
||||
continue
|
||||
|
||||
# Create FileInfo object
|
||||
filename = os.path.basename(key)
|
||||
checksum = metadata.get("sha256_checksum") if metadata else None
|
||||
|
||||
# Get SAS URL
|
||||
sas_urls = await self.async_client.create_sas_urls([key])
|
||||
if not sas_urls:
|
||||
continue
|
||||
|
||||
file_info = FileInfo(
|
||||
url=sas_urls[0],
|
||||
checksum=checksum,
|
||||
filename=metadata.get("original_filename", filename) if metadata else filename,
|
||||
modified_at=modified_at,
|
||||
)
|
||||
file_infos.append(file_info)
|
||||
|
||||
# Prefer the newest recording first (Azure list order is not guaranteed).
|
||||
# Treat None as "oldest".
|
||||
file_infos.sort(key=lambda f: (f.modified_at is not None, f.modified_at), reverse=True)
|
||||
return file_infos
|
||||
|
||||
async def save_downloaded_files(self, organization_id: str, run_id: str | None) -> None:
|
||||
download_dir = get_download_dir(run_id=run_id)
|
||||
files = os.listdir(download_dir)
|
||||
tier = await self._get_storage_tier_for_org(organization_id)
|
||||
tags = await self._get_tags_for_org(organization_id)
|
||||
base_uri = f"azure://{settings.AZURE_STORAGE_CONTAINER_UPLOADS}/{DOWNLOAD_FILE_PREFIX}/{settings.ENV}/{organization_id}/{run_id}"
|
||||
for file in files:
|
||||
fpath = os.path.join(download_dir, file)
|
||||
if not os.path.isfile(fpath):
|
||||
continue
|
||||
uri = f"{base_uri}/{file}"
|
||||
checksum = calculate_sha256_for_file(fpath)
|
||||
LOG.info(
|
||||
"Calculated checksum for file",
|
||||
file=file,
|
||||
checksum=checksum,
|
||||
organization_id=organization_id,
|
||||
storage_tier=tier,
|
||||
)
|
||||
# Upload file with checksum metadata
|
||||
await self.async_client.upload_file_from_path(
|
||||
uri=uri,
|
||||
file_path=fpath,
|
||||
metadata={"sha256_checksum": checksum, "original_filename": file},
|
||||
tier=tier,
|
||||
tags=tags,
|
||||
)
|
||||
|
||||
async def get_downloaded_files(self, organization_id: str, run_id: str | None) -> list[FileInfo]:
|
||||
uri = f"azure://{settings.AZURE_STORAGE_CONTAINER_UPLOADS}/{DOWNLOAD_FILE_PREFIX}/{settings.ENV}/{organization_id}/{run_id}"
|
||||
object_keys = await self.async_client.list_files(uri=uri)
|
||||
if len(object_keys) == 0:
|
||||
return []
|
||||
|
||||
file_infos: list[FileInfo] = []
|
||||
for key in object_keys:
|
||||
object_uri = f"azure://{settings.AZURE_STORAGE_CONTAINER_UPLOADS}/{key}"
|
||||
|
||||
# Get metadata (including checksum)
|
||||
metadata = await self.async_client.get_file_metadata(object_uri, log_exception=False)
|
||||
|
||||
# Create FileInfo object
|
||||
filename = os.path.basename(key)
|
||||
checksum = metadata.get("sha256_checksum") if metadata else None
|
||||
|
||||
# Get SAS URL
|
||||
sas_urls = await self.async_client.create_sas_urls([object_uri])
|
||||
if not sas_urls:
|
||||
continue
|
||||
|
||||
file_info = FileInfo(
|
||||
url=sas_urls[0],
|
||||
checksum=checksum,
|
||||
filename=metadata.get("original_filename", filename) if metadata else filename,
|
||||
)
|
||||
file_infos.append(file_info)
|
||||
|
||||
return file_infos
|
||||
|
||||
async def save_legacy_file(
|
||||
self, *, organization_id: str, filename: str, fileObj: BinaryIO
|
||||
) -> tuple[str, str] | None:
|
||||
todays_date = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d")
|
||||
container = settings.AZURE_STORAGE_CONTAINER_UPLOADS
|
||||
tier = await self._get_storage_tier_for_org(organization_id)
|
||||
tags = await self._get_tags_for_org(organization_id)
|
||||
# First try uploading with original filename
|
||||
try:
|
||||
sanitized_filename = os.path.basename(filename) # Remove any path components
|
||||
azure_uri = f"azure://{container}/{settings.ENV}/{organization_id}/{todays_date}/{sanitized_filename}"
|
||||
uploaded_uri = await self.async_client.upload_file_stream(azure_uri, fileObj, tier=tier, tags=tags)
|
||||
except Exception:
|
||||
LOG.error("Failed to upload file to Azure", exc_info=True)
|
||||
uploaded_uri = None
|
||||
|
||||
# If upload fails, try again with UUID prefix
|
||||
if not uploaded_uri:
|
||||
uuid_prefixed_filename = f"{str(uuid.uuid4())}_{filename}"
|
||||
azure_uri = f"azure://{container}/{settings.ENV}/{organization_id}/{todays_date}/{uuid_prefixed_filename}"
|
||||
fileObj.seek(0) # Reset file pointer
|
||||
uploaded_uri = await self.async_client.upload_file_stream(azure_uri, fileObj, tier=tier, tags=tags)
|
||||
|
||||
if not uploaded_uri:
|
||||
LOG.error(
|
||||
"Failed to upload file to Azure after retrying with UUID prefix",
|
||||
organization_id=organization_id,
|
||||
storage_tier=tier,
|
||||
filename=filename,
|
||||
exc_info=True,
|
||||
)
|
||||
return None
|
||||
LOG.debug(
|
||||
"Legacy file upload",
|
||||
organization_id=organization_id,
|
||||
storage_tier=tier,
|
||||
filename=filename,
|
||||
uploaded_uri=uploaded_uri,
|
||||
)
|
||||
# Generate a SAS URL for the uploaded file
|
||||
sas_urls = await self.async_client.create_sas_urls([uploaded_uri])
|
||||
if not sas_urls:
|
||||
LOG.error(
|
||||
"Failed to create SAS URL for uploaded file",
|
||||
organization_id=organization_id,
|
||||
storage_tier=tier,
|
||||
uploaded_uri=uploaded_uri,
|
||||
filename=filename,
|
||||
exc_info=True,
|
||||
)
|
||||
return None
|
||||
return sas_urls[0], uploaded_uri
|
||||
|
||||
def _build_browser_session_uri(
|
||||
self,
|
||||
organization_id: str,
|
||||
browser_session_id: str,
|
||||
artifact_type: str,
|
||||
remote_path: str,
|
||||
date: str | None = None,
|
||||
) -> str:
|
||||
"""Build the Azure URI for a browser session file."""
|
||||
base = f"azure://{self.container}/v1/{settings.ENV}/{organization_id}/browser_sessions/{browser_session_id}/{artifact_type}"
|
||||
if date:
|
||||
return f"{base}/{date}/{remote_path}"
|
||||
return f"{base}/{remote_path}"
|
||||
|
||||
async def sync_browser_session_file(
|
||||
self,
|
||||
organization_id: str,
|
||||
browser_session_id: str,
|
||||
artifact_type: str,
|
||||
local_file_path: str,
|
||||
remote_path: str,
|
||||
date: str | None = None,
|
||||
) -> str:
|
||||
"""Sync a file from local browser session to Azure."""
|
||||
uri = self._build_browser_session_uri(organization_id, browser_session_id, artifact_type, remote_path, date)
|
||||
tier = await self._get_storage_tier_for_org(organization_id)
|
||||
tags = await self._get_tags_for_org(organization_id)
|
||||
await self.async_client.upload_file_from_path(uri, local_file_path, tier=tier, tags=tags)
|
||||
return uri
|
||||
|
||||
async def delete_browser_session_file(
|
||||
self,
|
||||
organization_id: str,
|
||||
browser_session_id: str,
|
||||
artifact_type: str,
|
||||
remote_path: str,
|
||||
date: str | None = None,
|
||||
) -> None:
|
||||
"""Delete a file from browser session storage in Azure."""
|
||||
uri = self._build_browser_session_uri(organization_id, browser_session_id, artifact_type, remote_path, date)
|
||||
await self.async_client.delete_file(uri)
|
||||
|
||||
async def browser_session_file_exists(
|
||||
self,
|
||||
organization_id: str,
|
||||
browser_session_id: str,
|
||||
artifact_type: str,
|
||||
remote_path: str,
|
||||
date: str | None = None,
|
||||
) -> bool:
|
||||
"""Check if a file exists in browser session storage in Azure."""
|
||||
uri = self._build_browser_session_uri(organization_id, browser_session_id, artifact_type, remote_path, date)
|
||||
try:
|
||||
info = await self.async_client.get_object_info(uri)
|
||||
return info is not None
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def download_uploaded_file(self, uri: str) -> bytes | None:
|
||||
"""Download a user-uploaded file from Azure."""
|
||||
return await self.async_client.download_file(uri, log_exception=False)
|
||||
|
||||
async def file_exists(self, uri: str) -> bool:
|
||||
"""Check if a file exists at the given Azure URI."""
|
||||
try:
|
||||
info = await self.async_client.get_object_info(uri)
|
||||
return info is not None
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@property
|
||||
def storage_type(self) -> str:
|
||||
"""Returns 'azure' as the storage type."""
|
||||
return "azure"
|
||||
@@ -172,3 +172,50 @@ class BaseStorage(ABC):
|
||||
self, *, organization_id: str, filename: str, fileObj: BinaryIO
|
||||
) -> tuple[str, str] | None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def sync_browser_session_file(
|
||||
self,
|
||||
organization_id: str,
|
||||
browser_session_id: str,
|
||||
artifact_type: str,
|
||||
local_file_path: str,
|
||||
remote_path: str,
|
||||
date: str | None = None,
|
||||
) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def delete_browser_session_file(
|
||||
self,
|
||||
organization_id: str,
|
||||
browser_session_id: str,
|
||||
artifact_type: str,
|
||||
remote_path: str,
|
||||
date: str | None = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def browser_session_file_exists(
|
||||
self,
|
||||
organization_id: str,
|
||||
browser_session_id: str,
|
||||
artifact_type: str,
|
||||
remote_path: str,
|
||||
date: str | None = None,
|
||||
) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def download_uploaded_file(self, uri: str) -> bytes | None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def file_exists(self, uri: str) -> bool:
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def storage_type(self) -> str:
|
||||
pass
|
||||
|
||||
@@ -160,11 +160,11 @@ class LocalStorage(BaseStorage):
)
return None

async def get_share_link(self, artifact: Artifact) -> str:
return artifact.uri
async def get_share_link(self, artifact: Artifact) -> str | None:
return None

async def get_share_links(self, artifacts: list[Artifact]) -> list[str]:
return [artifact.uri for artifact in artifacts]
async def get_share_links(self, artifacts: list[Artifact]) -> list[str] | None:
return None

async def save_streaming_file(self, organization_id: str, file_name: str) -> None:
return
@@ -346,3 +346,98 @@ class LocalStorage(BaseStorage):
|
||||
raise NotImplementedError(
|
||||
"Legacy file storage is not implemented for LocalStorage. Please use a different storage backend."
|
||||
)
|
||||
|
||||
def _build_browser_session_path(
|
||||
self,
|
||||
organization_id: str,
|
||||
browser_session_id: str,
|
||||
artifact_type: str,
|
||||
remote_path: str,
|
||||
date: str | None = None,
|
||||
) -> Path:
|
||||
"""Build the local path for a browser session file."""
|
||||
base = (
|
||||
Path(self.artifact_path)
|
||||
/ settings.ENV
|
||||
/ organization_id
|
||||
/ "browser_sessions"
|
||||
/ browser_session_id
|
||||
/ artifact_type
|
||||
)
|
||||
if date:
|
||||
return base / date / remote_path
|
||||
return base / remote_path
|
||||
|
||||
async def sync_browser_session_file(
|
||||
self,
|
||||
organization_id: str,
|
||||
browser_session_id: str,
|
||||
artifact_type: str,
|
||||
local_file_path: str,
|
||||
remote_path: str,
|
||||
date: str | None = None,
|
||||
) -> str:
|
||||
"""Sync a file from local browser session to local storage."""
|
||||
target_path = self._build_browser_session_path(
|
||||
organization_id, browser_session_id, artifact_type, remote_path, date
|
||||
)
|
||||
if WINDOWS:
|
||||
target_path = target_path.with_name(_windows_safe_filename(target_path.name))
|
||||
self._create_directories_if_not_exists(target_path)
|
||||
shutil.copy2(local_file_path, target_path)
|
||||
return f"file://{target_path}"
|
||||
|
||||
async def delete_browser_session_file(
|
||||
self,
|
||||
organization_id: str,
|
||||
browser_session_id: str,
|
||||
artifact_type: str,
|
||||
remote_path: str,
|
||||
date: str | None = None,
|
||||
) -> None:
|
||||
"""Delete a file from browser session storage in local filesystem."""
|
||||
target_path = self._build_browser_session_path(
|
||||
organization_id, browser_session_id, artifact_type, remote_path, date
|
||||
)
|
||||
try:
|
||||
if target_path.exists():
|
||||
target_path.unlink()
|
||||
except Exception:
|
||||
LOG.exception("Failed to delete local browser session file", path=str(target_path))
|
||||
|
||||
async def browser_session_file_exists(
|
||||
self,
|
||||
organization_id: str,
|
||||
browser_session_id: str,
|
||||
artifact_type: str,
|
||||
remote_path: str,
|
||||
date: str | None = None,
|
||||
) -> bool:
|
||||
"""Check if a file exists in browser session storage in local filesystem."""
|
||||
target_path = self._build_browser_session_path(
|
||||
organization_id, browser_session_id, artifact_type, remote_path, date
|
||||
)
|
||||
return target_path.exists()
|
||||
|
||||
async def download_uploaded_file(self, uri: str) -> bytes | None:
|
||||
"""Download a user-uploaded file from local filesystem."""
|
||||
try:
|
||||
file_path = parse_uri_to_path(uri)
|
||||
with open(file_path, "rb") as f:
|
||||
return f.read()
|
||||
except Exception:
|
||||
LOG.exception("Failed to read local file", uri=uri)
|
||||
return None
|
||||
|
||||
async def file_exists(self, uri: str) -> bool:
|
||||
"""Check if a file exists at the given local URI."""
|
||||
try:
|
||||
file_path = parse_uri_to_path(uri)
|
||||
return os.path.exists(file_path)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@property
|
||||
def storage_type(self) -> str:
|
||||
"""Returns 'file' as the storage type."""
|
||||
return "file"
|
||||
|
||||
@@ -235,10 +235,9 @@ class S3Storage(BaseStorage):
|
||||
async def list_downloaded_files_in_browser_session(
|
||||
self, organization_id: str, browser_session_id: str
|
||||
) -> list[str]:
|
||||
uri = f"s3://{settings.AWS_S3_BUCKET_ARTIFACTS}/v1/{settings.ENV}/{organization_id}/browser_sessions/{browser_session_id}/downloads"
|
||||
return [
|
||||
f"s3://{settings.AWS_S3_BUCKET_ARTIFACTS}/{file}" for file in await self.async_client.list_files(uri=uri)
|
||||
]
|
||||
bucket = settings.AWS_S3_BUCKET_ARTIFACTS
uri = f"s3://{bucket}/v1/{settings.ENV}/{organization_id}/browser_sessions/{browser_session_id}/downloads"
return [f"s3://{bucket}/{file}" for file in await self.async_client.list_files(uri=uri)]

async def get_shared_downloaded_files_in_browser_session(
self, organization_id: str, browser_session_id: str
@@ -281,18 +280,16 @@ class S3Storage(BaseStorage):
async def list_downloading_files_in_browser_session(
self, organization_id: str, browser_session_id: str
) -> list[str]:
uri = f"s3://{settings.AWS_S3_BUCKET_ARTIFACTS}/v1/{settings.ENV}/{organization_id}/browser_sessions/{browser_session_id}/downloads"
files = [
f"s3://{settings.AWS_S3_BUCKET_ARTIFACTS}/{file}" for file in await self.async_client.list_files(uri=uri)
]
bucket = settings.AWS_S3_BUCKET_ARTIFACTS
uri = f"s3://{bucket}/v1/{settings.ENV}/{organization_id}/browser_sessions/{browser_session_id}/downloads"
files = [f"s3://{bucket}/{file}" for file in await self.async_client.list_files(uri=uri)]
return [file for file in files if file.endswith(BROWSER_DOWNLOADING_SUFFIX)]

async def list_recordings_in_browser_session(self, organization_id: str, browser_session_id: str) -> list[str]:
"""List all recording files for a browser session from S3."""
uri = f"s3://{settings.AWS_S3_BUCKET_ARTIFACTS}/v1/{settings.ENV}/{organization_id}/browser_sessions/{browser_session_id}/videos"
return [
f"s3://{settings.AWS_S3_BUCKET_ARTIFACTS}/{file}" for file in await self.async_client.list_files(uri=uri)
]
bucket = settings.AWS_S3_BUCKET_ARTIFACTS
uri = f"s3://{bucket}/v1/{settings.ENV}/{organization_id}/browser_sessions/{browser_session_id}/videos"
return [f"s3://{bucket}/{file}" for file in await self.async_client.list_files(uri=uri)]

async def get_shared_recordings_in_browser_session(
self, organization_id: str, browser_session_id: str
@@ -385,14 +382,15 @@ class S3Storage(BaseStorage):
)

async def get_downloaded_files(self, organization_id: str, run_id: str | None) -> list[FileInfo]:
uri = f"s3://{settings.AWS_S3_BUCKET_UPLOADS}/{DOWNLOAD_FILE_PREFIX}/{settings.ENV}/{organization_id}/{run_id}"
bucket = settings.AWS_S3_BUCKET_UPLOADS
uri = f"s3://{bucket}/{DOWNLOAD_FILE_PREFIX}/{settings.ENV}/{organization_id}/{run_id}"
object_keys = await self.async_client.list_files(uri=uri)
if len(object_keys) == 0:
return []

file_infos: list[FileInfo] = []
for key in object_keys:
object_uri = f"s3://{settings.AWS_S3_BUCKET_UPLOADS}/{key}"
object_uri = f"s3://{bucket}/{key}"

# Get metadata (including checksum)
metadata = await self.async_client.get_file_metadata(object_uri, log_exception=False)
@@ -467,3 +465,78 @@ class S3Storage(BaseStorage):
)
return None
return presigned_urls[0], uploaded_s3_uri

def _build_browser_session_uri(
self,
organization_id: str,
browser_session_id: str,
artifact_type: str,
remote_path: str,
date: str | None = None,
) -> str:
"""Build the S3 URI for a browser session file."""
base = f"s3://{self.bucket}/v1/{settings.ENV}/{organization_id}/browser_sessions/{browser_session_id}/{artifact_type}"
if date:
return f"{base}/{date}/{remote_path}"
return f"{base}/{remote_path}"

async def sync_browser_session_file(
self,
organization_id: str,
browser_session_id: str,
artifact_type: str,
local_file_path: str,
remote_path: str,
date: str | None = None,
) -> str:
"""Sync a file from local browser session to S3."""
uri = self._build_browser_session_uri(organization_id, browser_session_id, artifact_type, remote_path, date)
sc = await self._get_storage_class_for_org(organization_id)
tags = await self._get_tags_for_org(organization_id)
await self.async_client.upload_file_from_path(uri, local_file_path, storage_class=sc, tags=tags)
return uri

async def delete_browser_session_file(
self,
organization_id: str,
browser_session_id: str,
artifact_type: str,
remote_path: str,
date: str | None = None,
) -> None:
"""Delete a file from browser session storage in S3."""
uri = self._build_browser_session_uri(organization_id, browser_session_id, artifact_type, remote_path, date)
await self.async_client.delete_file(uri, log_exception=True)

async def browser_session_file_exists(
self,
organization_id: str,
browser_session_id: str,
artifact_type: str,
remote_path: str,
date: str | None = None,
) -> bool:
"""Check if a file exists in browser session storage in S3."""
uri = self._build_browser_session_uri(organization_id, browser_session_id, artifact_type, remote_path, date)
try:
info = await self.async_client.get_object_info(uri)
return info is not None
except Exception:
return False

async def download_uploaded_file(self, uri: str) -> bytes | None:
"""Download a user-uploaded file from S3."""
return await self.async_client.download_file(uri, log_exception=False)

async def file_exists(self, uri: str) -> bool:
"""Check if a file exists at the given S3 URI."""
try:
info = await self.async_client.get_object_info(uri)
return info is not None
except Exception:
return False

@property
def storage_type(self) -> str:
"""Returns 's3' as the storage type."""
return "s3"
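For orientation, a minimal usage sketch of the browser-session helpers added above (assumed setup: an S3 backend configured with valid AWS credentials and an existing organization; all IDs and paths below are hypothetical):

import asyncio

from skyvern.forge.sdk.artifact.storage.s3 import S3Storage


async def sync_one_recording() -> None:
    storage = S3Storage()
    # Dated artifact types (videos, HAR) get a date segment in the key; downloads do not.
    uri = await storage.sync_browser_session_file(
        organization_id="org_123",
        browser_session_id="bs_abc",
        artifact_type="videos",
        local_file_path="/tmp/recording.webm",
        remote_path="recording.webm",
        date="2025-01-15",
    )
    assert await storage.file_exists(uri)


asyncio.run(sync_one_recording())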
251
skyvern/forge/sdk/artifact/storage/test_azure_storage.py
Normal file
@@ -0,0 +1,251 @@
from pathlib import Path
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from skyvern.config import settings
from skyvern.forge.sdk.api.azure import StandardBlobTier
from skyvern.forge.sdk.api.real_azure import RealAsyncAzureStorageClient
from skyvern.forge.sdk.artifact.storage.azure import AzureStorage

# Test constants
TEST_CONTAINER = "test-azure-container"
TEST_ORGANIZATION_ID = "test-org-123"
TEST_BROWSER_SESSION_ID = "bs_test_123"


class AzureStorageForTests(AzureStorage):
"""Test subclass that overrides org-specific methods and bypasses client init."""

async_client: Any  # Allow mock attribute access

def __init__(self, container: str) -> None:
# Don't call super().__init__ to avoid creating real RealAsyncAzureStorageClient
self.container = container
self.async_client = AsyncMock()

async def _get_storage_tier_for_org(self, organization_id: str) -> StandardBlobTier:
return StandardBlobTier.HOT

async def _get_tags_for_org(self, organization_id: str) -> dict[str, str]:
return {"test": "tag"}


@pytest.fixture
def azure_storage() -> AzureStorageForTests:
"""Create AzureStorage with mocked async_client."""
return AzureStorageForTests(container=TEST_CONTAINER)


@pytest.mark.asyncio
class TestAzureStorageBrowserSessionFiles:
"""Test AzureStorage browser session file methods."""

async def test_sync_browser_session_file_with_date(
self, azure_storage: AzureStorageForTests, tmp_path: Path
) -> None:
"""Test syncing a file with date in path (videos/har)."""
test_file = tmp_path / "recording.webm"
test_file.write_bytes(b"fake video data")

uri = await azure_storage.sync_browser_session_file(
organization_id=TEST_ORGANIZATION_ID,
browser_session_id=TEST_BROWSER_SESSION_ID,
artifact_type="videos",
local_file_path=str(test_file),
remote_path="recording.webm",
date="2025-01-15",
)

expected_uri = f"azure://{TEST_CONTAINER}/v1/{settings.ENV}/{TEST_ORGANIZATION_ID}/browser_sessions/{TEST_BROWSER_SESSION_ID}/videos/2025-01-15/recording.webm"
assert uri == expected_uri
azure_storage.async_client.upload_file_from_path.assert_called_once()

async def test_sync_browser_session_file_without_date(
self, azure_storage: AzureStorageForTests, tmp_path: Path
) -> None:
"""Test syncing a file without date (downloads category)."""
test_file = tmp_path / "document.pdf"
test_file.write_bytes(b"fake download data")

uri = await azure_storage.sync_browser_session_file(
organization_id=TEST_ORGANIZATION_ID,
browser_session_id=TEST_BROWSER_SESSION_ID,
artifact_type="downloads",
local_file_path=str(test_file),
remote_path="document.pdf",
date=None,
)

expected_uri = f"azure://{TEST_CONTAINER}/v1/{settings.ENV}/{TEST_ORGANIZATION_ID}/browser_sessions/{TEST_BROWSER_SESSION_ID}/downloads/document.pdf"
assert uri == expected_uri

async def test_browser_session_file_exists_returns_true(self, azure_storage: AzureStorageForTests) -> None:
"""Test browser_session_file_exists returns True when file exists."""
azure_storage.async_client.get_object_info.return_value = {"LastModified": "2025-01-15"}

exists = await azure_storage.browser_session_file_exists(
organization_id=TEST_ORGANIZATION_ID,
browser_session_id=TEST_BROWSER_SESSION_ID,
artifact_type="videos",
remote_path="exists.webm",
date="2025-01-15",
)

assert exists is True

async def test_browser_session_file_exists_returns_false_on_exception(
self, azure_storage: AzureStorageForTests
) -> None:
"""Test browser_session_file_exists returns False when exception is raised."""
azure_storage.async_client.get_object_info.side_effect = Exception("Not found")

exists = await azure_storage.browser_session_file_exists(
organization_id=TEST_ORGANIZATION_ID,
browser_session_id=TEST_BROWSER_SESSION_ID,
artifact_type="videos",
remote_path="nonexistent.webm",
date="2025-01-15",
)

assert exists is False

async def test_delete_browser_session_file(self, azure_storage: AzureStorageForTests) -> None:
"""Test deleting a browser session file."""
await azure_storage.delete_browser_session_file(
organization_id=TEST_ORGANIZATION_ID,
browser_session_id=TEST_BROWSER_SESSION_ID,
artifact_type="videos",
remote_path="to_delete.webm",
date="2025-01-15",
)

expected_uri = f"azure://{TEST_CONTAINER}/v1/{settings.ENV}/{TEST_ORGANIZATION_ID}/browser_sessions/{TEST_BROWSER_SESSION_ID}/videos/2025-01-15/to_delete.webm"
azure_storage.async_client.delete_file.assert_called_once_with(expected_uri)

async def test_file_exists_returns_true(self, azure_storage: AzureStorageForTests) -> None:
"""Test file_exists returns True when file exists."""
azure_storage.async_client.get_object_info.return_value = {"LastModified": "2025-01-15"}
uri = f"azure://{TEST_CONTAINER}/test/file.txt"

exists = await azure_storage.file_exists(uri)

assert exists is True

async def test_file_exists_returns_false_on_exception(self, azure_storage: AzureStorageForTests) -> None:
"""Test file_exists returns False when exception is raised (404)."""
azure_storage.async_client.get_object_info.side_effect = Exception("Not found")
uri = f"azure://{TEST_CONTAINER}/nonexistent/file.txt"

exists = await azure_storage.file_exists(uri)

assert exists is False

async def test_download_uploaded_file(self, azure_storage: AzureStorageForTests) -> None:
"""Test downloading an uploaded file."""
test_data = b"uploaded file content"
azure_storage.async_client.download_file.return_value = test_data
uri = f"azure://{TEST_CONTAINER}/uploads/file.pdf"

downloaded = await azure_storage.download_uploaded_file(uri)

assert downloaded == test_data
azure_storage.async_client.download_file.assert_called_once_with(uri, log_exception=False)

async def test_download_uploaded_file_returns_none(self, azure_storage: AzureStorageForTests) -> None:
"""Test downloading a non-existent file returns None."""
azure_storage.async_client.download_file.return_value = None
uri = f"azure://{TEST_CONTAINER}/nonexistent/file.txt"

downloaded = await azure_storage.download_uploaded_file(uri)

assert downloaded is None

def test_storage_type_property(self, azure_storage: AzureStorageForTests) -> None:
"""Test storage_type returns 'azure'."""
assert azure_storage.storage_type == "azure"


class TestAzureStorageBuildUri:
"""Test Azure URI building methods."""

def test_build_browser_session_uri_with_date(self, azure_storage: AzureStorageForTests) -> None:
"""Test building URI with date."""
uri = azure_storage._build_browser_session_uri(
organization_id=TEST_ORGANIZATION_ID,
browser_session_id=TEST_BROWSER_SESSION_ID,
artifact_type="videos",
remote_path="file.webm",
date="2025-01-15",
)

expected = f"azure://{TEST_CONTAINER}/v1/{settings.ENV}/{TEST_ORGANIZATION_ID}/browser_sessions/{TEST_BROWSER_SESSION_ID}/videos/2025-01-15/file.webm"
assert uri == expected

def test_build_browser_session_uri_without_date(self, azure_storage: AzureStorageForTests) -> None:
"""Test building URI without date."""
uri = azure_storage._build_browser_session_uri(
organization_id=TEST_ORGANIZATION_ID,
browser_session_id=TEST_BROWSER_SESSION_ID,
artifact_type="downloads",
remote_path="file.pdf",
date=None,
)

expected = f"azure://{TEST_CONTAINER}/v1/{settings.ENV}/{TEST_ORGANIZATION_ID}/browser_sessions/{TEST_BROWSER_SESSION_ID}/downloads/file.pdf"
assert uri == expected


AZURE_CONTENT_TYPE_TEST_CASES = [
# (filename, expected_content_type, artifact_type, date)
("video.webm", "video/webm", "videos", "2025-01-15"),
("data.json", "application/json", "har", "2025-01-15"),
("network.har", "application/json", "har", "2025-01-15"),
("screenshot.png", "image/png", "downloads", None),
("output.txt", "text/plain", "downloads", None),
("debug.log", "text/plain", "downloads", None),
]


@pytest.mark.asyncio
class TestAzureStorageContentType:
"""Test Azure Storage content type guessing.

Tests at two levels:
1. High-level: sync_browser_session_file interface with artifact_type/date
2. Low-level: RealAsyncAzureStorageClient to verify ContentSettings is passed
"""

@pytest.mark.parametrize("filename,expected_content_type,artifact_type,date", AZURE_CONTENT_TYPE_TEST_CASES)
async def test_content_type_guessing(
self,
tmp_path: Path,
filename: str,
expected_content_type: str,
artifact_type: str,
date: str | None,
) -> None:
"""Test that RealAsyncAzureStorageClient sets correct content type based on extension."""
test_file = tmp_path / filename
test_file.write_bytes(b"test content")

with patch.object(RealAsyncAzureStorageClient, "_get_blob_service_client") as mock_get_client:
mock_container_client = MagicMock()
mock_container_client.upload_blob = AsyncMock()
mock_container_client.exists = AsyncMock(return_value=True)
mock_blob_service = MagicMock()
mock_blob_service.get_container_client.return_value = mock_container_client
mock_get_client.return_value = mock_blob_service

client = RealAsyncAzureStorageClient(account_name="test", account_key="testkey")
client._verified_containers.add("test-container")

await client.upload_file_from_path(
uri=f"azure://test-container/path/{filename}",
file_path=str(test_file),
)

call_kwargs = mock_container_client.upload_blob.call_args.kwargs
assert call_kwargs["content_settings"] is not None
assert call_kwargs["content_settings"].content_type == expected_content_type
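The .har and .log cases in the table above rely on the custom extension registrations that the storage layer performs against Python's standard mimetypes registry at import time; a standalone sketch of the resulting lookups (assumes those registrations are in effect):

from mimetypes import guess_type

# ".har" and ".log" resolve only because the storage module registers them;
# ".json" comes from the stock mimetypes table.
for name in ("network.har", "debug.log", "data.json"):
    content_type, _encoding = guess_type(name)
    print(name, "->", content_type)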
@@ -221,3 +221,229 @@ class TestS3StorageStore:
await s3_storage.store_artifact(artifact, test_data)
_assert_object_content(boto3_test_client, artifact.uri, test_data)
_assert_object_meta(boto3_test_client, artifact.uri)


TEST_BROWSER_SESSION_ID = "bs_test_123"


@pytest.mark.asyncio
class TestS3StorageBrowserSessionFiles:
"""Test S3Storage browser session file methods."""

async def test_sync_browser_session_file_with_date(
self, s3_storage: S3Storage, boto3_test_client: S3Client, tmp_path: Path
) -> None:
"""Test syncing a file with date in path (videos/har)."""
test_data = b"fake video data"
test_file = tmp_path / "recording.webm"
test_file.write_bytes(test_data)

uri = await s3_storage.sync_browser_session_file(
organization_id=TEST_ORGANIZATION_ID,
browser_session_id=TEST_BROWSER_SESSION_ID,
artifact_type="videos",
local_file_path=str(test_file),
remote_path="recording.webm",
date="2025-01-15",
)

expected_uri = f"s3://{TEST_BUCKET}/v1/{settings.ENV}/{TEST_ORGANIZATION_ID}/browser_sessions/{TEST_BROWSER_SESSION_ID}/videos/2025-01-15/recording.webm"
assert uri == expected_uri
_assert_object_content(boto3_test_client, uri, test_data)
_assert_object_meta(boto3_test_client, uri)

async def test_sync_browser_session_file_without_date(
self, s3_storage: S3Storage, boto3_test_client: S3Client, tmp_path: Path
) -> None:
"""Test syncing a file without date (downloads category)."""
test_data = b"fake download data"
test_file = tmp_path / "document.pdf"
test_file.write_bytes(test_data)

uri = await s3_storage.sync_browser_session_file(
organization_id=TEST_ORGANIZATION_ID,
browser_session_id=TEST_BROWSER_SESSION_ID,
artifact_type="downloads",
local_file_path=str(test_file),
remote_path="document.pdf",
date=None,
)

expected_uri = f"s3://{TEST_BUCKET}/v1/{settings.ENV}/{TEST_ORGANIZATION_ID}/browser_sessions/{TEST_BROWSER_SESSION_ID}/downloads/document.pdf"
assert uri == expected_uri
_assert_object_content(boto3_test_client, uri, test_data)

async def test_browser_session_file_exists_returns_true(
self, s3_storage: S3Storage, boto3_test_client: S3Client, tmp_path: Path
) -> None:
"""Test browser_session_file_exists returns True for existing file."""
test_file = tmp_path / "exists.webm"
test_file.write_bytes(b"test data")

await s3_storage.sync_browser_session_file(
organization_id=TEST_ORGANIZATION_ID,
browser_session_id=TEST_BROWSER_SESSION_ID,
artifact_type="videos",
local_file_path=str(test_file),
remote_path="exists.webm",
date="2025-01-15",
)

exists = await s3_storage.browser_session_file_exists(
organization_id=TEST_ORGANIZATION_ID,
browser_session_id=TEST_BROWSER_SESSION_ID,
artifact_type="videos",
remote_path="exists.webm",
date="2025-01-15",
)
assert exists is True

async def test_browser_session_file_exists_returns_false(self, s3_storage: S3Storage) -> None:
"""Test browser_session_file_exists returns False for non-existent file."""
exists = await s3_storage.browser_session_file_exists(
organization_id=TEST_ORGANIZATION_ID,
browser_session_id=TEST_BROWSER_SESSION_ID,
artifact_type="videos",
remote_path="nonexistent.webm",
date="2025-01-15",
)
assert exists is False

async def test_delete_browser_session_file(
self, s3_storage: S3Storage, boto3_test_client: S3Client, tmp_path: Path
) -> None:
"""Test deleting a browser session file."""
test_file = tmp_path / "to_delete.webm"
test_file.write_bytes(b"test data")

await s3_storage.sync_browser_session_file(
organization_id=TEST_ORGANIZATION_ID,
browser_session_id=TEST_BROWSER_SESSION_ID,
artifact_type="videos",
local_file_path=str(test_file),
remote_path="to_delete.webm",
date="2025-01-15",
)

exists_before = await s3_storage.browser_session_file_exists(
organization_id=TEST_ORGANIZATION_ID,
browser_session_id=TEST_BROWSER_SESSION_ID,
artifact_type="videos",
remote_path="to_delete.webm",
date="2025-01-15",
)
assert exists_before is True

await s3_storage.delete_browser_session_file(
organization_id=TEST_ORGANIZATION_ID,
browser_session_id=TEST_BROWSER_SESSION_ID,
artifact_type="videos",
remote_path="to_delete.webm",
date="2025-01-15",
)

exists_after = await s3_storage.browser_session_file_exists(
organization_id=TEST_ORGANIZATION_ID,
browser_session_id=TEST_BROWSER_SESSION_ID,
artifact_type="videos",
remote_path="to_delete.webm",
date="2025-01-15",
)
assert exists_after is False

async def test_file_exists_returns_true(
self, s3_storage: S3Storage, boto3_test_client: S3Client, tmp_path: Path
) -> None:
"""Test file_exists returns True for existing file."""
test_file = tmp_path / "test.txt"
test_file.write_bytes(b"test data")

uri = await s3_storage.sync_browser_session_file(
organization_id=TEST_ORGANIZATION_ID,
browser_session_id=TEST_BROWSER_SESSION_ID,
artifact_type="downloads",
local_file_path=str(test_file),
remote_path="test.txt",
)

exists = await s3_storage.file_exists(uri)
assert exists is True

async def test_file_exists_returns_false(self, s3_storage: S3Storage) -> None:
"""Test file_exists returns False for non-existent file."""
uri = f"s3://{TEST_BUCKET}/nonexistent/path/file.txt"
exists = await s3_storage.file_exists(uri)
assert exists is False

async def test_download_uploaded_file(
self, s3_storage: S3Storage, boto3_test_client: S3Client, tmp_path: Path
) -> None:
"""Test downloading an uploaded file."""
test_data = b"uploaded file content"
test_file = tmp_path / "uploaded.pdf"
test_file.write_bytes(test_data)

uri = await s3_storage.sync_browser_session_file(
organization_id=TEST_ORGANIZATION_ID,
browser_session_id=TEST_BROWSER_SESSION_ID,
artifact_type="downloads",
local_file_path=str(test_file),
remote_path="uploaded.pdf",
)

downloaded = await s3_storage.download_uploaded_file(uri)
assert downloaded == test_data

async def test_download_uploaded_file_nonexistent(self, s3_storage: S3Storage) -> None:
"""Test downloading a non-existent file returns None."""
uri = f"s3://{TEST_BUCKET}/nonexistent/path/file.txt"
downloaded = await s3_storage.download_uploaded_file(uri)
assert downloaded is None

def test_storage_type_property(self, s3_storage: S3Storage) -> None:
"""Test storage_type returns 's3'."""
assert s3_storage.storage_type == "s3"


CONTENT_TYPE_TEST_CASES = [
# (filename, expected_content_type, artifact_type, date)
("video.webm", "video/webm", "videos", "2025-01-15"),
("data.json", "application/json", "har", "2025-01-15"),
("network.har", "application/json", "har", "2025-01-15"),
("screenshot.png", "image/png", "downloads", None),
("output.txt", "text/plain", "downloads", None),
("debug.log", "text/plain", "downloads", None),
]


@pytest.mark.asyncio
class TestS3StorageContentType:
"""Test S3Storage content type guessing."""

@pytest.mark.parametrize("filename,expected_content_type,artifact_type,date", CONTENT_TYPE_TEST_CASES)
async def test_content_type_guessing(
self,
s3_storage: S3Storage,
boto3_test_client: S3Client,
tmp_path: Path,
filename: str,
expected_content_type: str,
artifact_type: str,
date: str | None,
) -> None:
"""Test that files get correct content type based on extension."""
test_file = tmp_path / filename
test_file.write_bytes(b"test content")

uri = await s3_storage.sync_browser_session_file(
organization_id=TEST_ORGANIZATION_ID,
browser_session_id=TEST_BROWSER_SESSION_ID,
artifact_type=artifact_type,
local_file_path=str(test_file),
remote_path=filename,
date=date,
)

s3uri = S3Uri(uri)
obj_meta = boto3_test_client.head_object(Bucket=TEST_BUCKET, Key=s3uri.key)
assert obj_meta["ContentType"] == expected_content_type
@@ -1294,15 +1294,9 @@ async def get_artifact(
status_code=http_status.HTTP_404_NOT_FOUND,
detail=f"Artifact not found {artifact_id}",
)
if settings.ENV != "local" or settings.GENERATE_PRESIGNED_URLS:
signed_urls = await app.ARTIFACT_MANAGER.get_share_links([artifact])
if signed_urls:
artifact.signed_url = signed_urls[0]
else:
LOG.warning(
"Failed to get signed url for artifact",
artifact_id=artifact_id,
)
signed_urls = await app.ARTIFACT_MANAGER.get_share_links([artifact])
if signed_urls and len(signed_urls) == 1:
artifact.signed_url = signed_urls[0]
return artifact


@@ -1334,23 +1328,10 @@ async def get_run_artifacts(
# Ensure we have a list of artifacts (since group_by_type=False, this will always be a list)
artifacts_list = artifacts if isinstance(artifacts, list) else []

if settings.ENV != "local" or settings.GENERATE_PRESIGNED_URLS:
# Get signed URLs for all artifacts
signed_urls = await app.ARTIFACT_MANAGER.get_share_links(artifacts_list)

if signed_urls and len(signed_urls) == len(artifacts_list):
for i, artifact in enumerate(artifacts_list):
if hasattr(artifact, "signed_url"):
artifact.signed_url = signed_urls[i]
elif signed_urls:
LOG.warning(
"Mismatch between artifacts and signed URLs count",
artifacts_count=len(artifacts_list),
urls_count=len(signed_urls),
run_id=run_id,
)
else:
LOG.warning("Failed to get signed urls for artifacts", run_id=run_id)
signed_urls = await app.ARTIFACT_MANAGER.get_share_links(artifacts_list)
if signed_urls and len(signed_urls) == len(artifacts_list):
for i, artifact in enumerate(artifacts_list):
artifact.signed_url = signed_urls[i]

return ORJSONResponse([artifact.model_dump() for artifact in artifacts_list])

@@ -1976,17 +1957,10 @@ async def get_artifacts(
}
artifacts = await app.DATABASE.get_artifacts_by_entity_id(organization_id=current_org.organization_id, **params) # type: ignore

if settings.ENV != "local" or settings.GENERATE_PRESIGNED_URLS:
signed_urls = await app.ARTIFACT_MANAGER.get_share_links(artifacts)
if signed_urls:
for i, artifact in enumerate(artifacts):
artifact.signed_url = signed_urls[i]
else:
LOG.warning(
"Failed to get signed urls for artifacts",
entity_type=entity_type,
entity_id=entity_id,
)
signed_urls = await app.ARTIFACT_MANAGER.get_share_links(artifacts)
if signed_urls and len(signed_urls) == len(artifacts):
for i, artifact in enumerate(artifacts):
artifact.signed_url = signed_urls[i]

return ORJSONResponse([artifact.model_dump() for artifact in artifacts])

@@ -2021,17 +1995,10 @@ async def get_step_artifacts(
step_id,
organization_id=current_org.organization_id,
)
if settings.ENV != "local" or settings.GENERATE_PRESIGNED_URLS:
signed_urls = await app.ARTIFACT_MANAGER.get_share_links(artifacts)
if signed_urls:
for i, artifact in enumerate(artifacts):
artifact.signed_url = signed_urls[i]
else:
LOG.warning(
"Failed to get signed urls for artifacts",
task_id=task_id,
step_id=step_id,
)
signed_urls = await app.ARTIFACT_MANAGER.get_share_links(artifacts)
if signed_urls and len(signed_urls) == len(artifacts):
for i, artifact in enumerate(artifacts):
artifact.signed_url = signed_urls[i]
return ORJSONResponse([artifact.model_dump() for artifact in artifacts])

@@ -2603,9 +2603,8 @@ class FileUploadBlock(Block):
blob_name = self._get_azure_blob_name(workflow_run_id, file_path)
azure_uri = self._get_azure_blob_uri(workflow_run_id, blob_name)
uploaded_uris.append(azure_uri)
await azure_client.upload_file_from_path(
container_name=self.azure_blob_container_name or "", blob_name=blob_name, file_path=file_path
)
uri = f"azure://{self.azure_blob_container_name or ''}/{blob_name}"
await azure_client.upload_file_from_path(uri, file_path)
LOG.info("FileUploadBlock: File(s) uploaded to Azure Blob Storage", file_path=self.path)
else:
# This case should ideally be caught by the initial validation
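For clarity, the azure:// URI passed to upload_file_from_path above follows a container-then-blob-path shape, mirroring the s3:// convention used elsewhere in the diff; a small parsing sketch (the URI value is hypothetical):

from urllib.parse import urlparse

uri = "azure://skyvern-uploads/workflow_runs/wr_123/report.pdf"  # hypothetical example
parsed = urlparse(uri)
container_name = parsed.netloc        # "skyvern-uploads"
blob_name = parsed.path.lstrip("/")   # "workflow_runs/wr_123/report.pdf"
print(container_name, blob_name)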