diff --git a/skyvern/forge/api_app.py b/skyvern/forge/api_app.py index 74850b05..7b8f8525 100644 --- a/skyvern/forge/api_app.py +++ b/skyvern/forge/api_app.py @@ -10,9 +10,11 @@ from starlette.requests import HTTPConnection, Request from starlette_context.middleware import RawContextMiddleware from starlette_context.plugins.base import Plugin +from skyvern.forge import app as forge_app from skyvern.forge.sdk.core import skyvern_context from skyvern.forge.sdk.core.skyvern_context import SkyvernContext from skyvern.forge.sdk.routes.agent_protocol import base_router +from skyvern.forge.sdk.settings_manager import SettingsManager from skyvern.scheduler import SCHEDULER LOG = structlog.get_logger() @@ -86,6 +88,17 @@ def get_agent_app(router: APIRouter = base_router) -> FastAPI: finally: skyvern_context.reset() + if SettingsManager.get_settings().ADDITIONAL_MODULES: + for module in SettingsManager.get_settings().ADDITIONAL_MODULES: + LOG.info("Loading additional module to set up api app", module=module) + __import__(module) + LOG.info( + "Additional modules loaded to set up api app", modules=SettingsManager.get_settings().ADDITIONAL_MODULES + ) + + if forge_app.setup_api_app: + forge_app.setup_api_app(app) + return app diff --git a/skyvern/forge/app.py b/skyvern/forge/app.py index fc38ada1..c0c006a4 100644 --- a/skyvern/forge/app.py +++ b/skyvern/forge/app.py @@ -1,7 +1,8 @@ -from typing import Callable +from typing import Awaitable, Callable from ddtrace import tracer from ddtrace.filters import FilterRequestsOnUrl +from fastapi import FastAPI from playwright.async_api import Page from skyvern.forge.agent import ForgeAgent @@ -38,5 +39,7 @@ LLM_API_HANDLER = LLMAPIHandlerFactory.get_llm_api_handler(SettingsManager.get_s WORKFLOW_CONTEXT_MANAGER = WorkflowContextManager() WORKFLOW_SERVICE = WorkflowService() generate_async_operations: Callable[[Organization, Task, Page], list[AsyncOperation]] | None = None +authentication_function: Callable[[str], Awaitable[Organization]] | None = None +setup_api_app: Callable[[FastAPI], None] | None = None agent = ForgeAgent() diff --git a/skyvern/forge/sdk/services/org_auth_service.py b/skyvern/forge/sdk/services/org_auth_service.py index f3c7b705..7f0c9461 100644 --- a/skyvern/forge/sdk/services/org_auth_service.py +++ b/skyvern/forge/sdk/services/org_auth_service.py @@ -21,7 +21,25 @@ ALGORITHM = "HS256" async def get_current_org( x_api_key: Annotated[str | None, Header()] = None, + authorization: Annotated[str | None, Header()] = None, ) -> Organization: + if not x_api_key and not authorization: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Invalid credentials", + ) + if x_api_key: + return await _get_current_org_cached(x_api_key, app.DATABASE) + elif authorization: + return await _authenticate_helper(authorization) + + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Invalid credentials", + ) + + +async def get_current_org_with_api_key(x_api_key: Annotated[str | None, Header()] = None) -> Organization: if not x_api_key: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, @@ -30,6 +48,31 @@ async def get_current_org( return await _get_current_org_cached(x_api_key, app.DATABASE) +async def get_current_org_with_authentication(authorization: Annotated[str | None, Header()] = None) -> Organization: + if not authorization: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Invalid credentials", + ) + return await _authenticate_helper(authorization) + + +async def _authenticate_helper(authorization: str) -> Organization: + token = authorization.split(" ")[1] + if not app.authentication_function: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Invalid authentication method", + ) + organization = await app.authentication_function(token) + if not organization: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Invalid credentials", + ) + return organization + + @cached(cache=TTLCache(maxsize=CACHE_SIZE, ttl=AUTHENTICATION_TTL)) async def _get_current_org_cached(x_api_key: str, db: AgentDB) -> Organization: """