add api key expired message to the 403 when an api key is expired/invalid (#532)
This commit is contained in:
@@ -566,18 +566,19 @@ class AgentDB:
|
||||
organization_id: str,
|
||||
token_type: OrganizationAuthTokenType,
|
||||
token: str,
|
||||
valid: bool | None = True,
|
||||
) -> OrganizationAuthToken | None:
|
||||
try:
|
||||
async with self.Session() as session:
|
||||
if token_obj := (
|
||||
await session.scalars(
|
||||
select(OrganizationAuthTokenModel)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter_by(token_type=token_type)
|
||||
.filter_by(token=token)
|
||||
.filter_by(valid=True)
|
||||
)
|
||||
).first():
|
||||
query = (
|
||||
select(OrganizationAuthTokenModel)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter_by(token_type=token_type)
|
||||
.filter_by(token=token)
|
||||
)
|
||||
if valid is not None:
|
||||
query = query.filter_by(valid=valid)
|
||||
if token_obj := (await session.scalars(query)).first():
|
||||
return convert_to_organization_auth_token(token_obj)
|
||||
else:
|
||||
return None
|
||||
|
||||
@@ -109,6 +109,7 @@ async def _get_current_org_cached(x_api_key: str, db: AgentDB) -> Organization:
|
||||
organization_id=organization.organization_id,
|
||||
token_type=OrganizationAuthTokenType.api,
|
||||
token=x_api_key,
|
||||
valid=None,
|
||||
)
|
||||
if not api_key_db_obj:
|
||||
raise HTTPException(
|
||||
@@ -116,6 +117,12 @@ async def _get_current_org_cached(x_api_key: str, db: AgentDB) -> Organization:
|
||||
detail="Invalid credentials",
|
||||
)
|
||||
|
||||
if api_key_db_obj.valid is False:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Your API key has expired. Please retrieve the latest one from https://app.skyvern.com/settings",
|
||||
)
|
||||
|
||||
# set organization_id in skyvern context and log context
|
||||
context = skyvern_context.current()
|
||||
if context:
|
||||
|
||||
Reference in New Issue
Block a user