diff --git a/skyvern-frontend/src/api/types.ts b/skyvern-frontend/src/api/types.ts index 96f141a9..b007f9f3 100644 --- a/skyvern-frontend/src/api/types.ts +++ b/skyvern-frontend/src/api/types.ts @@ -194,6 +194,30 @@ export type CreateOnePasswordTokenResponse = { token: OnePasswordTokenApiResponse; }; +export interface AzureClientSecretCredential { + tenant_id: string; + client_id: string; + client_secret: string; +} + +export interface AzureOrganizationAuthToken { + id: string; + organization_id: string; + credential: AzureClientSecretCredential; + created_at: string; + modified_at: string; + token_type: string; + valid: boolean; +} + +export interface CreateAzureClientSecretCredentialRequest { + credential: AzureClientSecretCredential; +} + +export interface AzureClientSecretCredentialResponse { + token: AzureOrganizationAuthToken; +} + // TODO complete this export const ActionTypes = { InputText: "input_text", diff --git a/skyvern-frontend/src/components/AzureClientSecretCredentialTokenForm.tsx b/skyvern-frontend/src/components/AzureClientSecretCredentialTokenForm.tsx new file mode 100644 index 00000000..6a797ad5 --- /dev/null +++ b/skyvern-frontend/src/components/AzureClientSecretCredentialTokenForm.tsx @@ -0,0 +1,202 @@ +import { useEffect, useState } from "react"; +import { useForm } from "react-hook-form"; +import { zodResolver } from "@hookform/resolvers/zod"; +import * as z from "zod"; +import { Button } from "@/components/ui/button"; +import { Input } from "@/components/ui/input"; +import { + Form, + FormControl, + FormField, + FormItem, + FormLabel, + FormMessage, +} from "@/components/ui/form"; +import { useAzureClientCredentialToken } from "@/hooks/useAzureClientCredentialToken"; +import { EyeOpenIcon, EyeClosedIcon } from "@radix-ui/react-icons"; + +const AzureClientSecretCredentialSchema = z + .object({ + tenant_id: z.string().min(1, "tenant_id is required"), + client_id: z.string().min(1, "client_id is required"), + client_secret: z.string().min(1, "client_secret is required"), + }) + .strict(); + +const formSchema = z + .object({ + credential: AzureClientSecretCredentialSchema, + }) + .strict(); + +type FormData = z.infer; + +export function AzureClientSecretCredentialTokenForm() { + const [showClientSecret, setShowClientSecret] = useState(false); + const { + azureOrganizationAuthToken, + isLoading, + createOrUpdateToken, + isUpdating, + } = useAzureClientCredentialToken(); + + const form = useForm({ + resolver: zodResolver(formSchema), + defaultValues: { + credential: azureOrganizationAuthToken?.credential || { + tenant_id: "", + client_id: "", + client_secret: "", + }, + }, + }); + + const onSubmit = (data: FormData) => { + createOrUpdateToken(data); + }; + + const toggleClientSecretVisibility = () => { + setShowClientSecret((v) => !v); + }; + + useEffect(() => { + if (azureOrganizationAuthToken?.credential) { + form.reset({ credential: azureOrganizationAuthToken.credential }); + } + }, [azureOrganizationAuthToken, form]); + + return ( +
+
+
+

+ Azure Client Secret Credential +

+

+ Configure your Azure Client Secret Credential to give access to your + Azure account. +

