Migrate credentials to Azure Key Vault (#3681)
This commit is contained in:
committed by
GitHub
parent
c3ce5b1952
commit
32e6aed8ce
@@ -0,0 +1,188 @@
|
||||
import uuid
|
||||
from typing import Annotated, Literal, Union
|
||||
|
||||
from azure.identity.aio import ClientSecretCredential
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.api.azure import AsyncAzureVaultClient
|
||||
from skyvern.forge.sdk.schemas.credentials import (
|
||||
CreateCredentialRequest,
|
||||
Credential,
|
||||
CredentialItem,
|
||||
CredentialResponse,
|
||||
CredentialType,
|
||||
CredentialVaultType,
|
||||
CreditCardCredential,
|
||||
CreditCardCredentialResponse,
|
||||
PasswordCredential,
|
||||
PasswordCredentialResponse,
|
||||
)
|
||||
from skyvern.forge.sdk.services.credential.credential_vault_service import CredentialVaultService
|
||||
|
||||
|
||||
class AzureCredentialVaultService(CredentialVaultService):
|
||||
class _PasswordCredentialDataImage(BaseModel):
|
||||
type: Literal["password"]
|
||||
password: str
|
||||
username: str
|
||||
totp: str | None = None
|
||||
|
||||
class _CreditCardCredentialDataImage(BaseModel):
|
||||
type: Literal["credit_card"]
|
||||
card_number: str
|
||||
card_cvv: str
|
||||
card_exp_month: str
|
||||
card_exp_year: str
|
||||
card_brand: str
|
||||
card_holder_name: str
|
||||
|
||||
_CredentialDataImage = Annotated[
|
||||
Union[_PasswordCredentialDataImage, _CreditCardCredentialDataImage], Field(discriminator="type")
|
||||
]
|
||||
|
||||
def __init__(self, tenant_id: str, client_id: str, client_secret: str, vault_name: str):
|
||||
self._client = AsyncAzureVaultClient(
|
||||
ClientSecretCredential(
|
||||
tenant_id=tenant_id,
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
)
|
||||
)
|
||||
self._vault_name = vault_name
|
||||
|
||||
async def create_credential(self, organization_id: str, data: CreateCredentialRequest) -> Credential:
|
||||
item_id = await self._create_azure_secret_item(
|
||||
organization_id=organization_id,
|
||||
credential=data.credential,
|
||||
)
|
||||
|
||||
credential = await self._create_db_credential(
|
||||
organization_id=organization_id,
|
||||
data=data,
|
||||
item_id=item_id,
|
||||
vault_type=CredentialVaultType.AZURE_VAULT,
|
||||
)
|
||||
|
||||
return credential
|
||||
|
||||
async def delete_credential(
|
||||
self,
|
||||
credential: Credential,
|
||||
) -> None:
|
||||
await app.DATABASE.delete_credential(credential.credential_id, credential.organization_id)
|
||||
await self.delete_credential_item(credential.item_id)
|
||||
|
||||
async def get_credential(self, organization_id: str, credential_id: str) -> CredentialResponse:
|
||||
credential = await app.DATABASE.get_credential(credential_id=credential_id, organization_id=organization_id)
|
||||
if not credential:
|
||||
raise HTTPException(status_code=404, detail="Credential not found")
|
||||
|
||||
return _convert_to_response(credential)
|
||||
|
||||
async def get_credentials(self, organization_id: str, page: int, page_size: int) -> list[CredentialResponse]:
|
||||
credentials = await app.DATABASE.get_credentials(organization_id, page=page, page_size=page_size)
|
||||
return [_convert_to_response(credential) for credential in credentials]
|
||||
|
||||
async def delete_credential_item(self, item_id: str) -> None:
|
||||
await self._client.delete_secret(
|
||||
vault_name=self._vault_name,
|
||||
secret_name=item_id,
|
||||
)
|
||||
|
||||
async def get_credential_item(self, db_credential: Credential) -> CredentialItem:
|
||||
secret_json_str = await self._client.get_secret(secret_name=db_credential.item_id, vault_name=self._vault_name)
|
||||
if secret_json_str is None:
|
||||
raise ValueError(f"Azure Credential Vault secret not found for {db_credential.item_id}")
|
||||
|
||||
data = TypeAdapter(AzureCredentialVaultService._CredentialDataImage).validate_json(secret_json_str)
|
||||
if isinstance(data, AzureCredentialVaultService._PasswordCredentialDataImage):
|
||||
return CredentialItem(
|
||||
item_id=db_credential.item_id,
|
||||
credential=PasswordCredential(
|
||||
username=data.username,
|
||||
password=data.password,
|
||||
totp=data.totp,
|
||||
totp_type=db_credential.totp_type,
|
||||
),
|
||||
name=db_credential.name,
|
||||
credential_type=CredentialType.PASSWORD,
|
||||
)
|
||||
elif isinstance(data, AzureCredentialVaultService._CreditCardCredentialDataImage):
|
||||
return CredentialItem(
|
||||
item_id=db_credential.item_id,
|
||||
credential=CreditCardCredential(
|
||||
card_holder_name=data.card_holder_name,
|
||||
card_number=data.card_number,
|
||||
card_exp_month=data.card_exp_month,
|
||||
card_exp_year=data.card_exp_year,
|
||||
card_cvv=data.card_cvv,
|
||||
card_brand=data.card_brand,
|
||||
),
|
||||
name=db_credential.name,
|
||||
credential_type=CredentialType.CREDIT_CARD,
|
||||
)
|
||||
else:
|
||||
raise TypeError(f"Invalid credential type: {type(data)}")
|
||||
|
||||
async def _create_azure_secret_item(
|
||||
self,
|
||||
organization_id: str,
|
||||
credential: PasswordCredential | CreditCardCredential,
|
||||
) -> str:
|
||||
if isinstance(credential, PasswordCredential):
|
||||
data = AzureCredentialVaultService._PasswordCredentialDataImage(
|
||||
type="password",
|
||||
username=credential.username,
|
||||
password=credential.password,
|
||||
totp=credential.totp,
|
||||
)
|
||||
elif isinstance(credential, CreditCardCredential):
|
||||
data = AzureCredentialVaultService._CreditCardCredentialDataImage(
|
||||
type="credit_card",
|
||||
card_number=credential.card_number,
|
||||
card_cvv=credential.card_cvv,
|
||||
card_exp_month=credential.card_exp_month,
|
||||
card_exp_year=credential.card_exp_year,
|
||||
card_brand=credential.card_brand,
|
||||
card_holder_name=credential.card_holder_name,
|
||||
)
|
||||
else:
|
||||
raise TypeError(f"Invalid credential type: {type(credential)}")
|
||||
|
||||
secret_name = f"{organization_id}-{uuid.uuid4()}".replace("_", "")
|
||||
secret_value = data.model_dump_json(exclude_none=True)
|
||||
|
||||
return await self._client.create_secret(
|
||||
vault_name=self._vault_name,
|
||||
secret_name=secret_name,
|
||||
secret_value=secret_value,
|
||||
)
|
||||
|
||||
|
||||
def _convert_to_response(credential: Credential) -> CredentialResponse:
|
||||
if credential.credential_type == CredentialType.PASSWORD:
|
||||
credential_response = PasswordCredentialResponse(
|
||||
username=credential.username or credential.credential_id,
|
||||
totp_type=credential.totp_type,
|
||||
)
|
||||
return CredentialResponse(
|
||||
credential=credential_response,
|
||||
credential_id=credential.credential_id,
|
||||
credential_type=credential.credential_type,
|
||||
name=credential.name,
|
||||
)
|
||||
elif credential.credential_type == CredentialType.CREDIT_CARD:
|
||||
credential_response = CreditCardCredentialResponse(
|
||||
last_four=credential.card_last4 or "****",
|
||||
brand=credential.card_brand or "Card Brand",
|
||||
)
|
||||
return CredentialResponse(
|
||||
credential=credential_response,
|
||||
credential_id=credential.credential_id,
|
||||
credential_type=credential.credential_type,
|
||||
name=credential.name,
|
||||
)
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="Credential type not supported")
|
||||
@@ -8,6 +8,7 @@ from skyvern.forge.sdk.schemas.credentials import (
|
||||
CredentialItem,
|
||||
CredentialResponse,
|
||||
CredentialType,
|
||||
CredentialVaultType,
|
||||
CreditCardCredentialResponse,
|
||||
PasswordCredentialResponse,
|
||||
)
|
||||
@@ -40,12 +41,11 @@ class BitwardenCredentialVaultService(CredentialVaultService):
|
||||
credential=data.credential,
|
||||
)
|
||||
|
||||
credential = await app.DATABASE.create_credential(
|
||||
credential = await self._create_db_credential(
|
||||
organization_id=organization_id,
|
||||
data=data,
|
||||
item_id=item_id,
|
||||
name=data.name,
|
||||
credential_type=data.credential_type,
|
||||
totp_type=data.credential.totp_type if hasattr(data.credential, "totp_type") else "none",
|
||||
vault_type=CredentialVaultType.BITWARDEN,
|
||||
)
|
||||
|
||||
return credential
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.schemas.credentials import (
|
||||
CreateCredentialRequest,
|
||||
Credential,
|
||||
CredentialItem,
|
||||
CredentialResponse,
|
||||
CredentialType,
|
||||
CredentialVaultType,
|
||||
)
|
||||
|
||||
|
||||
@@ -34,3 +37,37 @@ class CredentialVaultService(ABC):
|
||||
@abstractmethod
|
||||
async def get_credential_item(self, db_credential: Credential) -> CredentialItem:
|
||||
"""Retrieve the full credential data from the vault."""
|
||||
|
||||
@staticmethod
|
||||
async def _create_db_credential(
|
||||
organization_id: str,
|
||||
data: CreateCredentialRequest,
|
||||
item_id: str,
|
||||
vault_type: CredentialVaultType,
|
||||
) -> Credential:
|
||||
if data.credential_type == CredentialType.PASSWORD:
|
||||
return await app.DATABASE.create_credential(
|
||||
organization_id=organization_id,
|
||||
name=data.name,
|
||||
vault_type=vault_type,
|
||||
item_id=item_id,
|
||||
credential_type=data.credential_type,
|
||||
username=data.credential.username,
|
||||
totp_type=data.credential.totp_type,
|
||||
card_last4=None,
|
||||
card_brand=None,
|
||||
)
|
||||
elif data.credential_type == CredentialType.CREDIT_CARD:
|
||||
return await app.DATABASE.create_credential(
|
||||
organization_id=organization_id,
|
||||
name=data.name,
|
||||
vault_type=vault_type,
|
||||
item_id=item_id,
|
||||
credential_type=data.credential_type,
|
||||
username=None,
|
||||
totp_type="none",
|
||||
card_last4=data.credential.card_number[-4:],
|
||||
card_brand=data.credential.card_brand,
|
||||
)
|
||||
else:
|
||||
raise Exception(f"Unsupported credential type: {data.credential_type}")
|
||||
|
||||
Reference in New Issue
Block a user