Azure ClientSecretCredential support (#3456)
Co-authored-by: Suchintan <suchintan@users.noreply.github.com> Co-authored-by: Shuchang Zheng <wintonzheng0325@gmail.com>
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, List, Sequence
|
||||
from typing import Any, List, Literal, Sequence, overload
|
||||
|
||||
import structlog
|
||||
from sqlalchemy import and_, asc, delete, distinct, func, or_, pool, select, tuple_, update
|
||||
@@ -79,7 +79,12 @@ from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion
|
||||
from skyvern.forge.sdk.schemas.credentials import Credential, CredentialType
|
||||
from skyvern.forge.sdk.schemas.debug_sessions import BlockRun, DebugSession
|
||||
from skyvern.forge.sdk.schemas.organization_bitwarden_collections import OrganizationBitwardenCollection
|
||||
from skyvern.forge.sdk.schemas.organizations import Organization, OrganizationAuthToken
|
||||
from skyvern.forge.sdk.schemas.organizations import (
|
||||
AzureClientSecretCredential,
|
||||
AzureOrganizationAuthToken,
|
||||
Organization,
|
||||
OrganizationAuthToken,
|
||||
)
|
||||
from skyvern.forge.sdk.schemas.persistent_browser_sessions import PersistentBrowserSession
|
||||
from skyvern.forge.sdk.schemas.runs import Run
|
||||
from skyvern.forge.sdk.schemas.task_generations import TaskGeneration
|
||||
@@ -865,11 +870,25 @@ class AgentDB:
|
||||
await session.refresh(organization)
|
||||
return Organization.model_validate(organization)
|
||||
|
||||
@overload
|
||||
async def get_valid_org_auth_token(
|
||||
self,
|
||||
organization_id: str,
|
||||
token_type: Literal[OrganizationAuthTokenType.api, OrganizationAuthTokenType.onepassword_service_account],
|
||||
) -> OrganizationAuthToken | None: ...
|
||||
|
||||
@overload
|
||||
async def get_valid_org_auth_token(
|
||||
self,
|
||||
organization_id: str,
|
||||
token_type: Literal[OrganizationAuthTokenType.azure_client_secret_credential],
|
||||
) -> AzureOrganizationAuthToken | None: ...
|
||||
|
||||
async def get_valid_org_auth_token(
|
||||
self,
|
||||
organization_id: str,
|
||||
token_type: OrganizationAuthTokenType,
|
||||
) -> OrganizationAuthToken | None:
|
||||
) -> OrganizationAuthToken | AzureOrganizationAuthToken | None:
|
||||
try:
|
||||
async with self.Session() as session:
|
||||
if token := (
|
||||
@@ -881,7 +900,7 @@ class AgentDB:
|
||||
.order_by(OrganizationAuthTokenModel.created_at.desc())
|
||||
)
|
||||
).first():
|
||||
return await convert_to_organization_auth_token(token)
|
||||
return await convert_to_organization_auth_token(token, token_type)
|
||||
else:
|
||||
return None
|
||||
except SQLAlchemyError:
|
||||
@@ -907,7 +926,7 @@ class AgentDB:
|
||||
.order_by(OrganizationAuthTokenModel.created_at.desc())
|
||||
)
|
||||
).all()
|
||||
return [await convert_to_organization_auth_token(token) for token in tokens]
|
||||
return [await convert_to_organization_auth_token(token, token_type) for token in tokens]
|
||||
except SQLAlchemyError:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
raise
|
||||
@@ -941,7 +960,7 @@ class AgentDB:
|
||||
if valid is not None:
|
||||
query = query.filter_by(valid=valid)
|
||||
if token_obj := (await session.scalars(query)).first():
|
||||
return await convert_to_organization_auth_token(token_obj)
|
||||
return await convert_to_organization_auth_token(token_obj, token_type)
|
||||
else:
|
||||
return None
|
||||
except SQLAlchemyError:
|
||||
@@ -955,14 +974,22 @@ class AgentDB:
|
||||
self,
|
||||
organization_id: str,
|
||||
token_type: OrganizationAuthTokenType,
|
||||
token: str,
|
||||
token: str | AzureClientSecretCredential,
|
||||
encrypted_method: EncryptMethod | None = None,
|
||||
) -> OrganizationAuthToken:
|
||||
plaintext_token = token
|
||||
if token_type is OrganizationAuthTokenType.azure_client_secret_credential:
|
||||
if not isinstance(token, AzureClientSecretCredential):
|
||||
raise TypeError("Expected AzureClientSecretCredential for this token_type")
|
||||
plaintext_token = token.model_dump_json()
|
||||
else:
|
||||
if not isinstance(token, str):
|
||||
raise TypeError("Expected str token for this token_type")
|
||||
plaintext_token = token
|
||||
|
||||
encrypted_token = ""
|
||||
|
||||
if encrypted_method is not None:
|
||||
encrypted_token = await encryptor.encrypt(token, encrypted_method)
|
||||
encrypted_token = await encryptor.encrypt(plaintext_token, encrypted_method)
|
||||
plaintext_token = ""
|
||||
|
||||
async with self.Session() as session:
|
||||
@@ -977,7 +1004,7 @@ class AgentDB:
|
||||
await session.commit()
|
||||
await session.refresh(auth_token)
|
||||
|
||||
return await convert_to_organization_auth_token(auth_token)
|
||||
return await convert_to_organization_auth_token(auth_token, token_type)
|
||||
|
||||
async def invalidate_org_auth_tokens(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user