Initialize app at runtime instead of import time (#4024)

This commit is contained in:
Stanislav Novosad
2025-11-18 17:56:58 -07:00
committed by GitHub
parent f7e68141eb
commit 0efae234ab
14 changed files with 319 additions and 183 deletions

View File

@@ -16,6 +16,7 @@ from starlette_context.plugins.base import Plugin
from skyvern.config import settings
from skyvern.exceptions import SkyvernHTTPException
from skyvern.forge import app as forge_app
from skyvern.forge.forge_app_initializer import start_forge_app
from skyvern.forge.request_logging import log_raw_request_middleware
from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
@@ -33,7 +34,7 @@ class ExecutionDatePlugin(Plugin):
return datetime.now()
def custom_openapi() -> dict:
def custom_openapi(app: FastAPI) -> dict:
if app.openapi_schema:
return app.openapi_schema
openapi_schema = get_openapi(
@@ -54,6 +55,7 @@ def custom_openapi() -> dict:
@asynccontextmanager
async def lifespan(_: FastAPI) -> AsyncGenerator[None, Any]:
"""Lifespan context manager for FastAPI app startup and shutdown."""
LOG.info("Server started")
if forge_app.api_app_startup_event:
LOG.info("Calling api app startup event")
@@ -71,15 +73,17 @@ async def lifespan(_: FastAPI) -> AsyncGenerator[None, Any]:
LOG.info("Server shutting down")
def get_agent_app() -> FastAPI:
def create_api_app() -> FastAPI:
"""
Start the agent server.
"""
app = FastAPI(lifespan=lifespan)
forge_app_instance = start_forge_app()
fastapi_app = FastAPI(lifespan=lifespan)
# Add CORS middleware
app.add_middleware(
fastapi_app.add_middleware(
CORSMiddleware,
allow_origins=settings.ALLOWED_ORIGINS,
allow_credentials=True,
@@ -87,19 +91,19 @@ def get_agent_app() -> FastAPI:
allow_headers=["*"],
)
app.include_router(base_router, prefix="/v1")
app.include_router(legacy_base_router, prefix="/api/v1")
app.include_router(legacy_v2_router, prefix="/api/v2")
fastapi_app.include_router(base_router, prefix="/v1")
fastapi_app.include_router(legacy_base_router, prefix="/api/v1")
fastapi_app.include_router(legacy_v2_router, prefix="/api/v2")
# local dev endpoints
if settings.ENV == "local":
app.include_router(internal_auth.router, prefix="/v1")
app.include_router(internal_auth.router, prefix="/api/v1")
app.include_router(internal_auth.router, prefix="/api/v2")
fastapi_app.include_router(internal_auth.router, prefix="/v1")
fastapi_app.include_router(internal_auth.router, prefix="/api/v1")
fastapi_app.include_router(internal_auth.router, prefix="/api/v2")
app.openapi = custom_openapi
fastapi_app.openapi = lambda: custom_openapi(fastapi_app)
app.add_middleware(
fastapi_app.add_middleware(
RawContextMiddleware,
plugins=(
# TODO (suchintan): We should set these up
@@ -109,27 +113,27 @@ def get_agent_app() -> FastAPI:
),
)
@app.exception_handler(NotFoundError)
@fastapi_app.exception_handler(NotFoundError)
async def handle_not_found_error(request: Request, exc: NotFoundError) -> Response:
return Response(status_code=status.HTTP_404_NOT_FOUND)
@app.exception_handler(SkyvernHTTPException)
@fastapi_app.exception_handler(SkyvernHTTPException)
async def handle_skyvern_http_exception(request: Request, exc: SkyvernHTTPException) -> JSONResponse:
return JSONResponse(status_code=exc.status_code, content={"detail": exc.message})
@app.exception_handler(ValidationError)
@fastapi_app.exception_handler(ValidationError)
async def handle_pydantic_validation_error(request: Request, exc: ValidationError) -> JSONResponse:
return JSONResponse(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
content={"detail": str(exc)},
)
@app.exception_handler(Exception)
@fastapi_app.exception_handler(Exception)
async def unexpected_exception(request: Request, exc: Exception) -> JSONResponse:
LOG.exception("Unexpected error in agent server.", exc_info=exc)
return JSONResponse(status_code=500, content={"error": f"Unexpected error: {exc}"})
@app.middleware("http")
@fastapi_app.middleware("http")
async def request_middleware(request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
curr_ctx = skyvern_context.current()
if not curr_ctx:
@@ -143,23 +147,11 @@ def get_agent_app() -> FastAPI:
finally:
skyvern_context.reset()
@app.middleware("http")
@fastapi_app.middleware("http")
async def raw_request_logging(request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
return await log_raw_request_middleware(request, call_next)
if settings.ADDITIONAL_MODULES:
for module in settings.ADDITIONAL_MODULES:
LOG.info("Loading additional module to set up api app", module=module)
__import__(module)
LOG.info(
"Additional modules loaded to set up api app",
modules=settings.ADDITIONAL_MODULES,
)
if forge_app_instance.setup_api_app:
forge_app_instance.setup_api_app(fastapi_app)
if forge_app.setup_api_app:
forge_app.setup_api_app(app)
return app
app = get_agent_app()
return fastapi_app