Remove the base Agent; Separate skyvern agent and fastapi app (#213)
Co-authored-by: Shuchang Zheng <wintonzheng0325@gmail.com>
This commit is contained in:
@@ -17,4 +17,4 @@ if __name__ == "__main__":
|
||||
load_dotenv()
|
||||
|
||||
reload = SettingsManager.get_settings().ENV == "local"
|
||||
uvicorn.run("skyvern.forge.app:app", host="0.0.0.0", port=port, log_level="info", reload=reload)
|
||||
uvicorn.run("skyvern.forge.api_app:app", host="0.0.0.0", port=port, log_level="info", reload=reload)
|
||||
|
||||
@@ -21,7 +21,6 @@ from skyvern.exceptions import (
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.async_operations import AgentPhase, AsyncOperationPool
|
||||
from skyvern.forge.prompts import prompt_engine
|
||||
from skyvern.forge.sdk.agent import Agent
|
||||
from skyvern.forge.sdk.artifact.models import ArtifactType
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.forge.sdk.core.security import generate_skyvern_signature
|
||||
@@ -49,7 +48,7 @@ from skyvern.webeye.scraper.scraper import ScrapedPage, scrape_website
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class ForgeAgent(Agent):
|
||||
class ForgeAgent:
|
||||
def __init__(self) -> None:
|
||||
if SettingsManager.get_settings().ADDITIONAL_MODULES:
|
||||
for module in SettingsManager.get_settings().ADDITIONAL_MODULES:
|
||||
|
||||
92
skyvern/forge/api_app.py
Normal file
92
skyvern/forge/api_app.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Awaitable, Callable
|
||||
|
||||
import structlog
|
||||
from fastapi import APIRouter, FastAPI, Response
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.requests import HTTPConnection, Request
|
||||
from starlette_context.middleware import RawContextMiddleware
|
||||
from starlette_context.plugins.base import Plugin
|
||||
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
|
||||
from skyvern.forge.sdk.routes.agent_protocol import base_router
|
||||
from skyvern.scheduler import SCHEDULER
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class ExecutionDatePlugin(Plugin):
|
||||
key = "execution_date"
|
||||
|
||||
async def process_request(self, request: Request | HTTPConnection) -> datetime:
|
||||
return datetime.now()
|
||||
|
||||
|
||||
def get_agent_app(router: APIRouter = base_router) -> FastAPI:
|
||||
"""
|
||||
Start the agent server.
|
||||
"""
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# Add CORS middleware
|
||||
origins = [
|
||||
"http://localhost:5000",
|
||||
"http://127.0.0.1:5000",
|
||||
"http://localhost:8000",
|
||||
"http://127.0.0.1:8000",
|
||||
"http://localhost:8080",
|
||||
"http://127.0.0.1:8080",
|
||||
# Add any other origins you want to whitelist
|
||||
]
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
app.include_router(router, prefix="/api/v1")
|
||||
|
||||
app.add_middleware(
|
||||
RawContextMiddleware,
|
||||
plugins=(
|
||||
# TODO (suchintan): We should set these up
|
||||
ExecutionDatePlugin(),
|
||||
# RequestIdPlugin(),
|
||||
# UserAgentPlugin(),
|
||||
),
|
||||
)
|
||||
|
||||
# Register the scheduler on startup so that we can schedule jobs dynamically
|
||||
@app.on_event("startup")
|
||||
def start_scheduler() -> None:
|
||||
LOG.info("Starting the skyvern scheduler.")
|
||||
SCHEDULER.start()
|
||||
|
||||
LOG.info("Server startup complete. Skyvern is now online")
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def unexpected_exception(request: Request, exc: Exception) -> JSONResponse:
|
||||
LOG.exception("Unexpected error in agent server.", exc_info=exc)
|
||||
return JSONResponse(status_code=500, content={"error": f"Unexpected error: {exc}"})
|
||||
|
||||
@app.middleware("http")
|
||||
async def request_middleware(request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
|
||||
request_id = str(uuid.uuid4())
|
||||
skyvern_context.set(SkyvernContext(request_id=request_id))
|
||||
|
||||
try:
|
||||
return await call_next(request)
|
||||
finally:
|
||||
skyvern_context.reset()
|
||||
|
||||
return app
|
||||
|
||||
|
||||
app = get_agent_app()
|
||||
@@ -40,5 +40,3 @@ WORKFLOW_SERVICE = WorkflowService()
|
||||
generate_async_operations: Callable[[Organization, Task, Page], list[AsyncOperation]] | None = None
|
||||
|
||||
agent = ForgeAgent()
|
||||
|
||||
app = agent.get_agent_app()
|
||||
|
||||
@@ -1,90 +0,0 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Awaitable, Callable
|
||||
|
||||
import structlog
|
||||
from fastapi import APIRouter, FastAPI, Response
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.requests import HTTPConnection, Request
|
||||
from starlette_context.middleware import RawContextMiddleware
|
||||
from starlette_context.plugins.base import Plugin
|
||||
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
|
||||
from skyvern.forge.sdk.routes.agent_protocol import base_router
|
||||
from skyvern.scheduler import SCHEDULER
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class Agent:
|
||||
def get_agent_app(self, router: APIRouter = base_router) -> FastAPI:
|
||||
"""
|
||||
Start the agent server.
|
||||
"""
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# Add CORS middleware
|
||||
origins = [
|
||||
"http://localhost:5000",
|
||||
"http://127.0.0.1:5000",
|
||||
"http://localhost:8000",
|
||||
"http://127.0.0.1:8000",
|
||||
"http://localhost:8080",
|
||||
"http://127.0.0.1:8080",
|
||||
# Add any other origins you want to whitelist
|
||||
]
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
app.include_router(router, prefix="/api/v1")
|
||||
|
||||
app.add_middleware(
|
||||
RawContextMiddleware,
|
||||
plugins=(
|
||||
# TODO (suchintan): We should set these up
|
||||
ExecutionDatePlugin(),
|
||||
# RequestIdPlugin(),
|
||||
# UserAgentPlugin(),
|
||||
),
|
||||
)
|
||||
|
||||
# Register the scheduler on startup so that we can schedule jobs dynamically
|
||||
@app.on_event("startup")
|
||||
def start_scheduler() -> None:
|
||||
LOG.info("Starting the skyvern scheduler.")
|
||||
SCHEDULER.start()
|
||||
|
||||
LOG.info("Server startup complete. Skyvern is now online")
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def unexpected_exception(request: Request, exc: Exception) -> JSONResponse:
|
||||
LOG.exception("Unexpected error in agent server.", exc_info=exc)
|
||||
return JSONResponse(status_code=500, content={"error": f"Unexpected error: {exc}"})
|
||||
|
||||
@app.middleware("http")
|
||||
async def request_middleware(request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
|
||||
request_id = str(uuid.uuid4())
|
||||
skyvern_context.set(SkyvernContext(request_id=request_id))
|
||||
|
||||
try:
|
||||
return await call_next(request)
|
||||
finally:
|
||||
skyvern_context.reset()
|
||||
|
||||
return app
|
||||
|
||||
|
||||
class ExecutionDatePlugin(Plugin):
|
||||
key = "execution_date"
|
||||
|
||||
async def process_request(self, request: Request | HTTPConnection) -> datetime:
|
||||
return datetime.now()
|
||||
Reference in New Issue
Block a user