+
+ {azureOrganizationAuthToken && ( +
+ Status: + + {azureOrganizationAuthToken.valid ? "Active" : "Inactive"} + +
+ )} +
+ +
+ + ( + + Tenant ID +
+ + + +
+ +
+ )} + /> + ( + + Client ID +
+ + + +
+ +
+ )} + /> + ( + + Client Secret +
+ + + + +
+ +
+ )} + /> + +
+ + {azureOrganizationAuthToken && ( +
+ Last updated:{" "} + {new Date( + azureOrganizationAuthToken.modified_at, + ).toLocaleDateString()} +
+ )} +
+ + + + {azureOrganizationAuthToken && ( +
+

Credential Information

+
+
ID: {azureOrganizationAuthToken.id}
+
Type: {azureOrganizationAuthToken.token_type}
+
+ Created:{" "} + {new Date( + azureOrganizationAuthToken.created_at, + ).toLocaleDateString()} +
+
+
+ )} +
+ ); +} diff --git a/skyvern-frontend/src/hooks/useAzureClientCredentialToken.ts b/skyvern-frontend/src/hooks/useAzureClientCredentialToken.ts new file mode 100644 index 00000000..0f097451 --- /dev/null +++ b/skyvern-frontend/src/hooks/useAzureClientCredentialToken.ts @@ -0,0 +1,66 @@ +import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; +import { getClient } from "@/api/AxiosClient"; +import { useCredentialGetter } from "./useCredentialGetter"; +import { + AzureClientSecretCredentialResponse, + AzureOrganizationAuthToken, + CreateAzureClientSecretCredentialRequest, +} from "@/api/types"; +import { useToast } from "@/components/ui/use-toast"; + +export function useAzureClientCredentialToken() { + const credentialGetter = useCredentialGetter(); + const queryClient = useQueryClient(); + const { toast } = useToast(); + + const { data: azureOrganizationAuthToken, isLoading } = + useQuery({ + queryKey: ["azureOrganizationAuthToken"], + queryFn: async () => { + const client = await getClient(credentialGetter, "sans-api-v1"); + return await client + .get("/credentials/azure_credential/get") + .then((response) => response.data.token) + .catch(() => null); + }, + }); + + const createOrUpdateTokenMutation = useMutation({ + mutationFn: async (data: CreateAzureClientSecretCredentialRequest) => { + const client = await getClient(credentialGetter, "sans-api-v1"); + return await client + .post("/credentials/azure_credential/create", data) + .then( + (response) => response.data as AzureClientSecretCredentialResponse, + ); + }, + onSuccess: () => { + queryClient.invalidateQueries({ + queryKey: ["azureOrganizationAuthToken"], + }); + toast({ + title: "Success", + description: "Azure Client Secret Credential updated successfully", + }); + }, + onError: (error: unknown) => { + const message = + (error as { response?: { data?: { detail?: string } } })?.response?.data + ?.detail || + (error as Error)?.message || + "Failed to update Azure Client Secret Credential"; + toast({ + title: "Error", + description: message, + variant: "destructive", + }); + }, + }); + + return { + azureOrganizationAuthToken, + isLoading, + createOrUpdateToken: createOrUpdateTokenMutation.mutate, + isUpdating: createOrUpdateTokenMutation.isPending, + }; +} diff --git a/skyvern-frontend/src/routes/settings/Settings.tsx b/skyvern-frontend/src/routes/settings/Settings.tsx index 1d28778e..38f48765 100644 --- a/skyvern-frontend/src/routes/settings/Settings.tsx +++ b/skyvern-frontend/src/routes/settings/Settings.tsx @@ -17,6 +17,7 @@ import { import { envCredential } from "@/util/env"; import { HiddenCopyableInput } from "@/components/ui/hidden-copyable-input"; import { OnePasswordTokenForm } from "@/components/OnePasswordTokenForm"; +import { AzureClientSecretCredentialTokenForm } from "@/components/AzureClientSecretCredentialTokenForm"; function Settings() { const { environment, organization, setEnvironment, setOrganization } = @@ -87,6 +88,15 @@ function Settings() { + + + Azure Integration + Manage your Azure integration + + + + + ); } diff --git a/skyvern/forge/sdk/api/azure.py b/skyvern/forge/sdk/api/azure.py index d4073523..37e6cff7 100644 --- a/skyvern/forge/sdk/api/azure.py +++ b/skyvern/forge/sdk/api/azure.py @@ -1,39 +1,24 @@ import structlog -from azure.identity.aio import DefaultAzureCredential +from azure.identity.aio import ClientSecretCredential, DefaultAzureCredential from azure.keyvault.secrets.aio import SecretClient from azure.storage.blob.aio import BlobServiceClient -from skyvern.exceptions import AzureConfigurationError +from skyvern.forge.sdk.schemas.organizations import AzureClientSecretCredential LOG = structlog.get_logger() -class AsyncAzureClient: - def __init__(self, storage_account_name: str | None, storage_account_key: str | None): - self.storage_account_name = storage_account_name - self.storage_account_key = storage_account_key - - if storage_account_name and storage_account_key: - self.blob_service_client = BlobServiceClient( - account_url=f"https://{storage_account_name}.blob.core.windows.net", - credential=storage_account_key, - ) - else: - self.blob_service_client = None - - self.credential = DefaultAzureCredential() - - async def get_secret(self, secret_name: str, vault_name: str | None = None) -> str | None: - vault_subdomain = vault_name or self.storage_account_name - if not vault_subdomain: - raise AzureConfigurationError("Missing vault") +class AsyncAzureVaultClient: + def __init__(self, credential: ClientSecretCredential | DefaultAzureCredential): + self.credential = credential + async def get_secret(self, secret_name: str, vault_name: str) -> str | None: try: # Azure Key Vault URL format: https://.vault.azure.net # Assuming the secret_name is actually the Key Vault URL and the secret name # This needs to be clarified or passed as separate parameters # For now, let's assume secret_name is the actual secret name and Key Vault URL is in settings. - key_vault_url = f"https://{vault_subdomain}.vault.azure.net" # Placeholder, adjust as needed + key_vault_url = f"https://{vault_name}.vault.azure.net" # Placeholder, adjust as needed secret_client = SecretClient(vault_url=key_vault_url, credential=self.credential) secret = await secret_client.get_secret(secret_name) return secret.value @@ -43,10 +28,34 @@ class AsyncAzureClient: finally: await self.credential.close() - async def upload_file_from_path(self, container_name: str, blob_name: str, file_path: str) -> None: - if not self.blob_service_client: - raise AzureConfigurationError("Storage is not configured") + async def close(self) -> None: + await self.credential.close() + @classmethod + def create_default(cls) -> "AsyncAzureVaultClient": + return cls(DefaultAzureCredential()) + + @classmethod + def create_from_client_secret( + cls, + credential: AzureClientSecretCredential, + ) -> "AsyncAzureVaultClient": + cred = ClientSecretCredential( + tenant_id=credential.tenant_id, + client_id=credential.client_id, + client_secret=credential.client_secret, + ) + return cls(cred) + + +class AsyncAzureStorageClient: + def __init__(self, storage_account_name: str, storage_account_key: str): + self.blob_service_client = BlobServiceClient( + account_url=f"https://{storage_account_name}.blob.core.windows.net", + credential=storage_account_key, + ) + + async def upload_file_from_path(self, container_name: str, blob_name: str, file_path: str) -> None: try: container_client = self.blob_service_client.get_container_client(container_name) # Create the container if it doesn't exist @@ -69,4 +78,3 @@ class AsyncAzureClient: async def close(self) -> None: await self.blob_service_client.close() - await self.credential.close() diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py index 8704b955..3b7dbd81 100644 --- a/skyvern/forge/sdk/db/client.py +++ b/skyvern/forge/sdk/db/client.py @@ -1,6 +1,6 @@ import json from datetime import datetime, timedelta, timezone -from typing import Any, List, Sequence +from typing import Any, List, Literal, Sequence, overload import structlog from sqlalchemy import and_, asc, delete, distinct, func, or_, pool, select, tuple_, update @@ -79,7 +79,12 @@ from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion from skyvern.forge.sdk.schemas.credentials import Credential, CredentialType from skyvern.forge.sdk.schemas.debug_sessions import BlockRun, DebugSession from skyvern.forge.sdk.schemas.organization_bitwarden_collections import OrganizationBitwardenCollection -from skyvern.forge.sdk.schemas.organizations import Organization, OrganizationAuthToken +from skyvern.forge.sdk.schemas.organizations import ( + AzureClientSecretCredential, + AzureOrganizationAuthToken, + Organization, + OrganizationAuthToken, +) from skyvern.forge.sdk.schemas.persistent_browser_sessions import PersistentBrowserSession from skyvern.forge.sdk.schemas.runs import Run from skyvern.forge.sdk.schemas.task_generations import TaskGeneration @@ -865,11 +870,25 @@ class AgentDB: await session.refresh(organization) return Organization.model_validate(organization) + @overload + async def get_valid_org_auth_token( + self, + organization_id: str, + token_type: Literal[OrganizationAuthTokenType.api, OrganizationAuthTokenType.onepassword_service_account], + ) -> OrganizationAuthToken | None: ... + + @overload + async def get_valid_org_auth_token( + self, + organization_id: str, + token_type: Literal[OrganizationAuthTokenType.azure_client_secret_credential], + ) -> AzureOrganizationAuthToken | None: ... + async def get_valid_org_auth_token( self, organization_id: str, token_type: OrganizationAuthTokenType, - ) -> OrganizationAuthToken | None: + ) -> OrganizationAuthToken | AzureOrganizationAuthToken | None: try: async with self.Session() as session: if token := ( @@ -881,7 +900,7 @@ class AgentDB: .order_by(OrganizationAuthTokenModel.created_at.desc()) ) ).first(): - return await convert_to_organization_auth_token(token) + return await convert_to_organization_auth_token(token, token_type) else: return None except SQLAlchemyError: @@ -907,7 +926,7 @@ class AgentDB: .order_by(OrganizationAuthTokenModel.created_at.desc()) ) ).all() - return [await convert_to_organization_auth_token(token) for token in tokens] + return [await convert_to_organization_auth_token(token, token_type) for token in tokens] except SQLAlchemyError: LOG.error("SQLAlchemyError", exc_info=True) raise @@ -941,7 +960,7 @@ class AgentDB: if valid is not None: query = query.filter_by(valid=valid) if token_obj := (await session.scalars(query)).first(): - return await convert_to_organization_auth_token(token_obj) + return await convert_to_organization_auth_token(token_obj, token_type) else: return None except SQLAlchemyError: @@ -955,14 +974,22 @@ class AgentDB: self, organization_id: str, token_type: OrganizationAuthTokenType, - token: str, + token: str | AzureClientSecretCredential, encrypted_method: EncryptMethod | None = None, ) -> OrganizationAuthToken: - plaintext_token = token + if token_type is OrganizationAuthTokenType.azure_client_secret_credential: + if not isinstance(token, AzureClientSecretCredential): + raise TypeError("Expected AzureClientSecretCredential for this token_type") + plaintext_token = token.model_dump_json() + else: + if not isinstance(token, str): + raise TypeError("Expected str token for this token_type") + plaintext_token = token + encrypted_token = "" if encrypted_method is not None: - encrypted_token = await encryptor.encrypt(token, encrypted_method) + encrypted_token = await encryptor.encrypt(plaintext_token, encrypted_method) plaintext_token = "" async with self.Session() as session: @@ -977,7 +1004,7 @@ class AgentDB: await session.commit() await session.refresh(auth_token) - return await convert_to_organization_auth_token(auth_token) + return await convert_to_organization_auth_token(auth_token, token_type) async def invalidate_org_auth_tokens( self, diff --git a/skyvern/forge/sdk/db/enums.py b/skyvern/forge/sdk/db/enums.py index eea33f35..1718358d 100644 --- a/skyvern/forge/sdk/db/enums.py +++ b/skyvern/forge/sdk/db/enums.py @@ -4,6 +4,7 @@ from enum import StrEnum class OrganizationAuthTokenType(StrEnum): api = "api" onepassword_service_account = "onepassword_service_account" + azure_client_secret_credential = "azure_client_secret_credential" class TaskType(StrEnum): diff --git a/skyvern/forge/sdk/db/utils.py b/skyvern/forge/sdk/db/utils.py index ef21520f..12aafa79 100644 --- a/skyvern/forge/sdk/db/utils.py +++ b/skyvern/forge/sdk/db/utils.py @@ -30,7 +30,12 @@ from skyvern.forge.sdk.db.models import ( from skyvern.forge.sdk.encrypt import encryptor from skyvern.forge.sdk.encrypt.base import EncryptMethod from skyvern.forge.sdk.models import Step, StepStatus -from skyvern.forge.sdk.schemas.organizations import Organization, OrganizationAuthToken +from skyvern.forge.sdk.schemas.organizations import ( + AzureClientSecretCredential, + AzureOrganizationAuthToken, + Organization, + OrganizationAuthToken, +) from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunBlock from skyvern.forge.sdk.workflow.models.parameter import ( @@ -195,21 +200,33 @@ def convert_to_organization(org_model: OrganizationModel) -> Organization: async def convert_to_organization_auth_token( - org_auth_token: OrganizationAuthTokenModel, -) -> OrganizationAuthToken: + org_auth_token: OrganizationAuthTokenModel, token_type: OrganizationAuthTokenType +) -> OrganizationAuthToken | AzureOrganizationAuthToken: token = org_auth_token.token if org_auth_token.encrypted_token and org_auth_token.encrypted_method: token = await encryptor.decrypt(org_auth_token.encrypted_token, EncryptMethod(org_auth_token.encrypted_method)) - return OrganizationAuthToken( - id=org_auth_token.id, - organization_id=org_auth_token.organization_id, - token_type=OrganizationAuthTokenType(org_auth_token.token_type), - token=token, - valid=org_auth_token.valid, - created_at=org_auth_token.created_at, - modified_at=org_auth_token.modified_at, - ) + if token_type == OrganizationAuthTokenType.azure_client_secret_credential: + credential = AzureClientSecretCredential.model_validate_json(token) + return AzureOrganizationAuthToken( + id=org_auth_token.id, + organization_id=org_auth_token.organization_id, + token_type=OrganizationAuthTokenType(org_auth_token.token_type), + credential=credential, + valid=org_auth_token.valid, + created_at=org_auth_token.created_at, + modified_at=org_auth_token.modified_at, + ) + else: + return OrganizationAuthToken( + id=org_auth_token.id, + organization_id=org_auth_token.organization_id, + token_type=OrganizationAuthTokenType(org_auth_token.token_type), + token=token, + valid=org_auth_token.valid, + created_at=org_auth_token.created_at, + modified_at=org_auth_token.modified_at, + ) def convert_to_artifact(artifact_model: ArtifactModel, debug_enabled: bool = False) -> Artifact: diff --git a/skyvern/forge/sdk/routes/credentials.py b/skyvern/forge/sdk/routes/credentials.py index 16167599..9375c2e8 100644 --- a/skyvern/forge/sdk/routes/credentials.py +++ b/skyvern/forge/sdk/routes/credentials.py @@ -21,6 +21,8 @@ from skyvern.forge.sdk.schemas.credentials import ( PasswordCredentialResponse, ) from skyvern.forge.sdk.schemas.organizations import ( + AzureClientSecretCredentialResponse, + CreateAzureClientSecretCredentialRequest, CreateOnePasswordTokenRequest, CreateOnePasswordTokenResponse, Organization, @@ -478,3 +480,106 @@ async def update_onepassword_token( status_code=500, detail=f"Failed to create or update OnePassword service account token: {str(e)}", ) + + +@base_router.get( + "/credentials/azure_credential/get", + response_model=AzureClientSecretCredentialResponse, + summary="Get Azure Client Secret Credential", + description="Retrieves the current Azure Client Secret Credential for the organization.", + include_in_schema=False, +) +@base_router.get( + "/credentials/azure_credential/get/", + response_model=AzureClientSecretCredentialResponse, + include_in_schema=False, +) +async def get_azure_client_secret_credential( + current_org: Organization = Depends(org_auth_service.get_current_org), +) -> AzureClientSecretCredentialResponse: + """ + Get the current Azure Client Secret Credential for the organization. + """ + try: + auth_token = await app.DATABASE.get_valid_org_auth_token( + organization_id=current_org.organization_id, + token_type=OrganizationAuthTokenType.azure_client_secret_credential, + ) + if not auth_token: + raise HTTPException( + status_code=404, + detail="No Azure Client Secret Credential found for this organization", + ) + + return AzureClientSecretCredentialResponse(token=auth_token) + + except HTTPException: + raise + except Exception as e: + LOG.error( + "Failed to get Azure Client Secret Credential", + organization_id=current_org.organization_id, + error=str(e), + exc_info=True, + ) + raise HTTPException( + status_code=500, + detail=f"Failed to get Azure Client Secret Credential: {str(e)}", + ) + + +@base_router.post( + "/credentials/azure_credential/create", + response_model=AzureClientSecretCredentialResponse, + summary="Create or update Azure Client Secret Credential", + description="Creates or updates a Azure Client Secret Credential for the current organization. Only one valid record is allowed per organization.", + include_in_schema=False, +) +@base_router.post( + "/credentials/azure_credential/create/", + response_model=AzureClientSecretCredentialResponse, + include_in_schema=False, +) +async def update_azure_client_secret_credential( + request: CreateAzureClientSecretCredentialRequest, + current_org: Organization = Depends(org_auth_service.get_current_org), +) -> AzureClientSecretCredentialResponse: + """ + Create or update an Azure Client Secret Credential for the current organization. + + This endpoint ensures only one valid Azure Client Secret Credential exists per organization. + If a valid token already exists, it will be invalidated before creating the new one. + """ + try: + # Invalidate any existing valid Azure Client Secret Credential for this organization + await app.DATABASE.invalidate_org_auth_tokens( + organization_id=current_org.organization_id, + token_type=OrganizationAuthTokenType.azure_client_secret_credential, + ) + + # Create the new Azure token + auth_token = await app.DATABASE.create_org_auth_token( + organization_id=current_org.organization_id, + token_type=OrganizationAuthTokenType.azure_client_secret_credential, + token=request.credential, + ) + + LOG.info( + "Created or updated Azure Client Secret Credential", + organization_id=current_org.organization_id, + token_id=auth_token.id, + ) + + return AzureClientSecretCredentialResponse(token=auth_token) + + except Exception as e: + LOG.error( + "Failed to create or update Azure Client Secret Credential", + organization_id=current_org.organization_id, + error=str(e), + exc_info=True, + ) + raise HTTPException( + status_code=500, + detail=f"Failed to create or update Azure Client Secret Credential: {str(e)}", + ) diff --git a/skyvern/forge/sdk/schemas/organizations.py b/skyvern/forge/sdk/schemas/organizations.py index bd6dff18..faaf20f8 100644 --- a/skyvern/forge/sdk/schemas/organizations.py +++ b/skyvern/forge/sdk/schemas/organizations.py @@ -21,16 +21,31 @@ class Organization(BaseModel): modified_at: datetime -class OrganizationAuthToken(BaseModel): +class OrganizationAuthTokenBase(BaseModel): id: str organization_id: str token_type: OrganizationAuthTokenType - token: str valid: bool created_at: datetime modified_at: datetime +class OrganizationAuthToken(OrganizationAuthTokenBase): + token: str + + +class AzureClientSecretCredential(BaseModel): + tenant_id: str + client_id: str + client_secret: str + + +class AzureOrganizationAuthToken(OrganizationAuthTokenBase): + """Represents OrganizationAuthToken for Azure; defined by 3 fields: tenant_id, client_id, and client_secret""" + + credential: AzureClientSecretCredential + + class CreateOnePasswordTokenRequest(BaseModel): """Request model for creating or updating a 1Password service account token.""" @@ -50,6 +65,21 @@ class CreateOnePasswordTokenResponse(BaseModel): ) +class AzureClientSecretCredentialResponse(BaseModel): + """Response model for Azure ClientSecretCredential operations.""" + + token: AzureOrganizationAuthToken = Field( + ..., + description="The created or updated Azure ClientSecretCredential", + ) + + +class CreateAzureClientSecretCredentialRequest(BaseModel): + """Request model for creating or updating an Azure ClientSecretCredential.""" + + credential: AzureClientSecretCredential + + class GetOrganizationsResponse(BaseModel): organizations: list[Organization] diff --git a/skyvern/forge/sdk/workflow/context_manager.py b/skyvern/forge/sdk/workflow/context_manager.py index 290ce809..99a7f0f5 100644 --- a/skyvern/forge/sdk/workflow/context_manager.py +++ b/skyvern/forge/sdk/workflow/context_manager.py @@ -7,6 +7,7 @@ from onepassword.client import Client as OnePasswordClient from skyvern.config import settings from skyvern.exceptions import ( + AzureConfigurationError, BitwardenBaseError, CredentialParameterNotFoundError, SkyvernException, @@ -14,7 +15,7 @@ from skyvern.exceptions import ( ) from skyvern.forge import app from skyvern.forge.sdk.api.aws import AsyncAWSClient -from skyvern.forge.sdk.api.azure import AsyncAzureClient +from skyvern.forge.sdk.api.azure import AsyncAzureVaultClient from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType from skyvern.forge.sdk.schemas.credentials import PasswordCredential from skyvern.forge.sdk.schemas.organizations import Organization @@ -56,7 +57,6 @@ class WorkflowRunContext: async def init( cls, aws_client: AsyncAWSClient, - azure_client: AsyncAzureClient, organization: Organization, workflow_parameter_tuples: list[tuple[WorkflowParameter, "WorkflowRunParameter"]], workflow_output_parameters: list[OutputParameter], @@ -71,7 +71,7 @@ class WorkflowRunContext: block_outputs: dict[str, Any] | None = None, ) -> Self: # key is label name - workflow_run_context = cls(aws_client=aws_client, azure_client=azure_client) + workflow_run_context = cls(aws_client=aws_client) for parameter, run_parameter in workflow_parameter_tuples: if parameter.workflow_parameter_type == WorkflowParameterType.CREDENTIAL_ID: await workflow_run_context.register_secret_workflow_parameter_value( @@ -109,7 +109,9 @@ class WorkflowRunContext: secret_parameter, organization ) elif isinstance(secret_parameter, AzureVaultCredentialParameter): - await workflow_run_context.register_azure_vault_credential_parameter_value(secret_parameter) + await workflow_run_context.register_azure_vault_credential_parameter_value( + secret_parameter, organization + ) elif isinstance(secret_parameter, BitwardenLoginCredentialParameter): await workflow_run_context.register_bitwarden_login_credential_parameter_value( secret_parameter, organization @@ -131,13 +133,12 @@ class WorkflowRunContext: return workflow_run_context - def __init__(self, aws_client: AsyncAWSClient, azure_client: AsyncAzureClient) -> None: + def __init__(self, aws_client: AsyncAWSClient) -> None: self.blocks_metadata: dict[str, BlockMetadata] = {} self.parameters: dict[str, PARAMETER_TYPE] = {} self.values: dict[str, Any] = {} self.secrets: dict[str, Any] = {} self._aws_client = aws_client - self._azure_client = azure_client def get_parameter(self, key: str) -> Parameter: return self.parameters[key] @@ -343,10 +344,16 @@ class WorkflowRunContext: self, parameter: AzureSecretParameter, ) -> None: + vault_name = settings.AZURE_STORAGE_ACCOUNT_NAME + if vault_name is None: + LOG.error("AZURE_STORAGE_ACCOUNT_NAME is not configured, cannot register Azure secret parameter value") + raise AzureConfigurationError("AZURE_STORAGE_ACCOUNT_NAME is not configured") + # If the parameter is an Azure secret, fetch the secret value and store it in the secrets dict # The value of the parameter will be the random secret id with format `secret_`. # We'll replace the random secret id with the actual secret value when we need to use it. - secret_value = await self._azure_client.get_secret(parameter.azure_key) + azure_vault_client = AsyncAzureVaultClient.create_default() + secret_value = await azure_vault_client.get_secret(parameter.azure_key, vault_name) if secret_value is not None: random_secret_id = self.generate_random_secret_id() self.secrets[random_secret_id] = secret_value @@ -491,7 +498,11 @@ class WorkflowRunContext: LOG.error(f"Failed to get secret from Bitwarden. Error: {e}") raise e - async def register_azure_vault_credential_parameter_value(self, parameter: AzureVaultCredentialParameter) -> None: + async def register_azure_vault_credential_parameter_value( + self, + parameter: AzureVaultCredentialParameter, + organization: Organization, + ) -> None: vault_name = self._resolve_parameter_value(parameter.vault_name) if not vault_name: raise ValueError("Azure Vault Name is missing") @@ -504,18 +515,28 @@ class WorkflowRunContext: totp_secret_key = self._resolve_parameter_value(parameter.totp_secret_key) - secret_login = await self._azure_client.get_secret(username_key, vault_name) - secret_password = await self._azure_client.get_secret(password_key, vault_name) + azure_vault_client = await self._get_azure_vault_client_for_organization(organization) + + secret_username = await azure_vault_client.get_secret(username_key, vault_name) + if not secret_username: + raise ValueError(f"Azure Vault username not found by key: {username_key}") + + secret_password = await azure_vault_client.get_secret(password_key, vault_name) + if not secret_password: + raise ValueError(f"Azure Vault password not found by key: {password_key}") + if totp_secret_key: - totp_secret = await self._azure_client.get_secret(totp_secret_key, vault_name) + totp_secret = await azure_vault_client.get_secret(totp_secret_key, vault_name) + if not totp_secret: + raise ValueError(f"Azure Vault TOTP not found by key: {totp_secret_key}") else: totp_secret = None - if secret_login is not None and secret_password is not None: + if secret_username is not None and secret_password is not None: random_secret_id = self.generate_random_secret_id() # login secret username_secret_id = f"{random_secret_id}_username" - self.secrets[username_secret_id] = secret_login + self.secrets[username_secret_id] = secret_username # password secret password_secret_id = f"{random_secret_id}_password" self.secrets[password_secret_id] = secret_password @@ -895,10 +916,21 @@ class WorkflowRunContext: else: return jinja_sandbox_env.from_string(parameter_value).render(self.values) + @staticmethod + async def _get_azure_vault_client_for_organization(organization: Organization) -> AsyncAzureVaultClient: + org_auth_token = await app.DATABASE.get_valid_org_auth_token( + organization.organization_id, OrganizationAuthTokenType.azure_client_secret_credential + ) + if org_auth_token: + azure_vault_client = AsyncAzureVaultClient.create_from_client_secret(org_auth_token.credential) + else: + # Use the DefaultAzureCredential if not configured on organization level + azure_vault_client = AsyncAzureVaultClient.create_default() + return azure_vault_client + class WorkflowContextManager: aws_client: AsyncAWSClient - azure_client: AsyncAzureClient workflow_run_contexts: dict[str, WorkflowRunContext] parameters: dict[str, PARAMETER_TYPE] @@ -907,10 +939,6 @@ class WorkflowContextManager: def __init__(self) -> None: self.aws_client = AsyncAWSClient() - self.azure_client = AsyncAzureClient( - storage_account_name=settings.AZURE_STORAGE_ACCOUNT_NAME, - storage_account_key=settings.AZURE_STORAGE_ACCOUNT_KEY, - ) self.workflow_run_contexts = {} def _validate_workflow_run_context(self, workflow_run_id: str) -> None: @@ -935,7 +963,6 @@ class WorkflowContextManager: ) -> WorkflowRunContext: workflow_run_context = await WorkflowRunContext.init( self.aws_client, - self.azure_client, organization, workflow_parameter_tuples, workflow_output_parameters, diff --git a/skyvern/forge/sdk/workflow/models/block.py b/skyvern/forge/sdk/workflow/models/block.py index 4d9bfd27..24541687 100644 --- a/skyvern/forge/sdk/workflow/models/block.py +++ b/skyvern/forge/sdk/workflow/models/block.py @@ -34,6 +34,7 @@ from skyvern.constants import ( MAX_UPLOAD_FILE_COUNT, ) from skyvern.exceptions import ( + AzureConfigurationError, ContextParameterValueNotFound, MissingBrowserState, MissingBrowserStatePage, @@ -44,7 +45,7 @@ from skyvern.exceptions import ( from skyvern.forge import app from skyvern.forge.prompts import prompt_engine from skyvern.forge.sdk.api.aws import AsyncAWSClient -from skyvern.forge.sdk.api.azure import AsyncAzureClient +from skyvern.forge.sdk.api.azure import AsyncAzureStorageClient from skyvern.forge.sdk.api.files import ( calculate_sha256_for_file, create_named_temporary_file, @@ -2061,7 +2062,10 @@ class FileUploadBlock(Block): workflow_run_context.get_original_secret_value_or_none(self.azure_storage_account_key) or self.azure_storage_account_key ) - azure_client = AsyncAzureClient( + if actual_azure_storage_account_name is None or actual_azure_storage_account_key is None: + raise AzureConfigurationError("Azure Storage is not configured") + + azure_client = AsyncAzureStorageClient( storage_account_name=actual_azure_storage_account_name, storage_account_key=actual_azure_storage_account_key, )