Enable Custom Oauth interface (#214)

Co-authored-by: Shuchang Zheng <wintonzheng0325@gmail.com>
This commit is contained in:
Kerem Yilmaz
2024-04-22 00:44:16 -07:00
committed by GitHub
parent 566ff8af4e
commit 55d14db971
3 changed files with 60 additions and 1 deletions

View File

@@ -10,9 +10,11 @@ from starlette.requests import HTTPConnection, Request
from starlette_context.middleware import RawContextMiddleware
from starlette_context.plugins.base import Plugin
from skyvern.forge import app as forge_app
from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
from skyvern.forge.sdk.routes.agent_protocol import base_router
from skyvern.forge.sdk.settings_manager import SettingsManager
from skyvern.scheduler import SCHEDULER
LOG = structlog.get_logger()
@@ -86,6 +88,17 @@ def get_agent_app(router: APIRouter = base_router) -> FastAPI:
finally:
skyvern_context.reset()
if SettingsManager.get_settings().ADDITIONAL_MODULES:
for module in SettingsManager.get_settings().ADDITIONAL_MODULES:
LOG.info("Loading additional module to set up api app", module=module)
__import__(module)
LOG.info(
"Additional modules loaded to set up api app", modules=SettingsManager.get_settings().ADDITIONAL_MODULES
)
if forge_app.setup_api_app:
forge_app.setup_api_app(app)
return app

View File

@@ -1,7 +1,8 @@
from typing import Callable
from typing import Awaitable, Callable
from ddtrace import tracer
from ddtrace.filters import FilterRequestsOnUrl
from fastapi import FastAPI
from playwright.async_api import Page
from skyvern.forge.agent import ForgeAgent
@@ -38,5 +39,7 @@ LLM_API_HANDLER = LLMAPIHandlerFactory.get_llm_api_handler(SettingsManager.get_s
WORKFLOW_CONTEXT_MANAGER = WorkflowContextManager()
WORKFLOW_SERVICE = WorkflowService()
generate_async_operations: Callable[[Organization, Task, Page], list[AsyncOperation]] | None = None
authentication_function: Callable[[str], Awaitable[Organization]] | None = None
setup_api_app: Callable[[FastAPI], None] | None = None
agent = ForgeAgent()

View File

@@ -21,7 +21,25 @@ ALGORITHM = "HS256"
async def get_current_org(
x_api_key: Annotated[str | None, Header()] = None,
authorization: Annotated[str | None, Header()] = None,
) -> Organization:
if not x_api_key and not authorization:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Invalid credentials",
)
if x_api_key:
return await _get_current_org_cached(x_api_key, app.DATABASE)
elif authorization:
return await _authenticate_helper(authorization)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Invalid credentials",
)
async def get_current_org_with_api_key(x_api_key: Annotated[str | None, Header()] = None) -> Organization:
if not x_api_key:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
@@ -30,6 +48,31 @@ async def get_current_org(
return await _get_current_org_cached(x_api_key, app.DATABASE)
async def get_current_org_with_authentication(authorization: Annotated[str | None, Header()] = None) -> Organization:
if not authorization:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Invalid credentials",
)
return await _authenticate_helper(authorization)
async def _authenticate_helper(authorization: str) -> Organization:
token = authorization.split(" ")[1]
if not app.authentication_function:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Invalid authentication method",
)
organization = await app.authentication_function(token)
if not organization:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Invalid credentials",
)
return organization
@cached(cache=TTLCache(maxsize=CACHE_SIZE, ttl=AUTHENTICATION_TTL))
async def _get_current_org_cached(x_api_key: str, db: AgentDB) -> Organization:
"""