add azure blob storage (#4338)
Signed-off-by: Benji Visser <benji@093b.org> Co-authored-by: Benji Visser <benji@093b.org> Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
from enum import StrEnum
|
||||
from mimetypes import add_type, guess_type
|
||||
from typing import IO, Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
@@ -12,6 +13,10 @@ from types_boto3_secretsmanager.client import SecretsManagerClient
|
||||
|
||||
from skyvern.config import settings
|
||||
|
||||
# Register custom mime types for mimetypes guessing
|
||||
add_type("application/json", ".har")
|
||||
add_type("text/plain", ".log")
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
@@ -188,6 +193,10 @@ class AsyncAWSClient:
|
||||
extra_args["Tagging"] = self._create_tag_string(tags)
|
||||
if content_type:
|
||||
extra_args["ContentType"] = content_type
|
||||
else:
|
||||
guessed_type, _ = guess_type(file_path)
|
||||
if guessed_type:
|
||||
extra_args["ContentType"] = guessed_type
|
||||
await client.upload_file(
|
||||
Filename=file_path,
|
||||
Bucket=parsed_uri.bucket,
|
||||
|
||||
@@ -1,8 +1,35 @@
|
||||
from typing import Protocol, Self
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from azure.storage.blob import StandardBlobTier
|
||||
|
||||
from skyvern.forge.sdk.schemas.organizations import AzureClientSecretCredential
|
||||
|
||||
|
||||
class AzureUri:
|
||||
"""Parse azure://{container}/{blob_path} URIs."""
|
||||
|
||||
def __init__(self, uri: str) -> None:
|
||||
self._parsed = urlparse(uri, allow_fragments=False)
|
||||
|
||||
@property
|
||||
def container(self) -> str:
|
||||
return self._parsed.netloc
|
||||
|
||||
@property
|
||||
def blob_path(self) -> str:
|
||||
if self._parsed.query:
|
||||
return self._parsed.path.lstrip("/") + "?" + self._parsed.query
|
||||
return self._parsed.path.lstrip("/")
|
||||
|
||||
@property
|
||||
def uri(self) -> str:
|
||||
return self._parsed.geturl()
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.uri
|
||||
|
||||
|
||||
class AsyncAzureVaultClient(Protocol):
|
||||
"""Protocol defining the interface for Azure Vault clients.
|
||||
|
||||
@@ -68,21 +95,24 @@ class AsyncAzureVaultClient(Protocol):
|
||||
|
||||
|
||||
class AsyncAzureStorageClient(Protocol):
|
||||
"""Protocol defining the interface for Azure Storage clients.
|
||||
"""Protocol defining the interface for Azure Storage clients."""
|
||||
|
||||
This client provides methods to interact with Azure Blob Storage for file operations.
|
||||
"""
|
||||
|
||||
async def upload_file_from_path(self, container_name: str, blob_name: str, file_path: str) -> None:
|
||||
async def upload_file_from_path(
|
||||
self,
|
||||
uri: str,
|
||||
file_path: str,
|
||||
tier: StandardBlobTier = StandardBlobTier.HOT,
|
||||
tags: dict[str, str] | None = None,
|
||||
metadata: dict[str, str] | None = None,
|
||||
) -> None:
|
||||
"""Upload a file from the local filesystem to Azure Blob Storage.
|
||||
|
||||
Args:
|
||||
container_name: The name of the Azure Blob container
|
||||
blob_name: The name to give the blob in storage
|
||||
uri: The azure:// URI for the blob (azure://container/blob_path)
|
||||
file_path: The local path to the file to upload
|
||||
|
||||
Raises:
|
||||
Exception: If the upload fails
|
||||
tier: The storage tier for the blob
|
||||
tags: Optional tags to attach to the blob
|
||||
metadata: Optional metadata to attach to the blob
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
@@ -17,7 +17,8 @@ from yarl import URL
|
||||
from skyvern.config import settings
|
||||
from skyvern.constants import BROWSER_DOWNLOAD_TIMEOUT, BROWSER_DOWNLOADING_SUFFIX, REPO_ROOT_DIR
|
||||
from skyvern.exceptions import DownloadFileMaxSizeExceeded, DownloadFileMaxWaitingTime
|
||||
from skyvern.forge.sdk.api.aws import AsyncAWSClient, aws_client
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.api.aws import AsyncAWSClient
|
||||
from skyvern.utils.url_validators import encode_url
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
@@ -97,6 +98,12 @@ def validate_download_url(url: str) -> bool:
|
||||
return True
|
||||
return False
|
||||
|
||||
# Allow Azure URIs for Skyvern uploads container
|
||||
if scheme == "azure":
|
||||
if url.startswith(f"azure://{settings.AZURE_STORAGE_CONTAINER_UPLOADS}/{settings.ENV}/o_"):
|
||||
return True
|
||||
return False
|
||||
|
||||
# Allow file:// URLs only in local environment
|
||||
if scheme == "file":
|
||||
if settings.ENV != "local":
|
||||
@@ -129,20 +136,41 @@ async def download_file(url: str, max_size_mb: int | None = None) -> str:
|
||||
url = f"https://drive.google.com/uc?export=download&id={file_id}"
|
||||
LOG.info("Converting Google Drive link to direct download", url=url)
|
||||
|
||||
# Check if URL is an S3 URI
|
||||
if url.startswith(f"s3://{settings.AWS_S3_BUCKET_UPLOADS}/{settings.ENV}/o_"):
|
||||
LOG.info("Downloading Skyvern file from S3", url=url)
|
||||
client = AsyncAWSClient()
|
||||
return await download_from_s3(client, url)
|
||||
# Check if URL is a cloud storage URI (S3 or Azure)
|
||||
parsed = urlparse(url)
|
||||
if parsed.scheme == "s3":
|
||||
uploads_prefix = f"s3://{settings.AWS_S3_BUCKET_UPLOADS}/{settings.ENV}/o_"
|
||||
if url.startswith(uploads_prefix):
|
||||
LOG.info("Downloading Skyvern file from S3", url=url)
|
||||
data = await app.STORAGE.download_uploaded_file(url)
|
||||
if data is None:
|
||||
raise Exception(f"Failed to download file from S3: {url}")
|
||||
filename = url.split("/")[-1]
|
||||
temp_file = create_named_temporary_file(delete=False, file_name=filename)
|
||||
LOG.info(f"Downloaded file to {temp_file.name}")
|
||||
temp_file.write(data)
|
||||
return temp_file.name
|
||||
elif parsed.scheme == "azure":
|
||||
uploads_prefix = f"azure://{settings.AZURE_STORAGE_CONTAINER_UPLOADS}/{settings.ENV}/o_"
|
||||
if url.startswith(uploads_prefix):
|
||||
LOG.info("Downloading Skyvern file from Azure Blob Storage", url=url)
|
||||
data = await app.STORAGE.download_uploaded_file(url)
|
||||
if data is None:
|
||||
raise Exception(f"Failed to download file from Azure Blob Storage: {url}")
|
||||
filename = url.split("/")[-1]
|
||||
temp_file = create_named_temporary_file(delete=False, file_name=filename)
|
||||
LOG.info(f"Downloaded file to {temp_file.name}")
|
||||
temp_file.write(data)
|
||||
return temp_file.name
|
||||
|
||||
# Check if URL is a file:// URI
|
||||
# we only support to download local files when the environment is local
|
||||
# and the file is in the skyvern downloads directory
|
||||
if url.startswith("file://") and settings.ENV == "local":
|
||||
file_path = parse_uri_to_path(url)
|
||||
if file_path.startswith(f"{REPO_ROOT_DIR}/downloads"):
|
||||
local_path = parse_uri_to_path(url)
|
||||
if local_path.startswith(f"{REPO_ROOT_DIR}/downloads"):
|
||||
LOG.info("Downloading file from local file system", url=url)
|
||||
return file_path
|
||||
return local_path
|
||||
|
||||
async with aiohttp.ClientSession(raise_for_status=True) as session:
|
||||
LOG.info("Starting to download file", url=url)
|
||||
@@ -262,12 +290,13 @@ async def wait_for_download_finished(downloading_files: list[str], timeout: floa
|
||||
while len(cur_downloading_files) > 0:
|
||||
new_downloading_files: list[str] = []
|
||||
for path in cur_downloading_files:
|
||||
if path.startswith("s3://"):
|
||||
try:
|
||||
await aws_client.get_object_info(path)
|
||||
except Exception:
|
||||
# Check for cloud storage URIs (S3 or Azure)
|
||||
parsed = urlparse(path)
|
||||
if parsed.scheme in ("s3", "azure"):
|
||||
if not await app.STORAGE.file_exists(path):
|
||||
LOG.debug(
|
||||
"downloading file is not found in s3, means the file finished downloading", path=path
|
||||
"downloading file is not found in cloud storage, means the file finished downloading",
|
||||
path=path,
|
||||
)
|
||||
continue
|
||||
else:
|
||||
|
||||
@@ -1,15 +1,29 @@
|
||||
"""Real implementations of Azure clients (Vault and Storage) and their factories."""
|
||||
|
||||
from typing import Self
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from mimetypes import add_type, guess_type
|
||||
from typing import IO, Self
|
||||
|
||||
import structlog
|
||||
from azure.core.exceptions import ResourceNotFoundError
|
||||
from azure.identity.aio import ClientSecretCredential, DefaultAzureCredential
|
||||
from azure.keyvault.secrets.aio import SecretClient
|
||||
from azure.storage.blob import BlobSasPermissions, ContentSettings, StandardBlobTier, generate_blob_sas
|
||||
from azure.storage.blob.aio import BlobServiceClient
|
||||
|
||||
from skyvern.forge.sdk.api.azure import AsyncAzureStorageClient, AsyncAzureVaultClient, AzureClientFactory
|
||||
from skyvern.config import settings
|
||||
from skyvern.forge.sdk.api.azure import (
|
||||
AsyncAzureStorageClient,
|
||||
AsyncAzureVaultClient,
|
||||
AzureClientFactory,
|
||||
AzureUri,
|
||||
)
|
||||
from skyvern.forge.sdk.schemas.organizations import AzureClientSecretCredential
|
||||
|
||||
# Register custom mime types for mimetypes guessing
|
||||
add_type("application/json", ".har")
|
||||
add_type("text/plain", ".log")
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
@@ -73,37 +87,256 @@ class RealAsyncAzureVaultClient(AsyncAzureVaultClient):
|
||||
|
||||
|
||||
class RealAsyncAzureStorageClient(AsyncAzureStorageClient):
|
||||
"""Real implementation of Azure Storage client using Azure SDK."""
|
||||
"""Async client for Azure Blob Storage operations. Implements AsyncAzureStorageClient protocol."""
|
||||
|
||||
def __init__(self, storage_account_name: str, storage_account_key: str):
|
||||
self.blob_service_client = BlobServiceClient(
|
||||
account_url=f"https://{storage_account_name}.blob.core.windows.net",
|
||||
credential=storage_account_key,
|
||||
def __init__(
|
||||
self,
|
||||
account_name: str | None = None,
|
||||
account_key: str | None = None,
|
||||
) -> None:
|
||||
self.account_name = account_name or settings.AZURE_STORAGE_ACCOUNT_NAME
|
||||
self.account_key = account_key or settings.AZURE_STORAGE_ACCOUNT_KEY
|
||||
|
||||
if not self.account_name or not self.account_key:
|
||||
raise ValueError("Azure Storage account name and key are required")
|
||||
|
||||
self._blob_service_client: BlobServiceClient | None = None
|
||||
self._verified_containers: set[str] = set()
|
||||
|
||||
def _get_blob_service_client(self) -> BlobServiceClient:
|
||||
if self._blob_service_client is None:
|
||||
self._blob_service_client = BlobServiceClient(
|
||||
account_url=f"https://{self.account_name}.blob.core.windows.net",
|
||||
credential=self.account_key,
|
||||
)
|
||||
return self._blob_service_client
|
||||
|
||||
async def _ensure_container_exists(self, container: str) -> None:
|
||||
if container in self._verified_containers:
|
||||
return
|
||||
client = self._get_blob_service_client()
|
||||
container_client = client.get_container_client(container)
|
||||
try:
|
||||
if not await container_client.exists():
|
||||
await container_client.create_container()
|
||||
LOG.info("Created Azure container", container=container)
|
||||
except Exception:
|
||||
LOG.debug("Container may already exist", container=container)
|
||||
self._verified_containers.add(container)
|
||||
|
||||
async def upload_file(
|
||||
self,
|
||||
uri: str,
|
||||
data: bytes,
|
||||
tier: StandardBlobTier = StandardBlobTier.HOT,
|
||||
tags: dict[str, str] | None = None,
|
||||
) -> None:
|
||||
parsed = AzureUri(uri)
|
||||
await self._ensure_container_exists(parsed.container)
|
||||
client = self._get_blob_service_client()
|
||||
container_client = client.get_container_client(parsed.container)
|
||||
await container_client.upload_blob(
|
||||
name=parsed.blob_path,
|
||||
data=data,
|
||||
overwrite=True,
|
||||
standard_blob_tier=tier,
|
||||
tags=tags,
|
||||
)
|
||||
|
||||
async def upload_file_from_path(self, container_name: str, blob_name: str, file_path: str) -> None:
|
||||
try:
|
||||
container_client = self.blob_service_client.get_container_client(container_name)
|
||||
# Create the container if it doesn't exist
|
||||
try:
|
||||
await container_client.create_container()
|
||||
except Exception as e:
|
||||
LOG.info("Azure container already exists or failed to create", container_name=container_name, error=e)
|
||||
|
||||
with open(file_path, "rb") as data:
|
||||
await container_client.upload_blob(name=blob_name, data=data, overwrite=True)
|
||||
LOG.info("File uploaded to Azure Blob Storage", container_name=container_name, blob_name=blob_name)
|
||||
except Exception as e:
|
||||
LOG.error(
|
||||
"Failed to upload file to Azure Blob Storage",
|
||||
container_name=container_name,
|
||||
blob_name=blob_name,
|
||||
error=e,
|
||||
async def upload_file_from_path(
|
||||
self,
|
||||
uri: str,
|
||||
file_path: str,
|
||||
tier: StandardBlobTier = StandardBlobTier.HOT,
|
||||
tags: dict[str, str] | None = None,
|
||||
metadata: dict[str, str] | None = None,
|
||||
) -> None:
|
||||
parsed = AzureUri(uri)
|
||||
await self._ensure_container_exists(parsed.container)
|
||||
client = self._get_blob_service_client()
|
||||
container_client = client.get_container_client(parsed.container)
|
||||
content_type, _ = guess_type(file_path)
|
||||
content_settings = ContentSettings(content_type=content_type) if content_type else None
|
||||
with open(file_path, "rb") as f:
|
||||
await container_client.upload_blob(
|
||||
name=parsed.blob_path,
|
||||
data=f,
|
||||
overwrite=True,
|
||||
standard_blob_tier=tier,
|
||||
tags=tags,
|
||||
metadata=metadata,
|
||||
content_settings=content_settings,
|
||||
)
|
||||
raise e
|
||||
|
||||
async def upload_file_stream(
|
||||
self,
|
||||
uri: str,
|
||||
file_obj: IO[bytes],
|
||||
tier: StandardBlobTier = StandardBlobTier.HOT,
|
||||
tags: dict[str, str] | None = None,
|
||||
metadata: dict[str, str] | None = None,
|
||||
) -> str:
|
||||
parsed = AzureUri(uri)
|
||||
await self._ensure_container_exists(parsed.container)
|
||||
client = self._get_blob_service_client()
|
||||
container_client = client.get_container_client(parsed.container)
|
||||
await container_client.upload_blob(
|
||||
name=parsed.blob_path,
|
||||
data=file_obj,
|
||||
overwrite=True,
|
||||
standard_blob_tier=tier,
|
||||
tags=tags,
|
||||
metadata=metadata,
|
||||
)
|
||||
return uri
|
||||
|
||||
async def download_file(self, uri: str, log_exception: bool = True) -> bytes | None:
|
||||
parsed = AzureUri(uri)
|
||||
try:
|
||||
client = self._get_blob_service_client()
|
||||
container_client = client.get_container_client(parsed.container)
|
||||
blob_client = container_client.get_blob_client(parsed.blob_path)
|
||||
download = await blob_client.download_blob()
|
||||
return await download.readall()
|
||||
except ResourceNotFoundError:
|
||||
if log_exception:
|
||||
LOG.warning("Azure blob not found", uri=uri)
|
||||
return None
|
||||
except Exception:
|
||||
if log_exception:
|
||||
LOG.exception("Failed to download from Azure", uri=uri)
|
||||
return None
|
||||
|
||||
async def get_blob_properties(self, uri: str) -> dict | None:
|
||||
parsed = AzureUri(uri)
|
||||
try:
|
||||
client = self._get_blob_service_client()
|
||||
container_client = client.get_container_client(parsed.container)
|
||||
blob_client = container_client.get_blob_client(parsed.blob_path)
|
||||
props = await blob_client.get_blob_properties()
|
||||
return {
|
||||
"size": props.size,
|
||||
"content_type": props.content_settings.content_type if props.content_settings else None,
|
||||
"last_modified": props.last_modified,
|
||||
"etag": props.etag,
|
||||
"metadata": props.metadata,
|
||||
}
|
||||
except ResourceNotFoundError:
|
||||
return None
|
||||
except Exception:
|
||||
LOG.exception("Failed to get blob properties", uri=uri)
|
||||
return None
|
||||
|
||||
async def blob_exists(self, uri: str) -> bool:
|
||||
parsed = AzureUri(uri)
|
||||
try:
|
||||
client = self._get_blob_service_client()
|
||||
container_client = client.get_container_client(parsed.container)
|
||||
blob_client = container_client.get_blob_client(parsed.blob_path)
|
||||
return await blob_client.exists()
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def delete_blob(self, uri: str) -> None:
|
||||
parsed = AzureUri(uri)
|
||||
try:
|
||||
client = self._get_blob_service_client()
|
||||
container_client = client.get_container_client(parsed.container)
|
||||
blob_client = container_client.get_blob_client(parsed.blob_path)
|
||||
await blob_client.delete_blob()
|
||||
except ResourceNotFoundError:
|
||||
LOG.debug("Azure blob not found for deletion", uri=uri)
|
||||
except Exception:
|
||||
LOG.exception("Failed to delete Azure blob", uri=uri)
|
||||
raise
|
||||
|
||||
async def list_blobs(self, container: str, prefix: str | None = None) -> list[str]:
|
||||
try:
|
||||
client = self._get_blob_service_client()
|
||||
container_client = client.get_container_client(container)
|
||||
blobs = []
|
||||
async for blob in container_client.list_blobs(name_starts_with=prefix):
|
||||
blobs.append(blob.name)
|
||||
return blobs
|
||||
except ResourceNotFoundError:
|
||||
return []
|
||||
except Exception:
|
||||
LOG.exception("Failed to list Azure blobs", container=container, prefix=prefix)
|
||||
return []
|
||||
|
||||
def create_sas_url(self, uri: str, expiry_hours: int = 24) -> str | None:
|
||||
parsed = AzureUri(uri)
|
||||
try:
|
||||
sas_token = generate_blob_sas(
|
||||
account_name=self.account_name,
|
||||
container_name=parsed.container,
|
||||
blob_name=parsed.blob_path,
|
||||
account_key=self.account_key,
|
||||
permission=BlobSasPermissions(read=True),
|
||||
expiry=datetime.now(timezone.utc) + timedelta(hours=expiry_hours),
|
||||
)
|
||||
return (
|
||||
f"https://{self.account_name}.blob.core.windows.net/{parsed.container}/{parsed.blob_path}?{sas_token}"
|
||||
)
|
||||
except Exception:
|
||||
LOG.exception("Failed to create SAS URL", uri=uri)
|
||||
return None
|
||||
|
||||
async def create_sas_urls(self, uris: list[str], expiry_hours: int = 24) -> list[str] | None:
|
||||
try:
|
||||
sas_urls: list[str] = []
|
||||
for uri in uris:
|
||||
url = self.create_sas_url(uri, expiry_hours)
|
||||
if url is None:
|
||||
LOG.warning("SAS URL generation failed, aborting batch", failed_uri=uri, uris=uris)
|
||||
return None
|
||||
sas_urls.append(url)
|
||||
return sas_urls
|
||||
except Exception:
|
||||
LOG.exception("Failed to create SAS URLs")
|
||||
return None
|
||||
|
||||
async def close(self) -> None:
|
||||
await self.blob_service_client.close()
|
||||
if self._blob_service_client:
|
||||
await self._blob_service_client.close()
|
||||
self._blob_service_client = None
|
||||
|
||||
async def list_files(self, uri: str) -> list[str]:
|
||||
"""List files under a URI prefix. Returns blob names relative to container."""
|
||||
parsed = AzureUri(uri)
|
||||
return await self.list_blobs(parsed.container, parsed.blob_path)
|
||||
|
||||
async def get_object_info(self, uri: str) -> dict | None:
|
||||
"""Get object info including metadata. Returns dict with Metadata and LastModified keys."""
|
||||
props = await self.get_blob_properties(uri)
|
||||
if props is None:
|
||||
return None
|
||||
return {
|
||||
"Metadata": props.get("metadata", {}),
|
||||
"LastModified": props.get("last_modified"),
|
||||
}
|
||||
|
||||
async def delete_file(self, uri: str) -> None:
|
||||
"""Delete a file at the given URI."""
|
||||
await self.delete_blob(uri)
|
||||
|
||||
async def get_file_metadata(self, uri: str, log_exception: bool = True) -> dict[str, str] | None:
|
||||
"""Get only the metadata for a file."""
|
||||
parsed = AzureUri(uri)
|
||||
try:
|
||||
client = self._get_blob_service_client()
|
||||
container_client = client.get_container_client(parsed.container)
|
||||
blob_client = container_client.get_blob_client(parsed.blob_path)
|
||||
props = await blob_client.get_blob_properties()
|
||||
return props.metadata or {}
|
||||
except ResourceNotFoundError:
|
||||
if log_exception:
|
||||
LOG.warning("Azure blob not found for metadata", uri=uri)
|
||||
return None
|
||||
except Exception:
|
||||
if log_exception:
|
||||
LOG.exception("Failed to get blob metadata", uri=uri)
|
||||
return None
|
||||
|
||||
|
||||
class RealAzureClientFactory(AzureClientFactory):
|
||||
@@ -124,4 +357,4 @@ class RealAzureClientFactory(AzureClientFactory):
|
||||
|
||||
def create_storage_client(self, storage_account_name: str, storage_account_key: str) -> AsyncAzureStorageClient:
|
||||
"""Create an Azure Storage client with the provided credentials."""
|
||||
return RealAsyncAzureStorageClient(storage_account_name, storage_account_key)
|
||||
return RealAsyncAzureStorageClient(account_name=storage_account_name, account_key=storage_account_key)
|
||||
|
||||
Reference in New Issue
Block a user