fix mypy issue for org tokens (#3541)
This commit is contained in:
@@ -20,7 +20,7 @@ async def setup_local_organization() -> str:
|
||||
organization = await skyvern_agent.get_organization()
|
||||
org_auth_token = await app.DATABASE.get_valid_org_auth_token(
|
||||
organization_id=organization.organization_id,
|
||||
token_type=OrganizationAuthTokenType.api,
|
||||
token_type=OrganizationAuthTokenType.api.value,
|
||||
)
|
||||
return org_auth_token.token if org_auth_token else ""
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ async def poll_verification_code(
|
||||
timeout = timedelta(minutes=settings.VERIFICATION_CODE_POLLING_TIMEOUT_MINS)
|
||||
start_datetime = datetime.utcnow()
|
||||
timeout_datetime = start_datetime + timeout
|
||||
org_token = await app.DATABASE.get_valid_org_auth_token(organization_id, OrganizationAuthTokenType.api)
|
||||
org_token = await app.DATABASE.get_valid_org_auth_token(organization_id, OrganizationAuthTokenType.api.value)
|
||||
if not org_token:
|
||||
LOG.error("Failed to get organization token when trying to get verification code")
|
||||
return None
|
||||
|
||||
@@ -874,22 +874,23 @@ class AgentDB:
|
||||
async def get_valid_org_auth_token(
|
||||
self,
|
||||
organization_id: str,
|
||||
token_type: Literal[OrganizationAuthTokenType.api, OrganizationAuthTokenType.onepassword_service_account],
|
||||
token_type: Literal["api", "onepassword_service_account"],
|
||||
) -> OrganizationAuthToken | None: ...
|
||||
|
||||
@overload
|
||||
async def get_valid_org_auth_token(
|
||||
self,
|
||||
organization_id: str,
|
||||
token_type: Literal[OrganizationAuthTokenType.azure_client_secret_credential],
|
||||
token_type: Literal["azure_client_secret_credential"],
|
||||
) -> AzureOrganizationAuthToken | None: ...
|
||||
|
||||
async def get_valid_org_auth_token(
|
||||
self,
|
||||
organization_id: str,
|
||||
token_type: OrganizationAuthTokenType,
|
||||
token_type: Literal["api", "onepassword_service_account", "azure_client_secret_credential"],
|
||||
) -> OrganizationAuthToken | AzureOrganizationAuthToken | None:
|
||||
try:
|
||||
print("lol")
|
||||
async with self.Session() as session:
|
||||
if token := (
|
||||
await session.scalars(
|
||||
|
||||
@@ -200,7 +200,7 @@ def convert_to_organization(org_model: OrganizationModel) -> Organization:
|
||||
|
||||
|
||||
async def convert_to_organization_auth_token(
|
||||
org_auth_token: OrganizationAuthTokenModel, token_type: OrganizationAuthTokenType
|
||||
org_auth_token: OrganizationAuthTokenModel, token_type: str
|
||||
) -> OrganizationAuthToken | AzureOrganizationAuthToken:
|
||||
token = org_auth_token.token
|
||||
if org_auth_token.encrypted_token and org_auth_token.encrypted_method:
|
||||
|
||||
@@ -1985,7 +1985,7 @@ async def get_api_keys(
|
||||
if organization_id != current_org.organization_id:
|
||||
raise HTTPException(status_code=403, detail="You do not have permission to access this organization")
|
||||
api_keys = []
|
||||
org_auth_token = await app.DATABASE.get_valid_org_auth_token(organization_id, OrganizationAuthTokenType.api)
|
||||
org_auth_token = await app.DATABASE.get_valid_org_auth_token(organization_id, OrganizationAuthTokenType.api.value)
|
||||
if org_auth_token:
|
||||
api_keys.append(org_auth_token)
|
||||
return GetOrganizationAPIKeysResponse(api_keys=api_keys)
|
||||
|
||||
@@ -400,7 +400,7 @@ async def get_onepassword_token(
|
||||
try:
|
||||
auth_token = await app.DATABASE.get_valid_org_auth_token(
|
||||
organization_id=current_org.organization_id,
|
||||
token_type=OrganizationAuthTokenType.onepassword_service_account,
|
||||
token_type=OrganizationAuthTokenType.onepassword_service_account.value,
|
||||
)
|
||||
if not auth_token:
|
||||
raise HTTPException(
|
||||
@@ -503,7 +503,7 @@ async def get_azure_client_secret_credential(
|
||||
try:
|
||||
auth_token = await app.DATABASE.get_valid_org_auth_token(
|
||||
organization_id=current_org.organization_id,
|
||||
token_type=OrganizationAuthTokenType.azure_client_secret_credential,
|
||||
token_type=OrganizationAuthTokenType.azure_client_secret_credential.value,
|
||||
)
|
||||
if not auth_token:
|
||||
raise HTTPException(
|
||||
|
||||
@@ -364,7 +364,8 @@ class WorkflowRunContext:
|
||||
self, parameter: OnePasswordCredentialParameter, organization: Organization
|
||||
) -> None:
|
||||
org_auth_token = await app.DATABASE.get_valid_org_auth_token(
|
||||
organization.organization_id, OrganizationAuthTokenType.onepassword_service_account
|
||||
organization.organization_id,
|
||||
OrganizationAuthTokenType.onepassword_service_account.value,
|
||||
)
|
||||
token = settings.OP_SERVICE_ACCOUNT_TOKEN
|
||||
if org_auth_token:
|
||||
@@ -919,7 +920,7 @@ class WorkflowRunContext:
|
||||
@staticmethod
|
||||
async def _get_azure_vault_client_for_organization(organization: Organization) -> AsyncAzureVaultClient:
|
||||
org_auth_token = await app.DATABASE.get_valid_org_auth_token(
|
||||
organization.organization_id, OrganizationAuthTokenType.azure_client_secret_credential
|
||||
organization.organization_id, OrganizationAuthTokenType.azure_client_secret_credential.value
|
||||
)
|
||||
if org_auth_token:
|
||||
azure_vault_client = AsyncAzureVaultClient.create_from_client_secret(org_auth_token.credential)
|
||||
|
||||
@@ -124,7 +124,7 @@ class Skyvern(AsyncSkyvern):
|
||||
) -> None:
|
||||
org_auth_token = await app.DATABASE.get_valid_org_auth_token(
|
||||
organization_id=organization.organization_id,
|
||||
token_type=OrganizationAuthTokenType.api,
|
||||
token_type=OrganizationAuthTokenType.api.value,
|
||||
)
|
||||
|
||||
step = await app.DATABASE.create_step(
|
||||
|
||||
@@ -1723,7 +1723,7 @@ async def send_task_v2_webhook(task_v2: TaskV2) -> None:
|
||||
return
|
||||
api_key = await app.DATABASE.get_valid_org_auth_token(
|
||||
organization_id,
|
||||
OrganizationAuthTokenType.api,
|
||||
OrganizationAuthTokenType.api.value,
|
||||
)
|
||||
if not api_key:
|
||||
LOG.warning(
|
||||
|
||||
Reference in New Issue
Block a user