Move the code over from private repository (#3)
This commit is contained in:
0
skyvern/forge/sdk/services/__init__.py
Normal file
0
skyvern/forge/sdk/services/__init__.py
Normal file
76
skyvern/forge/sdk/services/org_auth_service.py
Normal file
76
skyvern/forge/sdk/services/org_auth_service.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import time
|
||||
from typing import Annotated
|
||||
|
||||
from asyncache import cached
|
||||
from cachetools import TTLCache
|
||||
from fastapi import Header, HTTPException, status
|
||||
from jose import jwt
|
||||
from jose.exceptions import JWTError
|
||||
from pydantic import ValidationError
|
||||
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.forge.sdk.db.client import AgentDB
|
||||
from skyvern.forge.sdk.models import Organization, OrganizationAuthTokenType, TokenPayload
|
||||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||
|
||||
AUTHENTICATION_TTL = 60 * 60 # one hour
|
||||
CACHE_SIZE = 128
|
||||
ALGORITHM = "HS256"
|
||||
|
||||
|
||||
async def get_current_org(
|
||||
x_api_key: Annotated[str | None, Header()] = None,
|
||||
) -> Organization:
|
||||
if not x_api_key:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Invalid credentials",
|
||||
)
|
||||
return await _get_current_org_cached(x_api_key, app.DATABASE)
|
||||
|
||||
|
||||
@cached(cache=TTLCache(maxsize=CACHE_SIZE, ttl=AUTHENTICATION_TTL))
|
||||
async def _get_current_org_cached(x_api_key: str, db: AgentDB) -> Organization:
|
||||
"""
|
||||
Authentication is cached for one hour
|
||||
"""
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
x_api_key,
|
||||
SettingsManager.get_settings().SECRET_KEY,
|
||||
algorithms=[ALGORITHM],
|
||||
)
|
||||
api_key_data = TokenPayload(**payload)
|
||||
except (JWTError, ValidationError):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Could not validate credentials",
|
||||
)
|
||||
if api_key_data.exp < time.time():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Auth token is expired",
|
||||
)
|
||||
|
||||
organization = await db.get_organization(organization_id=api_key_data.sub)
|
||||
if not organization:
|
||||
raise HTTPException(status_code=404, detail="Organization not found")
|
||||
|
||||
# check if the token exists in the database
|
||||
api_key_db_obj = await db.validate_org_auth_token(
|
||||
organization_id=organization.organization_id,
|
||||
token_type=OrganizationAuthTokenType.api,
|
||||
token=x_api_key,
|
||||
)
|
||||
if not api_key_db_obj:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Invalid credentials",
|
||||
)
|
||||
|
||||
# set organization_id in skyvern context and log context
|
||||
context = skyvern_context.current()
|
||||
if context:
|
||||
context.organization_id = organization.organization_id
|
||||
return organization
|
||||
Reference in New Issue
Block a user