Azure ClientSecretCredential support (#3456)
Co-authored-by: Suchintan <suchintan@users.noreply.github.com> Co-authored-by: Shuchang Zheng <wintonzheng0325@gmail.com>
This commit is contained in:
@@ -1,39 +1,24 @@
|
||||
import structlog
|
||||
from azure.identity.aio import DefaultAzureCredential
|
||||
from azure.identity.aio import ClientSecretCredential, DefaultAzureCredential
|
||||
from azure.keyvault.secrets.aio import SecretClient
|
||||
from azure.storage.blob.aio import BlobServiceClient
|
||||
|
||||
from skyvern.exceptions import AzureConfigurationError
|
||||
from skyvern.forge.sdk.schemas.organizations import AzureClientSecretCredential
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class AsyncAzureClient:
|
||||
def __init__(self, storage_account_name: str | None, storage_account_key: str | None):
|
||||
self.storage_account_name = storage_account_name
|
||||
self.storage_account_key = storage_account_key
|
||||
|
||||
if storage_account_name and storage_account_key:
|
||||
self.blob_service_client = BlobServiceClient(
|
||||
account_url=f"https://{storage_account_name}.blob.core.windows.net",
|
||||
credential=storage_account_key,
|
||||
)
|
||||
else:
|
||||
self.blob_service_client = None
|
||||
|
||||
self.credential = DefaultAzureCredential()
|
||||
|
||||
async def get_secret(self, secret_name: str, vault_name: str | None = None) -> str | None:
|
||||
vault_subdomain = vault_name or self.storage_account_name
|
||||
if not vault_subdomain:
|
||||
raise AzureConfigurationError("Missing vault")
|
||||
class AsyncAzureVaultClient:
|
||||
def __init__(self, credential: ClientSecretCredential | DefaultAzureCredential):
|
||||
self.credential = credential
|
||||
|
||||
async def get_secret(self, secret_name: str, vault_name: str) -> str | None:
|
||||
try:
|
||||
# Azure Key Vault URL format: https://<your-key-vault-name>.vault.azure.net
|
||||
# Assuming the secret_name is actually the Key Vault URL and the secret name
|
||||
# This needs to be clarified or passed as separate parameters
|
||||
# For now, let's assume secret_name is the actual secret name and Key Vault URL is in settings.
|
||||
key_vault_url = f"https://{vault_subdomain}.vault.azure.net" # Placeholder, adjust as needed
|
||||
key_vault_url = f"https://{vault_name}.vault.azure.net" # Placeholder, adjust as needed
|
||||
secret_client = SecretClient(vault_url=key_vault_url, credential=self.credential)
|
||||
secret = await secret_client.get_secret(secret_name)
|
||||
return secret.value
|
||||
@@ -43,10 +28,34 @@ class AsyncAzureClient:
|
||||
finally:
|
||||
await self.credential.close()
|
||||
|
||||
async def upload_file_from_path(self, container_name: str, blob_name: str, file_path: str) -> None:
|
||||
if not self.blob_service_client:
|
||||
raise AzureConfigurationError("Storage is not configured")
|
||||
async def close(self) -> None:
|
||||
await self.credential.close()
|
||||
|
||||
@classmethod
|
||||
def create_default(cls) -> "AsyncAzureVaultClient":
|
||||
return cls(DefaultAzureCredential())
|
||||
|
||||
@classmethod
|
||||
def create_from_client_secret(
|
||||
cls,
|
||||
credential: AzureClientSecretCredential,
|
||||
) -> "AsyncAzureVaultClient":
|
||||
cred = ClientSecretCredential(
|
||||
tenant_id=credential.tenant_id,
|
||||
client_id=credential.client_id,
|
||||
client_secret=credential.client_secret,
|
||||
)
|
||||
return cls(cred)
|
||||
|
||||
|
||||
class AsyncAzureStorageClient:
|
||||
def __init__(self, storage_account_name: str, storage_account_key: str):
|
||||
self.blob_service_client = BlobServiceClient(
|
||||
account_url=f"https://{storage_account_name}.blob.core.windows.net",
|
||||
credential=storage_account_key,
|
||||
)
|
||||
|
||||
async def upload_file_from_path(self, container_name: str, blob_name: str, file_path: str) -> None:
|
||||
try:
|
||||
container_client = self.blob_service_client.get_container_client(container_name)
|
||||
# Create the container if it doesn't exist
|
||||
@@ -69,4 +78,3 @@ class AsyncAzureClient:
|
||||
|
||||
async def close(self) -> None:
|
||||
await self.blob_service_client.close()
|
||||
await self.credential.close()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, List, Sequence
|
||||
from typing import Any, List, Literal, Sequence, overload
|
||||
|
||||
import structlog
|
||||
from sqlalchemy import and_, asc, delete, distinct, func, or_, pool, select, tuple_, update
|
||||
@@ -79,7 +79,12 @@ from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion
|
||||
from skyvern.forge.sdk.schemas.credentials import Credential, CredentialType
|
||||
from skyvern.forge.sdk.schemas.debug_sessions import BlockRun, DebugSession
|
||||
from skyvern.forge.sdk.schemas.organization_bitwarden_collections import OrganizationBitwardenCollection
|
||||
from skyvern.forge.sdk.schemas.organizations import Organization, OrganizationAuthToken
|
||||
from skyvern.forge.sdk.schemas.organizations import (
|
||||
AzureClientSecretCredential,
|
||||
AzureOrganizationAuthToken,
|
||||
Organization,
|
||||
OrganizationAuthToken,
|
||||
)
|
||||
from skyvern.forge.sdk.schemas.persistent_browser_sessions import PersistentBrowserSession
|
||||
from skyvern.forge.sdk.schemas.runs import Run
|
||||
from skyvern.forge.sdk.schemas.task_generations import TaskGeneration
|
||||
@@ -865,11 +870,25 @@ class AgentDB:
|
||||
await session.refresh(organization)
|
||||
return Organization.model_validate(organization)
|
||||
|
||||
@overload
|
||||
async def get_valid_org_auth_token(
|
||||
self,
|
||||
organization_id: str,
|
||||
token_type: Literal[OrganizationAuthTokenType.api, OrganizationAuthTokenType.onepassword_service_account],
|
||||
) -> OrganizationAuthToken | None: ...
|
||||
|
||||
@overload
|
||||
async def get_valid_org_auth_token(
|
||||
self,
|
||||
organization_id: str,
|
||||
token_type: Literal[OrganizationAuthTokenType.azure_client_secret_credential],
|
||||
) -> AzureOrganizationAuthToken | None: ...
|
||||
|
||||
async def get_valid_org_auth_token(
|
||||
self,
|
||||
organization_id: str,
|
||||
token_type: OrganizationAuthTokenType,
|
||||
) -> OrganizationAuthToken | None:
|
||||
) -> OrganizationAuthToken | AzureOrganizationAuthToken | None:
|
||||
try:
|
||||
async with self.Session() as session:
|
||||
if token := (
|
||||
@@ -881,7 +900,7 @@ class AgentDB:
|
||||
.order_by(OrganizationAuthTokenModel.created_at.desc())
|
||||
)
|
||||
).first():
|
||||
return await convert_to_organization_auth_token(token)
|
||||
return await convert_to_organization_auth_token(token, token_type)
|
||||
else:
|
||||
return None
|
||||
except SQLAlchemyError:
|
||||
@@ -907,7 +926,7 @@ class AgentDB:
|
||||
.order_by(OrganizationAuthTokenModel.created_at.desc())
|
||||
)
|
||||
).all()
|
||||
return [await convert_to_organization_auth_token(token) for token in tokens]
|
||||
return [await convert_to_organization_auth_token(token, token_type) for token in tokens]
|
||||
except SQLAlchemyError:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
raise
|
||||
@@ -941,7 +960,7 @@ class AgentDB:
|
||||
if valid is not None:
|
||||
query = query.filter_by(valid=valid)
|
||||
if token_obj := (await session.scalars(query)).first():
|
||||
return await convert_to_organization_auth_token(token_obj)
|
||||
return await convert_to_organization_auth_token(token_obj, token_type)
|
||||
else:
|
||||
return None
|
||||
except SQLAlchemyError:
|
||||
@@ -955,14 +974,22 @@ class AgentDB:
|
||||
self,
|
||||
organization_id: str,
|
||||
token_type: OrganizationAuthTokenType,
|
||||
token: str,
|
||||
token: str | AzureClientSecretCredential,
|
||||
encrypted_method: EncryptMethod | None = None,
|
||||
) -> OrganizationAuthToken:
|
||||
plaintext_token = token
|
||||
if token_type is OrganizationAuthTokenType.azure_client_secret_credential:
|
||||
if not isinstance(token, AzureClientSecretCredential):
|
||||
raise TypeError("Expected AzureClientSecretCredential for this token_type")
|
||||
plaintext_token = token.model_dump_json()
|
||||
else:
|
||||
if not isinstance(token, str):
|
||||
raise TypeError("Expected str token for this token_type")
|
||||
plaintext_token = token
|
||||
|
||||
encrypted_token = ""
|
||||
|
||||
if encrypted_method is not None:
|
||||
encrypted_token = await encryptor.encrypt(token, encrypted_method)
|
||||
encrypted_token = await encryptor.encrypt(plaintext_token, encrypted_method)
|
||||
plaintext_token = ""
|
||||
|
||||
async with self.Session() as session:
|
||||
@@ -977,7 +1004,7 @@ class AgentDB:
|
||||
await session.commit()
|
||||
await session.refresh(auth_token)
|
||||
|
||||
return await convert_to_organization_auth_token(auth_token)
|
||||
return await convert_to_organization_auth_token(auth_token, token_type)
|
||||
|
||||
async def invalidate_org_auth_tokens(
|
||||
self,
|
||||
|
||||
@@ -4,6 +4,7 @@ from enum import StrEnum
|
||||
class OrganizationAuthTokenType(StrEnum):
|
||||
api = "api"
|
||||
onepassword_service_account = "onepassword_service_account"
|
||||
azure_client_secret_credential = "azure_client_secret_credential"
|
||||
|
||||
|
||||
class TaskType(StrEnum):
|
||||
|
||||
@@ -30,7 +30,12 @@ from skyvern.forge.sdk.db.models import (
|
||||
from skyvern.forge.sdk.encrypt import encryptor
|
||||
from skyvern.forge.sdk.encrypt.base import EncryptMethod
|
||||
from skyvern.forge.sdk.models import Step, StepStatus
|
||||
from skyvern.forge.sdk.schemas.organizations import Organization, OrganizationAuthToken
|
||||
from skyvern.forge.sdk.schemas.organizations import (
|
||||
AzureClientSecretCredential,
|
||||
AzureOrganizationAuthToken,
|
||||
Organization,
|
||||
OrganizationAuthToken,
|
||||
)
|
||||
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
|
||||
from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunBlock
|
||||
from skyvern.forge.sdk.workflow.models.parameter import (
|
||||
@@ -195,21 +200,33 @@ def convert_to_organization(org_model: OrganizationModel) -> Organization:
|
||||
|
||||
|
||||
async def convert_to_organization_auth_token(
|
||||
org_auth_token: OrganizationAuthTokenModel,
|
||||
) -> OrganizationAuthToken:
|
||||
org_auth_token: OrganizationAuthTokenModel, token_type: OrganizationAuthTokenType
|
||||
) -> OrganizationAuthToken | AzureOrganizationAuthToken:
|
||||
token = org_auth_token.token
|
||||
if org_auth_token.encrypted_token and org_auth_token.encrypted_method:
|
||||
token = await encryptor.decrypt(org_auth_token.encrypted_token, EncryptMethod(org_auth_token.encrypted_method))
|
||||
|
||||
return OrganizationAuthToken(
|
||||
id=org_auth_token.id,
|
||||
organization_id=org_auth_token.organization_id,
|
||||
token_type=OrganizationAuthTokenType(org_auth_token.token_type),
|
||||
token=token,
|
||||
valid=org_auth_token.valid,
|
||||
created_at=org_auth_token.created_at,
|
||||
modified_at=org_auth_token.modified_at,
|
||||
)
|
||||
if token_type == OrganizationAuthTokenType.azure_client_secret_credential:
|
||||
credential = AzureClientSecretCredential.model_validate_json(token)
|
||||
return AzureOrganizationAuthToken(
|
||||
id=org_auth_token.id,
|
||||
organization_id=org_auth_token.organization_id,
|
||||
token_type=OrganizationAuthTokenType(org_auth_token.token_type),
|
||||
credential=credential,
|
||||
valid=org_auth_token.valid,
|
||||
created_at=org_auth_token.created_at,
|
||||
modified_at=org_auth_token.modified_at,
|
||||
)
|
||||
else:
|
||||
return OrganizationAuthToken(
|
||||
id=org_auth_token.id,
|
||||
organization_id=org_auth_token.organization_id,
|
||||
token_type=OrganizationAuthTokenType(org_auth_token.token_type),
|
||||
token=token,
|
||||
valid=org_auth_token.valid,
|
||||
created_at=org_auth_token.created_at,
|
||||
modified_at=org_auth_token.modified_at,
|
||||
)
|
||||
|
||||
|
||||
def convert_to_artifact(artifact_model: ArtifactModel, debug_enabled: bool = False) -> Artifact:
|
||||
|
||||
@@ -21,6 +21,8 @@ from skyvern.forge.sdk.schemas.credentials import (
|
||||
PasswordCredentialResponse,
|
||||
)
|
||||
from skyvern.forge.sdk.schemas.organizations import (
|
||||
AzureClientSecretCredentialResponse,
|
||||
CreateAzureClientSecretCredentialRequest,
|
||||
CreateOnePasswordTokenRequest,
|
||||
CreateOnePasswordTokenResponse,
|
||||
Organization,
|
||||
@@ -478,3 +480,106 @@ async def update_onepassword_token(
|
||||
status_code=500,
|
||||
detail=f"Failed to create or update OnePassword service account token: {str(e)}",
|
||||
)
|
||||
|
||||
|
||||
@base_router.get(
|
||||
"/credentials/azure_credential/get",
|
||||
response_model=AzureClientSecretCredentialResponse,
|
||||
summary="Get Azure Client Secret Credential",
|
||||
description="Retrieves the current Azure Client Secret Credential for the organization.",
|
||||
include_in_schema=False,
|
||||
)
|
||||
@base_router.get(
|
||||
"/credentials/azure_credential/get/",
|
||||
response_model=AzureClientSecretCredentialResponse,
|
||||
include_in_schema=False,
|
||||
)
|
||||
async def get_azure_client_secret_credential(
|
||||
current_org: Organization = Depends(org_auth_service.get_current_org),
|
||||
) -> AzureClientSecretCredentialResponse:
|
||||
"""
|
||||
Get the current Azure Client Secret Credential for the organization.
|
||||
"""
|
||||
try:
|
||||
auth_token = await app.DATABASE.get_valid_org_auth_token(
|
||||
organization_id=current_org.organization_id,
|
||||
token_type=OrganizationAuthTokenType.azure_client_secret_credential,
|
||||
)
|
||||
if not auth_token:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="No Azure Client Secret Credential found for this organization",
|
||||
)
|
||||
|
||||
return AzureClientSecretCredentialResponse(token=auth_token)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
LOG.error(
|
||||
"Failed to get Azure Client Secret Credential",
|
||||
organization_id=current_org.organization_id,
|
||||
error=str(e),
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to get Azure Client Secret Credential: {str(e)}",
|
||||
)
|
||||
|
||||
|
||||
@base_router.post(
|
||||
"/credentials/azure_credential/create",
|
||||
response_model=AzureClientSecretCredentialResponse,
|
||||
summary="Create or update Azure Client Secret Credential",
|
||||
description="Creates or updates a Azure Client Secret Credential for the current organization. Only one valid record is allowed per organization.",
|
||||
include_in_schema=False,
|
||||
)
|
||||
@base_router.post(
|
||||
"/credentials/azure_credential/create/",
|
||||
response_model=AzureClientSecretCredentialResponse,
|
||||
include_in_schema=False,
|
||||
)
|
||||
async def update_azure_client_secret_credential(
|
||||
request: CreateAzureClientSecretCredentialRequest,
|
||||
current_org: Organization = Depends(org_auth_service.get_current_org),
|
||||
) -> AzureClientSecretCredentialResponse:
|
||||
"""
|
||||
Create or update an Azure Client Secret Credential for the current organization.
|
||||
|
||||
This endpoint ensures only one valid Azure Client Secret Credential exists per organization.
|
||||
If a valid token already exists, it will be invalidated before creating the new one.
|
||||
"""
|
||||
try:
|
||||
# Invalidate any existing valid Azure Client Secret Credential for this organization
|
||||
await app.DATABASE.invalidate_org_auth_tokens(
|
||||
organization_id=current_org.organization_id,
|
||||
token_type=OrganizationAuthTokenType.azure_client_secret_credential,
|
||||
)
|
||||
|
||||
# Create the new Azure token
|
||||
auth_token = await app.DATABASE.create_org_auth_token(
|
||||
organization_id=current_org.organization_id,
|
||||
token_type=OrganizationAuthTokenType.azure_client_secret_credential,
|
||||
token=request.credential,
|
||||
)
|
||||
|
||||
LOG.info(
|
||||
"Created or updated Azure Client Secret Credential",
|
||||
organization_id=current_org.organization_id,
|
||||
token_id=auth_token.id,
|
||||
)
|
||||
|
||||
return AzureClientSecretCredentialResponse(token=auth_token)
|
||||
|
||||
except Exception as e:
|
||||
LOG.error(
|
||||
"Failed to create or update Azure Client Secret Credential",
|
||||
organization_id=current_org.organization_id,
|
||||
error=str(e),
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to create or update Azure Client Secret Credential: {str(e)}",
|
||||
)
|
||||
|
||||
@@ -21,16 +21,31 @@ class Organization(BaseModel):
|
||||
modified_at: datetime
|
||||
|
||||
|
||||
class OrganizationAuthToken(BaseModel):
|
||||
class OrganizationAuthTokenBase(BaseModel):
|
||||
id: str
|
||||
organization_id: str
|
||||
token_type: OrganizationAuthTokenType
|
||||
token: str
|
||||
valid: bool
|
||||
created_at: datetime
|
||||
modified_at: datetime
|
||||
|
||||
|
||||
class OrganizationAuthToken(OrganizationAuthTokenBase):
|
||||
token: str
|
||||
|
||||
|
||||
class AzureClientSecretCredential(BaseModel):
|
||||
tenant_id: str
|
||||
client_id: str
|
||||
client_secret: str
|
||||
|
||||
|
||||
class AzureOrganizationAuthToken(OrganizationAuthTokenBase):
|
||||
"""Represents OrganizationAuthToken for Azure; defined by 3 fields: tenant_id, client_id, and client_secret"""
|
||||
|
||||
credential: AzureClientSecretCredential
|
||||
|
||||
|
||||
class CreateOnePasswordTokenRequest(BaseModel):
|
||||
"""Request model for creating or updating a 1Password service account token."""
|
||||
|
||||
@@ -50,6 +65,21 @@ class CreateOnePasswordTokenResponse(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class AzureClientSecretCredentialResponse(BaseModel):
|
||||
"""Response model for Azure ClientSecretCredential operations."""
|
||||
|
||||
token: AzureOrganizationAuthToken = Field(
|
||||
...,
|
||||
description="The created or updated Azure ClientSecretCredential",
|
||||
)
|
||||
|
||||
|
||||
class CreateAzureClientSecretCredentialRequest(BaseModel):
|
||||
"""Request model for creating or updating an Azure ClientSecretCredential."""
|
||||
|
||||
credential: AzureClientSecretCredential
|
||||
|
||||
|
||||
class GetOrganizationsResponse(BaseModel):
|
||||
organizations: list[Organization]
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from onepassword.client import Client as OnePasswordClient
|
||||
|
||||
from skyvern.config import settings
|
||||
from skyvern.exceptions import (
|
||||
AzureConfigurationError,
|
||||
BitwardenBaseError,
|
||||
CredentialParameterNotFoundError,
|
||||
SkyvernException,
|
||||
@@ -14,7 +15,7 @@ from skyvern.exceptions import (
|
||||
)
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.api.aws import AsyncAWSClient
|
||||
from skyvern.forge.sdk.api.azure import AsyncAzureClient
|
||||
from skyvern.forge.sdk.api.azure import AsyncAzureVaultClient
|
||||
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
|
||||
from skyvern.forge.sdk.schemas.credentials import PasswordCredential
|
||||
from skyvern.forge.sdk.schemas.organizations import Organization
|
||||
@@ -56,7 +57,6 @@ class WorkflowRunContext:
|
||||
async def init(
|
||||
cls,
|
||||
aws_client: AsyncAWSClient,
|
||||
azure_client: AsyncAzureClient,
|
||||
organization: Organization,
|
||||
workflow_parameter_tuples: list[tuple[WorkflowParameter, "WorkflowRunParameter"]],
|
||||
workflow_output_parameters: list[OutputParameter],
|
||||
@@ -71,7 +71,7 @@ class WorkflowRunContext:
|
||||
block_outputs: dict[str, Any] | None = None,
|
||||
) -> Self:
|
||||
# key is label name
|
||||
workflow_run_context = cls(aws_client=aws_client, azure_client=azure_client)
|
||||
workflow_run_context = cls(aws_client=aws_client)
|
||||
for parameter, run_parameter in workflow_parameter_tuples:
|
||||
if parameter.workflow_parameter_type == WorkflowParameterType.CREDENTIAL_ID:
|
||||
await workflow_run_context.register_secret_workflow_parameter_value(
|
||||
@@ -109,7 +109,9 @@ class WorkflowRunContext:
|
||||
secret_parameter, organization
|
||||
)
|
||||
elif isinstance(secret_parameter, AzureVaultCredentialParameter):
|
||||
await workflow_run_context.register_azure_vault_credential_parameter_value(secret_parameter)
|
||||
await workflow_run_context.register_azure_vault_credential_parameter_value(
|
||||
secret_parameter, organization
|
||||
)
|
||||
elif isinstance(secret_parameter, BitwardenLoginCredentialParameter):
|
||||
await workflow_run_context.register_bitwarden_login_credential_parameter_value(
|
||||
secret_parameter, organization
|
||||
@@ -131,13 +133,12 @@ class WorkflowRunContext:
|
||||
|
||||
return workflow_run_context
|
||||
|
||||
def __init__(self, aws_client: AsyncAWSClient, azure_client: AsyncAzureClient) -> None:
|
||||
def __init__(self, aws_client: AsyncAWSClient) -> None:
|
||||
self.blocks_metadata: dict[str, BlockMetadata] = {}
|
||||
self.parameters: dict[str, PARAMETER_TYPE] = {}
|
||||
self.values: dict[str, Any] = {}
|
||||
self.secrets: dict[str, Any] = {}
|
||||
self._aws_client = aws_client
|
||||
self._azure_client = azure_client
|
||||
|
||||
def get_parameter(self, key: str) -> Parameter:
|
||||
return self.parameters[key]
|
||||
@@ -343,10 +344,16 @@ class WorkflowRunContext:
|
||||
self,
|
||||
parameter: AzureSecretParameter,
|
||||
) -> None:
|
||||
vault_name = settings.AZURE_STORAGE_ACCOUNT_NAME
|
||||
if vault_name is None:
|
||||
LOG.error("AZURE_STORAGE_ACCOUNT_NAME is not configured, cannot register Azure secret parameter value")
|
||||
raise AzureConfigurationError("AZURE_STORAGE_ACCOUNT_NAME is not configured")
|
||||
|
||||
# If the parameter is an Azure secret, fetch the secret value and store it in the secrets dict
|
||||
# The value of the parameter will be the random secret id with format `secret_<uuid>`.
|
||||
# We'll replace the random secret id with the actual secret value when we need to use it.
|
||||
secret_value = await self._azure_client.get_secret(parameter.azure_key)
|
||||
azure_vault_client = AsyncAzureVaultClient.create_default()
|
||||
secret_value = await azure_vault_client.get_secret(parameter.azure_key, vault_name)
|
||||
if secret_value is not None:
|
||||
random_secret_id = self.generate_random_secret_id()
|
||||
self.secrets[random_secret_id] = secret_value
|
||||
@@ -491,7 +498,11 @@ class WorkflowRunContext:
|
||||
LOG.error(f"Failed to get secret from Bitwarden. Error: {e}")
|
||||
raise e
|
||||
|
||||
async def register_azure_vault_credential_parameter_value(self, parameter: AzureVaultCredentialParameter) -> None:
|
||||
async def register_azure_vault_credential_parameter_value(
|
||||
self,
|
||||
parameter: AzureVaultCredentialParameter,
|
||||
organization: Organization,
|
||||
) -> None:
|
||||
vault_name = self._resolve_parameter_value(parameter.vault_name)
|
||||
if not vault_name:
|
||||
raise ValueError("Azure Vault Name is missing")
|
||||
@@ -504,18 +515,28 @@ class WorkflowRunContext:
|
||||
|
||||
totp_secret_key = self._resolve_parameter_value(parameter.totp_secret_key)
|
||||
|
||||
secret_login = await self._azure_client.get_secret(username_key, vault_name)
|
||||
secret_password = await self._azure_client.get_secret(password_key, vault_name)
|
||||
azure_vault_client = await self._get_azure_vault_client_for_organization(organization)
|
||||
|
||||
secret_username = await azure_vault_client.get_secret(username_key, vault_name)
|
||||
if not secret_username:
|
||||
raise ValueError(f"Azure Vault username not found by key: {username_key}")
|
||||
|
||||
secret_password = await azure_vault_client.get_secret(password_key, vault_name)
|
||||
if not secret_password:
|
||||
raise ValueError(f"Azure Vault password not found by key: {password_key}")
|
||||
|
||||
if totp_secret_key:
|
||||
totp_secret = await self._azure_client.get_secret(totp_secret_key, vault_name)
|
||||
totp_secret = await azure_vault_client.get_secret(totp_secret_key, vault_name)
|
||||
if not totp_secret:
|
||||
raise ValueError(f"Azure Vault TOTP not found by key: {totp_secret_key}")
|
||||
else:
|
||||
totp_secret = None
|
||||
|
||||
if secret_login is not None and secret_password is not None:
|
||||
if secret_username is not None and secret_password is not None:
|
||||
random_secret_id = self.generate_random_secret_id()
|
||||
# login secret
|
||||
username_secret_id = f"{random_secret_id}_username"
|
||||
self.secrets[username_secret_id] = secret_login
|
||||
self.secrets[username_secret_id] = secret_username
|
||||
# password secret
|
||||
password_secret_id = f"{random_secret_id}_password"
|
||||
self.secrets[password_secret_id] = secret_password
|
||||
@@ -895,10 +916,21 @@ class WorkflowRunContext:
|
||||
else:
|
||||
return jinja_sandbox_env.from_string(parameter_value).render(self.values)
|
||||
|
||||
@staticmethod
|
||||
async def _get_azure_vault_client_for_organization(organization: Organization) -> AsyncAzureVaultClient:
|
||||
org_auth_token = await app.DATABASE.get_valid_org_auth_token(
|
||||
organization.organization_id, OrganizationAuthTokenType.azure_client_secret_credential
|
||||
)
|
||||
if org_auth_token:
|
||||
azure_vault_client = AsyncAzureVaultClient.create_from_client_secret(org_auth_token.credential)
|
||||
else:
|
||||
# Use the DefaultAzureCredential if not configured on organization level
|
||||
azure_vault_client = AsyncAzureVaultClient.create_default()
|
||||
return azure_vault_client
|
||||
|
||||
|
||||
class WorkflowContextManager:
|
||||
aws_client: AsyncAWSClient
|
||||
azure_client: AsyncAzureClient
|
||||
workflow_run_contexts: dict[str, WorkflowRunContext]
|
||||
|
||||
parameters: dict[str, PARAMETER_TYPE]
|
||||
@@ -907,10 +939,6 @@ class WorkflowContextManager:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.aws_client = AsyncAWSClient()
|
||||
self.azure_client = AsyncAzureClient(
|
||||
storage_account_name=settings.AZURE_STORAGE_ACCOUNT_NAME,
|
||||
storage_account_key=settings.AZURE_STORAGE_ACCOUNT_KEY,
|
||||
)
|
||||
self.workflow_run_contexts = {}
|
||||
|
||||
def _validate_workflow_run_context(self, workflow_run_id: str) -> None:
|
||||
@@ -935,7 +963,6 @@ class WorkflowContextManager:
|
||||
) -> WorkflowRunContext:
|
||||
workflow_run_context = await WorkflowRunContext.init(
|
||||
self.aws_client,
|
||||
self.azure_client,
|
||||
organization,
|
||||
workflow_parameter_tuples,
|
||||
workflow_output_parameters,
|
||||
|
||||
@@ -34,6 +34,7 @@ from skyvern.constants import (
|
||||
MAX_UPLOAD_FILE_COUNT,
|
||||
)
|
||||
from skyvern.exceptions import (
|
||||
AzureConfigurationError,
|
||||
ContextParameterValueNotFound,
|
||||
MissingBrowserState,
|
||||
MissingBrowserStatePage,
|
||||
@@ -44,7 +45,7 @@ from skyvern.exceptions import (
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.prompts import prompt_engine
|
||||
from skyvern.forge.sdk.api.aws import AsyncAWSClient
|
||||
from skyvern.forge.sdk.api.azure import AsyncAzureClient
|
||||
from skyvern.forge.sdk.api.azure import AsyncAzureStorageClient
|
||||
from skyvern.forge.sdk.api.files import (
|
||||
calculate_sha256_for_file,
|
||||
create_named_temporary_file,
|
||||
@@ -2061,7 +2062,10 @@ class FileUploadBlock(Block):
|
||||
workflow_run_context.get_original_secret_value_or_none(self.azure_storage_account_key)
|
||||
or self.azure_storage_account_key
|
||||
)
|
||||
azure_client = AsyncAzureClient(
|
||||
if actual_azure_storage_account_name is None or actual_azure_storage_account_key is None:
|
||||
raise AzureConfigurationError("Azure Storage is not configured")
|
||||
|
||||
azure_client = AsyncAzureStorageClient(
|
||||
storage_account_name=actual_azure_storage_account_name,
|
||||
storage_account_key=actual_azure_storage_account_key,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user