Enable Custom Oauth interface (#214)

Co-authored-by: Shuchang Zheng <wintonzheng0325@gmail.com>
This commit is contained in:
Kerem Yilmaz
2024-04-22 00:44:16 -07:00
committed by GitHub
parent 566ff8af4e
commit 55d14db971
3 changed files with 60 additions and 1 deletions

View File

@@ -10,9 +10,11 @@ from starlette.requests import HTTPConnection, Request
from starlette_context.middleware import RawContextMiddleware from starlette_context.middleware import RawContextMiddleware
from starlette_context.plugins.base import Plugin from starlette_context.plugins.base import Plugin
from skyvern.forge import app as forge_app
from skyvern.forge.sdk.core import skyvern_context from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
from skyvern.forge.sdk.routes.agent_protocol import base_router from skyvern.forge.sdk.routes.agent_protocol import base_router
from skyvern.forge.sdk.settings_manager import SettingsManager
from skyvern.scheduler import SCHEDULER from skyvern.scheduler import SCHEDULER
LOG = structlog.get_logger() LOG = structlog.get_logger()
@@ -86,6 +88,17 @@ def get_agent_app(router: APIRouter = base_router) -> FastAPI:
finally: finally:
skyvern_context.reset() skyvern_context.reset()
if SettingsManager.get_settings().ADDITIONAL_MODULES:
for module in SettingsManager.get_settings().ADDITIONAL_MODULES:
LOG.info("Loading additional module to set up api app", module=module)
__import__(module)
LOG.info(
"Additional modules loaded to set up api app", modules=SettingsManager.get_settings().ADDITIONAL_MODULES
)
if forge_app.setup_api_app:
forge_app.setup_api_app(app)
return app return app

View File

@@ -1,7 +1,8 @@
from typing import Callable from typing import Awaitable, Callable
from ddtrace import tracer from ddtrace import tracer
from ddtrace.filters import FilterRequestsOnUrl from ddtrace.filters import FilterRequestsOnUrl
from fastapi import FastAPI
from playwright.async_api import Page from playwright.async_api import Page
from skyvern.forge.agent import ForgeAgent from skyvern.forge.agent import ForgeAgent
@@ -38,5 +39,7 @@ LLM_API_HANDLER = LLMAPIHandlerFactory.get_llm_api_handler(SettingsManager.get_s
WORKFLOW_CONTEXT_MANAGER = WorkflowContextManager() WORKFLOW_CONTEXT_MANAGER = WorkflowContextManager()
WORKFLOW_SERVICE = WorkflowService() WORKFLOW_SERVICE = WorkflowService()
generate_async_operations: Callable[[Organization, Task, Page], list[AsyncOperation]] | None = None generate_async_operations: Callable[[Organization, Task, Page], list[AsyncOperation]] | None = None
authentication_function: Callable[[str], Awaitable[Organization]] | None = None
setup_api_app: Callable[[FastAPI], None] | None = None
agent = ForgeAgent() agent = ForgeAgent()

View File

@@ -21,7 +21,25 @@ ALGORITHM = "HS256"
async def get_current_org( async def get_current_org(
x_api_key: Annotated[str | None, Header()] = None, x_api_key: Annotated[str | None, Header()] = None,
authorization: Annotated[str | None, Header()] = None,
) -> Organization: ) -> Organization:
if not x_api_key and not authorization:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Invalid credentials",
)
if x_api_key:
return await _get_current_org_cached(x_api_key, app.DATABASE)
elif authorization:
return await _authenticate_helper(authorization)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Invalid credentials",
)
async def get_current_org_with_api_key(x_api_key: Annotated[str | None, Header()] = None) -> Organization:
if not x_api_key: if not x_api_key:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
@@ -30,6 +48,31 @@ async def get_current_org(
return await _get_current_org_cached(x_api_key, app.DATABASE) return await _get_current_org_cached(x_api_key, app.DATABASE)
async def get_current_org_with_authentication(authorization: Annotated[str | None, Header()] = None) -> Organization:
if not authorization:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Invalid credentials",
)
return await _authenticate_helper(authorization)
async def _authenticate_helper(authorization: str) -> Organization:
token = authorization.split(" ")[1]
if not app.authentication_function:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Invalid authentication method",
)
organization = await app.authentication_function(token)
if not organization:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Invalid credentials",
)
return organization
@cached(cache=TTLCache(maxsize=CACHE_SIZE, ttl=AUTHENTICATION_TTL)) @cached(cache=TTLCache(maxsize=CACHE_SIZE, ttl=AUTHENTICATION_TTL))
async def _get_current_org_cached(x_api_key: str, db: AgentDB) -> Organization: async def _get_current_org_cached(x_api_key: str, db: AgentDB) -> Organization:
""" """