diff --git a/skyvern/cli/mcp.py b/skyvern/cli/mcp.py index 498e6e29..1aa25741 100644 --- a/skyvern/cli/mcp.py +++ b/skyvern/cli/mcp.py @@ -20,7 +20,7 @@ async def setup_local_organization() -> str: organization = await skyvern_agent.get_organization() org_auth_token = await app.DATABASE.get_valid_org_auth_token( organization_id=organization.organization_id, - token_type=OrganizationAuthTokenType.api, + token_type=OrganizationAuthTokenType.api.value, ) return org_auth_token.token if org_auth_token else "" diff --git a/skyvern/core/totp.py b/skyvern/core/totp.py index 69b1e64b..aea88f43 100644 --- a/skyvern/core/totp.py +++ b/skyvern/core/totp.py @@ -26,7 +26,7 @@ async def poll_verification_code( timeout = timedelta(minutes=settings.VERIFICATION_CODE_POLLING_TIMEOUT_MINS) start_datetime = datetime.utcnow() timeout_datetime = start_datetime + timeout - org_token = await app.DATABASE.get_valid_org_auth_token(organization_id, OrganizationAuthTokenType.api) + org_token = await app.DATABASE.get_valid_org_auth_token(organization_id, OrganizationAuthTokenType.api.value) if not org_token: LOG.error("Failed to get organization token when trying to get verification code") return None diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py index dadb41dc..e2fa1e72 100644 --- a/skyvern/forge/sdk/db/client.py +++ b/skyvern/forge/sdk/db/client.py @@ -874,22 +874,23 @@ class AgentDB: async def get_valid_org_auth_token( self, organization_id: str, - token_type: Literal[OrganizationAuthTokenType.api, OrganizationAuthTokenType.onepassword_service_account], + token_type: Literal["api", "onepassword_service_account"], ) -> OrganizationAuthToken | None: ... @overload async def get_valid_org_auth_token( self, organization_id: str, - token_type: Literal[OrganizationAuthTokenType.azure_client_secret_credential], + token_type: Literal["azure_client_secret_credential"], ) -> AzureOrganizationAuthToken | None: ... async def get_valid_org_auth_token( self, organization_id: str, - token_type: OrganizationAuthTokenType, + token_type: Literal["api", "onepassword_service_account", "azure_client_secret_credential"], ) -> OrganizationAuthToken | AzureOrganizationAuthToken | None: try: + print("lol") async with self.Session() as session: if token := ( await session.scalars( diff --git a/skyvern/forge/sdk/db/utils.py b/skyvern/forge/sdk/db/utils.py index 2232b658..7500eadb 100644 --- a/skyvern/forge/sdk/db/utils.py +++ b/skyvern/forge/sdk/db/utils.py @@ -200,7 +200,7 @@ def convert_to_organization(org_model: OrganizationModel) -> Organization: async def convert_to_organization_auth_token( - org_auth_token: OrganizationAuthTokenModel, token_type: OrganizationAuthTokenType + org_auth_token: OrganizationAuthTokenModel, token_type: str ) -> OrganizationAuthToken | AzureOrganizationAuthToken: token = org_auth_token.token if org_auth_token.encrypted_token and org_auth_token.encrypted_method: diff --git a/skyvern/forge/sdk/routes/agent_protocol.py b/skyvern/forge/sdk/routes/agent_protocol.py index 60b1e70b..65386c55 100644 --- a/skyvern/forge/sdk/routes/agent_protocol.py +++ b/skyvern/forge/sdk/routes/agent_protocol.py @@ -1985,7 +1985,7 @@ async def get_api_keys( if organization_id != current_org.organization_id: raise HTTPException(status_code=403, detail="You do not have permission to access this organization") api_keys = [] - org_auth_token = await app.DATABASE.get_valid_org_auth_token(organization_id, OrganizationAuthTokenType.api) + org_auth_token = await app.DATABASE.get_valid_org_auth_token(organization_id, OrganizationAuthTokenType.api.value) if org_auth_token: api_keys.append(org_auth_token) return GetOrganizationAPIKeysResponse(api_keys=api_keys) diff --git a/skyvern/forge/sdk/routes/credentials.py b/skyvern/forge/sdk/routes/credentials.py index 9375c2e8..3dd7aaec 100644 --- a/skyvern/forge/sdk/routes/credentials.py +++ b/skyvern/forge/sdk/routes/credentials.py @@ -400,7 +400,7 @@ async def get_onepassword_token( try: auth_token = await app.DATABASE.get_valid_org_auth_token( organization_id=current_org.organization_id, - token_type=OrganizationAuthTokenType.onepassword_service_account, + token_type=OrganizationAuthTokenType.onepassword_service_account.value, ) if not auth_token: raise HTTPException( @@ -503,7 +503,7 @@ async def get_azure_client_secret_credential( try: auth_token = await app.DATABASE.get_valid_org_auth_token( organization_id=current_org.organization_id, - token_type=OrganizationAuthTokenType.azure_client_secret_credential, + token_type=OrganizationAuthTokenType.azure_client_secret_credential.value, ) if not auth_token: raise HTTPException( diff --git a/skyvern/forge/sdk/workflow/context_manager.py b/skyvern/forge/sdk/workflow/context_manager.py index 99a7f0f5..c1cd77fc 100644 --- a/skyvern/forge/sdk/workflow/context_manager.py +++ b/skyvern/forge/sdk/workflow/context_manager.py @@ -364,7 +364,8 @@ class WorkflowRunContext: self, parameter: OnePasswordCredentialParameter, organization: Organization ) -> None: org_auth_token = await app.DATABASE.get_valid_org_auth_token( - organization.organization_id, OrganizationAuthTokenType.onepassword_service_account + organization.organization_id, + OrganizationAuthTokenType.onepassword_service_account.value, ) token = settings.OP_SERVICE_ACCOUNT_TOKEN if org_auth_token: @@ -919,7 +920,7 @@ class WorkflowRunContext: @staticmethod async def _get_azure_vault_client_for_organization(organization: Organization) -> AsyncAzureVaultClient: org_auth_token = await app.DATABASE.get_valid_org_auth_token( - organization.organization_id, OrganizationAuthTokenType.azure_client_secret_credential + organization.organization_id, OrganizationAuthTokenType.azure_client_secret_credential.value ) if org_auth_token: azure_vault_client = AsyncAzureVaultClient.create_from_client_secret(org_auth_token.credential) diff --git a/skyvern/library/skyvern.py b/skyvern/library/skyvern.py index 87adec08..a7851ce2 100644 --- a/skyvern/library/skyvern.py +++ b/skyvern/library/skyvern.py @@ -124,7 +124,7 @@ class Skyvern(AsyncSkyvern): ) -> None: org_auth_token = await app.DATABASE.get_valid_org_auth_token( organization_id=organization.organization_id, - token_type=OrganizationAuthTokenType.api, + token_type=OrganizationAuthTokenType.api.value, ) step = await app.DATABASE.create_step( diff --git a/skyvern/services/task_v2_service.py b/skyvern/services/task_v2_service.py index 2bfb6739..bc7d5f72 100644 --- a/skyvern/services/task_v2_service.py +++ b/skyvern/services/task_v2_service.py @@ -1723,7 +1723,7 @@ async def send_task_v2_webhook(task_v2: TaskV2) -> None: return api_key = await app.DATABASE.get_valid_org_auth_token( organization_id, - OrganizationAuthTokenType.api, + OrganizationAuthTokenType.api.value, ) if not api_key: LOG.warning(