From 566ff8af4e95ab6dacb10f8c0b3a86f21637724b Mon Sep 17 00:00:00 2001 From: Kerem Yilmaz Date: Sun, 21 Apr 2024 16:46:27 -0700 Subject: [PATCH] Remove the base Agent; Separate skyvern agent and fastapi app (#213) Co-authored-by: Shuchang Zheng --- skyvern/forge/__main__.py | 2 +- skyvern/forge/agent.py | 3 +- skyvern/forge/api_app.py | 92 ++++++++++++++++++++++++++++++++++++++ skyvern/forge/app.py | 2 - skyvern/forge/sdk/agent.py | 90 ------------------------------------- 5 files changed, 94 insertions(+), 95 deletions(-) create mode 100644 skyvern/forge/api_app.py delete mode 100644 skyvern/forge/sdk/agent.py diff --git a/skyvern/forge/__main__.py b/skyvern/forge/__main__.py index a3bb6acf..db4246d4 100644 --- a/skyvern/forge/__main__.py +++ b/skyvern/forge/__main__.py @@ -17,4 +17,4 @@ if __name__ == "__main__": load_dotenv() reload = SettingsManager.get_settings().ENV == "local" - uvicorn.run("skyvern.forge.app:app", host="0.0.0.0", port=port, log_level="info", reload=reload) + uvicorn.run("skyvern.forge.api_app:app", host="0.0.0.0", port=port, log_level="info", reload=reload) diff --git a/skyvern/forge/agent.py b/skyvern/forge/agent.py index c571f66f..e9d12996 100644 --- a/skyvern/forge/agent.py +++ b/skyvern/forge/agent.py @@ -21,7 +21,6 @@ from skyvern.exceptions import ( from skyvern.forge import app from skyvern.forge.async_operations import AgentPhase, AsyncOperationPool from skyvern.forge.prompts import prompt_engine -from skyvern.forge.sdk.agent import Agent from skyvern.forge.sdk.artifact.models import ArtifactType from skyvern.forge.sdk.core import skyvern_context from skyvern.forge.sdk.core.security import generate_skyvern_signature @@ -49,7 +48,7 @@ from skyvern.webeye.scraper.scraper import ScrapedPage, scrape_website LOG = structlog.get_logger() -class ForgeAgent(Agent): +class ForgeAgent: def __init__(self) -> None: if SettingsManager.get_settings().ADDITIONAL_MODULES: for module in SettingsManager.get_settings().ADDITIONAL_MODULES: diff --git a/skyvern/forge/api_app.py b/skyvern/forge/api_app.py new file mode 100644 index 00000000..74850b05 --- /dev/null +++ b/skyvern/forge/api_app.py @@ -0,0 +1,92 @@ +import uuid +from datetime import datetime +from typing import Awaitable, Callable + +import structlog +from fastapi import APIRouter, FastAPI, Response +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse +from starlette.requests import HTTPConnection, Request +from starlette_context.middleware import RawContextMiddleware +from starlette_context.plugins.base import Plugin + +from skyvern.forge.sdk.core import skyvern_context +from skyvern.forge.sdk.core.skyvern_context import SkyvernContext +from skyvern.forge.sdk.routes.agent_protocol import base_router +from skyvern.scheduler import SCHEDULER + +LOG = structlog.get_logger() + + +class ExecutionDatePlugin(Plugin): + key = "execution_date" + + async def process_request(self, request: Request | HTTPConnection) -> datetime: + return datetime.now() + + +def get_agent_app(router: APIRouter = base_router) -> FastAPI: + """ + Start the agent server. + """ + + app = FastAPI() + + # Add CORS middleware + origins = [ + "http://localhost:5000", + "http://127.0.0.1:5000", + "http://localhost:8000", + "http://127.0.0.1:8000", + "http://localhost:8080", + "http://127.0.0.1:8080", + # Add any other origins you want to whitelist + ] + + app.add_middleware( + CORSMiddleware, + allow_origins=origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + app.include_router(router, prefix="/api/v1") + + app.add_middleware( + RawContextMiddleware, + plugins=( + # TODO (suchintan): We should set these up + ExecutionDatePlugin(), + # RequestIdPlugin(), + # UserAgentPlugin(), + ), + ) + + # Register the scheduler on startup so that we can schedule jobs dynamically + @app.on_event("startup") + def start_scheduler() -> None: + LOG.info("Starting the skyvern scheduler.") + SCHEDULER.start() + + LOG.info("Server startup complete. Skyvern is now online") + + @app.exception_handler(Exception) + async def unexpected_exception(request: Request, exc: Exception) -> JSONResponse: + LOG.exception("Unexpected error in agent server.", exc_info=exc) + return JSONResponse(status_code=500, content={"error": f"Unexpected error: {exc}"}) + + @app.middleware("http") + async def request_middleware(request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response: + request_id = str(uuid.uuid4()) + skyvern_context.set(SkyvernContext(request_id=request_id)) + + try: + return await call_next(request) + finally: + skyvern_context.reset() + + return app + + +app = get_agent_app() diff --git a/skyvern/forge/app.py b/skyvern/forge/app.py index 7cec463f..fc38ada1 100644 --- a/skyvern/forge/app.py +++ b/skyvern/forge/app.py @@ -40,5 +40,3 @@ WORKFLOW_SERVICE = WorkflowService() generate_async_operations: Callable[[Organization, Task, Page], list[AsyncOperation]] | None = None agent = ForgeAgent() - -app = agent.get_agent_app() diff --git a/skyvern/forge/sdk/agent.py b/skyvern/forge/sdk/agent.py deleted file mode 100644 index 61cb763f..00000000 --- a/skyvern/forge/sdk/agent.py +++ /dev/null @@ -1,90 +0,0 @@ -import uuid -from datetime import datetime -from typing import Awaitable, Callable - -import structlog -from fastapi import APIRouter, FastAPI, Response -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import JSONResponse -from starlette.requests import HTTPConnection, Request -from starlette_context.middleware import RawContextMiddleware -from starlette_context.plugins.base import Plugin - -from skyvern.forge.sdk.core import skyvern_context -from skyvern.forge.sdk.core.skyvern_context import SkyvernContext -from skyvern.forge.sdk.routes.agent_protocol import base_router -from skyvern.scheduler import SCHEDULER - -LOG = structlog.get_logger() - - -class Agent: - def get_agent_app(self, router: APIRouter = base_router) -> FastAPI: - """ - Start the agent server. - """ - - app = FastAPI() - - # Add CORS middleware - origins = [ - "http://localhost:5000", - "http://127.0.0.1:5000", - "http://localhost:8000", - "http://127.0.0.1:8000", - "http://localhost:8080", - "http://127.0.0.1:8080", - # Add any other origins you want to whitelist - ] - - app.add_middleware( - CORSMiddleware, - allow_origins=origins, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], - ) - - app.include_router(router, prefix="/api/v1") - - app.add_middleware( - RawContextMiddleware, - plugins=( - # TODO (suchintan): We should set these up - ExecutionDatePlugin(), - # RequestIdPlugin(), - # UserAgentPlugin(), - ), - ) - - # Register the scheduler on startup so that we can schedule jobs dynamically - @app.on_event("startup") - def start_scheduler() -> None: - LOG.info("Starting the skyvern scheduler.") - SCHEDULER.start() - - LOG.info("Server startup complete. Skyvern is now online") - - @app.exception_handler(Exception) - async def unexpected_exception(request: Request, exc: Exception) -> JSONResponse: - LOG.exception("Unexpected error in agent server.", exc_info=exc) - return JSONResponse(status_code=500, content={"error": f"Unexpected error: {exc}"}) - - @app.middleware("http") - async def request_middleware(request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response: - request_id = str(uuid.uuid4()) - skyvern_context.set(SkyvernContext(request_id=request_id)) - - try: - return await call_next(request) - finally: - skyvern_context.reset() - - return app - - -class ExecutionDatePlugin(Plugin): - key = "execution_date" - - async def process_request(self, request: Request | HTTPConnection) -> datetime: - return datetime.now()