From 32e6aed8ceeafa8f05b95716ff689ea94e0dc717 Mon Sep 17 00:00:00 2001 From: Stanislav Novosad Date: Fri, 10 Oct 2025 10:10:18 -0600 Subject: [PATCH] Migrate credentials to Azure Key Vault (#3681) --- ..._1533-d648e2df239e_creds_in_azure_vault.py | 41 ++++ skyvern/config.py | 10 + skyvern/forge/app.py | 17 +- skyvern/forge/sdk/api/azure.py | 51 ++++- skyvern/forge/sdk/db/client.py | 22 +- skyvern/forge/sdk/db/models.py | 6 + skyvern/forge/sdk/routes/credentials.py | 42 +++- skyvern/forge/sdk/schemas/credentials.py | 11 +- .../azure_credential_vault_service.py | 188 ++++++++++++++++++ .../bitwarden_credential_service.py | 8 +- .../credential/credential_vault_service.py | 37 ++++ skyvern/forge/sdk/workflow/context_manager.py | 57 +++--- 12 files changed, 438 insertions(+), 52 deletions(-) create mode 100644 alembic/versions/2025_10_10_1533-d648e2df239e_creds_in_azure_vault.py create mode 100644 skyvern/forge/sdk/services/credential/azure_credential_vault_service.py diff --git a/alembic/versions/2025_10_10_1533-d648e2df239e_creds_in_azure_vault.py b/alembic/versions/2025_10_10_1533-d648e2df239e_creds_in_azure_vault.py new file mode 100644 index 00000000..cb4f2664 --- /dev/null +++ b/alembic/versions/2025_10_10_1533-d648e2df239e_creds_in_azure_vault.py @@ -0,0 +1,41 @@ +"""Creds in Azure Vault + +Revision ID: d648e2df239e +Revises: 7cd6f55be8d2 +Create Date: 2025-10-10 15:33:02.700316+00:00 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "d648e2df239e" +down_revision: Union[str, None] = "7cd6f55be8d2" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("credentials", sa.Column("vault_type", sa.String(), nullable=True)) + op.add_column("credentials", sa.Column("username", sa.String(), nullable=True)) + op.add_column("credentials", sa.Column("totp_identifier", sa.String(), nullable=True)) + op.add_column("credentials", sa.Column("card_last4", sa.String(), nullable=True)) + op.add_column("credentials", sa.Column("card_brand", sa.String(), nullable=True)) + op.add_column("organization_bitwarden_collections", sa.Column("deleted_at", sa.DateTime(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("organization_bitwarden_collections", "deleted_at") + op.drop_column("credentials", "card_brand") + op.drop_column("credentials", "card_last4") + op.drop_column("credentials", "totp_identifier") + op.drop_column("credentials", "username") + op.drop_column("credentials", "vault_type") + # ### end Alembic commands ### diff --git a/skyvern/config.py b/skyvern/config.py index d2c8ce1e..b90450ff 100644 --- a/skyvern/config.py +++ b/skyvern/config.py @@ -295,6 +295,16 @@ class Settings(BaseSettings): BITWARDEN_EMAIL: str | None = None OP_SERVICE_ACCOUNT_TOKEN: str | None = None + # Where credentials are stored: bitwarden or azure_vault + CREDENTIAL_VAULT_TYPE: str = "bitwarden" + + # Azure Setting + AZURE_TENANT_ID: str | None = None + AZURE_CLIENT_ID: str | None = None + AZURE_CLIENT_SECRET: str | None = None + # The Azure Key Vault name to store credentials + AZURE_CREDENTIAL_VAULT: str | None = None + # Skyvern Auth Bitwarden Settings SKYVERN_AUTH_BITWARDEN_CLIENT_ID: str | None = None SKYVERN_AUTH_BITWARDEN_CLIENT_SECRET: str | None = None diff --git a/skyvern/forge/app.py b/skyvern/forge/app.py index 27c4f8b8..9d7773af 100644 --- a/skyvern/forge/app.py +++ b/skyvern/forge/app.py @@ -13,7 +13,9 @@ from skyvern.forge.sdk.artifact.storage.s3 import S3Storage from skyvern.forge.sdk.cache.factory import CacheFactory from skyvern.forge.sdk.db.client import AgentDB from skyvern.forge.sdk.experimentation.providers import BaseExperimentationProvider, NoOpExperimentationProvider +from skyvern.forge.sdk.schemas.credentials import CredentialVaultType from skyvern.forge.sdk.schemas.organizations import Organization +from skyvern.forge.sdk.services.credential.azure_credential_vault_service import AzureCredentialVaultService from skyvern.forge.sdk.services.credential.bitwarden_credential_service import BitwardenCredentialVaultService from skyvern.forge.sdk.services.credential.credential_vault_service import CredentialVaultService from skyvern.forge.sdk.settings_manager import SettingsManager @@ -96,7 +98,20 @@ WORKFLOW_CONTEXT_MANAGER = WorkflowContextManager() WORKFLOW_SERVICE = WorkflowService() AGENT_FUNCTION = AgentFunction() PERSISTENT_SESSIONS_MANAGER = PersistentSessionsManager(database=DATABASE) -CREDENTIAL_VAULT_SERVICE: CredentialVaultService = BitwardenCredentialVaultService() + +BITWARDEN_CREDENTIAL_VAULT_SERVICE: BitwardenCredentialVaultService = BitwardenCredentialVaultService() +AZURE_CREDENTIAL_VAULT_SERVICE: AzureCredentialVaultService | None = None +if SettingsManager.get_settings().AZURE_CREDENTIAL_VAULT: + AZURE_CREDENTIAL_VAULT_SERVICE = AzureCredentialVaultService( + tenant_id=SettingsManager.get_settings().AZURE_TENANT_ID, # type: ignore + client_id=SettingsManager.get_settings().AZURE_CLIENT_ID, # type: ignore + client_secret=SettingsManager.get_settings().AZURE_CLIENT_SECRET, # type: ignore + vault_name=SettingsManager.get_settings().AZURE_CREDENTIAL_VAULT, # type: ignore + ) +CREDENTIAL_VAULT_SERVICES: dict[str, CredentialVaultService | None] = { + CredentialVaultType.BITWARDEN: BITWARDEN_CREDENTIAL_VAULT_SERVICE, + CredentialVaultType.AZURE_VAULT: AZURE_CREDENTIAL_VAULT_SERVICE, +} scrape_exclude: ScrapeExcludeFunc | None = None authentication_function: Callable[[str], Awaitable[Organization]] | None = None diff --git a/skyvern/forge/sdk/api/azure.py b/skyvern/forge/sdk/api/azure.py index 37e6cff7..b0121f59 100644 --- a/skyvern/forge/sdk/api/azure.py +++ b/skyvern/forge/sdk/api/azure.py @@ -1,3 +1,5 @@ +from typing import Self + import structlog from azure.identity.aio import ClientSecretCredential, DefaultAzureCredential from azure.keyvault.secrets.aio import SecretClient @@ -9,24 +11,57 @@ LOG = structlog.get_logger() class AsyncAzureVaultClient: - def __init__(self, credential: ClientSecretCredential | DefaultAzureCredential): + def __init__(self, credential: ClientSecretCredential | DefaultAzureCredential) -> None: self.credential = credential + async def __aenter__(self) -> Self: + return self + + async def __aexit__( + self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: object + ) -> None: + await self.credential.close() + async def get_secret(self, secret_name: str, vault_name: str) -> str | None: + secret_client = await self._get_secret_client(vault_name) try: - # Azure Key Vault URL format: https://.vault.azure.net - # Assuming the secret_name is actually the Key Vault URL and the secret name - # This needs to be clarified or passed as separate parameters - # For now, let's assume secret_name is the actual secret name and Key Vault URL is in settings. - key_vault_url = f"https://{vault_name}.vault.azure.net" # Placeholder, adjust as needed - secret_client = SecretClient(vault_url=key_vault_url, credential=self.credential) secret = await secret_client.get_secret(secret_name) return secret.value except Exception as e: LOG.exception("Failed to get secret from Azure Key Vault.", secret_name=secret_name, error=e) return None finally: - await self.credential.close() + await secret_client.close() + + async def create_secret(self, secret_name: str, secret_value: str, vault_name: str) -> str: + secret_client = await self._get_secret_client(vault_name) + try: + secret = await secret_client.set_secret(secret_name, secret_value) + return secret.name + except Exception as e: + LOG.exception("Failed to create secret from Azure Key Vault.", secret_name=secret_name, error=e) + raise e + finally: + await secret_client.close() + + async def delete_secret(self, secret_name: str, vault_name: str) -> str: + secret_client = await self._get_secret_client(vault_name) + try: + secret = await secret_client.delete_secret(secret_name) + return secret.name + except Exception as e: + LOG.exception("Failed to delete secret from Azure Key Vault.", secret_name=secret_name, error=e) + raise e + finally: + await secret_client.close() + + async def _get_secret_client(self, vault_name: str) -> SecretClient: + # Azure Key Vault URL format: https://.vault.azure.net + # Assuming the secret_name is actually the Key Vault URL and the secret name + # This needs to be clarified or passed as separate parameters + # For now, let's assume secret_name is the actual secret name and Key Vault URL is in settings. + key_vault_url = f"https://{vault_name}.vault.azure.net" # Placeholder, adjust as needed + return SecretClient(vault_url=key_vault_url, credential=self.credential) async def close(self) -> None: await self.credential.close() diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py index 29010002..623973f9 100644 --- a/skyvern/forge/sdk/db/client.py +++ b/skyvern/forge/sdk/db/client.py @@ -77,7 +77,7 @@ from skyvern.forge.sdk.encrypt.base import EncryptMethod from skyvern.forge.sdk.log_artifacts import save_workflow_run_logs from skyvern.forge.sdk.models import Step, StepStatus from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion -from skyvern.forge.sdk.schemas.credentials import Credential, CredentialType +from skyvern.forge.sdk.schemas.credentials import Credential, CredentialType, CredentialVaultType from skyvern.forge.sdk.schemas.debug_sessions import BlockRun, DebugSession from skyvern.forge.sdk.schemas.organization_bitwarden_collections import OrganizationBitwardenCollection from skyvern.forge.sdk.schemas.organizations import ( @@ -3628,19 +3628,27 @@ class AgentDB: async def create_credential( self, - name: str, - credential_type: CredentialType, organization_id: str, + name: str, + vault_type: CredentialVaultType, item_id: str, - totp_type: str = "none", + credential_type: CredentialType, + username: str | None, + totp_type: str, + card_last4: str | None, + card_brand: str | None, ) -> Credential: async with self.Session() as session: credential = CredentialModel( organization_id=organization_id, name=name, - credential_type=credential_type, + vault_type=vault_type, item_id=item_id, + credential_type=credential_type, + username=username, totp_type=totp_type, + card_last4=card_last4, + card_brand=card_brand, ) session.add(credential) await session.commit() @@ -3733,7 +3741,9 @@ class AgentDB: async with self.Session() as session: organization_bitwarden_collection = ( await session.scalars( - select(OrganizationBitwardenCollectionModel).filter_by(organization_id=organization_id) + select(OrganizationBitwardenCollectionModel) + .filter_by(organization_id=organization_id) + .filter_by(deleted_at=None) ) ).first() if organization_bitwarden_collection: diff --git a/skyvern/forge/sdk/db/models.py b/skyvern/forge/sdk/db/models.py index 003b1fe2..4d82a3a2 100644 --- a/skyvern/forge/sdk/db/models.py +++ b/skyvern/forge/sdk/db/models.py @@ -807,6 +807,7 @@ class OrganizationBitwardenCollectionModel(Base): created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False) + deleted_at = Column(DateTime, nullable=True) class CredentialModel(Base): @@ -814,11 +815,16 @@ class CredentialModel(Base): credential_id = Column(String, primary_key=True, default=generate_credential_id) organization_id = Column(String, nullable=False) + vault_type = Column(String, nullable=True) item_id = Column(String, nullable=True) name = Column(String, nullable=False) credential_type = Column(String, nullable=False) + username = Column(String, nullable=True) totp_type = Column(String, nullable=False, default="none") + totp_identifier = Column(String, nullable=True, default=None) + card_last4 = Column(String, nullable=True) + card_brand = Column(String, nullable=True) created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False) diff --git a/skyvern/forge/sdk/routes/credentials.py b/skyvern/forge/sdk/routes/credentials.py index 30719349..b82f8db1 100644 --- a/skyvern/forge/sdk/routes/credentials.py +++ b/skyvern/forge/sdk/routes/credentials.py @@ -1,6 +1,7 @@ import structlog from fastapi import BackgroundTasks, Body, Depends, HTTPException, Path, Query +from skyvern.config import settings from skyvern.forge import app from skyvern.forge.prompts import prompt_engine from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType @@ -17,6 +18,7 @@ from skyvern.forge.sdk.schemas.credentials import ( CreateCredentialRequest, CredentialResponse, CredentialType, + CredentialVaultType, CreditCardCredentialResponse, PasswordCredentialResponse, ) @@ -30,6 +32,7 @@ from skyvern.forge.sdk.schemas.organizations import ( from skyvern.forge.sdk.schemas.totp_codes import TOTPCode, TOTPCodeCreate from skyvern.forge.sdk.services import org_auth_service from skyvern.forge.sdk.services.bitwarden import BitwardenService +from skyvern.forge.sdk.services.credential.credential_vault_service import CredentialVaultService LOG = structlog.get_logger() @@ -155,12 +158,13 @@ async def create_credential( ), current_org: Organization = Depends(org_auth_service.get_current_org), ) -> CredentialResponse: - credential = await app.CREDENTIAL_VAULT_SERVICE.create_credential( - organization_id=current_org.organization_id, data=data - ) + credential_service = await _get_credential_vault_service(current_org.organization_id) - # Early resyncing the Bitwarden vault - background_tasks.add_task(fetch_credential_item_background, credential.item_id) + credential = await credential_service.create_credential(organization_id=current_org.organization_id, data=data) + + if credential.vault_type == CredentialVaultType.BITWARDEN: + # Early resyncing the Bitwarden vault + background_tasks.add_task(fetch_credential_item_background, credential.item_id) if data.credential_type == CredentialType.PASSWORD: credential_response = PasswordCredentialResponse( @@ -221,7 +225,12 @@ async def delete_credential( if not credential: raise HTTPException(status_code=404, detail=f"Credential not found, credential_id={credential_id}") - await app.CREDENTIAL_VAULT_SERVICE.delete_credential(credential) + vault_type = credential.vault_type or CredentialVaultType.BITWARDEN + credential_service = app.CREDENTIAL_VAULT_SERVICES.get(vault_type) + if not credential_service: + raise HTTPException(status_code=400, detail="Unsupported credential storage type") + + await credential_service.delete_credential(credential) return None @@ -253,7 +262,9 @@ async def get_credential( ), current_org: Organization = Depends(org_auth_service.get_current_org), ) -> CredentialResponse: - return await app.CREDENTIAL_VAULT_SERVICE.get_credential(current_org.organization_id, credential_id) + credential_service = await _get_credential_vault_service(current_org.organization_id) + + return await credential_service.get_credential(current_org.organization_id, credential_id) @legacy_base_router.get("/credentials") @@ -291,7 +302,9 @@ async def get_credentials( openapi_extra={"x-fern-sdk-parameter-name": "page_size"}, ), ) -> list[CredentialResponse]: - return await app.CREDENTIAL_VAULT_SERVICE.get_credentials(current_org.organization_id, page, page_size) + credential_service = await _get_credential_vault_service(current_org.organization_id) + + return await credential_service.get_credentials(current_org.organization_id, page, page_size) @base_router.get( @@ -498,3 +511,16 @@ async def update_azure_client_secret_credential( status_code=500, detail=f"Failed to create or update Azure Client Secret Credential: {str(e)}", ) + + +async def _get_credential_vault_service(organization_id: str) -> CredentialVaultService: + org_collection = await app.DATABASE.get_organization_bitwarden_collection(organization_id) + + if settings.CREDENTIAL_VAULT_TYPE == CredentialVaultType.BITWARDEN or org_collection: + return app.BITWARDEN_CREDENTIAL_VAULT_SERVICE + elif settings.CREDENTIAL_VAULT_TYPE == CredentialVaultType.AZURE_VAULT: + if not app.AZURE_CREDENTIAL_VAULT_SERVICE: + raise HTTPException(status_code=400, detail="Azure Vault credential is not supported") + return app.AZURE_CREDENTIAL_VAULT_SERVICE + else: + raise HTTPException(status_code=400, detail="Credential storage not supported") diff --git a/skyvern/forge/sdk/schemas/credentials.py b/skyvern/forge/sdk/schemas/credentials.py index 8232507e..4690881b 100644 --- a/skyvern/forge/sdk/schemas/credentials.py +++ b/skyvern/forge/sdk/schemas/credentials.py @@ -4,6 +4,11 @@ from enum import StrEnum from pydantic import BaseModel, ConfigDict, Field +class CredentialVaultType(StrEnum): + BITWARDEN = "bitwarden" + AZURE_VAULT = "azure_vault" + + class CredentialType(StrEnum): """Type of credential stored in the system.""" @@ -141,13 +146,17 @@ class Credential(BaseModel): ..., description="ID of the organization that owns the credential", examples=["o_1234567890"] ) name: str = Field(..., description="Name of the credential", examples=["Skyvern Login"]) - credential_type: CredentialType = Field(..., description="Type of the credential. Eg password, credit card, etc.") + vault_type: CredentialVaultType | None = Field(..., description="Where the secret is stored: Bitwarden vs Azure") item_id: str = Field(..., description="ID of the associated credential item", examples=["item_1234567890"]) + credential_type: CredentialType = Field(..., description="Type of the credential. Eg password, credit card, etc.") + username: str | None = Field(..., description="For password credentials: the username") totp_type: TotpType = Field( TotpType.NONE, description="Type of 2FA method used for this credential", examples=[TotpType.AUTHENTICATOR], ) + card_last4: str | None = Field(..., description="For credit_card credentials: the last four digits of the card") + card_brand: str | None = Field(..., description="For credit_card credentials: the card brand") created_at: datetime = Field(..., description="Timestamp when the credential was created") modified_at: datetime = Field(..., description="Timestamp when the credential was last modified") diff --git a/skyvern/forge/sdk/services/credential/azure_credential_vault_service.py b/skyvern/forge/sdk/services/credential/azure_credential_vault_service.py new file mode 100644 index 00000000..866ec234 --- /dev/null +++ b/skyvern/forge/sdk/services/credential/azure_credential_vault_service.py @@ -0,0 +1,188 @@ +import uuid +from typing import Annotated, Literal, Union + +from azure.identity.aio import ClientSecretCredential +from fastapi import HTTPException +from pydantic import BaseModel, Field, TypeAdapter + +from skyvern.forge import app +from skyvern.forge.sdk.api.azure import AsyncAzureVaultClient +from skyvern.forge.sdk.schemas.credentials import ( + CreateCredentialRequest, + Credential, + CredentialItem, + CredentialResponse, + CredentialType, + CredentialVaultType, + CreditCardCredential, + CreditCardCredentialResponse, + PasswordCredential, + PasswordCredentialResponse, +) +from skyvern.forge.sdk.services.credential.credential_vault_service import CredentialVaultService + + +class AzureCredentialVaultService(CredentialVaultService): + class _PasswordCredentialDataImage(BaseModel): + type: Literal["password"] + password: str + username: str + totp: str | None = None + + class _CreditCardCredentialDataImage(BaseModel): + type: Literal["credit_card"] + card_number: str + card_cvv: str + card_exp_month: str + card_exp_year: str + card_brand: str + card_holder_name: str + + _CredentialDataImage = Annotated[ + Union[_PasswordCredentialDataImage, _CreditCardCredentialDataImage], Field(discriminator="type") + ] + + def __init__(self, tenant_id: str, client_id: str, client_secret: str, vault_name: str): + self._client = AsyncAzureVaultClient( + ClientSecretCredential( + tenant_id=tenant_id, + client_id=client_id, + client_secret=client_secret, + ) + ) + self._vault_name = vault_name + + async def create_credential(self, organization_id: str, data: CreateCredentialRequest) -> Credential: + item_id = await self._create_azure_secret_item( + organization_id=organization_id, + credential=data.credential, + ) + + credential = await self._create_db_credential( + organization_id=organization_id, + data=data, + item_id=item_id, + vault_type=CredentialVaultType.AZURE_VAULT, + ) + + return credential + + async def delete_credential( + self, + credential: Credential, + ) -> None: + await app.DATABASE.delete_credential(credential.credential_id, credential.organization_id) + await self.delete_credential_item(credential.item_id) + + async def get_credential(self, organization_id: str, credential_id: str) -> CredentialResponse: + credential = await app.DATABASE.get_credential(credential_id=credential_id, organization_id=organization_id) + if not credential: + raise HTTPException(status_code=404, detail="Credential not found") + + return _convert_to_response(credential) + + async def get_credentials(self, organization_id: str, page: int, page_size: int) -> list[CredentialResponse]: + credentials = await app.DATABASE.get_credentials(organization_id, page=page, page_size=page_size) + return [_convert_to_response(credential) for credential in credentials] + + async def delete_credential_item(self, item_id: str) -> None: + await self._client.delete_secret( + vault_name=self._vault_name, + secret_name=item_id, + ) + + async def get_credential_item(self, db_credential: Credential) -> CredentialItem: + secret_json_str = await self._client.get_secret(secret_name=db_credential.item_id, vault_name=self._vault_name) + if secret_json_str is None: + raise ValueError(f"Azure Credential Vault secret not found for {db_credential.item_id}") + + data = TypeAdapter(AzureCredentialVaultService._CredentialDataImage).validate_json(secret_json_str) + if isinstance(data, AzureCredentialVaultService._PasswordCredentialDataImage): + return CredentialItem( + item_id=db_credential.item_id, + credential=PasswordCredential( + username=data.username, + password=data.password, + totp=data.totp, + totp_type=db_credential.totp_type, + ), + name=db_credential.name, + credential_type=CredentialType.PASSWORD, + ) + elif isinstance(data, AzureCredentialVaultService._CreditCardCredentialDataImage): + return CredentialItem( + item_id=db_credential.item_id, + credential=CreditCardCredential( + card_holder_name=data.card_holder_name, + card_number=data.card_number, + card_exp_month=data.card_exp_month, + card_exp_year=data.card_exp_year, + card_cvv=data.card_cvv, + card_brand=data.card_brand, + ), + name=db_credential.name, + credential_type=CredentialType.CREDIT_CARD, + ) + else: + raise TypeError(f"Invalid credential type: {type(data)}") + + async def _create_azure_secret_item( + self, + organization_id: str, + credential: PasswordCredential | CreditCardCredential, + ) -> str: + if isinstance(credential, PasswordCredential): + data = AzureCredentialVaultService._PasswordCredentialDataImage( + type="password", + username=credential.username, + password=credential.password, + totp=credential.totp, + ) + elif isinstance(credential, CreditCardCredential): + data = AzureCredentialVaultService._CreditCardCredentialDataImage( + type="credit_card", + card_number=credential.card_number, + card_cvv=credential.card_cvv, + card_exp_month=credential.card_exp_month, + card_exp_year=credential.card_exp_year, + card_brand=credential.card_brand, + card_holder_name=credential.card_holder_name, + ) + else: + raise TypeError(f"Invalid credential type: {type(credential)}") + + secret_name = f"{organization_id}-{uuid.uuid4()}".replace("_", "") + secret_value = data.model_dump_json(exclude_none=True) + + return await self._client.create_secret( + vault_name=self._vault_name, + secret_name=secret_name, + secret_value=secret_value, + ) + + +def _convert_to_response(credential: Credential) -> CredentialResponse: + if credential.credential_type == CredentialType.PASSWORD: + credential_response = PasswordCredentialResponse( + username=credential.username or credential.credential_id, + totp_type=credential.totp_type, + ) + return CredentialResponse( + credential=credential_response, + credential_id=credential.credential_id, + credential_type=credential.credential_type, + name=credential.name, + ) + elif credential.credential_type == CredentialType.CREDIT_CARD: + credential_response = CreditCardCredentialResponse( + last_four=credential.card_last4 or "****", + brand=credential.card_brand or "Card Brand", + ) + return CredentialResponse( + credential=credential_response, + credential_id=credential.credential_id, + credential_type=credential.credential_type, + name=credential.name, + ) + else: + raise HTTPException(status_code=400, detail="Credential type not supported") diff --git a/skyvern/forge/sdk/services/credential/bitwarden_credential_service.py b/skyvern/forge/sdk/services/credential/bitwarden_credential_service.py index f41fd975..78caefbb 100644 --- a/skyvern/forge/sdk/services/credential/bitwarden_credential_service.py +++ b/skyvern/forge/sdk/services/credential/bitwarden_credential_service.py @@ -8,6 +8,7 @@ from skyvern.forge.sdk.schemas.credentials import ( CredentialItem, CredentialResponse, CredentialType, + CredentialVaultType, CreditCardCredentialResponse, PasswordCredentialResponse, ) @@ -40,12 +41,11 @@ class BitwardenCredentialVaultService(CredentialVaultService): credential=data.credential, ) - credential = await app.DATABASE.create_credential( + credential = await self._create_db_credential( organization_id=organization_id, + data=data, item_id=item_id, - name=data.name, - credential_type=data.credential_type, - totp_type=data.credential.totp_type if hasattr(data.credential, "totp_type") else "none", + vault_type=CredentialVaultType.BITWARDEN, ) return credential diff --git a/skyvern/forge/sdk/services/credential/credential_vault_service.py b/skyvern/forge/sdk/services/credential/credential_vault_service.py index 2161a673..f39e2b28 100644 --- a/skyvern/forge/sdk/services/credential/credential_vault_service.py +++ b/skyvern/forge/sdk/services/credential/credential_vault_service.py @@ -1,10 +1,13 @@ from abc import ABC, abstractmethod +from skyvern.forge import app from skyvern.forge.sdk.schemas.credentials import ( CreateCredentialRequest, Credential, CredentialItem, CredentialResponse, + CredentialType, + CredentialVaultType, ) @@ -34,3 +37,37 @@ class CredentialVaultService(ABC): @abstractmethod async def get_credential_item(self, db_credential: Credential) -> CredentialItem: """Retrieve the full credential data from the vault.""" + + @staticmethod + async def _create_db_credential( + organization_id: str, + data: CreateCredentialRequest, + item_id: str, + vault_type: CredentialVaultType, + ) -> Credential: + if data.credential_type == CredentialType.PASSWORD: + return await app.DATABASE.create_credential( + organization_id=organization_id, + name=data.name, + vault_type=vault_type, + item_id=item_id, + credential_type=data.credential_type, + username=data.credential.username, + totp_type=data.credential.totp_type, + card_last4=None, + card_brand=None, + ) + elif data.credential_type == CredentialType.CREDIT_CARD: + return await app.DATABASE.create_credential( + organization_id=organization_id, + name=data.name, + vault_type=vault_type, + item_id=item_id, + credential_type=data.credential_type, + username=None, + totp_type="none", + card_last4=data.credential.card_number[-4:], + card_brand=data.credential.card_brand, + ) + else: + raise Exception(f"Unsupported credential type: {data.credential_type}") diff --git a/skyvern/forge/sdk/workflow/context_manager.py b/skyvern/forge/sdk/workflow/context_manager.py index b54c3053..c94710c2 100644 --- a/skyvern/forge/sdk/workflow/context_manager.py +++ b/skyvern/forge/sdk/workflow/context_manager.py @@ -17,7 +17,7 @@ from skyvern.forge import app from skyvern.forge.sdk.api.aws import AsyncAWSClient from skyvern.forge.sdk.api.azure import AsyncAzureVaultClient from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType -from skyvern.forge.sdk.schemas.credentials import PasswordCredential +from skyvern.forge.sdk.schemas.credentials import CredentialVaultType, PasswordCredential from skyvern.forge.sdk.schemas.organizations import Organization from skyvern.forge.sdk.schemas.tasks import TaskStatus from skyvern.forge.sdk.services.bitwarden import BitwardenConstants, BitwardenService @@ -295,7 +295,12 @@ class WorkflowRunContext: if db_credential is None: raise CredentialParameterNotFoundError(credential_id) - credential_item = await app.CREDENTIAL_VAULT_SERVICE.get_credential_item(db_credential) + vault_type = db_credential.vault_type or CredentialVaultType.BITWARDEN + credential_service = app.CREDENTIAL_VAULT_SERVICES.get(vault_type) + if credential_service is None: + raise CredentialParameterNotFoundError(credential_id) + + credential_item = await credential_service.get_credential_item(db_credential) credential = credential_item.credential self.parameters[parameter.key] = parameter @@ -347,7 +352,12 @@ class WorkflowRunContext: if db_credential is None: raise CredentialParameterNotFoundError(credential_id) - credential_item = await app.CREDENTIAL_VAULT_SERVICE.get_credential_item(db_credential) + vault_type = db_credential.vault_type or CredentialVaultType.BITWARDEN + credential_service = app.CREDENTIAL_VAULT_SERVICES.get(vault_type) + if credential_service is None: + raise CredentialParameterNotFoundError(credential_id) + + credential_item = await credential_service.get_credential_item(db_credential) credential = credential_item.credential self.parameters[parameter.key] = parameter @@ -398,13 +408,13 @@ class WorkflowRunContext: # If the parameter is an Azure secret, fetch the secret value and store it in the secrets dict # The value of the parameter will be the random secret id with format `secret_`. # We'll replace the random secret id with the actual secret value when we need to use it. - azure_vault_client = AsyncAzureVaultClient.create_default() - secret_value = await azure_vault_client.get_secret(parameter.azure_key, vault_name) - if secret_value is not None: - random_secret_id = self.generate_random_secret_id() - self.secrets[random_secret_id] = secret_value - self.values[parameter.key] = random_secret_id - self.parameters[parameter.key] = parameter + async with AsyncAzureVaultClient.create_default() as azure_vault_client: + secret_value = await azure_vault_client.get_secret(parameter.azure_key, vault_name) + if secret_value is not None: + random_secret_id = self.generate_random_secret_id() + self.secrets[random_secret_id] = secret_value + self.values[parameter.key] = random_secret_id + self.parameters[parameter.key] = parameter async def register_onepassword_credential_parameter_value( self, parameter: OnePasswordCredentialParameter, organization: Organization @@ -562,22 +572,21 @@ class WorkflowRunContext: totp_secret_key = self._resolve_parameter_value(parameter.totp_secret_key) - azure_vault_client = await self._get_azure_vault_client_for_organization(organization) + async with await self._get_azure_vault_client_for_organization(organization) as azure_vault_client: + secret_username = await azure_vault_client.get_secret(username_key, vault_name) + if not secret_username: + raise ValueError(f"Azure Vault username not found by key: {username_key}") - secret_username = await azure_vault_client.get_secret(username_key, vault_name) - if not secret_username: - raise ValueError(f"Azure Vault username not found by key: {username_key}") + secret_password = await azure_vault_client.get_secret(password_key, vault_name) + if not secret_password: + raise ValueError(f"Azure Vault password not found by key: {password_key}") - secret_password = await azure_vault_client.get_secret(password_key, vault_name) - if not secret_password: - raise ValueError(f"Azure Vault password not found by key: {password_key}") - - if totp_secret_key: - totp_secret = await azure_vault_client.get_secret(totp_secret_key, vault_name) - if not totp_secret: - raise ValueError(f"Azure Vault TOTP not found by key: {totp_secret_key}") - else: - totp_secret = None + if totp_secret_key: + totp_secret = await azure_vault_client.get_secret(totp_secret_key, vault_name) + if not totp_secret: + raise ValueError(f"Azure Vault TOTP not found by key: {totp_secret_key}") + else: + totp_secret = None if secret_username is not None and secret_password is not None: random_secret_id = self.generate_random_secret_id()