Azure ClientSecretCredential support (#3456)
Co-authored-by: Suchintan <suchintan@users.noreply.github.com> Co-authored-by: Shuchang Zheng <wintonzheng0325@gmail.com>
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, List, Sequence
|
||||
from typing import Any, List, Literal, Sequence, overload
|
||||
|
||||
import structlog
|
||||
from sqlalchemy import and_, asc, delete, distinct, func, or_, pool, select, tuple_, update
|
||||
@@ -79,7 +79,12 @@ from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion
|
||||
from skyvern.forge.sdk.schemas.credentials import Credential, CredentialType
|
||||
from skyvern.forge.sdk.schemas.debug_sessions import BlockRun, DebugSession
|
||||
from skyvern.forge.sdk.schemas.organization_bitwarden_collections import OrganizationBitwardenCollection
|
||||
from skyvern.forge.sdk.schemas.organizations import Organization, OrganizationAuthToken
|
||||
from skyvern.forge.sdk.schemas.organizations import (
|
||||
AzureClientSecretCredential,
|
||||
AzureOrganizationAuthToken,
|
||||
Organization,
|
||||
OrganizationAuthToken,
|
||||
)
|
||||
from skyvern.forge.sdk.schemas.persistent_browser_sessions import PersistentBrowserSession
|
||||
from skyvern.forge.sdk.schemas.runs import Run
|
||||
from skyvern.forge.sdk.schemas.task_generations import TaskGeneration
|
||||
@@ -865,11 +870,25 @@ class AgentDB:
|
||||
await session.refresh(organization)
|
||||
return Organization.model_validate(organization)
|
||||
|
||||
@overload
|
||||
async def get_valid_org_auth_token(
|
||||
self,
|
||||
organization_id: str,
|
||||
token_type: Literal[OrganizationAuthTokenType.api, OrganizationAuthTokenType.onepassword_service_account],
|
||||
) -> OrganizationAuthToken | None: ...
|
||||
|
||||
@overload
|
||||
async def get_valid_org_auth_token(
|
||||
self,
|
||||
organization_id: str,
|
||||
token_type: Literal[OrganizationAuthTokenType.azure_client_secret_credential],
|
||||
) -> AzureOrganizationAuthToken | None: ...
|
||||
|
||||
async def get_valid_org_auth_token(
|
||||
self,
|
||||
organization_id: str,
|
||||
token_type: OrganizationAuthTokenType,
|
||||
) -> OrganizationAuthToken | None:
|
||||
) -> OrganizationAuthToken | AzureOrganizationAuthToken | None:
|
||||
try:
|
||||
async with self.Session() as session:
|
||||
if token := (
|
||||
@@ -881,7 +900,7 @@ class AgentDB:
|
||||
.order_by(OrganizationAuthTokenModel.created_at.desc())
|
||||
)
|
||||
).first():
|
||||
return await convert_to_organization_auth_token(token)
|
||||
return await convert_to_organization_auth_token(token, token_type)
|
||||
else:
|
||||
return None
|
||||
except SQLAlchemyError:
|
||||
@@ -907,7 +926,7 @@ class AgentDB:
|
||||
.order_by(OrganizationAuthTokenModel.created_at.desc())
|
||||
)
|
||||
).all()
|
||||
return [await convert_to_organization_auth_token(token) for token in tokens]
|
||||
return [await convert_to_organization_auth_token(token, token_type) for token in tokens]
|
||||
except SQLAlchemyError:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
raise
|
||||
@@ -941,7 +960,7 @@ class AgentDB:
|
||||
if valid is not None:
|
||||
query = query.filter_by(valid=valid)
|
||||
if token_obj := (await session.scalars(query)).first():
|
||||
return await convert_to_organization_auth_token(token_obj)
|
||||
return await convert_to_organization_auth_token(token_obj, token_type)
|
||||
else:
|
||||
return None
|
||||
except SQLAlchemyError:
|
||||
@@ -955,14 +974,22 @@ class AgentDB:
|
||||
self,
|
||||
organization_id: str,
|
||||
token_type: OrganizationAuthTokenType,
|
||||
token: str,
|
||||
token: str | AzureClientSecretCredential,
|
||||
encrypted_method: EncryptMethod | None = None,
|
||||
) -> OrganizationAuthToken:
|
||||
plaintext_token = token
|
||||
if token_type is OrganizationAuthTokenType.azure_client_secret_credential:
|
||||
if not isinstance(token, AzureClientSecretCredential):
|
||||
raise TypeError("Expected AzureClientSecretCredential for this token_type")
|
||||
plaintext_token = token.model_dump_json()
|
||||
else:
|
||||
if not isinstance(token, str):
|
||||
raise TypeError("Expected str token for this token_type")
|
||||
plaintext_token = token
|
||||
|
||||
encrypted_token = ""
|
||||
|
||||
if encrypted_method is not None:
|
||||
encrypted_token = await encryptor.encrypt(token, encrypted_method)
|
||||
encrypted_token = await encryptor.encrypt(plaintext_token, encrypted_method)
|
||||
plaintext_token = ""
|
||||
|
||||
async with self.Session() as session:
|
||||
@@ -977,7 +1004,7 @@ class AgentDB:
|
||||
await session.commit()
|
||||
await session.refresh(auth_token)
|
||||
|
||||
return await convert_to_organization_auth_token(auth_token)
|
||||
return await convert_to_organization_auth_token(auth_token, token_type)
|
||||
|
||||
async def invalidate_org_auth_tokens(
|
||||
self,
|
||||
|
||||
@@ -4,6 +4,7 @@ from enum import StrEnum
|
||||
class OrganizationAuthTokenType(StrEnum):
|
||||
api = "api"
|
||||
onepassword_service_account = "onepassword_service_account"
|
||||
azure_client_secret_credential = "azure_client_secret_credential"
|
||||
|
||||
|
||||
class TaskType(StrEnum):
|
||||
|
||||
@@ -30,7 +30,12 @@ from skyvern.forge.sdk.db.models import (
|
||||
from skyvern.forge.sdk.encrypt import encryptor
|
||||
from skyvern.forge.sdk.encrypt.base import EncryptMethod
|
||||
from skyvern.forge.sdk.models import Step, StepStatus
|
||||
from skyvern.forge.sdk.schemas.organizations import Organization, OrganizationAuthToken
|
||||
from skyvern.forge.sdk.schemas.organizations import (
|
||||
AzureClientSecretCredential,
|
||||
AzureOrganizationAuthToken,
|
||||
Organization,
|
||||
OrganizationAuthToken,
|
||||
)
|
||||
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
|
||||
from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunBlock
|
||||
from skyvern.forge.sdk.workflow.models.parameter import (
|
||||
@@ -195,21 +200,33 @@ def convert_to_organization(org_model: OrganizationModel) -> Organization:
|
||||
|
||||
|
||||
async def convert_to_organization_auth_token(
|
||||
org_auth_token: OrganizationAuthTokenModel,
|
||||
) -> OrganizationAuthToken:
|
||||
org_auth_token: OrganizationAuthTokenModel, token_type: OrganizationAuthTokenType
|
||||
) -> OrganizationAuthToken | AzureOrganizationAuthToken:
|
||||
token = org_auth_token.token
|
||||
if org_auth_token.encrypted_token and org_auth_token.encrypted_method:
|
||||
token = await encryptor.decrypt(org_auth_token.encrypted_token, EncryptMethod(org_auth_token.encrypted_method))
|
||||
|
||||
return OrganizationAuthToken(
|
||||
id=org_auth_token.id,
|
||||
organization_id=org_auth_token.organization_id,
|
||||
token_type=OrganizationAuthTokenType(org_auth_token.token_type),
|
||||
token=token,
|
||||
valid=org_auth_token.valid,
|
||||
created_at=org_auth_token.created_at,
|
||||
modified_at=org_auth_token.modified_at,
|
||||
)
|
||||
if token_type == OrganizationAuthTokenType.azure_client_secret_credential:
|
||||
credential = AzureClientSecretCredential.model_validate_json(token)
|
||||
return AzureOrganizationAuthToken(
|
||||
id=org_auth_token.id,
|
||||
organization_id=org_auth_token.organization_id,
|
||||
token_type=OrganizationAuthTokenType(org_auth_token.token_type),
|
||||
credential=credential,
|
||||
valid=org_auth_token.valid,
|
||||
created_at=org_auth_token.created_at,
|
||||
modified_at=org_auth_token.modified_at,
|
||||
)
|
||||
else:
|
||||
return OrganizationAuthToken(
|
||||
id=org_auth_token.id,
|
||||
organization_id=org_auth_token.organization_id,
|
||||
token_type=OrganizationAuthTokenType(org_auth_token.token_type),
|
||||
token=token,
|
||||
valid=org_auth_token.valid,
|
||||
created_at=org_auth_token.created_at,
|
||||
modified_at=org_auth_token.modified_at,
|
||||
)
|
||||
|
||||
|
||||
def convert_to_artifact(artifact_model: ArtifactModel, debug_enabled: bool = False) -> Artifact:
|
||||
|
||||
Reference in New Issue
Block a user