Migrate credentials to Azure Key Vault (#3681)
This commit is contained in:
committed by
GitHub
parent
c3ce5b1952
commit
32e6aed8ce
@@ -0,0 +1,41 @@
|
||||
"""Creds in Azure Vault
|
||||
|
||||
Revision ID: d648e2df239e
|
||||
Revises: 7cd6f55be8d2
|
||||
Create Date: 2025-10-10 15:33:02.700316+00:00
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "d648e2df239e"
|
||||
down_revision: Union[str, None] = "7cd6f55be8d2"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("credentials", sa.Column("vault_type", sa.String(), nullable=True))
|
||||
op.add_column("credentials", sa.Column("username", sa.String(), nullable=True))
|
||||
op.add_column("credentials", sa.Column("totp_identifier", sa.String(), nullable=True))
|
||||
op.add_column("credentials", sa.Column("card_last4", sa.String(), nullable=True))
|
||||
op.add_column("credentials", sa.Column("card_brand", sa.String(), nullable=True))
|
||||
op.add_column("organization_bitwarden_collections", sa.Column("deleted_at", sa.DateTime(), nullable=True))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("organization_bitwarden_collections", "deleted_at")
|
||||
op.drop_column("credentials", "card_brand")
|
||||
op.drop_column("credentials", "card_last4")
|
||||
op.drop_column("credentials", "totp_identifier")
|
||||
op.drop_column("credentials", "username")
|
||||
op.drop_column("credentials", "vault_type")
|
||||
# ### end Alembic commands ###
|
||||
@@ -295,6 +295,16 @@ class Settings(BaseSettings):
|
||||
BITWARDEN_EMAIL: str | None = None
|
||||
OP_SERVICE_ACCOUNT_TOKEN: str | None = None
|
||||
|
||||
# Where credentials are stored: bitwarden or azure_vault
|
||||
CREDENTIAL_VAULT_TYPE: str = "bitwarden"
|
||||
|
||||
# Azure Setting
|
||||
AZURE_TENANT_ID: str | None = None
|
||||
AZURE_CLIENT_ID: str | None = None
|
||||
AZURE_CLIENT_SECRET: str | None = None
|
||||
# The Azure Key Vault name to store credentials
|
||||
AZURE_CREDENTIAL_VAULT: str | None = None
|
||||
|
||||
# Skyvern Auth Bitwarden Settings
|
||||
SKYVERN_AUTH_BITWARDEN_CLIENT_ID: str | None = None
|
||||
SKYVERN_AUTH_BITWARDEN_CLIENT_SECRET: str | None = None
|
||||
|
||||
@@ -13,7 +13,9 @@ from skyvern.forge.sdk.artifact.storage.s3 import S3Storage
|
||||
from skyvern.forge.sdk.cache.factory import CacheFactory
|
||||
from skyvern.forge.sdk.db.client import AgentDB
|
||||
from skyvern.forge.sdk.experimentation.providers import BaseExperimentationProvider, NoOpExperimentationProvider
|
||||
from skyvern.forge.sdk.schemas.credentials import CredentialVaultType
|
||||
from skyvern.forge.sdk.schemas.organizations import Organization
|
||||
from skyvern.forge.sdk.services.credential.azure_credential_vault_service import AzureCredentialVaultService
|
||||
from skyvern.forge.sdk.services.credential.bitwarden_credential_service import BitwardenCredentialVaultService
|
||||
from skyvern.forge.sdk.services.credential.credential_vault_service import CredentialVaultService
|
||||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||
@@ -96,7 +98,20 @@ WORKFLOW_CONTEXT_MANAGER = WorkflowContextManager()
|
||||
WORKFLOW_SERVICE = WorkflowService()
|
||||
AGENT_FUNCTION = AgentFunction()
|
||||
PERSISTENT_SESSIONS_MANAGER = PersistentSessionsManager(database=DATABASE)
|
||||
CREDENTIAL_VAULT_SERVICE: CredentialVaultService = BitwardenCredentialVaultService()
|
||||
|
||||
BITWARDEN_CREDENTIAL_VAULT_SERVICE: BitwardenCredentialVaultService = BitwardenCredentialVaultService()
|
||||
AZURE_CREDENTIAL_VAULT_SERVICE: AzureCredentialVaultService | None = None
|
||||
if SettingsManager.get_settings().AZURE_CREDENTIAL_VAULT:
|
||||
AZURE_CREDENTIAL_VAULT_SERVICE = AzureCredentialVaultService(
|
||||
tenant_id=SettingsManager.get_settings().AZURE_TENANT_ID, # type: ignore
|
||||
client_id=SettingsManager.get_settings().AZURE_CLIENT_ID, # type: ignore
|
||||
client_secret=SettingsManager.get_settings().AZURE_CLIENT_SECRET, # type: ignore
|
||||
vault_name=SettingsManager.get_settings().AZURE_CREDENTIAL_VAULT, # type: ignore
|
||||
)
|
||||
CREDENTIAL_VAULT_SERVICES: dict[str, CredentialVaultService | None] = {
|
||||
CredentialVaultType.BITWARDEN: BITWARDEN_CREDENTIAL_VAULT_SERVICE,
|
||||
CredentialVaultType.AZURE_VAULT: AZURE_CREDENTIAL_VAULT_SERVICE,
|
||||
}
|
||||
|
||||
scrape_exclude: ScrapeExcludeFunc | None = None
|
||||
authentication_function: Callable[[str], Awaitable[Organization]] | None = None
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Self
|
||||
|
||||
import structlog
|
||||
from azure.identity.aio import ClientSecretCredential, DefaultAzureCredential
|
||||
from azure.keyvault.secrets.aio import SecretClient
|
||||
@@ -9,24 +11,57 @@ LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class AsyncAzureVaultClient:
|
||||
def __init__(self, credential: ClientSecretCredential | DefaultAzureCredential):
|
||||
def __init__(self, credential: ClientSecretCredential | DefaultAzureCredential) -> None:
|
||||
self.credential = credential
|
||||
|
||||
async def __aenter__(self) -> Self:
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: object
|
||||
) -> None:
|
||||
await self.credential.close()
|
||||
|
||||
async def get_secret(self, secret_name: str, vault_name: str) -> str | None:
|
||||
secret_client = await self._get_secret_client(vault_name)
|
||||
try:
|
||||
# Azure Key Vault URL format: https://<your-key-vault-name>.vault.azure.net
|
||||
# Assuming the secret_name is actually the Key Vault URL and the secret name
|
||||
# This needs to be clarified or passed as separate parameters
|
||||
# For now, let's assume secret_name is the actual secret name and Key Vault URL is in settings.
|
||||
key_vault_url = f"https://{vault_name}.vault.azure.net" # Placeholder, adjust as needed
|
||||
secret_client = SecretClient(vault_url=key_vault_url, credential=self.credential)
|
||||
secret = await secret_client.get_secret(secret_name)
|
||||
return secret.value
|
||||
except Exception as e:
|
||||
LOG.exception("Failed to get secret from Azure Key Vault.", secret_name=secret_name, error=e)
|
||||
return None
|
||||
finally:
|
||||
await self.credential.close()
|
||||
await secret_client.close()
|
||||
|
||||
async def create_secret(self, secret_name: str, secret_value: str, vault_name: str) -> str:
|
||||
secret_client = await self._get_secret_client(vault_name)
|
||||
try:
|
||||
secret = await secret_client.set_secret(secret_name, secret_value)
|
||||
return secret.name
|
||||
except Exception as e:
|
||||
LOG.exception("Failed to create secret from Azure Key Vault.", secret_name=secret_name, error=e)
|
||||
raise e
|
||||
finally:
|
||||
await secret_client.close()
|
||||
|
||||
async def delete_secret(self, secret_name: str, vault_name: str) -> str:
|
||||
secret_client = await self._get_secret_client(vault_name)
|
||||
try:
|
||||
secret = await secret_client.delete_secret(secret_name)
|
||||
return secret.name
|
||||
except Exception as e:
|
||||
LOG.exception("Failed to delete secret from Azure Key Vault.", secret_name=secret_name, error=e)
|
||||
raise e
|
||||
finally:
|
||||
await secret_client.close()
|
||||
|
||||
async def _get_secret_client(self, vault_name: str) -> SecretClient:
|
||||
# Azure Key Vault URL format: https://<your-key-vault-name>.vault.azure.net
|
||||
# Assuming the secret_name is actually the Key Vault URL and the secret name
|
||||
# This needs to be clarified or passed as separate parameters
|
||||
# For now, let's assume secret_name is the actual secret name and Key Vault URL is in settings.
|
||||
key_vault_url = f"https://{vault_name}.vault.azure.net" # Placeholder, adjust as needed
|
||||
return SecretClient(vault_url=key_vault_url, credential=self.credential)
|
||||
|
||||
async def close(self) -> None:
|
||||
await self.credential.close()
|
||||
|
||||
@@ -77,7 +77,7 @@ from skyvern.forge.sdk.encrypt.base import EncryptMethod
|
||||
from skyvern.forge.sdk.log_artifacts import save_workflow_run_logs
|
||||
from skyvern.forge.sdk.models import Step, StepStatus
|
||||
from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion
|
||||
from skyvern.forge.sdk.schemas.credentials import Credential, CredentialType
|
||||
from skyvern.forge.sdk.schemas.credentials import Credential, CredentialType, CredentialVaultType
|
||||
from skyvern.forge.sdk.schemas.debug_sessions import BlockRun, DebugSession
|
||||
from skyvern.forge.sdk.schemas.organization_bitwarden_collections import OrganizationBitwardenCollection
|
||||
from skyvern.forge.sdk.schemas.organizations import (
|
||||
@@ -3628,19 +3628,27 @@ class AgentDB:
|
||||
|
||||
async def create_credential(
|
||||
self,
|
||||
name: str,
|
||||
credential_type: CredentialType,
|
||||
organization_id: str,
|
||||
name: str,
|
||||
vault_type: CredentialVaultType,
|
||||
item_id: str,
|
||||
totp_type: str = "none",
|
||||
credential_type: CredentialType,
|
||||
username: str | None,
|
||||
totp_type: str,
|
||||
card_last4: str | None,
|
||||
card_brand: str | None,
|
||||
) -> Credential:
|
||||
async with self.Session() as session:
|
||||
credential = CredentialModel(
|
||||
organization_id=organization_id,
|
||||
name=name,
|
||||
credential_type=credential_type,
|
||||
vault_type=vault_type,
|
||||
item_id=item_id,
|
||||
credential_type=credential_type,
|
||||
username=username,
|
||||
totp_type=totp_type,
|
||||
card_last4=card_last4,
|
||||
card_brand=card_brand,
|
||||
)
|
||||
session.add(credential)
|
||||
await session.commit()
|
||||
@@ -3733,7 +3741,9 @@ class AgentDB:
|
||||
async with self.Session() as session:
|
||||
organization_bitwarden_collection = (
|
||||
await session.scalars(
|
||||
select(OrganizationBitwardenCollectionModel).filter_by(organization_id=organization_id)
|
||||
select(OrganizationBitwardenCollectionModel)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter_by(deleted_at=None)
|
||||
)
|
||||
).first()
|
||||
if organization_bitwarden_collection:
|
||||
|
||||
@@ -807,6 +807,7 @@ class OrganizationBitwardenCollectionModel(Base):
|
||||
|
||||
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
|
||||
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
|
||||
deleted_at = Column(DateTime, nullable=True)
|
||||
|
||||
|
||||
class CredentialModel(Base):
|
||||
@@ -814,11 +815,16 @@ class CredentialModel(Base):
|
||||
|
||||
credential_id = Column(String, primary_key=True, default=generate_credential_id)
|
||||
organization_id = Column(String, nullable=False)
|
||||
vault_type = Column(String, nullable=True)
|
||||
item_id = Column(String, nullable=True)
|
||||
|
||||
name = Column(String, nullable=False)
|
||||
credential_type = Column(String, nullable=False)
|
||||
username = Column(String, nullable=True)
|
||||
totp_type = Column(String, nullable=False, default="none")
|
||||
totp_identifier = Column(String, nullable=True, default=None)
|
||||
card_last4 = Column(String, nullable=True)
|
||||
card_brand = Column(String, nullable=True)
|
||||
|
||||
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
|
||||
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import structlog
|
||||
from fastapi import BackgroundTasks, Body, Depends, HTTPException, Path, Query
|
||||
|
||||
from skyvern.config import settings
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.prompts import prompt_engine
|
||||
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
|
||||
@@ -17,6 +18,7 @@ from skyvern.forge.sdk.schemas.credentials import (
|
||||
CreateCredentialRequest,
|
||||
CredentialResponse,
|
||||
CredentialType,
|
||||
CredentialVaultType,
|
||||
CreditCardCredentialResponse,
|
||||
PasswordCredentialResponse,
|
||||
)
|
||||
@@ -30,6 +32,7 @@ from skyvern.forge.sdk.schemas.organizations import (
|
||||
from skyvern.forge.sdk.schemas.totp_codes import TOTPCode, TOTPCodeCreate
|
||||
from skyvern.forge.sdk.services import org_auth_service
|
||||
from skyvern.forge.sdk.services.bitwarden import BitwardenService
|
||||
from skyvern.forge.sdk.services.credential.credential_vault_service import CredentialVaultService
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
@@ -155,12 +158,13 @@ async def create_credential(
|
||||
),
|
||||
current_org: Organization = Depends(org_auth_service.get_current_org),
|
||||
) -> CredentialResponse:
|
||||
credential = await app.CREDENTIAL_VAULT_SERVICE.create_credential(
|
||||
organization_id=current_org.organization_id, data=data
|
||||
)
|
||||
credential_service = await _get_credential_vault_service(current_org.organization_id)
|
||||
|
||||
# Early resyncing the Bitwarden vault
|
||||
background_tasks.add_task(fetch_credential_item_background, credential.item_id)
|
||||
credential = await credential_service.create_credential(organization_id=current_org.organization_id, data=data)
|
||||
|
||||
if credential.vault_type == CredentialVaultType.BITWARDEN:
|
||||
# Early resyncing the Bitwarden vault
|
||||
background_tasks.add_task(fetch_credential_item_background, credential.item_id)
|
||||
|
||||
if data.credential_type == CredentialType.PASSWORD:
|
||||
credential_response = PasswordCredentialResponse(
|
||||
@@ -221,7 +225,12 @@ async def delete_credential(
|
||||
if not credential:
|
||||
raise HTTPException(status_code=404, detail=f"Credential not found, credential_id={credential_id}")
|
||||
|
||||
await app.CREDENTIAL_VAULT_SERVICE.delete_credential(credential)
|
||||
vault_type = credential.vault_type or CredentialVaultType.BITWARDEN
|
||||
credential_service = app.CREDENTIAL_VAULT_SERVICES.get(vault_type)
|
||||
if not credential_service:
|
||||
raise HTTPException(status_code=400, detail="Unsupported credential storage type")
|
||||
|
||||
await credential_service.delete_credential(credential)
|
||||
|
||||
return None
|
||||
|
||||
@@ -253,7 +262,9 @@ async def get_credential(
|
||||
),
|
||||
current_org: Organization = Depends(org_auth_service.get_current_org),
|
||||
) -> CredentialResponse:
|
||||
return await app.CREDENTIAL_VAULT_SERVICE.get_credential(current_org.organization_id, credential_id)
|
||||
credential_service = await _get_credential_vault_service(current_org.organization_id)
|
||||
|
||||
return await credential_service.get_credential(current_org.organization_id, credential_id)
|
||||
|
||||
|
||||
@legacy_base_router.get("/credentials")
|
||||
@@ -291,7 +302,9 @@ async def get_credentials(
|
||||
openapi_extra={"x-fern-sdk-parameter-name": "page_size"},
|
||||
),
|
||||
) -> list[CredentialResponse]:
|
||||
return await app.CREDENTIAL_VAULT_SERVICE.get_credentials(current_org.organization_id, page, page_size)
|
||||
credential_service = await _get_credential_vault_service(current_org.organization_id)
|
||||
|
||||
return await credential_service.get_credentials(current_org.organization_id, page, page_size)
|
||||
|
||||
|
||||
@base_router.get(
|
||||
@@ -498,3 +511,16 @@ async def update_azure_client_secret_credential(
|
||||
status_code=500,
|
||||
detail=f"Failed to create or update Azure Client Secret Credential: {str(e)}",
|
||||
)
|
||||
|
||||
|
||||
async def _get_credential_vault_service(organization_id: str) -> CredentialVaultService:
|
||||
org_collection = await app.DATABASE.get_organization_bitwarden_collection(organization_id)
|
||||
|
||||
if settings.CREDENTIAL_VAULT_TYPE == CredentialVaultType.BITWARDEN or org_collection:
|
||||
return app.BITWARDEN_CREDENTIAL_VAULT_SERVICE
|
||||
elif settings.CREDENTIAL_VAULT_TYPE == CredentialVaultType.AZURE_VAULT:
|
||||
if not app.AZURE_CREDENTIAL_VAULT_SERVICE:
|
||||
raise HTTPException(status_code=400, detail="Azure Vault credential is not supported")
|
||||
return app.AZURE_CREDENTIAL_VAULT_SERVICE
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="Credential storage not supported")
|
||||
|
||||
@@ -4,6 +4,11 @@ from enum import StrEnum
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class CredentialVaultType(StrEnum):
|
||||
BITWARDEN = "bitwarden"
|
||||
AZURE_VAULT = "azure_vault"
|
||||
|
||||
|
||||
class CredentialType(StrEnum):
|
||||
"""Type of credential stored in the system."""
|
||||
|
||||
@@ -141,13 +146,17 @@ class Credential(BaseModel):
|
||||
..., description="ID of the organization that owns the credential", examples=["o_1234567890"]
|
||||
)
|
||||
name: str = Field(..., description="Name of the credential", examples=["Skyvern Login"])
|
||||
credential_type: CredentialType = Field(..., description="Type of the credential. Eg password, credit card, etc.")
|
||||
vault_type: CredentialVaultType | None = Field(..., description="Where the secret is stored: Bitwarden vs Azure")
|
||||
item_id: str = Field(..., description="ID of the associated credential item", examples=["item_1234567890"])
|
||||
credential_type: CredentialType = Field(..., description="Type of the credential. Eg password, credit card, etc.")
|
||||
username: str | None = Field(..., description="For password credentials: the username")
|
||||
totp_type: TotpType = Field(
|
||||
TotpType.NONE,
|
||||
description="Type of 2FA method used for this credential",
|
||||
examples=[TotpType.AUTHENTICATOR],
|
||||
)
|
||||
card_last4: str | None = Field(..., description="For credit_card credentials: the last four digits of the card")
|
||||
card_brand: str | None = Field(..., description="For credit_card credentials: the card brand")
|
||||
|
||||
created_at: datetime = Field(..., description="Timestamp when the credential was created")
|
||||
modified_at: datetime = Field(..., description="Timestamp when the credential was last modified")
|
||||
|
||||
@@ -0,0 +1,188 @@
|
||||
import uuid
|
||||
from typing import Annotated, Literal, Union
|
||||
|
||||
from azure.identity.aio import ClientSecretCredential
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.api.azure import AsyncAzureVaultClient
|
||||
from skyvern.forge.sdk.schemas.credentials import (
|
||||
CreateCredentialRequest,
|
||||
Credential,
|
||||
CredentialItem,
|
||||
CredentialResponse,
|
||||
CredentialType,
|
||||
CredentialVaultType,
|
||||
CreditCardCredential,
|
||||
CreditCardCredentialResponse,
|
||||
PasswordCredential,
|
||||
PasswordCredentialResponse,
|
||||
)
|
||||
from skyvern.forge.sdk.services.credential.credential_vault_service import CredentialVaultService
|
||||
|
||||
|
||||
class AzureCredentialVaultService(CredentialVaultService):
|
||||
class _PasswordCredentialDataImage(BaseModel):
|
||||
type: Literal["password"]
|
||||
password: str
|
||||
username: str
|
||||
totp: str | None = None
|
||||
|
||||
class _CreditCardCredentialDataImage(BaseModel):
|
||||
type: Literal["credit_card"]
|
||||
card_number: str
|
||||
card_cvv: str
|
||||
card_exp_month: str
|
||||
card_exp_year: str
|
||||
card_brand: str
|
||||
card_holder_name: str
|
||||
|
||||
_CredentialDataImage = Annotated[
|
||||
Union[_PasswordCredentialDataImage, _CreditCardCredentialDataImage], Field(discriminator="type")
|
||||
]
|
||||
|
||||
def __init__(self, tenant_id: str, client_id: str, client_secret: str, vault_name: str):
|
||||
self._client = AsyncAzureVaultClient(
|
||||
ClientSecretCredential(
|
||||
tenant_id=tenant_id,
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
)
|
||||
)
|
||||
self._vault_name = vault_name
|
||||
|
||||
async def create_credential(self, organization_id: str, data: CreateCredentialRequest) -> Credential:
|
||||
item_id = await self._create_azure_secret_item(
|
||||
organization_id=organization_id,
|
||||
credential=data.credential,
|
||||
)
|
||||
|
||||
credential = await self._create_db_credential(
|
||||
organization_id=organization_id,
|
||||
data=data,
|
||||
item_id=item_id,
|
||||
vault_type=CredentialVaultType.AZURE_VAULT,
|
||||
)
|
||||
|
||||
return credential
|
||||
|
||||
async def delete_credential(
|
||||
self,
|
||||
credential: Credential,
|
||||
) -> None:
|
||||
await app.DATABASE.delete_credential(credential.credential_id, credential.organization_id)
|
||||
await self.delete_credential_item(credential.item_id)
|
||||
|
||||
async def get_credential(self, organization_id: str, credential_id: str) -> CredentialResponse:
|
||||
credential = await app.DATABASE.get_credential(credential_id=credential_id, organization_id=organization_id)
|
||||
if not credential:
|
||||
raise HTTPException(status_code=404, detail="Credential not found")
|
||||
|
||||
return _convert_to_response(credential)
|
||||
|
||||
async def get_credentials(self, organization_id: str, page: int, page_size: int) -> list[CredentialResponse]:
|
||||
credentials = await app.DATABASE.get_credentials(organization_id, page=page, page_size=page_size)
|
||||
return [_convert_to_response(credential) for credential in credentials]
|
||||
|
||||
async def delete_credential_item(self, item_id: str) -> None:
|
||||
await self._client.delete_secret(
|
||||
vault_name=self._vault_name,
|
||||
secret_name=item_id,
|
||||
)
|
||||
|
||||
async def get_credential_item(self, db_credential: Credential) -> CredentialItem:
|
||||
secret_json_str = await self._client.get_secret(secret_name=db_credential.item_id, vault_name=self._vault_name)
|
||||
if secret_json_str is None:
|
||||
raise ValueError(f"Azure Credential Vault secret not found for {db_credential.item_id}")
|
||||
|
||||
data = TypeAdapter(AzureCredentialVaultService._CredentialDataImage).validate_json(secret_json_str)
|
||||
if isinstance(data, AzureCredentialVaultService._PasswordCredentialDataImage):
|
||||
return CredentialItem(
|
||||
item_id=db_credential.item_id,
|
||||
credential=PasswordCredential(
|
||||
username=data.username,
|
||||
password=data.password,
|
||||
totp=data.totp,
|
||||
totp_type=db_credential.totp_type,
|
||||
),
|
||||
name=db_credential.name,
|
||||
credential_type=CredentialType.PASSWORD,
|
||||
)
|
||||
elif isinstance(data, AzureCredentialVaultService._CreditCardCredentialDataImage):
|
||||
return CredentialItem(
|
||||
item_id=db_credential.item_id,
|
||||
credential=CreditCardCredential(
|
||||
card_holder_name=data.card_holder_name,
|
||||
card_number=data.card_number,
|
||||
card_exp_month=data.card_exp_month,
|
||||
card_exp_year=data.card_exp_year,
|
||||
card_cvv=data.card_cvv,
|
||||
card_brand=data.card_brand,
|
||||
),
|
||||
name=db_credential.name,
|
||||
credential_type=CredentialType.CREDIT_CARD,
|
||||
)
|
||||
else:
|
||||
raise TypeError(f"Invalid credential type: {type(data)}")
|
||||
|
||||
async def _create_azure_secret_item(
|
||||
self,
|
||||
organization_id: str,
|
||||
credential: PasswordCredential | CreditCardCredential,
|
||||
) -> str:
|
||||
if isinstance(credential, PasswordCredential):
|
||||
data = AzureCredentialVaultService._PasswordCredentialDataImage(
|
||||
type="password",
|
||||
username=credential.username,
|
||||
password=credential.password,
|
||||
totp=credential.totp,
|
||||
)
|
||||
elif isinstance(credential, CreditCardCredential):
|
||||
data = AzureCredentialVaultService._CreditCardCredentialDataImage(
|
||||
type="credit_card",
|
||||
card_number=credential.card_number,
|
||||
card_cvv=credential.card_cvv,
|
||||
card_exp_month=credential.card_exp_month,
|
||||
card_exp_year=credential.card_exp_year,
|
||||
card_brand=credential.card_brand,
|
||||
card_holder_name=credential.card_holder_name,
|
||||
)
|
||||
else:
|
||||
raise TypeError(f"Invalid credential type: {type(credential)}")
|
||||
|
||||
secret_name = f"{organization_id}-{uuid.uuid4()}".replace("_", "")
|
||||
secret_value = data.model_dump_json(exclude_none=True)
|
||||
|
||||
return await self._client.create_secret(
|
||||
vault_name=self._vault_name,
|
||||
secret_name=secret_name,
|
||||
secret_value=secret_value,
|
||||
)
|
||||
|
||||
|
||||
def _convert_to_response(credential: Credential) -> CredentialResponse:
|
||||
if credential.credential_type == CredentialType.PASSWORD:
|
||||
credential_response = PasswordCredentialResponse(
|
||||
username=credential.username or credential.credential_id,
|
||||
totp_type=credential.totp_type,
|
||||
)
|
||||
return CredentialResponse(
|
||||
credential=credential_response,
|
||||
credential_id=credential.credential_id,
|
||||
credential_type=credential.credential_type,
|
||||
name=credential.name,
|
||||
)
|
||||
elif credential.credential_type == CredentialType.CREDIT_CARD:
|
||||
credential_response = CreditCardCredentialResponse(
|
||||
last_four=credential.card_last4 or "****",
|
||||
brand=credential.card_brand or "Card Brand",
|
||||
)
|
||||
return CredentialResponse(
|
||||
credential=credential_response,
|
||||
credential_id=credential.credential_id,
|
||||
credential_type=credential.credential_type,
|
||||
name=credential.name,
|
||||
)
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="Credential type not supported")
|
||||
@@ -8,6 +8,7 @@ from skyvern.forge.sdk.schemas.credentials import (
|
||||
CredentialItem,
|
||||
CredentialResponse,
|
||||
CredentialType,
|
||||
CredentialVaultType,
|
||||
CreditCardCredentialResponse,
|
||||
PasswordCredentialResponse,
|
||||
)
|
||||
@@ -40,12 +41,11 @@ class BitwardenCredentialVaultService(CredentialVaultService):
|
||||
credential=data.credential,
|
||||
)
|
||||
|
||||
credential = await app.DATABASE.create_credential(
|
||||
credential = await self._create_db_credential(
|
||||
organization_id=organization_id,
|
||||
data=data,
|
||||
item_id=item_id,
|
||||
name=data.name,
|
||||
credential_type=data.credential_type,
|
||||
totp_type=data.credential.totp_type if hasattr(data.credential, "totp_type") else "none",
|
||||
vault_type=CredentialVaultType.BITWARDEN,
|
||||
)
|
||||
|
||||
return credential
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.schemas.credentials import (
|
||||
CreateCredentialRequest,
|
||||
Credential,
|
||||
CredentialItem,
|
||||
CredentialResponse,
|
||||
CredentialType,
|
||||
CredentialVaultType,
|
||||
)
|
||||
|
||||
|
||||
@@ -34,3 +37,37 @@ class CredentialVaultService(ABC):
|
||||
@abstractmethod
|
||||
async def get_credential_item(self, db_credential: Credential) -> CredentialItem:
|
||||
"""Retrieve the full credential data from the vault."""
|
||||
|
||||
@staticmethod
|
||||
async def _create_db_credential(
|
||||
organization_id: str,
|
||||
data: CreateCredentialRequest,
|
||||
item_id: str,
|
||||
vault_type: CredentialVaultType,
|
||||
) -> Credential:
|
||||
if data.credential_type == CredentialType.PASSWORD:
|
||||
return await app.DATABASE.create_credential(
|
||||
organization_id=organization_id,
|
||||
name=data.name,
|
||||
vault_type=vault_type,
|
||||
item_id=item_id,
|
||||
credential_type=data.credential_type,
|
||||
username=data.credential.username,
|
||||
totp_type=data.credential.totp_type,
|
||||
card_last4=None,
|
||||
card_brand=None,
|
||||
)
|
||||
elif data.credential_type == CredentialType.CREDIT_CARD:
|
||||
return await app.DATABASE.create_credential(
|
||||
organization_id=organization_id,
|
||||
name=data.name,
|
||||
vault_type=vault_type,
|
||||
item_id=item_id,
|
||||
credential_type=data.credential_type,
|
||||
username=None,
|
||||
totp_type="none",
|
||||
card_last4=data.credential.card_number[-4:],
|
||||
card_brand=data.credential.card_brand,
|
||||
)
|
||||
else:
|
||||
raise Exception(f"Unsupported credential type: {data.credential_type}")
|
||||
|
||||
@@ -17,7 +17,7 @@ from skyvern.forge import app
|
||||
from skyvern.forge.sdk.api.aws import AsyncAWSClient
|
||||
from skyvern.forge.sdk.api.azure import AsyncAzureVaultClient
|
||||
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
|
||||
from skyvern.forge.sdk.schemas.credentials import PasswordCredential
|
||||
from skyvern.forge.sdk.schemas.credentials import CredentialVaultType, PasswordCredential
|
||||
from skyvern.forge.sdk.schemas.organizations import Organization
|
||||
from skyvern.forge.sdk.schemas.tasks import TaskStatus
|
||||
from skyvern.forge.sdk.services.bitwarden import BitwardenConstants, BitwardenService
|
||||
@@ -295,7 +295,12 @@ class WorkflowRunContext:
|
||||
if db_credential is None:
|
||||
raise CredentialParameterNotFoundError(credential_id)
|
||||
|
||||
credential_item = await app.CREDENTIAL_VAULT_SERVICE.get_credential_item(db_credential)
|
||||
vault_type = db_credential.vault_type or CredentialVaultType.BITWARDEN
|
||||
credential_service = app.CREDENTIAL_VAULT_SERVICES.get(vault_type)
|
||||
if credential_service is None:
|
||||
raise CredentialParameterNotFoundError(credential_id)
|
||||
|
||||
credential_item = await credential_service.get_credential_item(db_credential)
|
||||
credential = credential_item.credential
|
||||
|
||||
self.parameters[parameter.key] = parameter
|
||||
@@ -347,7 +352,12 @@ class WorkflowRunContext:
|
||||
if db_credential is None:
|
||||
raise CredentialParameterNotFoundError(credential_id)
|
||||
|
||||
credential_item = await app.CREDENTIAL_VAULT_SERVICE.get_credential_item(db_credential)
|
||||
vault_type = db_credential.vault_type or CredentialVaultType.BITWARDEN
|
||||
credential_service = app.CREDENTIAL_VAULT_SERVICES.get(vault_type)
|
||||
if credential_service is None:
|
||||
raise CredentialParameterNotFoundError(credential_id)
|
||||
|
||||
credential_item = await credential_service.get_credential_item(db_credential)
|
||||
credential = credential_item.credential
|
||||
|
||||
self.parameters[parameter.key] = parameter
|
||||
@@ -398,13 +408,13 @@ class WorkflowRunContext:
|
||||
# If the parameter is an Azure secret, fetch the secret value and store it in the secrets dict
|
||||
# The value of the parameter will be the random secret id with format `secret_<uuid>`.
|
||||
# We'll replace the random secret id with the actual secret value when we need to use it.
|
||||
azure_vault_client = AsyncAzureVaultClient.create_default()
|
||||
secret_value = await azure_vault_client.get_secret(parameter.azure_key, vault_name)
|
||||
if secret_value is not None:
|
||||
random_secret_id = self.generate_random_secret_id()
|
||||
self.secrets[random_secret_id] = secret_value
|
||||
self.values[parameter.key] = random_secret_id
|
||||
self.parameters[parameter.key] = parameter
|
||||
async with AsyncAzureVaultClient.create_default() as azure_vault_client:
|
||||
secret_value = await azure_vault_client.get_secret(parameter.azure_key, vault_name)
|
||||
if secret_value is not None:
|
||||
random_secret_id = self.generate_random_secret_id()
|
||||
self.secrets[random_secret_id] = secret_value
|
||||
self.values[parameter.key] = random_secret_id
|
||||
self.parameters[parameter.key] = parameter
|
||||
|
||||
async def register_onepassword_credential_parameter_value(
|
||||
self, parameter: OnePasswordCredentialParameter, organization: Organization
|
||||
@@ -562,22 +572,21 @@ class WorkflowRunContext:
|
||||
|
||||
totp_secret_key = self._resolve_parameter_value(parameter.totp_secret_key)
|
||||
|
||||
azure_vault_client = await self._get_azure_vault_client_for_organization(organization)
|
||||
async with await self._get_azure_vault_client_for_organization(organization) as azure_vault_client:
|
||||
secret_username = await azure_vault_client.get_secret(username_key, vault_name)
|
||||
if not secret_username:
|
||||
raise ValueError(f"Azure Vault username not found by key: {username_key}")
|
||||
|
||||
secret_username = await azure_vault_client.get_secret(username_key, vault_name)
|
||||
if not secret_username:
|
||||
raise ValueError(f"Azure Vault username not found by key: {username_key}")
|
||||
secret_password = await azure_vault_client.get_secret(password_key, vault_name)
|
||||
if not secret_password:
|
||||
raise ValueError(f"Azure Vault password not found by key: {password_key}")
|
||||
|
||||
secret_password = await azure_vault_client.get_secret(password_key, vault_name)
|
||||
if not secret_password:
|
||||
raise ValueError(f"Azure Vault password not found by key: {password_key}")
|
||||
|
||||
if totp_secret_key:
|
||||
totp_secret = await azure_vault_client.get_secret(totp_secret_key, vault_name)
|
||||
if not totp_secret:
|
||||
raise ValueError(f"Azure Vault TOTP not found by key: {totp_secret_key}")
|
||||
else:
|
||||
totp_secret = None
|
||||
if totp_secret_key:
|
||||
totp_secret = await azure_vault_client.get_secret(totp_secret_key, vault_name)
|
||||
if not totp_secret:
|
||||
raise ValueError(f"Azure Vault TOTP not found by key: {totp_secret_key}")
|
||||
else:
|
||||
totp_secret = None
|
||||
|
||||
if secret_username is not None and secret_password is not None:
|
||||
random_secret_id = self.generate_random_secret_id()
|
||||
|
||||
Reference in New Issue
Block a user