fix mypy issue for org tokens (#3541)

This commit is contained in:
Shuchang Zheng
2025-09-26 16:35:47 -07:00
committed by GitHub
parent 7cd1b37d9c
commit 8c54475fda
9 changed files with 15 additions and 13 deletions

View File

@@ -20,7 +20,7 @@ async def setup_local_organization() -> str:
organization = await skyvern_agent.get_organization()
org_auth_token = await app.DATABASE.get_valid_org_auth_token(
organization_id=organization.organization_id,
token_type=OrganizationAuthTokenType.api,
token_type=OrganizationAuthTokenType.api.value,
)
return org_auth_token.token if org_auth_token else ""

View File

@@ -26,7 +26,7 @@ async def poll_verification_code(
timeout = timedelta(minutes=settings.VERIFICATION_CODE_POLLING_TIMEOUT_MINS)
start_datetime = datetime.utcnow()
timeout_datetime = start_datetime + timeout
org_token = await app.DATABASE.get_valid_org_auth_token(organization_id, OrganizationAuthTokenType.api)
org_token = await app.DATABASE.get_valid_org_auth_token(organization_id, OrganizationAuthTokenType.api.value)
if not org_token:
LOG.error("Failed to get organization token when trying to get verification code")
return None

View File

@@ -874,22 +874,23 @@ class AgentDB:
async def get_valid_org_auth_token(
self,
organization_id: str,
token_type: Literal[OrganizationAuthTokenType.api, OrganizationAuthTokenType.onepassword_service_account],
token_type: Literal["api", "onepassword_service_account"],
) -> OrganizationAuthToken | None: ...
@overload
async def get_valid_org_auth_token(
self,
organization_id: str,
token_type: Literal[OrganizationAuthTokenType.azure_client_secret_credential],
token_type: Literal["azure_client_secret_credential"],
) -> AzureOrganizationAuthToken | None: ...
async def get_valid_org_auth_token(
self,
organization_id: str,
token_type: OrganizationAuthTokenType,
token_type: Literal["api", "onepassword_service_account", "azure_client_secret_credential"],
) -> OrganizationAuthToken | AzureOrganizationAuthToken | None:
try:
print("lol")
async with self.Session() as session:
if token := (
await session.scalars(

View File

@@ -200,7 +200,7 @@ def convert_to_organization(org_model: OrganizationModel) -> Organization:
async def convert_to_organization_auth_token(
org_auth_token: OrganizationAuthTokenModel, token_type: OrganizationAuthTokenType
org_auth_token: OrganizationAuthTokenModel, token_type: str
) -> OrganizationAuthToken | AzureOrganizationAuthToken:
token = org_auth_token.token
if org_auth_token.encrypted_token and org_auth_token.encrypted_method:

View File

@@ -1985,7 +1985,7 @@ async def get_api_keys(
if organization_id != current_org.organization_id:
raise HTTPException(status_code=403, detail="You do not have permission to access this organization")
api_keys = []
org_auth_token = await app.DATABASE.get_valid_org_auth_token(organization_id, OrganizationAuthTokenType.api)
org_auth_token = await app.DATABASE.get_valid_org_auth_token(organization_id, OrganizationAuthTokenType.api.value)
if org_auth_token:
api_keys.append(org_auth_token)
return GetOrganizationAPIKeysResponse(api_keys=api_keys)

View File

@@ -400,7 +400,7 @@ async def get_onepassword_token(
try:
auth_token = await app.DATABASE.get_valid_org_auth_token(
organization_id=current_org.organization_id,
token_type=OrganizationAuthTokenType.onepassword_service_account,
token_type=OrganizationAuthTokenType.onepassword_service_account.value,
)
if not auth_token:
raise HTTPException(
@@ -503,7 +503,7 @@ async def get_azure_client_secret_credential(
try:
auth_token = await app.DATABASE.get_valid_org_auth_token(
organization_id=current_org.organization_id,
token_type=OrganizationAuthTokenType.azure_client_secret_credential,
token_type=OrganizationAuthTokenType.azure_client_secret_credential.value,
)
if not auth_token:
raise HTTPException(

View File

@@ -364,7 +364,8 @@ class WorkflowRunContext:
self, parameter: OnePasswordCredentialParameter, organization: Organization
) -> None:
org_auth_token = await app.DATABASE.get_valid_org_auth_token(
organization.organization_id, OrganizationAuthTokenType.onepassword_service_account
organization.organization_id,
OrganizationAuthTokenType.onepassword_service_account.value,
)
token = settings.OP_SERVICE_ACCOUNT_TOKEN
if org_auth_token:
@@ -919,7 +920,7 @@ class WorkflowRunContext:
@staticmethod
async def _get_azure_vault_client_for_organization(organization: Organization) -> AsyncAzureVaultClient:
org_auth_token = await app.DATABASE.get_valid_org_auth_token(
organization.organization_id, OrganizationAuthTokenType.azure_client_secret_credential
organization.organization_id, OrganizationAuthTokenType.azure_client_secret_credential.value
)
if org_auth_token:
azure_vault_client = AsyncAzureVaultClient.create_from_client_secret(org_auth_token.credential)

View File

@@ -124,7 +124,7 @@ class Skyvern(AsyncSkyvern):
) -> None:
org_auth_token = await app.DATABASE.get_valid_org_auth_token(
organization_id=organization.organization_id,
token_type=OrganizationAuthTokenType.api,
token_type=OrganizationAuthTokenType.api.value,
)
step = await app.DATABASE.create_step(

View File

@@ -1723,7 +1723,7 @@ async def send_task_v2_webhook(task_v2: TaskV2) -> None:
return
api_key = await app.DATABASE.get_valid_org_auth_token(
organization_id,
OrganizationAuthTokenType.api,
OrganizationAuthTokenType.api.value,
)
if not api_key:
LOG.warning(