add azure blob storage (#4338)
Signed-off-by: Benji Visser <benji@093b.org> Co-authored-by: Benji Visser <benji@093b.org> Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
This commit is contained in:
539
skyvern/forge/sdk/artifact/storage/azure.py
Normal file
539
skyvern/forge/sdk/artifact/storage/azure.py
Normal file
@@ -0,0 +1,539 @@
|
||||
import os
|
||||
import shutil
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import BinaryIO
|
||||
|
||||
import structlog
|
||||
|
||||
from skyvern.config import settings
|
||||
from skyvern.constants import BROWSER_DOWNLOADING_SUFFIX, DOWNLOAD_FILE_PREFIX
|
||||
from skyvern.forge.sdk.api.azure import StandardBlobTier
|
||||
from skyvern.forge.sdk.api.files import (
|
||||
calculate_sha256_for_file,
|
||||
create_named_temporary_file,
|
||||
get_download_dir,
|
||||
get_skyvern_temp_dir,
|
||||
make_temp_directory,
|
||||
unzip_files,
|
||||
)
|
||||
from skyvern.forge.sdk.api.real_azure import RealAsyncAzureStorageClient
|
||||
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType, LogEntityType
|
||||
from skyvern.forge.sdk.artifact.storage.base import FILE_EXTENTSION_MAP, BaseStorage
|
||||
from skyvern.forge.sdk.models import Step
|
||||
from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion
|
||||
from skyvern.forge.sdk.schemas.files import FileInfo
|
||||
from skyvern.forge.sdk.schemas.task_v2 import TaskV2, Thought
|
||||
from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunBlock
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class AzureStorage(BaseStorage):
|
||||
_PATH_VERSION = "v1"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
container: str | None = None,
|
||||
account_name: str | None = None,
|
||||
account_key: str | None = None,
|
||||
) -> None:
|
||||
self.async_client = RealAsyncAzureStorageClient(account_name=account_name, account_key=account_key)
|
||||
self.container = container or settings.AZURE_STORAGE_CONTAINER_ARTIFACTS
|
||||
|
||||
def build_uri(self, *, organization_id: str, artifact_id: str, step: Step, artifact_type: ArtifactType) -> str:
|
||||
file_ext = FILE_EXTENTSION_MAP[artifact_type]
|
||||
return f"{self._build_base_uri(organization_id)}/{step.task_id}/{step.order:02d}_{step.retry_index}_{step.step_id}/{datetime.utcnow().isoformat()}_{artifact_id}_{artifact_type}.{file_ext}"
|
||||
|
||||
async def retrieve_global_workflows(self) -> list[str]:
|
||||
uri = f"azure://{self.container}/{settings.ENV}/global_workflows.txt"
|
||||
data = await self.async_client.download_file(uri, log_exception=False)
|
||||
if not data:
|
||||
return []
|
||||
return [line.strip() for line in data.decode("utf-8").split("\n") if line.strip()]
|
||||
|
||||
def _build_base_uri(self, organization_id: str) -> str:
|
||||
return f"azure://{self.container}/{self._PATH_VERSION}/{settings.ENV}/{organization_id}"
|
||||
|
||||
def build_log_uri(
|
||||
self, *, organization_id: str, log_entity_type: LogEntityType, log_entity_id: str, artifact_type: ArtifactType
|
||||
) -> str:
|
||||
file_ext = FILE_EXTENTSION_MAP[artifact_type]
|
||||
return f"{self._build_base_uri(organization_id)}/logs/{log_entity_type}/{log_entity_id}/{datetime.utcnow().isoformat()}_{artifact_type}.{file_ext}"
|
||||
|
||||
def build_thought_uri(
|
||||
self, *, organization_id: str, artifact_id: str, thought: Thought, artifact_type: ArtifactType
|
||||
) -> str:
|
||||
file_ext = FILE_EXTENTSION_MAP[artifact_type]
|
||||
return f"{self._build_base_uri(organization_id)}/observers/{thought.observer_cruise_id}/{thought.observer_thought_id}/{datetime.utcnow().isoformat()}_{artifact_id}_{artifact_type}.{file_ext}"
|
||||
|
||||
def build_task_v2_uri(
|
||||
self, *, organization_id: str, artifact_id: str, task_v2: TaskV2, artifact_type: ArtifactType
|
||||
) -> str:
|
||||
file_ext = FILE_EXTENTSION_MAP[artifact_type]
|
||||
return f"{self._build_base_uri(organization_id)}/observers/{task_v2.observer_cruise_id}/{datetime.utcnow().isoformat()}_{artifact_id}_{artifact_type}.{file_ext}"
|
||||
|
||||
def build_workflow_run_block_uri(
|
||||
self,
|
||||
*,
|
||||
organization_id: str,
|
||||
artifact_id: str,
|
||||
workflow_run_block: WorkflowRunBlock,
|
||||
artifact_type: ArtifactType,
|
||||
) -> str:
|
||||
file_ext = FILE_EXTENTSION_MAP[artifact_type]
|
||||
return f"{self._build_base_uri(organization_id)}/workflow_runs/{workflow_run_block.workflow_run_id}/{workflow_run_block.workflow_run_block_id}/{datetime.utcnow().isoformat()}_{artifact_id}_{artifact_type}.{file_ext}"
|
||||
|
||||
def build_ai_suggestion_uri(
|
||||
self, *, organization_id: str, artifact_id: str, ai_suggestion: AISuggestion, artifact_type: ArtifactType
|
||||
) -> str:
|
||||
file_ext = FILE_EXTENTSION_MAP[artifact_type]
|
||||
return f"{self._build_base_uri(organization_id)}/ai_suggestions/{ai_suggestion.ai_suggestion_id}/{datetime.utcnow().isoformat()}_{artifact_id}_{artifact_type}.{file_ext}"
|
||||
|
||||
def build_script_file_uri(
|
||||
self, *, organization_id: str, script_id: str, script_version: int, file_path: str
|
||||
) -> str:
|
||||
"""Build the Azure URI for a script file."""
|
||||
return f"{self._build_base_uri(organization_id)}/scripts/{script_id}/{script_version}/{file_path}"
|
||||
|
||||
async def store_artifact(self, artifact: Artifact, data: bytes) -> None:
|
||||
tier = await self._get_storage_tier_for_org(artifact.organization_id)
|
||||
tags = await self._get_tags_for_org(artifact.organization_id)
|
||||
LOG.debug(
|
||||
"Storing artifact",
|
||||
artifact_id=artifact.artifact_id,
|
||||
organization_id=artifact.organization_id,
|
||||
uri=artifact.uri,
|
||||
storage_tier=tier,
|
||||
tags=tags,
|
||||
)
|
||||
await self.async_client.upload_file(artifact.uri, data, tier=tier, tags=tags)
|
||||
|
||||
async def _get_storage_tier_for_org(self, organization_id: str) -> StandardBlobTier:
|
||||
return StandardBlobTier.HOT
|
||||
|
||||
async def _get_tags_for_org(self, organization_id: str) -> dict[str, str]:
|
||||
return {}
|
||||
|
||||
async def retrieve_artifact(self, artifact: Artifact) -> bytes | None:
|
||||
return await self.async_client.download_file(artifact.uri)
|
||||
|
||||
async def get_share_link(self, artifact: Artifact) -> str | None:
|
||||
share_urls = await self.async_client.create_sas_urls([artifact.uri])
|
||||
return share_urls[0] if share_urls else None
|
||||
|
||||
async def get_share_links(self, artifacts: list[Artifact]) -> list[str] | None:
|
||||
return await self.async_client.create_sas_urls([artifact.uri for artifact in artifacts])
|
||||
|
||||
async def store_artifact_from_path(self, artifact: Artifact, path: str) -> None:
|
||||
tier = await self._get_storage_tier_for_org(artifact.organization_id)
|
||||
tags = await self._get_tags_for_org(artifact.organization_id)
|
||||
LOG.debug(
|
||||
"Storing artifact from path",
|
||||
artifact_id=artifact.artifact_id,
|
||||
organization_id=artifact.organization_id,
|
||||
uri=artifact.uri,
|
||||
storage_tier=tier,
|
||||
path=path,
|
||||
tags=tags,
|
||||
)
|
||||
await self.async_client.upload_file_from_path(artifact.uri, path, tier=tier, tags=tags)
|
||||
|
||||
async def save_streaming_file(self, organization_id: str, file_name: str) -> None:
|
||||
from_path = f"{get_skyvern_temp_dir()}/{organization_id}/{file_name}"
|
||||
to_path = f"azure://{settings.AZURE_STORAGE_CONTAINER_SCREENSHOTS}/{settings.ENV}/{organization_id}/{file_name}"
|
||||
tier = await self._get_storage_tier_for_org(organization_id)
|
||||
tags = await self._get_tags_for_org(organization_id)
|
||||
LOG.debug(
|
||||
"Saving streaming file",
|
||||
organization_id=organization_id,
|
||||
file_name=file_name,
|
||||
from_path=from_path,
|
||||
to_path=to_path,
|
||||
storage_tier=tier,
|
||||
tags=tags,
|
||||
)
|
||||
await self.async_client.upload_file_from_path(to_path, from_path, tier=tier, tags=tags)
|
||||
|
||||
async def get_streaming_file(self, organization_id: str, file_name: str, use_default: bool = True) -> bytes | None:
|
||||
path = f"azure://{settings.AZURE_STORAGE_CONTAINER_SCREENSHOTS}/{settings.ENV}/{organization_id}/{file_name}"
|
||||
return await self.async_client.download_file(path, log_exception=False)
|
||||
|
||||
async def store_browser_session(self, organization_id: str, workflow_permanent_id: str, directory: str) -> None:
|
||||
# Zip the directory to a temp file
|
||||
temp_zip_file = create_named_temporary_file()
|
||||
zip_file_path = shutil.make_archive(temp_zip_file.name, "zip", directory)
|
||||
browser_session_uri = f"azure://{settings.AZURE_STORAGE_CONTAINER_BROWSER_SESSIONS}/{settings.ENV}/{organization_id}/{workflow_permanent_id}.zip"
|
||||
tier = await self._get_storage_tier_for_org(organization_id)
|
||||
tags = await self._get_tags_for_org(organization_id)
|
||||
LOG.debug(
|
||||
"Storing browser session",
|
||||
organization_id=organization_id,
|
||||
workflow_permanent_id=workflow_permanent_id,
|
||||
zip_file_path=zip_file_path,
|
||||
browser_session_uri=browser_session_uri,
|
||||
storage_tier=tier,
|
||||
tags=tags,
|
||||
)
|
||||
await self.async_client.upload_file_from_path(browser_session_uri, zip_file_path, tier=tier, tags=tags)
|
||||
|
||||
async def retrieve_browser_session(self, organization_id: str, workflow_permanent_id: str) -> str | None:
|
||||
browser_session_uri = f"azure://{settings.AZURE_STORAGE_CONTAINER_BROWSER_SESSIONS}/{settings.ENV}/{organization_id}/{workflow_permanent_id}.zip"
|
||||
downloaded_zip_bytes = await self.async_client.download_file(browser_session_uri, log_exception=True)
|
||||
if not downloaded_zip_bytes:
|
||||
return None
|
||||
temp_zip_file = create_named_temporary_file(delete=False)
|
||||
temp_zip_file.write(downloaded_zip_bytes)
|
||||
temp_zip_file_path = temp_zip_file.name
|
||||
|
||||
temp_dir = make_temp_directory(prefix="skyvern_browser_session_")
|
||||
unzip_files(temp_zip_file_path, temp_dir)
|
||||
temp_zip_file.close()
|
||||
return temp_dir
|
||||
|
||||
async def store_browser_profile(self, organization_id: str, profile_id: str, directory: str) -> None:
|
||||
"""Store browser profile to Azure."""
|
||||
temp_zip_file = create_named_temporary_file()
|
||||
zip_file_path = shutil.make_archive(temp_zip_file.name, "zip", directory)
|
||||
profile_uri = f"azure://{settings.AZURE_STORAGE_CONTAINER_BROWSER_SESSIONS}/{settings.ENV}/{organization_id}/profiles/{profile_id}.zip"
|
||||
tier = await self._get_storage_tier_for_org(organization_id)
|
||||
tags = await self._get_tags_for_org(organization_id)
|
||||
LOG.debug(
|
||||
"Storing browser profile",
|
||||
organization_id=organization_id,
|
||||
profile_id=profile_id,
|
||||
zip_file_path=zip_file_path,
|
||||
profile_uri=profile_uri,
|
||||
storage_tier=tier,
|
||||
tags=tags,
|
||||
)
|
||||
await self.async_client.upload_file_from_path(profile_uri, zip_file_path, tier=tier, tags=tags)
|
||||
|
||||
async def retrieve_browser_profile(self, organization_id: str, profile_id: str) -> str | None:
|
||||
"""Retrieve browser profile from Azure."""
|
||||
profile_uri = f"azure://{settings.AZURE_STORAGE_CONTAINER_BROWSER_SESSIONS}/{settings.ENV}/{organization_id}/profiles/{profile_id}.zip"
|
||||
downloaded_zip_bytes = await self.async_client.download_file(profile_uri, log_exception=True)
|
||||
if not downloaded_zip_bytes:
|
||||
return None
|
||||
temp_zip_file = create_named_temporary_file(delete=False)
|
||||
temp_zip_file.write(downloaded_zip_bytes)
|
||||
temp_zip_file_path = temp_zip_file.name
|
||||
|
||||
temp_dir = make_temp_directory(prefix="skyvern_browser_profile_")
|
||||
unzip_files(temp_zip_file_path, temp_dir)
|
||||
temp_zip_file.close()
|
||||
return temp_dir
|
||||
|
||||
async def list_downloaded_files_in_browser_session(
|
||||
self, organization_id: str, browser_session_id: str
|
||||
) -> list[str]:
|
||||
uri = f"azure://{settings.AZURE_STORAGE_CONTAINER_ARTIFACTS}/v1/{settings.ENV}/{organization_id}/browser_sessions/{browser_session_id}/downloads"
|
||||
return [
|
||||
f"azure://{settings.AZURE_STORAGE_CONTAINER_ARTIFACTS}/{file}"
|
||||
for file in await self.async_client.list_files(uri=uri)
|
||||
]
|
||||
|
||||
async def get_shared_downloaded_files_in_browser_session(
|
||||
self, organization_id: str, browser_session_id: str
|
||||
) -> list[FileInfo]:
|
||||
object_keys = await self.list_downloaded_files_in_browser_session(organization_id, browser_session_id)
|
||||
if len(object_keys) == 0:
|
||||
return []
|
||||
|
||||
file_infos: list[FileInfo] = []
|
||||
for key in object_keys:
|
||||
metadata = {}
|
||||
modified_at: datetime | None = None
|
||||
# Get metadata (including checksum)
|
||||
try:
|
||||
object_info = await self.async_client.get_object_info(key)
|
||||
if object_info:
|
||||
metadata = object_info.get("Metadata", {})
|
||||
modified_at = object_info.get("LastModified")
|
||||
except Exception:
|
||||
LOG.exception("Object info retrieval failed", uri=key)
|
||||
|
||||
# Create FileInfo object
|
||||
filename = os.path.basename(key)
|
||||
checksum = metadata.get("sha256_checksum") if metadata else None
|
||||
|
||||
# Get SAS URL
|
||||
sas_urls = await self.async_client.create_sas_urls([key])
|
||||
if not sas_urls:
|
||||
continue
|
||||
|
||||
file_info = FileInfo(
|
||||
url=sas_urls[0],
|
||||
checksum=checksum,
|
||||
filename=metadata.get("original_filename", filename) if metadata else filename,
|
||||
modified_at=modified_at,
|
||||
)
|
||||
file_infos.append(file_info)
|
||||
|
||||
return file_infos
|
||||
|
||||
async def list_downloading_files_in_browser_session(
|
||||
self, organization_id: str, browser_session_id: str
|
||||
) -> list[str]:
|
||||
uri = f"azure://{settings.AZURE_STORAGE_CONTAINER_ARTIFACTS}/v1/{settings.ENV}/{organization_id}/browser_sessions/{browser_session_id}/downloads"
|
||||
files = [
|
||||
f"azure://{settings.AZURE_STORAGE_CONTAINER_ARTIFACTS}/{file}"
|
||||
for file in await self.async_client.list_files(uri=uri)
|
||||
]
|
||||
return [file for file in files if file.endswith(BROWSER_DOWNLOADING_SUFFIX)]
|
||||
|
||||
async def list_recordings_in_browser_session(self, organization_id: str, browser_session_id: str) -> list[str]:
|
||||
"""List all recording files for a browser session from Azure."""
|
||||
uri = f"azure://{settings.AZURE_STORAGE_CONTAINER_ARTIFACTS}/v1/{settings.ENV}/{organization_id}/browser_sessions/{browser_session_id}/videos"
|
||||
return [
|
||||
f"azure://{settings.AZURE_STORAGE_CONTAINER_ARTIFACTS}/{file}"
|
||||
for file in await self.async_client.list_files(uri=uri)
|
||||
]
|
||||
|
||||
async def get_shared_recordings_in_browser_session(
|
||||
self, organization_id: str, browser_session_id: str
|
||||
) -> list[FileInfo]:
|
||||
"""Get recording files with SAS URLs for a browser session."""
|
||||
object_keys = await self.list_recordings_in_browser_session(organization_id, browser_session_id)
|
||||
if len(object_keys) == 0:
|
||||
return []
|
||||
|
||||
file_infos: list[FileInfo] = []
|
||||
for key in object_keys:
|
||||
# Playwright's record_video_dir should only contain .webm files.
|
||||
# Filter defensively in case of unexpected files.
|
||||
key_lower = key.lower()
|
||||
if not (key_lower.endswith(".webm") or key_lower.endswith(".mp4")):
|
||||
LOG.warning(
|
||||
"Skipping recording file with unsupported extension",
|
||||
uri=key,
|
||||
organization_id=organization_id,
|
||||
browser_session_id=browser_session_id,
|
||||
)
|
||||
continue
|
||||
|
||||
metadata = {}
|
||||
modified_at: datetime | None = None
|
||||
content_length: int | None = None
|
||||
# Get metadata (including checksum)
|
||||
try:
|
||||
object_info = await self.async_client.get_object_info(key)
|
||||
if object_info:
|
||||
metadata = object_info.get("Metadata", {})
|
||||
modified_at = object_info.get("LastModified")
|
||||
content_length = object_info.get("ContentLength") or object_info.get("Size")
|
||||
except Exception:
|
||||
LOG.exception("Recording object info retrieval failed", uri=key)
|
||||
|
||||
# Skip zero-byte objects (if any incomplete uploads)
|
||||
if content_length == 0:
|
||||
continue
|
||||
|
||||
# Create FileInfo object
|
||||
filename = os.path.basename(key)
|
||||
checksum = metadata.get("sha256_checksum") if metadata else None
|
||||
|
||||
# Get SAS URL
|
||||
sas_urls = await self.async_client.create_sas_urls([key])
|
||||
if not sas_urls:
|
||||
continue
|
||||
|
||||
file_info = FileInfo(
|
||||
url=sas_urls[0],
|
||||
checksum=checksum,
|
||||
filename=metadata.get("original_filename", filename) if metadata else filename,
|
||||
modified_at=modified_at,
|
||||
)
|
||||
file_infos.append(file_info)
|
||||
|
||||
# Prefer the newest recording first (Azure list order is not guaranteed).
|
||||
# Treat None as "oldest".
|
||||
file_infos.sort(key=lambda f: (f.modified_at is not None, f.modified_at), reverse=True)
|
||||
return file_infos
|
||||
|
||||
async def save_downloaded_files(self, organization_id: str, run_id: str | None) -> None:
|
||||
download_dir = get_download_dir(run_id=run_id)
|
||||
files = os.listdir(download_dir)
|
||||
tier = await self._get_storage_tier_for_org(organization_id)
|
||||
tags = await self._get_tags_for_org(organization_id)
|
||||
base_uri = f"azure://{settings.AZURE_STORAGE_CONTAINER_UPLOADS}/{DOWNLOAD_FILE_PREFIX}/{settings.ENV}/{organization_id}/{run_id}"
|
||||
for file in files:
|
||||
fpath = os.path.join(download_dir, file)
|
||||
if not os.path.isfile(fpath):
|
||||
continue
|
||||
uri = f"{base_uri}/{file}"
|
||||
checksum = calculate_sha256_for_file(fpath)
|
||||
LOG.info(
|
||||
"Calculated checksum for file",
|
||||
file=file,
|
||||
checksum=checksum,
|
||||
organization_id=organization_id,
|
||||
storage_tier=tier,
|
||||
)
|
||||
# Upload file with checksum metadata
|
||||
await self.async_client.upload_file_from_path(
|
||||
uri=uri,
|
||||
file_path=fpath,
|
||||
metadata={"sha256_checksum": checksum, "original_filename": file},
|
||||
tier=tier,
|
||||
tags=tags,
|
||||
)
|
||||
|
||||
async def get_downloaded_files(self, organization_id: str, run_id: str | None) -> list[FileInfo]:
|
||||
uri = f"azure://{settings.AZURE_STORAGE_CONTAINER_UPLOADS}/{DOWNLOAD_FILE_PREFIX}/{settings.ENV}/{organization_id}/{run_id}"
|
||||
object_keys = await self.async_client.list_files(uri=uri)
|
||||
if len(object_keys) == 0:
|
||||
return []
|
||||
|
||||
file_infos: list[FileInfo] = []
|
||||
for key in object_keys:
|
||||
object_uri = f"azure://{settings.AZURE_STORAGE_CONTAINER_UPLOADS}/{key}"
|
||||
|
||||
# Get metadata (including checksum)
|
||||
metadata = await self.async_client.get_file_metadata(object_uri, log_exception=False)
|
||||
|
||||
# Create FileInfo object
|
||||
filename = os.path.basename(key)
|
||||
checksum = metadata.get("sha256_checksum") if metadata else None
|
||||
|
||||
# Get SAS URL
|
||||
sas_urls = await self.async_client.create_sas_urls([object_uri])
|
||||
if not sas_urls:
|
||||
continue
|
||||
|
||||
file_info = FileInfo(
|
||||
url=sas_urls[0],
|
||||
checksum=checksum,
|
||||
filename=metadata.get("original_filename", filename) if metadata else filename,
|
||||
)
|
||||
file_infos.append(file_info)
|
||||
|
||||
return file_infos
|
||||
|
||||
async def save_legacy_file(
|
||||
self, *, organization_id: str, filename: str, fileObj: BinaryIO
|
||||
) -> tuple[str, str] | None:
|
||||
todays_date = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d")
|
||||
container = settings.AZURE_STORAGE_CONTAINER_UPLOADS
|
||||
tier = await self._get_storage_tier_for_org(organization_id)
|
||||
tags = await self._get_tags_for_org(organization_id)
|
||||
# First try uploading with original filename
|
||||
try:
|
||||
sanitized_filename = os.path.basename(filename) # Remove any path components
|
||||
azure_uri = f"azure://{container}/{settings.ENV}/{organization_id}/{todays_date}/{sanitized_filename}"
|
||||
uploaded_uri = await self.async_client.upload_file_stream(azure_uri, fileObj, tier=tier, tags=tags)
|
||||
except Exception:
|
||||
LOG.error("Failed to upload file to Azure", exc_info=True)
|
||||
uploaded_uri = None
|
||||
|
||||
# If upload fails, try again with UUID prefix
|
||||
if not uploaded_uri:
|
||||
uuid_prefixed_filename = f"{str(uuid.uuid4())}_{filename}"
|
||||
azure_uri = f"azure://{container}/{settings.ENV}/{organization_id}/{todays_date}/{uuid_prefixed_filename}"
|
||||
fileObj.seek(0) # Reset file pointer
|
||||
uploaded_uri = await self.async_client.upload_file_stream(azure_uri, fileObj, tier=tier, tags=tags)
|
||||
|
||||
if not uploaded_uri:
|
||||
LOG.error(
|
||||
"Failed to upload file to Azure after retrying with UUID prefix",
|
||||
organization_id=organization_id,
|
||||
storage_tier=tier,
|
||||
filename=filename,
|
||||
exc_info=True,
|
||||
)
|
||||
return None
|
||||
LOG.debug(
|
||||
"Legacy file upload",
|
||||
organization_id=organization_id,
|
||||
storage_tier=tier,
|
||||
filename=filename,
|
||||
uploaded_uri=uploaded_uri,
|
||||
)
|
||||
# Generate a SAS URL for the uploaded file
|
||||
sas_urls = await self.async_client.create_sas_urls([uploaded_uri])
|
||||
if not sas_urls:
|
||||
LOG.error(
|
||||
"Failed to create SAS URL for uploaded file",
|
||||
organization_id=organization_id,
|
||||
storage_tier=tier,
|
||||
uploaded_uri=uploaded_uri,
|
||||
filename=filename,
|
||||
exc_info=True,
|
||||
)
|
||||
return None
|
||||
return sas_urls[0], uploaded_uri
|
||||
|
||||
def _build_browser_session_uri(
|
||||
self,
|
||||
organization_id: str,
|
||||
browser_session_id: str,
|
||||
artifact_type: str,
|
||||
remote_path: str,
|
||||
date: str | None = None,
|
||||
) -> str:
|
||||
"""Build the Azure URI for a browser session file."""
|
||||
base = f"azure://{self.container}/v1/{settings.ENV}/{organization_id}/browser_sessions/{browser_session_id}/{artifact_type}"
|
||||
if date:
|
||||
return f"{base}/{date}/{remote_path}"
|
||||
return f"{base}/{remote_path}"
|
||||
|
||||
async def sync_browser_session_file(
|
||||
self,
|
||||
organization_id: str,
|
||||
browser_session_id: str,
|
||||
artifact_type: str,
|
||||
local_file_path: str,
|
||||
remote_path: str,
|
||||
date: str | None = None,
|
||||
) -> str:
|
||||
"""Sync a file from local browser session to Azure."""
|
||||
uri = self._build_browser_session_uri(organization_id, browser_session_id, artifact_type, remote_path, date)
|
||||
tier = await self._get_storage_tier_for_org(organization_id)
|
||||
tags = await self._get_tags_for_org(organization_id)
|
||||
await self.async_client.upload_file_from_path(uri, local_file_path, tier=tier, tags=tags)
|
||||
return uri
|
||||
|
||||
async def delete_browser_session_file(
|
||||
self,
|
||||
organization_id: str,
|
||||
browser_session_id: str,
|
||||
artifact_type: str,
|
||||
remote_path: str,
|
||||
date: str | None = None,
|
||||
) -> None:
|
||||
"""Delete a file from browser session storage in Azure."""
|
||||
uri = self._build_browser_session_uri(organization_id, browser_session_id, artifact_type, remote_path, date)
|
||||
await self.async_client.delete_file(uri)
|
||||
|
||||
async def browser_session_file_exists(
|
||||
self,
|
||||
organization_id: str,
|
||||
browser_session_id: str,
|
||||
artifact_type: str,
|
||||
remote_path: str,
|
||||
date: str | None = None,
|
||||
) -> bool:
|
||||
"""Check if a file exists in browser session storage in Azure."""
|
||||
uri = self._build_browser_session_uri(organization_id, browser_session_id, artifact_type, remote_path, date)
|
||||
try:
|
||||
info = await self.async_client.get_object_info(uri)
|
||||
return info is not None
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def download_uploaded_file(self, uri: str) -> bytes | None:
|
||||
"""Download a user-uploaded file from Azure."""
|
||||
return await self.async_client.download_file(uri, log_exception=False)
|
||||
|
||||
async def file_exists(self, uri: str) -> bool:
|
||||
"""Check if a file exists at the given Azure URI."""
|
||||
try:
|
||||
info = await self.async_client.get_object_info(uri)
|
||||
return info is not None
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@property
|
||||
def storage_type(self) -> str:
|
||||
"""Returns 'azure' as the storage type."""
|
||||
return "azure"
|
||||
@@ -172,3 +172,50 @@ class BaseStorage(ABC):
|
||||
self, *, organization_id: str, filename: str, fileObj: BinaryIO
|
||||
) -> tuple[str, str] | None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def sync_browser_session_file(
|
||||
self,
|
||||
organization_id: str,
|
||||
browser_session_id: str,
|
||||
artifact_type: str,
|
||||
local_file_path: str,
|
||||
remote_path: str,
|
||||
date: str | None = None,
|
||||
) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def delete_browser_session_file(
|
||||
self,
|
||||
organization_id: str,
|
||||
browser_session_id: str,
|
||||
artifact_type: str,
|
||||
remote_path: str,
|
||||
date: str | None = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def browser_session_file_exists(
|
||||
self,
|
||||
organization_id: str,
|
||||
browser_session_id: str,
|
||||
artifact_type: str,
|
||||
remote_path: str,
|
||||
date: str | None = None,
|
||||
) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def download_uploaded_file(self, uri: str) -> bytes | None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def file_exists(self, uri: str) -> bool:
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def storage_type(self) -> str:
|
||||
pass
|
||||
|
||||
@@ -160,11 +160,11 @@ class LocalStorage(BaseStorage):
|
||||
)
|
||||
return None
|
||||
|
||||
async def get_share_link(self, artifact: Artifact) -> str:
|
||||
return artifact.uri
|
||||
async def get_share_link(self, artifact: Artifact) -> str | None:
|
||||
return None
|
||||
|
||||
async def get_share_links(self, artifacts: list[Artifact]) -> list[str]:
|
||||
return [artifact.uri for artifact in artifacts]
|
||||
async def get_share_links(self, artifacts: list[Artifact]) -> list[str] | None:
|
||||
return None
|
||||
|
||||
async def save_streaming_file(self, organization_id: str, file_name: str) -> None:
|
||||
return
|
||||
@@ -346,3 +346,98 @@ class LocalStorage(BaseStorage):
|
||||
raise NotImplementedError(
|
||||
"Legacy file storage is not implemented for LocalStorage. Please use a different storage backend."
|
||||
)
|
||||
|
||||
def _build_browser_session_path(
|
||||
self,
|
||||
organization_id: str,
|
||||
browser_session_id: str,
|
||||
artifact_type: str,
|
||||
remote_path: str,
|
||||
date: str | None = None,
|
||||
) -> Path:
|
||||
"""Build the local path for a browser session file."""
|
||||
base = (
|
||||
Path(self.artifact_path)
|
||||
/ settings.ENV
|
||||
/ organization_id
|
||||
/ "browser_sessions"
|
||||
/ browser_session_id
|
||||
/ artifact_type
|
||||
)
|
||||
if date:
|
||||
return base / date / remote_path
|
||||
return base / remote_path
|
||||
|
||||
async def sync_browser_session_file(
|
||||
self,
|
||||
organization_id: str,
|
||||
browser_session_id: str,
|
||||
artifact_type: str,
|
||||
local_file_path: str,
|
||||
remote_path: str,
|
||||
date: str | None = None,
|
||||
) -> str:
|
||||
"""Sync a file from local browser session to local storage."""
|
||||
target_path = self._build_browser_session_path(
|
||||
organization_id, browser_session_id, artifact_type, remote_path, date
|
||||
)
|
||||
if WINDOWS:
|
||||
target_path = target_path.with_name(_windows_safe_filename(target_path.name))
|
||||
self._create_directories_if_not_exists(target_path)
|
||||
shutil.copy2(local_file_path, target_path)
|
||||
return f"file://{target_path}"
|
||||
|
||||
async def delete_browser_session_file(
|
||||
self,
|
||||
organization_id: str,
|
||||
browser_session_id: str,
|
||||
artifact_type: str,
|
||||
remote_path: str,
|
||||
date: str | None = None,
|
||||
) -> None:
|
||||
"""Delete a file from browser session storage in local filesystem."""
|
||||
target_path = self._build_browser_session_path(
|
||||
organization_id, browser_session_id, artifact_type, remote_path, date
|
||||
)
|
||||
try:
|
||||
if target_path.exists():
|
||||
target_path.unlink()
|
||||
except Exception:
|
||||
LOG.exception("Failed to delete local browser session file", path=str(target_path))
|
||||
|
||||
async def browser_session_file_exists(
|
||||
self,
|
||||
organization_id: str,
|
||||
browser_session_id: str,
|
||||
artifact_type: str,
|
||||
remote_path: str,
|
||||
date: str | None = None,
|
||||
) -> bool:
|
||||
"""Check if a file exists in browser session storage in local filesystem."""
|
||||
target_path = self._build_browser_session_path(
|
||||
organization_id, browser_session_id, artifact_type, remote_path, date
|
||||
)
|
||||
return target_path.exists()
|
||||
|
||||
async def download_uploaded_file(self, uri: str) -> bytes | None:
|
||||
"""Download a user-uploaded file from local filesystem."""
|
||||
try:
|
||||
file_path = parse_uri_to_path(uri)
|
||||
with open(file_path, "rb") as f:
|
||||
return f.read()
|
||||
except Exception:
|
||||
LOG.exception("Failed to read local file", uri=uri)
|
||||
return None
|
||||
|
||||
async def file_exists(self, uri: str) -> bool:
|
||||
"""Check if a file exists at the given local URI."""
|
||||
try:
|
||||
file_path = parse_uri_to_path(uri)
|
||||
return os.path.exists(file_path)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@property
|
||||
def storage_type(self) -> str:
|
||||
"""Returns 'file' as the storage type."""
|
||||
return "file"
|
||||
|
||||
@@ -235,10 +235,9 @@ class S3Storage(BaseStorage):
|
||||
async def list_downloaded_files_in_browser_session(
|
||||
self, organization_id: str, browser_session_id: str
|
||||
) -> list[str]:
|
||||
uri = f"s3://{settings.AWS_S3_BUCKET_ARTIFACTS}/v1/{settings.ENV}/{organization_id}/browser_sessions/{browser_session_id}/downloads"
|
||||
return [
|
||||
f"s3://{settings.AWS_S3_BUCKET_ARTIFACTS}/{file}" for file in await self.async_client.list_files(uri=uri)
|
||||
]
|
||||
bucket = settings.AWS_S3_BUCKET_ARTIFACTS
|
||||
uri = f"s3://{bucket}/v1/{settings.ENV}/{organization_id}/browser_sessions/{browser_session_id}/downloads"
|
||||
return [f"s3://{bucket}/{file}" for file in await self.async_client.list_files(uri=uri)]
|
||||
|
||||
async def get_shared_downloaded_files_in_browser_session(
|
||||
self, organization_id: str, browser_session_id: str
|
||||
@@ -281,18 +280,16 @@ class S3Storage(BaseStorage):
|
||||
async def list_downloading_files_in_browser_session(
|
||||
self, organization_id: str, browser_session_id: str
|
||||
) -> list[str]:
|
||||
uri = f"s3://{settings.AWS_S3_BUCKET_ARTIFACTS}/v1/{settings.ENV}/{organization_id}/browser_sessions/{browser_session_id}/downloads"
|
||||
files = [
|
||||
f"s3://{settings.AWS_S3_BUCKET_ARTIFACTS}/{file}" for file in await self.async_client.list_files(uri=uri)
|
||||
]
|
||||
bucket = settings.AWS_S3_BUCKET_ARTIFACTS
|
||||
uri = f"s3://{bucket}/v1/{settings.ENV}/{organization_id}/browser_sessions/{browser_session_id}/downloads"
|
||||
files = [f"s3://{bucket}/{file}" for file in await self.async_client.list_files(uri=uri)]
|
||||
return [file for file in files if file.endswith(BROWSER_DOWNLOADING_SUFFIX)]
|
||||
|
||||
async def list_recordings_in_browser_session(self, organization_id: str, browser_session_id: str) -> list[str]:
|
||||
"""List all recording files for a browser session from S3."""
|
||||
uri = f"s3://{settings.AWS_S3_BUCKET_ARTIFACTS}/v1/{settings.ENV}/{organization_id}/browser_sessions/{browser_session_id}/videos"
|
||||
return [
|
||||
f"s3://{settings.AWS_S3_BUCKET_ARTIFACTS}/{file}" for file in await self.async_client.list_files(uri=uri)
|
||||
]
|
||||
bucket = settings.AWS_S3_BUCKET_ARTIFACTS
|
||||
uri = f"s3://{bucket}/v1/{settings.ENV}/{organization_id}/browser_sessions/{browser_session_id}/videos"
|
||||
return [f"s3://{bucket}/{file}" for file in await self.async_client.list_files(uri=uri)]
|
||||
|
||||
async def get_shared_recordings_in_browser_session(
|
||||
self, organization_id: str, browser_session_id: str
|
||||
@@ -385,14 +382,15 @@ class S3Storage(BaseStorage):
|
||||
)
|
||||
|
||||
async def get_downloaded_files(self, organization_id: str, run_id: str | None) -> list[FileInfo]:
|
||||
uri = f"s3://{settings.AWS_S3_BUCKET_UPLOADS}/{DOWNLOAD_FILE_PREFIX}/{settings.ENV}/{organization_id}/{run_id}"
|
||||
bucket = settings.AWS_S3_BUCKET_UPLOADS
|
||||
uri = f"s3://{bucket}/{DOWNLOAD_FILE_PREFIX}/{settings.ENV}/{organization_id}/{run_id}"
|
||||
object_keys = await self.async_client.list_files(uri=uri)
|
||||
if len(object_keys) == 0:
|
||||
return []
|
||||
|
||||
file_infos: list[FileInfo] = []
|
||||
for key in object_keys:
|
||||
object_uri = f"s3://{settings.AWS_S3_BUCKET_UPLOADS}/{key}"
|
||||
object_uri = f"s3://{bucket}/{key}"
|
||||
|
||||
# Get metadata (including checksum)
|
||||
metadata = await self.async_client.get_file_metadata(object_uri, log_exception=False)
|
||||
@@ -467,3 +465,78 @@ class S3Storage(BaseStorage):
|
||||
)
|
||||
return None
|
||||
return presigned_urls[0], uploaded_s3_uri
|
||||
|
||||
def _build_browser_session_uri(
|
||||
self,
|
||||
organization_id: str,
|
||||
browser_session_id: str,
|
||||
artifact_type: str,
|
||||
remote_path: str,
|
||||
date: str | None = None,
|
||||
) -> str:
|
||||
"""Build the S3 URI for a browser session file."""
|
||||
base = f"s3://{self.bucket}/v1/{settings.ENV}/{organization_id}/browser_sessions/{browser_session_id}/{artifact_type}"
|
||||
if date:
|
||||
return f"{base}/{date}/{remote_path}"
|
||||
return f"{base}/{remote_path}"
|
||||
|
||||
async def sync_browser_session_file(
|
||||
self,
|
||||
organization_id: str,
|
||||
browser_session_id: str,
|
||||
artifact_type: str,
|
||||
local_file_path: str,
|
||||
remote_path: str,
|
||||
date: str | None = None,
|
||||
) -> str:
|
||||
"""Sync a file from local browser session to S3."""
|
||||
uri = self._build_browser_session_uri(organization_id, browser_session_id, artifact_type, remote_path, date)
|
||||
sc = await self._get_storage_class_for_org(organization_id)
|
||||
tags = await self._get_tags_for_org(organization_id)
|
||||
await self.async_client.upload_file_from_path(uri, local_file_path, storage_class=sc, tags=tags)
|
||||
return uri
|
||||
|
||||
async def delete_browser_session_file(
|
||||
self,
|
||||
organization_id: str,
|
||||
browser_session_id: str,
|
||||
artifact_type: str,
|
||||
remote_path: str,
|
||||
date: str | None = None,
|
||||
) -> None:
|
||||
"""Delete a file from browser session storage in S3."""
|
||||
uri = self._build_browser_session_uri(organization_id, browser_session_id, artifact_type, remote_path, date)
|
||||
await self.async_client.delete_file(uri, log_exception=True)
|
||||
|
||||
async def browser_session_file_exists(
|
||||
self,
|
||||
organization_id: str,
|
||||
browser_session_id: str,
|
||||
artifact_type: str,
|
||||
remote_path: str,
|
||||
date: str | None = None,
|
||||
) -> bool:
|
||||
"""Check if a file exists in browser session storage in S3."""
|
||||
uri = self._build_browser_session_uri(organization_id, browser_session_id, artifact_type, remote_path, date)
|
||||
try:
|
||||
info = await self.async_client.get_object_info(uri)
|
||||
return info is not None
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def download_uploaded_file(self, uri: str) -> bytes | None:
|
||||
"""Download a user-uploaded file from S3."""
|
||||
return await self.async_client.download_file(uri, log_exception=False)
|
||||
|
||||
async def file_exists(self, uri: str) -> bool:
|
||||
"""Check if a file exists at the given S3 URI."""
|
||||
try:
|
||||
info = await self.async_client.get_object_info(uri)
|
||||
return info is not None
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@property
|
||||
def storage_type(self) -> str:
|
||||
"""Returns 's3' as the storage type."""
|
||||
return "s3"
|
||||
|
||||
251
skyvern/forge/sdk/artifact/storage/test_azure_storage.py
Normal file
251
skyvern/forge/sdk/artifact/storage/test_azure_storage.py
Normal file
@@ -0,0 +1,251 @@
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from skyvern.config import settings
|
||||
from skyvern.forge.sdk.api.azure import StandardBlobTier
|
||||
from skyvern.forge.sdk.api.real_azure import RealAsyncAzureStorageClient
|
||||
from skyvern.forge.sdk.artifact.storage.azure import AzureStorage
|
||||
|
||||
# Test constants
|
||||
TEST_CONTAINER = "test-azure-container"
|
||||
TEST_ORGANIZATION_ID = "test-org-123"
|
||||
TEST_BROWSER_SESSION_ID = "bs_test_123"
|
||||
|
||||
|
||||
class AzureStorageForTests(AzureStorage):
|
||||
"""Test subclass that overrides org-specific methods and bypasses client init."""
|
||||
|
||||
async_client: Any # Allow mock attribute access
|
||||
|
||||
def __init__(self, container: str) -> None:
|
||||
# Don't call super().__init__ to avoid creating real RealAsyncAzureStorageClient
|
||||
self.container = container
|
||||
self.async_client = AsyncMock()
|
||||
|
||||
async def _get_storage_tier_for_org(self, organization_id: str) -> StandardBlobTier:
|
||||
return StandardBlobTier.HOT
|
||||
|
||||
async def _get_tags_for_org(self, organization_id: str) -> dict[str, str]:
|
||||
return {"test": "tag"}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def azure_storage() -> AzureStorageForTests:
|
||||
"""Create AzureStorage with mocked async_client."""
|
||||
return AzureStorageForTests(container=TEST_CONTAINER)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestAzureStorageBrowserSessionFiles:
|
||||
"""Test AzureStorage browser session file methods."""
|
||||
|
||||
async def test_sync_browser_session_file_with_date(
|
||||
self, azure_storage: AzureStorageForTests, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test syncing a file with date in path (videos/har)."""
|
||||
test_file = tmp_path / "recording.webm"
|
||||
test_file.write_bytes(b"fake video data")
|
||||
|
||||
uri = await azure_storage.sync_browser_session_file(
|
||||
organization_id=TEST_ORGANIZATION_ID,
|
||||
browser_session_id=TEST_BROWSER_SESSION_ID,
|
||||
artifact_type="videos",
|
||||
local_file_path=str(test_file),
|
||||
remote_path="recording.webm",
|
||||
date="2025-01-15",
|
||||
)
|
||||
|
||||
expected_uri = f"azure://{TEST_CONTAINER}/v1/{settings.ENV}/{TEST_ORGANIZATION_ID}/browser_sessions/{TEST_BROWSER_SESSION_ID}/videos/2025-01-15/recording.webm"
|
||||
assert uri == expected_uri
|
||||
azure_storage.async_client.upload_file_from_path.assert_called_once()
|
||||
|
||||
async def test_sync_browser_session_file_without_date(
|
||||
self, azure_storage: AzureStorageForTests, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test syncing a file without date (downloads category)."""
|
||||
test_file = tmp_path / "document.pdf"
|
||||
test_file.write_bytes(b"fake download data")
|
||||
|
||||
uri = await azure_storage.sync_browser_session_file(
|
||||
organization_id=TEST_ORGANIZATION_ID,
|
||||
browser_session_id=TEST_BROWSER_SESSION_ID,
|
||||
artifact_type="downloads",
|
||||
local_file_path=str(test_file),
|
||||
remote_path="document.pdf",
|
||||
date=None,
|
||||
)
|
||||
|
||||
expected_uri = f"azure://{TEST_CONTAINER}/v1/{settings.ENV}/{TEST_ORGANIZATION_ID}/browser_sessions/{TEST_BROWSER_SESSION_ID}/downloads/document.pdf"
|
||||
assert uri == expected_uri
|
||||
|
||||
async def test_browser_session_file_exists_returns_true(self, azure_storage: AzureStorageForTests) -> None:
|
||||
"""Test browser_session_file_exists returns True when file exists."""
|
||||
azure_storage.async_client.get_object_info.return_value = {"LastModified": "2025-01-15"}
|
||||
|
||||
exists = await azure_storage.browser_session_file_exists(
|
||||
organization_id=TEST_ORGANIZATION_ID,
|
||||
browser_session_id=TEST_BROWSER_SESSION_ID,
|
||||
artifact_type="videos",
|
||||
remote_path="exists.webm",
|
||||
date="2025-01-15",
|
||||
)
|
||||
|
||||
assert exists is True
|
||||
|
||||
async def test_browser_session_file_exists_returns_false_on_exception(
|
||||
self, azure_storage: AzureStorageForTests
|
||||
) -> None:
|
||||
"""Test browser_session_file_exists returns False when exception is raised."""
|
||||
azure_storage.async_client.get_object_info.side_effect = Exception("Not found")
|
||||
|
||||
exists = await azure_storage.browser_session_file_exists(
|
||||
organization_id=TEST_ORGANIZATION_ID,
|
||||
browser_session_id=TEST_BROWSER_SESSION_ID,
|
||||
artifact_type="videos",
|
||||
remote_path="nonexistent.webm",
|
||||
date="2025-01-15",
|
||||
)
|
||||
|
||||
assert exists is False
|
||||
|
||||
async def test_delete_browser_session_file(self, azure_storage: AzureStorageForTests) -> None:
|
||||
"""Test deleting a browser session file."""
|
||||
await azure_storage.delete_browser_session_file(
|
||||
organization_id=TEST_ORGANIZATION_ID,
|
||||
browser_session_id=TEST_BROWSER_SESSION_ID,
|
||||
artifact_type="videos",
|
||||
remote_path="to_delete.webm",
|
||||
date="2025-01-15",
|
||||
)
|
||||
|
||||
expected_uri = f"azure://{TEST_CONTAINER}/v1/{settings.ENV}/{TEST_ORGANIZATION_ID}/browser_sessions/{TEST_BROWSER_SESSION_ID}/videos/2025-01-15/to_delete.webm"
|
||||
azure_storage.async_client.delete_file.assert_called_once_with(expected_uri)
|
||||
|
||||
async def test_file_exists_returns_true(self, azure_storage: AzureStorageForTests) -> None:
|
||||
"""Test file_exists returns True when file exists."""
|
||||
azure_storage.async_client.get_object_info.return_value = {"LastModified": "2025-01-15"}
|
||||
uri = f"azure://{TEST_CONTAINER}/test/file.txt"
|
||||
|
||||
exists = await azure_storage.file_exists(uri)
|
||||
|
||||
assert exists is True
|
||||
|
||||
async def test_file_exists_returns_false_on_exception(self, azure_storage: AzureStorageForTests) -> None:
|
||||
"""Test file_exists returns False when exception is raised (404)."""
|
||||
azure_storage.async_client.get_object_info.side_effect = Exception("Not found")
|
||||
uri = f"azure://{TEST_CONTAINER}/nonexistent/file.txt"
|
||||
|
||||
exists = await azure_storage.file_exists(uri)
|
||||
|
||||
assert exists is False
|
||||
|
||||
async def test_download_uploaded_file(self, azure_storage: AzureStorageForTests) -> None:
|
||||
"""Test downloading an uploaded file."""
|
||||
test_data = b"uploaded file content"
|
||||
azure_storage.async_client.download_file.return_value = test_data
|
||||
uri = f"azure://{TEST_CONTAINER}/uploads/file.pdf"
|
||||
|
||||
downloaded = await azure_storage.download_uploaded_file(uri)
|
||||
|
||||
assert downloaded == test_data
|
||||
azure_storage.async_client.download_file.assert_called_once_with(uri, log_exception=False)
|
||||
|
||||
async def test_download_uploaded_file_returns_none(self, azure_storage: AzureStorageForTests) -> None:
|
||||
"""Test downloading a non-existent file returns None."""
|
||||
azure_storage.async_client.download_file.return_value = None
|
||||
uri = f"azure://{TEST_CONTAINER}/nonexistent/file.txt"
|
||||
|
||||
downloaded = await azure_storage.download_uploaded_file(uri)
|
||||
|
||||
assert downloaded is None
|
||||
|
||||
def test_storage_type_property(self, azure_storage: AzureStorageForTests) -> None:
|
||||
"""Test storage_type returns 'azure'."""
|
||||
assert azure_storage.storage_type == "azure"
|
||||
|
||||
|
||||
class TestAzureStorageBuildUri:
|
||||
"""Test Azure URI building methods."""
|
||||
|
||||
def test_build_browser_session_uri_with_date(self, azure_storage: AzureStorageForTests) -> None:
|
||||
"""Test building URI with date."""
|
||||
uri = azure_storage._build_browser_session_uri(
|
||||
organization_id=TEST_ORGANIZATION_ID,
|
||||
browser_session_id=TEST_BROWSER_SESSION_ID,
|
||||
artifact_type="videos",
|
||||
remote_path="file.webm",
|
||||
date="2025-01-15",
|
||||
)
|
||||
|
||||
expected = f"azure://{TEST_CONTAINER}/v1/{settings.ENV}/{TEST_ORGANIZATION_ID}/browser_sessions/{TEST_BROWSER_SESSION_ID}/videos/2025-01-15/file.webm"
|
||||
assert uri == expected
|
||||
|
||||
def test_build_browser_session_uri_without_date(self, azure_storage: AzureStorageForTests) -> None:
|
||||
"""Test building URI without date."""
|
||||
uri = azure_storage._build_browser_session_uri(
|
||||
organization_id=TEST_ORGANIZATION_ID,
|
||||
browser_session_id=TEST_BROWSER_SESSION_ID,
|
||||
artifact_type="downloads",
|
||||
remote_path="file.pdf",
|
||||
date=None,
|
||||
)
|
||||
|
||||
expected = f"azure://{TEST_CONTAINER}/v1/{settings.ENV}/{TEST_ORGANIZATION_ID}/browser_sessions/{TEST_BROWSER_SESSION_ID}/downloads/file.pdf"
|
||||
assert uri == expected
|
||||
|
||||
|
||||
AZURE_CONTENT_TYPE_TEST_CASES = [
|
||||
# (filename, expected_content_type, artifact_type, date)
|
||||
("video.webm", "video/webm", "videos", "2025-01-15"),
|
||||
("data.json", "application/json", "har", "2025-01-15"),
|
||||
("network.har", "application/json", "har", "2025-01-15"),
|
||||
("screenshot.png", "image/png", "downloads", None),
|
||||
("output.txt", "text/plain", "downloads", None),
|
||||
("debug.log", "text/plain", "downloads", None),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestAzureStorageContentType:
|
||||
"""Test Azure Storage content type guessing.
|
||||
|
||||
Tests at two levels:
|
||||
1. High-level: sync_browser_session_file interface with artifact_type/date
|
||||
2. Low-level: RealAsyncAzureStorageClient to verify ContentSettings is passed
|
||||
"""
|
||||
|
||||
@pytest.mark.parametrize("filename,expected_content_type,artifact_type,date", AZURE_CONTENT_TYPE_TEST_CASES)
|
||||
async def test_content_type_guessing(
|
||||
self,
|
||||
tmp_path: Path,
|
||||
filename: str,
|
||||
expected_content_type: str,
|
||||
artifact_type: str,
|
||||
date: str | None,
|
||||
) -> None:
|
||||
"""Test that RealAsyncAzureStorageClient sets correct content type based on extension."""
|
||||
test_file = tmp_path / filename
|
||||
test_file.write_bytes(b"test content")
|
||||
|
||||
with patch.object(RealAsyncAzureStorageClient, "_get_blob_service_client") as mock_get_client:
|
||||
mock_container_client = MagicMock()
|
||||
mock_container_client.upload_blob = AsyncMock()
|
||||
mock_container_client.exists = AsyncMock(return_value=True)
|
||||
mock_blob_service = MagicMock()
|
||||
mock_blob_service.get_container_client.return_value = mock_container_client
|
||||
mock_get_client.return_value = mock_blob_service
|
||||
|
||||
client = RealAsyncAzureStorageClient(account_name="test", account_key="testkey")
|
||||
client._verified_containers.add("test-container")
|
||||
|
||||
await client.upload_file_from_path(
|
||||
uri=f"azure://test-container/path/{filename}",
|
||||
file_path=str(test_file),
|
||||
)
|
||||
|
||||
call_kwargs = mock_container_client.upload_blob.call_args.kwargs
|
||||
assert call_kwargs["content_settings"] is not None
|
||||
assert call_kwargs["content_settings"].content_type == expected_content_type
|
||||
@@ -221,3 +221,229 @@ class TestS3StorageStore:
|
||||
await s3_storage.store_artifact(artifact, test_data)
|
||||
_assert_object_content(boto3_test_client, artifact.uri, test_data)
|
||||
_assert_object_meta(boto3_test_client, artifact.uri)
|
||||
|
||||
|
||||
TEST_BROWSER_SESSION_ID = "bs_test_123"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestS3StorageBrowserSessionFiles:
|
||||
"""Test S3Storage browser session file methods."""
|
||||
|
||||
async def test_sync_browser_session_file_with_date(
|
||||
self, s3_storage: S3Storage, boto3_test_client: S3Client, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test syncing a file with date in path (videos/har)."""
|
||||
test_data = b"fake video data"
|
||||
test_file = tmp_path / "recording.webm"
|
||||
test_file.write_bytes(test_data)
|
||||
|
||||
uri = await s3_storage.sync_browser_session_file(
|
||||
organization_id=TEST_ORGANIZATION_ID,
|
||||
browser_session_id=TEST_BROWSER_SESSION_ID,
|
||||
artifact_type="videos",
|
||||
local_file_path=str(test_file),
|
||||
remote_path="recording.webm",
|
||||
date="2025-01-15",
|
||||
)
|
||||
|
||||
expected_uri = f"s3://{TEST_BUCKET}/v1/{settings.ENV}/{TEST_ORGANIZATION_ID}/browser_sessions/{TEST_BROWSER_SESSION_ID}/videos/2025-01-15/recording.webm"
|
||||
assert uri == expected_uri
|
||||
_assert_object_content(boto3_test_client, uri, test_data)
|
||||
_assert_object_meta(boto3_test_client, uri)
|
||||
|
||||
async def test_sync_browser_session_file_without_date(
|
||||
self, s3_storage: S3Storage, boto3_test_client: S3Client, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test syncing a file without date (downloads category)."""
|
||||
test_data = b"fake download data"
|
||||
test_file = tmp_path / "document.pdf"
|
||||
test_file.write_bytes(test_data)
|
||||
|
||||
uri = await s3_storage.sync_browser_session_file(
|
||||
organization_id=TEST_ORGANIZATION_ID,
|
||||
browser_session_id=TEST_BROWSER_SESSION_ID,
|
||||
artifact_type="downloads",
|
||||
local_file_path=str(test_file),
|
||||
remote_path="document.pdf",
|
||||
date=None,
|
||||
)
|
||||
|
||||
expected_uri = f"s3://{TEST_BUCKET}/v1/{settings.ENV}/{TEST_ORGANIZATION_ID}/browser_sessions/{TEST_BROWSER_SESSION_ID}/downloads/document.pdf"
|
||||
assert uri == expected_uri
|
||||
_assert_object_content(boto3_test_client, uri, test_data)
|
||||
|
||||
async def test_browser_session_file_exists_returns_true(
|
||||
self, s3_storage: S3Storage, boto3_test_client: S3Client, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test browser_session_file_exists returns True for existing file."""
|
||||
test_file = tmp_path / "exists.webm"
|
||||
test_file.write_bytes(b"test data")
|
||||
|
||||
await s3_storage.sync_browser_session_file(
|
||||
organization_id=TEST_ORGANIZATION_ID,
|
||||
browser_session_id=TEST_BROWSER_SESSION_ID,
|
||||
artifact_type="videos",
|
||||
local_file_path=str(test_file),
|
||||
remote_path="exists.webm",
|
||||
date="2025-01-15",
|
||||
)
|
||||
|
||||
exists = await s3_storage.browser_session_file_exists(
|
||||
organization_id=TEST_ORGANIZATION_ID,
|
||||
browser_session_id=TEST_BROWSER_SESSION_ID,
|
||||
artifact_type="videos",
|
||||
remote_path="exists.webm",
|
||||
date="2025-01-15",
|
||||
)
|
||||
assert exists is True
|
||||
|
||||
async def test_browser_session_file_exists_returns_false(self, s3_storage: S3Storage) -> None:
|
||||
"""Test browser_session_file_exists returns False for non-existent file."""
|
||||
exists = await s3_storage.browser_session_file_exists(
|
||||
organization_id=TEST_ORGANIZATION_ID,
|
||||
browser_session_id=TEST_BROWSER_SESSION_ID,
|
||||
artifact_type="videos",
|
||||
remote_path="nonexistent.webm",
|
||||
date="2025-01-15",
|
||||
)
|
||||
assert exists is False
|
||||
|
||||
async def test_delete_browser_session_file(
|
||||
self, s3_storage: S3Storage, boto3_test_client: S3Client, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test deleting a browser session file."""
|
||||
test_file = tmp_path / "to_delete.webm"
|
||||
test_file.write_bytes(b"test data")
|
||||
|
||||
await s3_storage.sync_browser_session_file(
|
||||
organization_id=TEST_ORGANIZATION_ID,
|
||||
browser_session_id=TEST_BROWSER_SESSION_ID,
|
||||
artifact_type="videos",
|
||||
local_file_path=str(test_file),
|
||||
remote_path="to_delete.webm",
|
||||
date="2025-01-15",
|
||||
)
|
||||
|
||||
exists_before = await s3_storage.browser_session_file_exists(
|
||||
organization_id=TEST_ORGANIZATION_ID,
|
||||
browser_session_id=TEST_BROWSER_SESSION_ID,
|
||||
artifact_type="videos",
|
||||
remote_path="to_delete.webm",
|
||||
date="2025-01-15",
|
||||
)
|
||||
assert exists_before is True
|
||||
|
||||
await s3_storage.delete_browser_session_file(
|
||||
organization_id=TEST_ORGANIZATION_ID,
|
||||
browser_session_id=TEST_BROWSER_SESSION_ID,
|
||||
artifact_type="videos",
|
||||
remote_path="to_delete.webm",
|
||||
date="2025-01-15",
|
||||
)
|
||||
|
||||
exists_after = await s3_storage.browser_session_file_exists(
|
||||
organization_id=TEST_ORGANIZATION_ID,
|
||||
browser_session_id=TEST_BROWSER_SESSION_ID,
|
||||
artifact_type="videos",
|
||||
remote_path="to_delete.webm",
|
||||
date="2025-01-15",
|
||||
)
|
||||
assert exists_after is False
|
||||
|
||||
async def test_file_exists_returns_true(
|
||||
self, s3_storage: S3Storage, boto3_test_client: S3Client, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test file_exists returns True for existing file."""
|
||||
test_file = tmp_path / "test.txt"
|
||||
test_file.write_bytes(b"test data")
|
||||
|
||||
uri = await s3_storage.sync_browser_session_file(
|
||||
organization_id=TEST_ORGANIZATION_ID,
|
||||
browser_session_id=TEST_BROWSER_SESSION_ID,
|
||||
artifact_type="downloads",
|
||||
local_file_path=str(test_file),
|
||||
remote_path="test.txt",
|
||||
)
|
||||
|
||||
exists = await s3_storage.file_exists(uri)
|
||||
assert exists is True
|
||||
|
||||
async def test_file_exists_returns_false(self, s3_storage: S3Storage) -> None:
|
||||
"""Test file_exists returns False for non-existent file."""
|
||||
uri = f"s3://{TEST_BUCKET}/nonexistent/path/file.txt"
|
||||
exists = await s3_storage.file_exists(uri)
|
||||
assert exists is False
|
||||
|
||||
async def test_download_uploaded_file(
|
||||
self, s3_storage: S3Storage, boto3_test_client: S3Client, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test downloading an uploaded file."""
|
||||
test_data = b"uploaded file content"
|
||||
test_file = tmp_path / "uploaded.pdf"
|
||||
test_file.write_bytes(test_data)
|
||||
|
||||
uri = await s3_storage.sync_browser_session_file(
|
||||
organization_id=TEST_ORGANIZATION_ID,
|
||||
browser_session_id=TEST_BROWSER_SESSION_ID,
|
||||
artifact_type="downloads",
|
||||
local_file_path=str(test_file),
|
||||
remote_path="uploaded.pdf",
|
||||
)
|
||||
|
||||
downloaded = await s3_storage.download_uploaded_file(uri)
|
||||
assert downloaded == test_data
|
||||
|
||||
async def test_download_uploaded_file_nonexistent(self, s3_storage: S3Storage) -> None:
|
||||
"""Test downloading a non-existent file returns None."""
|
||||
uri = f"s3://{TEST_BUCKET}/nonexistent/path/file.txt"
|
||||
downloaded = await s3_storage.download_uploaded_file(uri)
|
||||
assert downloaded is None
|
||||
|
||||
def test_storage_type_property(self, s3_storage: S3Storage) -> None:
|
||||
"""Test storage_type returns 's3'."""
|
||||
assert s3_storage.storage_type == "s3"
|
||||
|
||||
|
||||
CONTENT_TYPE_TEST_CASES = [
|
||||
# (filename, expected_content_type, artifact_type, date)
|
||||
("video.webm", "video/webm", "videos", "2025-01-15"),
|
||||
("data.json", "application/json", "har", "2025-01-15"),
|
||||
("network.har", "application/json", "har", "2025-01-15"),
|
||||
("screenshot.png", "image/png", "downloads", None),
|
||||
("output.txt", "text/plain", "downloads", None),
|
||||
("debug.log", "text/plain", "downloads", None),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestS3StorageContentType:
|
||||
"""Test S3Storage content type guessing."""
|
||||
|
||||
@pytest.mark.parametrize("filename,expected_content_type,artifact_type,date", CONTENT_TYPE_TEST_CASES)
|
||||
async def test_content_type_guessing(
|
||||
self,
|
||||
s3_storage: S3Storage,
|
||||
boto3_test_client: S3Client,
|
||||
tmp_path: Path,
|
||||
filename: str,
|
||||
expected_content_type: str,
|
||||
artifact_type: str,
|
||||
date: str | None,
|
||||
) -> None:
|
||||
"""Test that files get correct content type based on extension."""
|
||||
test_file = tmp_path / filename
|
||||
test_file.write_bytes(b"test content")
|
||||
|
||||
uri = await s3_storage.sync_browser_session_file(
|
||||
organization_id=TEST_ORGANIZATION_ID,
|
||||
browser_session_id=TEST_BROWSER_SESSION_ID,
|
||||
artifact_type=artifact_type,
|
||||
local_file_path=str(test_file),
|
||||
remote_path=filename,
|
||||
date=date,
|
||||
)
|
||||
|
||||
s3uri = S3Uri(uri)
|
||||
obj_meta = boto3_test_client.head_object(Bucket=TEST_BUCKET, Key=s3uri.key)
|
||||
assert obj_meta["ContentType"] == expected_content_type
|
||||
|
||||
Reference in New Issue
Block a user