Enable Custom Oauth interface (#214)
Co-authored-by: Shuchang Zheng <wintonzheng0325@gmail.com>
This commit is contained in:
@@ -10,9 +10,11 @@ from starlette.requests import HTTPConnection, Request
|
||||
from starlette_context.middleware import RawContextMiddleware
|
||||
from starlette_context.plugins.base import Plugin
|
||||
|
||||
from skyvern.forge import app as forge_app
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
|
||||
from skyvern.forge.sdk.routes.agent_protocol import base_router
|
||||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||
from skyvern.scheduler import SCHEDULER
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
@@ -86,6 +88,17 @@ def get_agent_app(router: APIRouter = base_router) -> FastAPI:
|
||||
finally:
|
||||
skyvern_context.reset()
|
||||
|
||||
if SettingsManager.get_settings().ADDITIONAL_MODULES:
|
||||
for module in SettingsManager.get_settings().ADDITIONAL_MODULES:
|
||||
LOG.info("Loading additional module to set up api app", module=module)
|
||||
__import__(module)
|
||||
LOG.info(
|
||||
"Additional modules loaded to set up api app", modules=SettingsManager.get_settings().ADDITIONAL_MODULES
|
||||
)
|
||||
|
||||
if forge_app.setup_api_app:
|
||||
forge_app.setup_api_app(app)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from typing import Callable
|
||||
from typing import Awaitable, Callable
|
||||
|
||||
from ddtrace import tracer
|
||||
from ddtrace.filters import FilterRequestsOnUrl
|
||||
from fastapi import FastAPI
|
||||
from playwright.async_api import Page
|
||||
|
||||
from skyvern.forge.agent import ForgeAgent
|
||||
@@ -38,5 +39,7 @@ LLM_API_HANDLER = LLMAPIHandlerFactory.get_llm_api_handler(SettingsManager.get_s
|
||||
WORKFLOW_CONTEXT_MANAGER = WorkflowContextManager()
|
||||
WORKFLOW_SERVICE = WorkflowService()
|
||||
generate_async_operations: Callable[[Organization, Task, Page], list[AsyncOperation]] | None = None
|
||||
authentication_function: Callable[[str], Awaitable[Organization]] | None = None
|
||||
setup_api_app: Callable[[FastAPI], None] | None = None
|
||||
|
||||
agent = ForgeAgent()
|
||||
|
||||
@@ -21,7 +21,25 @@ ALGORITHM = "HS256"
|
||||
|
||||
async def get_current_org(
|
||||
x_api_key: Annotated[str | None, Header()] = None,
|
||||
authorization: Annotated[str | None, Header()] = None,
|
||||
) -> Organization:
|
||||
if not x_api_key and not authorization:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Invalid credentials",
|
||||
)
|
||||
if x_api_key:
|
||||
return await _get_current_org_cached(x_api_key, app.DATABASE)
|
||||
elif authorization:
|
||||
return await _authenticate_helper(authorization)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Invalid credentials",
|
||||
)
|
||||
|
||||
|
||||
async def get_current_org_with_api_key(x_api_key: Annotated[str | None, Header()] = None) -> Organization:
|
||||
if not x_api_key:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
@@ -30,6 +48,31 @@ async def get_current_org(
|
||||
return await _get_current_org_cached(x_api_key, app.DATABASE)
|
||||
|
||||
|
||||
async def get_current_org_with_authentication(authorization: Annotated[str | None, Header()] = None) -> Organization:
|
||||
if not authorization:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Invalid credentials",
|
||||
)
|
||||
return await _authenticate_helper(authorization)
|
||||
|
||||
|
||||
async def _authenticate_helper(authorization: str) -> Organization:
|
||||
token = authorization.split(" ")[1]
|
||||
if not app.authentication_function:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Invalid authentication method",
|
||||
)
|
||||
organization = await app.authentication_function(token)
|
||||
if not organization:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Invalid credentials",
|
||||
)
|
||||
return organization
|
||||
|
||||
|
||||
@cached(cache=TTLCache(maxsize=CACHE_SIZE, ttl=AUTHENTICATION_TTL))
|
||||
async def _get_current_org_cached(x_api_key: str, db: AgentDB) -> Organization:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user