Azure ClientSecretCredential support (#3456)

Co-authored-by: Suchintan <suchintan@users.noreply.github.com>
Co-authored-by: Shuchang Zheng <wintonzheng0325@gmail.com>
This commit is contained in:
stenn930
2025-09-23 10:16:48 -06:00
committed by GitHub
parent 10fac9bad0
commit a29a2bc49b
12 changed files with 592 additions and 71 deletions

View File

@@ -1,6 +1,6 @@
import json
from datetime import datetime, timedelta, timezone
from typing import Any, List, Sequence
from typing import Any, List, Literal, Sequence, overload
import structlog
from sqlalchemy import and_, asc, delete, distinct, func, or_, pool, select, tuple_, update
@@ -79,7 +79,12 @@ from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion
from skyvern.forge.sdk.schemas.credentials import Credential, CredentialType
from skyvern.forge.sdk.schemas.debug_sessions import BlockRun, DebugSession
from skyvern.forge.sdk.schemas.organization_bitwarden_collections import OrganizationBitwardenCollection
from skyvern.forge.sdk.schemas.organizations import Organization, OrganizationAuthToken
from skyvern.forge.sdk.schemas.organizations import (
AzureClientSecretCredential,
AzureOrganizationAuthToken,
Organization,
OrganizationAuthToken,
)
from skyvern.forge.sdk.schemas.persistent_browser_sessions import PersistentBrowserSession
from skyvern.forge.sdk.schemas.runs import Run
from skyvern.forge.sdk.schemas.task_generations import TaskGeneration
@@ -865,11 +870,25 @@ class AgentDB:
await session.refresh(organization)
return Organization.model_validate(organization)
@overload
async def get_valid_org_auth_token(
self,
organization_id: str,
token_type: Literal[OrganizationAuthTokenType.api, OrganizationAuthTokenType.onepassword_service_account],
) -> OrganizationAuthToken | None: ...
@overload
async def get_valid_org_auth_token(
self,
organization_id: str,
token_type: Literal[OrganizationAuthTokenType.azure_client_secret_credential],
) -> AzureOrganizationAuthToken | None: ...
async def get_valid_org_auth_token(
self,
organization_id: str,
token_type: OrganizationAuthTokenType,
) -> OrganizationAuthToken | None:
) -> OrganizationAuthToken | AzureOrganizationAuthToken | None:
try:
async with self.Session() as session:
if token := (
@@ -881,7 +900,7 @@ class AgentDB:
.order_by(OrganizationAuthTokenModel.created_at.desc())
)
).first():
return await convert_to_organization_auth_token(token)
return await convert_to_organization_auth_token(token, token_type)
else:
return None
except SQLAlchemyError:
@@ -907,7 +926,7 @@ class AgentDB:
.order_by(OrganizationAuthTokenModel.created_at.desc())
)
).all()
return [await convert_to_organization_auth_token(token) for token in tokens]
return [await convert_to_organization_auth_token(token, token_type) for token in tokens]
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
@@ -941,7 +960,7 @@ class AgentDB:
if valid is not None:
query = query.filter_by(valid=valid)
if token_obj := (await session.scalars(query)).first():
return await convert_to_organization_auth_token(token_obj)
return await convert_to_organization_auth_token(token_obj, token_type)
else:
return None
except SQLAlchemyError:
@@ -955,14 +974,22 @@ class AgentDB:
self,
organization_id: str,
token_type: OrganizationAuthTokenType,
token: str,
token: str | AzureClientSecretCredential,
encrypted_method: EncryptMethod | None = None,
) -> OrganizationAuthToken:
plaintext_token = token
if token_type is OrganizationAuthTokenType.azure_client_secret_credential:
if not isinstance(token, AzureClientSecretCredential):
raise TypeError("Expected AzureClientSecretCredential for this token_type")
plaintext_token = token.model_dump_json()
else:
if not isinstance(token, str):
raise TypeError("Expected str token for this token_type")
plaintext_token = token
encrypted_token = ""
if encrypted_method is not None:
encrypted_token = await encryptor.encrypt(token, encrypted_method)
encrypted_token = await encryptor.encrypt(plaintext_token, encrypted_method)
plaintext_token = ""
async with self.Session() as session:
@@ -977,7 +1004,7 @@ class AgentDB:
await session.commit()
await session.refresh(auth_token)
return await convert_to_organization_auth_token(auth_token)
return await convert_to_organization_auth_token(auth_token, token_type)
async def invalidate_org_auth_tokens(
self,