Azure ClientSecretCredential support (#3456)

Co-authored-by: Suchintan <suchintan@users.noreply.github.com>
Co-authored-by: Shuchang Zheng <wintonzheng0325@gmail.com>
This commit is contained in:
stenn930
2025-09-23 10:16:48 -06:00
committed by GitHub
parent 10fac9bad0
commit a29a2bc49b
12 changed files with 592 additions and 71 deletions

View File

@@ -1,39 +1,24 @@
import structlog
from azure.identity.aio import DefaultAzureCredential
from azure.identity.aio import ClientSecretCredential, DefaultAzureCredential
from azure.keyvault.secrets.aio import SecretClient
from azure.storage.blob.aio import BlobServiceClient
from skyvern.exceptions import AzureConfigurationError
from skyvern.forge.sdk.schemas.organizations import AzureClientSecretCredential
LOG = structlog.get_logger()
class AsyncAzureClient:
def __init__(self, storage_account_name: str | None, storage_account_key: str | None):
self.storage_account_name = storage_account_name
self.storage_account_key = storage_account_key
if storage_account_name and storage_account_key:
self.blob_service_client = BlobServiceClient(
account_url=f"https://{storage_account_name}.blob.core.windows.net",
credential=storage_account_key,
)
else:
self.blob_service_client = None
self.credential = DefaultAzureCredential()
async def get_secret(self, secret_name: str, vault_name: str | None = None) -> str | None:
vault_subdomain = vault_name or self.storage_account_name
if not vault_subdomain:
raise AzureConfigurationError("Missing vault")
class AsyncAzureVaultClient:
def __init__(self, credential: ClientSecretCredential | DefaultAzureCredential):
self.credential = credential
async def get_secret(self, secret_name: str, vault_name: str) -> str | None:
try:
# Azure Key Vault URL format: https://<your-key-vault-name>.vault.azure.net
# Assuming the secret_name is actually the Key Vault URL and the secret name
# This needs to be clarified or passed as separate parameters
# For now, let's assume secret_name is the actual secret name and Key Vault URL is in settings.
key_vault_url = f"https://{vault_subdomain}.vault.azure.net" # Placeholder, adjust as needed
key_vault_url = f"https://{vault_name}.vault.azure.net" # Placeholder, adjust as needed
secret_client = SecretClient(vault_url=key_vault_url, credential=self.credential)
secret = await secret_client.get_secret(secret_name)
return secret.value
@@ -43,10 +28,34 @@ class AsyncAzureClient:
finally:
await self.credential.close()
async def upload_file_from_path(self, container_name: str, blob_name: str, file_path: str) -> None:
if not self.blob_service_client:
raise AzureConfigurationError("Storage is not configured")
async def close(self) -> None:
await self.credential.close()
@classmethod
def create_default(cls) -> "AsyncAzureVaultClient":
return cls(DefaultAzureCredential())
@classmethod
def create_from_client_secret(
cls,
credential: AzureClientSecretCredential,
) -> "AsyncAzureVaultClient":
cred = ClientSecretCredential(
tenant_id=credential.tenant_id,
client_id=credential.client_id,
client_secret=credential.client_secret,
)
return cls(cred)
class AsyncAzureStorageClient:
def __init__(self, storage_account_name: str, storage_account_key: str):
self.blob_service_client = BlobServiceClient(
account_url=f"https://{storage_account_name}.blob.core.windows.net",
credential=storage_account_key,
)
async def upload_file_from_path(self, container_name: str, blob_name: str, file_path: str) -> None:
try:
container_client = self.blob_service_client.get_container_client(container_name)
# Create the container if it doesn't exist
@@ -69,4 +78,3 @@ class AsyncAzureClient:
async def close(self) -> None:
await self.blob_service_client.close()
await self.credential.close()

View File

@@ -1,6 +1,6 @@
import json
from datetime import datetime, timedelta, timezone
from typing import Any, List, Sequence
from typing import Any, List, Literal, Sequence, overload
import structlog
from sqlalchemy import and_, asc, delete, distinct, func, or_, pool, select, tuple_, update
@@ -79,7 +79,12 @@ from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion
from skyvern.forge.sdk.schemas.credentials import Credential, CredentialType
from skyvern.forge.sdk.schemas.debug_sessions import BlockRun, DebugSession
from skyvern.forge.sdk.schemas.organization_bitwarden_collections import OrganizationBitwardenCollection
from skyvern.forge.sdk.schemas.organizations import Organization, OrganizationAuthToken
from skyvern.forge.sdk.schemas.organizations import (
AzureClientSecretCredential,
AzureOrganizationAuthToken,
Organization,
OrganizationAuthToken,
)
from skyvern.forge.sdk.schemas.persistent_browser_sessions import PersistentBrowserSession
from skyvern.forge.sdk.schemas.runs import Run
from skyvern.forge.sdk.schemas.task_generations import TaskGeneration
@@ -865,11 +870,25 @@ class AgentDB:
await session.refresh(organization)
return Organization.model_validate(organization)
@overload
async def get_valid_org_auth_token(
self,
organization_id: str,
token_type: Literal[OrganizationAuthTokenType.api, OrganizationAuthTokenType.onepassword_service_account],
) -> OrganizationAuthToken | None: ...
@overload
async def get_valid_org_auth_token(
self,
organization_id: str,
token_type: Literal[OrganizationAuthTokenType.azure_client_secret_credential],
) -> AzureOrganizationAuthToken | None: ...
async def get_valid_org_auth_token(
self,
organization_id: str,
token_type: OrganizationAuthTokenType,
) -> OrganizationAuthToken | None:
) -> OrganizationAuthToken | AzureOrganizationAuthToken | None:
try:
async with self.Session() as session:
if token := (
@@ -881,7 +900,7 @@ class AgentDB:
.order_by(OrganizationAuthTokenModel.created_at.desc())
)
).first():
return await convert_to_organization_auth_token(token)
return await convert_to_organization_auth_token(token, token_type)
else:
return None
except SQLAlchemyError:
@@ -907,7 +926,7 @@ class AgentDB:
.order_by(OrganizationAuthTokenModel.created_at.desc())
)
).all()
return [await convert_to_organization_auth_token(token) for token in tokens]
return [await convert_to_organization_auth_token(token, token_type) for token in tokens]
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
@@ -941,7 +960,7 @@ class AgentDB:
if valid is not None:
query = query.filter_by(valid=valid)
if token_obj := (await session.scalars(query)).first():
return await convert_to_organization_auth_token(token_obj)
return await convert_to_organization_auth_token(token_obj, token_type)
else:
return None
except SQLAlchemyError:
@@ -955,14 +974,22 @@ class AgentDB:
self,
organization_id: str,
token_type: OrganizationAuthTokenType,
token: str,
token: str | AzureClientSecretCredential,
encrypted_method: EncryptMethod | None = None,
) -> OrganizationAuthToken:
plaintext_token = token
if token_type is OrganizationAuthTokenType.azure_client_secret_credential:
if not isinstance(token, AzureClientSecretCredential):
raise TypeError("Expected AzureClientSecretCredential for this token_type")
plaintext_token = token.model_dump_json()
else:
if not isinstance(token, str):
raise TypeError("Expected str token for this token_type")
plaintext_token = token
encrypted_token = ""
if encrypted_method is not None:
encrypted_token = await encryptor.encrypt(token, encrypted_method)
encrypted_token = await encryptor.encrypt(plaintext_token, encrypted_method)
plaintext_token = ""
async with self.Session() as session:
@@ -977,7 +1004,7 @@ class AgentDB:
await session.commit()
await session.refresh(auth_token)
return await convert_to_organization_auth_token(auth_token)
return await convert_to_organization_auth_token(auth_token, token_type)
async def invalidate_org_auth_tokens(
self,

View File

@@ -4,6 +4,7 @@ from enum import StrEnum
class OrganizationAuthTokenType(StrEnum):
api = "api"
onepassword_service_account = "onepassword_service_account"
azure_client_secret_credential = "azure_client_secret_credential"
class TaskType(StrEnum):

View File

@@ -30,7 +30,12 @@ from skyvern.forge.sdk.db.models import (
from skyvern.forge.sdk.encrypt import encryptor
from skyvern.forge.sdk.encrypt.base import EncryptMethod
from skyvern.forge.sdk.models import Step, StepStatus
from skyvern.forge.sdk.schemas.organizations import Organization, OrganizationAuthToken
from skyvern.forge.sdk.schemas.organizations import (
AzureClientSecretCredential,
AzureOrganizationAuthToken,
Organization,
OrganizationAuthToken,
)
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunBlock
from skyvern.forge.sdk.workflow.models.parameter import (
@@ -195,21 +200,33 @@ def convert_to_organization(org_model: OrganizationModel) -> Organization:
async def convert_to_organization_auth_token(
org_auth_token: OrganizationAuthTokenModel,
) -> OrganizationAuthToken:
org_auth_token: OrganizationAuthTokenModel, token_type: OrganizationAuthTokenType
) -> OrganizationAuthToken | AzureOrganizationAuthToken:
token = org_auth_token.token
if org_auth_token.encrypted_token and org_auth_token.encrypted_method:
token = await encryptor.decrypt(org_auth_token.encrypted_token, EncryptMethod(org_auth_token.encrypted_method))
return OrganizationAuthToken(
id=org_auth_token.id,
organization_id=org_auth_token.organization_id,
token_type=OrganizationAuthTokenType(org_auth_token.token_type),
token=token,
valid=org_auth_token.valid,
created_at=org_auth_token.created_at,
modified_at=org_auth_token.modified_at,
)
if token_type == OrganizationAuthTokenType.azure_client_secret_credential:
credential = AzureClientSecretCredential.model_validate_json(token)
return AzureOrganizationAuthToken(
id=org_auth_token.id,
organization_id=org_auth_token.organization_id,
token_type=OrganizationAuthTokenType(org_auth_token.token_type),
credential=credential,
valid=org_auth_token.valid,
created_at=org_auth_token.created_at,
modified_at=org_auth_token.modified_at,
)
else:
return OrganizationAuthToken(
id=org_auth_token.id,
organization_id=org_auth_token.organization_id,
token_type=OrganizationAuthTokenType(org_auth_token.token_type),
token=token,
valid=org_auth_token.valid,
created_at=org_auth_token.created_at,
modified_at=org_auth_token.modified_at,
)
def convert_to_artifact(artifact_model: ArtifactModel, debug_enabled: bool = False) -> Artifact:

View File

@@ -21,6 +21,8 @@ from skyvern.forge.sdk.schemas.credentials import (
PasswordCredentialResponse,
)
from skyvern.forge.sdk.schemas.organizations import (
AzureClientSecretCredentialResponse,
CreateAzureClientSecretCredentialRequest,
CreateOnePasswordTokenRequest,
CreateOnePasswordTokenResponse,
Organization,
@@ -478,3 +480,106 @@ async def update_onepassword_token(
status_code=500,
detail=f"Failed to create or update OnePassword service account token: {str(e)}",
)
@base_router.get(
"/credentials/azure_credential/get",
response_model=AzureClientSecretCredentialResponse,
summary="Get Azure Client Secret Credential",
description="Retrieves the current Azure Client Secret Credential for the organization.",
include_in_schema=False,
)
@base_router.get(
"/credentials/azure_credential/get/",
response_model=AzureClientSecretCredentialResponse,
include_in_schema=False,
)
async def get_azure_client_secret_credential(
current_org: Organization = Depends(org_auth_service.get_current_org),
) -> AzureClientSecretCredentialResponse:
"""
Get the current Azure Client Secret Credential for the organization.
"""
try:
auth_token = await app.DATABASE.get_valid_org_auth_token(
organization_id=current_org.organization_id,
token_type=OrganizationAuthTokenType.azure_client_secret_credential,
)
if not auth_token:
raise HTTPException(
status_code=404,
detail="No Azure Client Secret Credential found for this organization",
)
return AzureClientSecretCredentialResponse(token=auth_token)
except HTTPException:
raise
except Exception as e:
LOG.error(
"Failed to get Azure Client Secret Credential",
organization_id=current_org.organization_id,
error=str(e),
exc_info=True,
)
raise HTTPException(
status_code=500,
detail=f"Failed to get Azure Client Secret Credential: {str(e)}",
)
@base_router.post(
"/credentials/azure_credential/create",
response_model=AzureClientSecretCredentialResponse,
summary="Create or update Azure Client Secret Credential",
description="Creates or updates a Azure Client Secret Credential for the current organization. Only one valid record is allowed per organization.",
include_in_schema=False,
)
@base_router.post(
"/credentials/azure_credential/create/",
response_model=AzureClientSecretCredentialResponse,
include_in_schema=False,
)
async def update_azure_client_secret_credential(
request: CreateAzureClientSecretCredentialRequest,
current_org: Organization = Depends(org_auth_service.get_current_org),
) -> AzureClientSecretCredentialResponse:
"""
Create or update an Azure Client Secret Credential for the current organization.
This endpoint ensures only one valid Azure Client Secret Credential exists per organization.
If a valid token already exists, it will be invalidated before creating the new one.
"""
try:
# Invalidate any existing valid Azure Client Secret Credential for this organization
await app.DATABASE.invalidate_org_auth_tokens(
organization_id=current_org.organization_id,
token_type=OrganizationAuthTokenType.azure_client_secret_credential,
)
# Create the new Azure token
auth_token = await app.DATABASE.create_org_auth_token(
organization_id=current_org.organization_id,
token_type=OrganizationAuthTokenType.azure_client_secret_credential,
token=request.credential,
)
LOG.info(
"Created or updated Azure Client Secret Credential",
organization_id=current_org.organization_id,
token_id=auth_token.id,
)
return AzureClientSecretCredentialResponse(token=auth_token)
except Exception as e:
LOG.error(
"Failed to create or update Azure Client Secret Credential",
organization_id=current_org.organization_id,
error=str(e),
exc_info=True,
)
raise HTTPException(
status_code=500,
detail=f"Failed to create or update Azure Client Secret Credential: {str(e)}",
)

View File

@@ -21,16 +21,31 @@ class Organization(BaseModel):
modified_at: datetime
class OrganizationAuthToken(BaseModel):
class OrganizationAuthTokenBase(BaseModel):
id: str
organization_id: str
token_type: OrganizationAuthTokenType
token: str
valid: bool
created_at: datetime
modified_at: datetime
class OrganizationAuthToken(OrganizationAuthTokenBase):
token: str
class AzureClientSecretCredential(BaseModel):
tenant_id: str
client_id: str
client_secret: str
class AzureOrganizationAuthToken(OrganizationAuthTokenBase):
"""Represents OrganizationAuthToken for Azure; defined by 3 fields: tenant_id, client_id, and client_secret"""
credential: AzureClientSecretCredential
class CreateOnePasswordTokenRequest(BaseModel):
"""Request model for creating or updating a 1Password service account token."""
@@ -50,6 +65,21 @@ class CreateOnePasswordTokenResponse(BaseModel):
)
class AzureClientSecretCredentialResponse(BaseModel):
"""Response model for Azure ClientSecretCredential operations."""
token: AzureOrganizationAuthToken = Field(
...,
description="The created or updated Azure ClientSecretCredential",
)
class CreateAzureClientSecretCredentialRequest(BaseModel):
"""Request model for creating or updating an Azure ClientSecretCredential."""
credential: AzureClientSecretCredential
class GetOrganizationsResponse(BaseModel):
organizations: list[Organization]

View File

@@ -7,6 +7,7 @@ from onepassword.client import Client as OnePasswordClient
from skyvern.config import settings
from skyvern.exceptions import (
AzureConfigurationError,
BitwardenBaseError,
CredentialParameterNotFoundError,
SkyvernException,
@@ -14,7 +15,7 @@ from skyvern.exceptions import (
)
from skyvern.forge import app
from skyvern.forge.sdk.api.aws import AsyncAWSClient
from skyvern.forge.sdk.api.azure import AsyncAzureClient
from skyvern.forge.sdk.api.azure import AsyncAzureVaultClient
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
from skyvern.forge.sdk.schemas.credentials import PasswordCredential
from skyvern.forge.sdk.schemas.organizations import Organization
@@ -56,7 +57,6 @@ class WorkflowRunContext:
async def init(
cls,
aws_client: AsyncAWSClient,
azure_client: AsyncAzureClient,
organization: Organization,
workflow_parameter_tuples: list[tuple[WorkflowParameter, "WorkflowRunParameter"]],
workflow_output_parameters: list[OutputParameter],
@@ -71,7 +71,7 @@ class WorkflowRunContext:
block_outputs: dict[str, Any] | None = None,
) -> Self:
# key is label name
workflow_run_context = cls(aws_client=aws_client, azure_client=azure_client)
workflow_run_context = cls(aws_client=aws_client)
for parameter, run_parameter in workflow_parameter_tuples:
if parameter.workflow_parameter_type == WorkflowParameterType.CREDENTIAL_ID:
await workflow_run_context.register_secret_workflow_parameter_value(
@@ -109,7 +109,9 @@ class WorkflowRunContext:
secret_parameter, organization
)
elif isinstance(secret_parameter, AzureVaultCredentialParameter):
await workflow_run_context.register_azure_vault_credential_parameter_value(secret_parameter)
await workflow_run_context.register_azure_vault_credential_parameter_value(
secret_parameter, organization
)
elif isinstance(secret_parameter, BitwardenLoginCredentialParameter):
await workflow_run_context.register_bitwarden_login_credential_parameter_value(
secret_parameter, organization
@@ -131,13 +133,12 @@ class WorkflowRunContext:
return workflow_run_context
def __init__(self, aws_client: AsyncAWSClient, azure_client: AsyncAzureClient) -> None:
def __init__(self, aws_client: AsyncAWSClient) -> None:
self.blocks_metadata: dict[str, BlockMetadata] = {}
self.parameters: dict[str, PARAMETER_TYPE] = {}
self.values: dict[str, Any] = {}
self.secrets: dict[str, Any] = {}
self._aws_client = aws_client
self._azure_client = azure_client
def get_parameter(self, key: str) -> Parameter:
return self.parameters[key]
@@ -343,10 +344,16 @@ class WorkflowRunContext:
self,
parameter: AzureSecretParameter,
) -> None:
vault_name = settings.AZURE_STORAGE_ACCOUNT_NAME
if vault_name is None:
LOG.error("AZURE_STORAGE_ACCOUNT_NAME is not configured, cannot register Azure secret parameter value")
raise AzureConfigurationError("AZURE_STORAGE_ACCOUNT_NAME is not configured")
# If the parameter is an Azure secret, fetch the secret value and store it in the secrets dict
# The value of the parameter will be the random secret id with format `secret_<uuid>`.
# We'll replace the random secret id with the actual secret value when we need to use it.
secret_value = await self._azure_client.get_secret(parameter.azure_key)
azure_vault_client = AsyncAzureVaultClient.create_default()
secret_value = await azure_vault_client.get_secret(parameter.azure_key, vault_name)
if secret_value is not None:
random_secret_id = self.generate_random_secret_id()
self.secrets[random_secret_id] = secret_value
@@ -491,7 +498,11 @@ class WorkflowRunContext:
LOG.error(f"Failed to get secret from Bitwarden. Error: {e}")
raise e
async def register_azure_vault_credential_parameter_value(self, parameter: AzureVaultCredentialParameter) -> None:
async def register_azure_vault_credential_parameter_value(
self,
parameter: AzureVaultCredentialParameter,
organization: Organization,
) -> None:
vault_name = self._resolve_parameter_value(parameter.vault_name)
if not vault_name:
raise ValueError("Azure Vault Name is missing")
@@ -504,18 +515,28 @@ class WorkflowRunContext:
totp_secret_key = self._resolve_parameter_value(parameter.totp_secret_key)
secret_login = await self._azure_client.get_secret(username_key, vault_name)
secret_password = await self._azure_client.get_secret(password_key, vault_name)
azure_vault_client = await self._get_azure_vault_client_for_organization(organization)
secret_username = await azure_vault_client.get_secret(username_key, vault_name)
if not secret_username:
raise ValueError(f"Azure Vault username not found by key: {username_key}")
secret_password = await azure_vault_client.get_secret(password_key, vault_name)
if not secret_password:
raise ValueError(f"Azure Vault password not found by key: {password_key}")
if totp_secret_key:
totp_secret = await self._azure_client.get_secret(totp_secret_key, vault_name)
totp_secret = await azure_vault_client.get_secret(totp_secret_key, vault_name)
if not totp_secret:
raise ValueError(f"Azure Vault TOTP not found by key: {totp_secret_key}")
else:
totp_secret = None
if secret_login is not None and secret_password is not None:
if secret_username is not None and secret_password is not None:
random_secret_id = self.generate_random_secret_id()
# login secret
username_secret_id = f"{random_secret_id}_username"
self.secrets[username_secret_id] = secret_login
self.secrets[username_secret_id] = secret_username
# password secret
password_secret_id = f"{random_secret_id}_password"
self.secrets[password_secret_id] = secret_password
@@ -895,10 +916,21 @@ class WorkflowRunContext:
else:
return jinja_sandbox_env.from_string(parameter_value).render(self.values)
@staticmethod
async def _get_azure_vault_client_for_organization(organization: Organization) -> AsyncAzureVaultClient:
org_auth_token = await app.DATABASE.get_valid_org_auth_token(
organization.organization_id, OrganizationAuthTokenType.azure_client_secret_credential
)
if org_auth_token:
azure_vault_client = AsyncAzureVaultClient.create_from_client_secret(org_auth_token.credential)
else:
# Use the DefaultAzureCredential if not configured on organization level
azure_vault_client = AsyncAzureVaultClient.create_default()
return azure_vault_client
class WorkflowContextManager:
aws_client: AsyncAWSClient
azure_client: AsyncAzureClient
workflow_run_contexts: dict[str, WorkflowRunContext]
parameters: dict[str, PARAMETER_TYPE]
@@ -907,10 +939,6 @@ class WorkflowContextManager:
def __init__(self) -> None:
self.aws_client = AsyncAWSClient()
self.azure_client = AsyncAzureClient(
storage_account_name=settings.AZURE_STORAGE_ACCOUNT_NAME,
storage_account_key=settings.AZURE_STORAGE_ACCOUNT_KEY,
)
self.workflow_run_contexts = {}
def _validate_workflow_run_context(self, workflow_run_id: str) -> None:
@@ -935,7 +963,6 @@ class WorkflowContextManager:
) -> WorkflowRunContext:
workflow_run_context = await WorkflowRunContext.init(
self.aws_client,
self.azure_client,
organization,
workflow_parameter_tuples,
workflow_output_parameters,

View File

@@ -34,6 +34,7 @@ from skyvern.constants import (
MAX_UPLOAD_FILE_COUNT,
)
from skyvern.exceptions import (
AzureConfigurationError,
ContextParameterValueNotFound,
MissingBrowserState,
MissingBrowserStatePage,
@@ -44,7 +45,7 @@ from skyvern.exceptions import (
from skyvern.forge import app
from skyvern.forge.prompts import prompt_engine
from skyvern.forge.sdk.api.aws import AsyncAWSClient
from skyvern.forge.sdk.api.azure import AsyncAzureClient
from skyvern.forge.sdk.api.azure import AsyncAzureStorageClient
from skyvern.forge.sdk.api.files import (
calculate_sha256_for_file,
create_named_temporary_file,
@@ -2061,7 +2062,10 @@ class FileUploadBlock(Block):
workflow_run_context.get_original_secret_value_or_none(self.azure_storage_account_key)
or self.azure_storage_account_key
)
azure_client = AsyncAzureClient(
if actual_azure_storage_account_name is None or actual_azure_storage_account_key is None:
raise AzureConfigurationError("Azure Storage is not configured")
azure_client = AsyncAzureStorageClient(
storage_account_name=actual_azure_storage_account_name,
storage_account_key=actual_azure_storage_account_key,
)