Enable Custom Oauth interface (#214)
Co-authored-by: Shuchang Zheng <wintonzheng0325@gmail.com>
This commit is contained in:
@@ -21,7 +21,25 @@ ALGORITHM = "HS256"
|
||||
|
||||
async def get_current_org(
|
||||
x_api_key: Annotated[str | None, Header()] = None,
|
||||
authorization: Annotated[str | None, Header()] = None,
|
||||
) -> Organization:
|
||||
if not x_api_key and not authorization:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Invalid credentials",
|
||||
)
|
||||
if x_api_key:
|
||||
return await _get_current_org_cached(x_api_key, app.DATABASE)
|
||||
elif authorization:
|
||||
return await _authenticate_helper(authorization)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Invalid credentials",
|
||||
)
|
||||
|
||||
|
||||
async def get_current_org_with_api_key(x_api_key: Annotated[str | None, Header()] = None) -> Organization:
|
||||
if not x_api_key:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
@@ -30,6 +48,31 @@ async def get_current_org(
|
||||
return await _get_current_org_cached(x_api_key, app.DATABASE)
|
||||
|
||||
|
||||
async def get_current_org_with_authentication(authorization: Annotated[str | None, Header()] = None) -> Organization:
|
||||
if not authorization:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Invalid credentials",
|
||||
)
|
||||
return await _authenticate_helper(authorization)
|
||||
|
||||
|
||||
async def _authenticate_helper(authorization: str) -> Organization:
|
||||
token = authorization.split(" ")[1]
|
||||
if not app.authentication_function:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Invalid authentication method",
|
||||
)
|
||||
organization = await app.authentication_function(token)
|
||||
if not organization:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Invalid credentials",
|
||||
)
|
||||
return organization
|
||||
|
||||
|
||||
@cached(cache=TTLCache(maxsize=CACHE_SIZE, ttl=AUTHENTICATION_TTL))
|
||||
async def _get_current_org_cached(x_api_key: str, db: AgentDB) -> Organization:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user