diff --git a/evaluation/script/create_webvoyager_task_v2.py b/evaluation/script/create_webvoyager_task_v2.py index 1f777dfb..ebb2dffb 100644 --- a/evaluation/script/create_webvoyager_task_v2.py +++ b/evaluation/script/create_webvoyager_task_v2.py @@ -9,6 +9,7 @@ from dotenv import load_dotenv from evaluation.core import Evaluator, SkyvernClient from evaluation.core.utils import load_webvoyager_case_from_json from skyvern.forge import app +from skyvern.forge.forge_app_initializer import start_forge_app from skyvern.forge.prompts import prompt_engine from skyvern.forge.sdk.schemas.task_v2 import TaskV2Request @@ -19,6 +20,8 @@ async def create_task_v2( base_url: str, cred: str, ) -> None: + start_forge_app() + client = SkyvernClient(base_url=base_url, credentials=cred) group_id = uuid4() diff --git a/evaluation/script/create_webvoyager_workflow.py b/evaluation/script/create_webvoyager_workflow.py index 577c0f78..3060678c 100644 --- a/evaluation/script/create_webvoyager_workflow.py +++ b/evaluation/script/create_webvoyager_workflow.py @@ -10,6 +10,7 @@ from dotenv import load_dotenv from evaluation.core import Evaluator, SkyvernClient from evaluation.core.utils import load_webvoyager_case_from_json from skyvern.forge import app +from skyvern.forge.forge_app_initializer import start_forge_app from skyvern.forge.prompts import prompt_engine from skyvern.forge.sdk.workflow.models.workflow import WorkflowRequestBody from skyvern.schemas.runs import ProxyLocation @@ -69,6 +70,8 @@ def main( None, "--proxy-location", help="overwrite the workflow proxy location" ), ) -> None: + start_forge_app() + asyncio.run( create_workflow_run(base_url=base_url, cred=cred, workflow_pid=workflow_pid, proxy_location=proxy_location) ) diff --git a/evaluation/script/eval_webvoyager_task_v2.py b/evaluation/script/eval_webvoyager_task_v2.py index 8bca329c..d1eef4d1 100644 --- a/evaluation/script/eval_webvoyager_task_v2.py +++ b/evaluation/script/eval_webvoyager_task_v2.py @@ -7,6 +7,7 @@ import typer from dotenv import load_dotenv from evaluation.core import Evaluator, SkyvernClient +from skyvern.forge.forge_app_initializer import start_forge_app from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus load_dotenv() @@ -109,6 +110,8 @@ def main( record_json_path: str = typer.Option(..., "--record-json", help="record json path for evaluation run"), output_csv_path: str = typer.Option("output.csv", "--output-path", help="output csv path for evaluation run"), ) -> None: + start_forge_app() + asyncio.run( run_eval(base_url=base_url, cred=cred, record_json_path=record_json_path, output_csv_path=output_csv_path) ) diff --git a/skyvern/cli/run_commands.py b/skyvern/cli/run_commands.py index 9804e18a..312cb3ff 100644 --- a/skyvern/cli/run_commands.py +++ b/skyvern/cli/run_commands.py @@ -95,10 +95,11 @@ def run_server() -> None: port = settings.PORT console.print(Panel(f"[bold green]Starting Skyvern API Server on port {port}...", border_style="green")) uvicorn.run( - "skyvern.forge.api_app:app", + "skyvern.forge.api_app:create_api_app", host="0.0.0.0", port=port, log_level="info", + factory=True, ) diff --git a/skyvern/forge/__init__.py b/skyvern/forge/__init__.py index e69de29b..1df634d5 100644 --- a/skyvern/forge/__init__.py +++ b/skyvern/forge/__init__.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +import typing +from typing import Any + +if typing.TYPE_CHECKING: + from skyvern.forge.forge_app import ForgeApp + + +class AppHolder: + def __init__(self) -> None: + object.__setattr__(self, "_inst", None) + + def set_app(self, inst: ForgeApp) -> None: + object.__setattr__(self, "_inst", inst) + + def __getattr__(self, name: str) -> Any: + inst = object.__getattribute__(self, "_inst") + if inst is None: + raise RuntimeError("ForgeApp is not initialized, start_forge_app should be called") + + return getattr(inst, name) + + def __setattr__(self, name: str, value: Any) -> None: + inst = object.__getattribute__(self, "_inst") + if inst is None: + raise RuntimeError("ForgeApp is not initialized, start_forge_app should be called") + + setattr(inst, name, value) + + +if typing.TYPE_CHECKING: + app: ForgeApp +else: + app = AppHolder() # type: ignore + + +def set_force_app_instance(inst: ForgeApp) -> None: + app.set_app(inst) # type: ignore diff --git a/skyvern/forge/__main__.py b/skyvern/forge/__main__.py index 5b30c7ca..0d47c4b9 100644 --- a/skyvern/forge/__main__.py +++ b/skyvern/forge/__main__.py @@ -30,7 +30,7 @@ if __name__ == "__main__": ) uvicorn.run( - "skyvern.forge.api_app:app", + "skyvern.forge.api_app:create_api_app", host="0.0.0.0", port=port, log_level="info", @@ -39,4 +39,5 @@ if __name__ == "__main__": f"{temp_path_for_excludes}/**/*.py", f"{artifact_path_for_excludes}/{settings.ENV}/**/scripts/**/**/*.py", ], + factory=True, ) diff --git a/skyvern/forge/agent.py b/skyvern/forge/agent.py index e6e25bdd..50e577d7 100644 --- a/skyvern/forge/agent.py +++ b/skyvern/forge/agent.py @@ -153,14 +153,6 @@ class ActionLinkedNode: class ForgeAgent: def __init__(self) -> None: - if settings.ADDITIONAL_MODULES: - for module in settings.ADDITIONAL_MODULES: - LOG.debug("Loading additional module", module=module) - __import__(module) - LOG.debug( - "Additional modules loaded", - modules=settings.ADDITIONAL_MODULES, - ) self.async_operation_pool = AsyncOperationPool() async def create_task_and_step_from_block( diff --git a/skyvern/forge/api_app.py b/skyvern/forge/api_app.py index a62bef55..2c3501f5 100644 --- a/skyvern/forge/api_app.py +++ b/skyvern/forge/api_app.py @@ -16,6 +16,7 @@ from starlette_context.plugins.base import Plugin from skyvern.config import settings from skyvern.exceptions import SkyvernHTTPException from skyvern.forge import app as forge_app +from skyvern.forge.forge_app_initializer import start_forge_app from skyvern.forge.request_logging import log_raw_request_middleware from skyvern.forge.sdk.core import skyvern_context from skyvern.forge.sdk.core.skyvern_context import SkyvernContext @@ -33,7 +34,7 @@ class ExecutionDatePlugin(Plugin): return datetime.now() -def custom_openapi() -> dict: +def custom_openapi(app: FastAPI) -> dict: if app.openapi_schema: return app.openapi_schema openapi_schema = get_openapi( @@ -54,6 +55,7 @@ def custom_openapi() -> dict: @asynccontextmanager async def lifespan(_: FastAPI) -> AsyncGenerator[None, Any]: """Lifespan context manager for FastAPI app startup and shutdown.""" + LOG.info("Server started") if forge_app.api_app_startup_event: LOG.info("Calling api app startup event") @@ -71,15 +73,17 @@ async def lifespan(_: FastAPI) -> AsyncGenerator[None, Any]: LOG.info("Server shutting down") -def get_agent_app() -> FastAPI: +def create_api_app() -> FastAPI: """ Start the agent server. """ - app = FastAPI(lifespan=lifespan) + forge_app_instance = start_forge_app() + + fastapi_app = FastAPI(lifespan=lifespan) # Add CORS middleware - app.add_middleware( + fastapi_app.add_middleware( CORSMiddleware, allow_origins=settings.ALLOWED_ORIGINS, allow_credentials=True, @@ -87,19 +91,19 @@ def get_agent_app() -> FastAPI: allow_headers=["*"], ) - app.include_router(base_router, prefix="/v1") - app.include_router(legacy_base_router, prefix="/api/v1") - app.include_router(legacy_v2_router, prefix="/api/v2") + fastapi_app.include_router(base_router, prefix="/v1") + fastapi_app.include_router(legacy_base_router, prefix="/api/v1") + fastapi_app.include_router(legacy_v2_router, prefix="/api/v2") # local dev endpoints if settings.ENV == "local": - app.include_router(internal_auth.router, prefix="/v1") - app.include_router(internal_auth.router, prefix="/api/v1") - app.include_router(internal_auth.router, prefix="/api/v2") + fastapi_app.include_router(internal_auth.router, prefix="/v1") + fastapi_app.include_router(internal_auth.router, prefix="/api/v1") + fastapi_app.include_router(internal_auth.router, prefix="/api/v2") - app.openapi = custom_openapi + fastapi_app.openapi = lambda: custom_openapi(fastapi_app) - app.add_middleware( + fastapi_app.add_middleware( RawContextMiddleware, plugins=( # TODO (suchintan): We should set these up @@ -109,27 +113,27 @@ def get_agent_app() -> FastAPI: ), ) - @app.exception_handler(NotFoundError) + @fastapi_app.exception_handler(NotFoundError) async def handle_not_found_error(request: Request, exc: NotFoundError) -> Response: return Response(status_code=status.HTTP_404_NOT_FOUND) - @app.exception_handler(SkyvernHTTPException) + @fastapi_app.exception_handler(SkyvernHTTPException) async def handle_skyvern_http_exception(request: Request, exc: SkyvernHTTPException) -> JSONResponse: return JSONResponse(status_code=exc.status_code, content={"detail": exc.message}) - @app.exception_handler(ValidationError) + @fastapi_app.exception_handler(ValidationError) async def handle_pydantic_validation_error(request: Request, exc: ValidationError) -> JSONResponse: return JSONResponse( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, content={"detail": str(exc)}, ) - @app.exception_handler(Exception) + @fastapi_app.exception_handler(Exception) async def unexpected_exception(request: Request, exc: Exception) -> JSONResponse: LOG.exception("Unexpected error in agent server.", exc_info=exc) return JSONResponse(status_code=500, content={"error": f"Unexpected error: {exc}"}) - @app.middleware("http") + @fastapi_app.middleware("http") async def request_middleware(request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response: curr_ctx = skyvern_context.current() if not curr_ctx: @@ -143,23 +147,11 @@ def get_agent_app() -> FastAPI: finally: skyvern_context.reset() - @app.middleware("http") + @fastapi_app.middleware("http") async def raw_request_logging(request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response: return await log_raw_request_middleware(request, call_next) - if settings.ADDITIONAL_MODULES: - for module in settings.ADDITIONAL_MODULES: - LOG.info("Loading additional module to set up api app", module=module) - __import__(module) - LOG.info( - "Additional modules loaded to set up api app", - modules=settings.ADDITIONAL_MODULES, - ) + if forge_app_instance.setup_api_app: + forge_app_instance.setup_api_app(fastapi_app) - if forge_app.setup_api_app: - forge_app.setup_api_app(app) - - return app - - -app = get_agent_app() + return fastapi_app diff --git a/skyvern/forge/app.py b/skyvern/forge/app.py deleted file mode 100644 index d069962b..00000000 --- a/skyvern/forge/app.py +++ /dev/null @@ -1,136 +0,0 @@ -from typing import Awaitable, Callable - -from anthropic import AsyncAnthropic, AsyncAnthropicBedrock -from fastapi import FastAPI -from openai import AsyncAzureOpenAI, AsyncOpenAI - -from skyvern.forge.agent import ForgeAgent -from skyvern.forge.agent_functions import AgentFunction -from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory -from skyvern.forge.sdk.artifact.manager import ArtifactManager -from skyvern.forge.sdk.artifact.storage.factory import StorageFactory -from skyvern.forge.sdk.artifact.storage.s3 import S3Storage -from skyvern.forge.sdk.cache.factory import CacheFactory -from skyvern.forge.sdk.db.client import AgentDB -from skyvern.forge.sdk.experimentation.providers import BaseExperimentationProvider, NoOpExperimentationProvider -from skyvern.forge.sdk.schemas.credentials import CredentialVaultType -from skyvern.forge.sdk.schemas.organizations import Organization -from skyvern.forge.sdk.services.credential.azure_credential_vault_service import AzureCredentialVaultService -from skyvern.forge.sdk.services.credential.bitwarden_credential_service import BitwardenCredentialVaultService -from skyvern.forge.sdk.services.credential.credential_vault_service import CredentialVaultService -from skyvern.forge.sdk.settings_manager import SettingsManager -from skyvern.forge.sdk.workflow.context_manager import WorkflowContextManager -from skyvern.forge.sdk.workflow.service import WorkflowService -from skyvern.webeye.browser_manager import BrowserManager -from skyvern.webeye.persistent_sessions_manager import PersistentSessionsManager -from skyvern.webeye.scraper.scraper import ScrapeExcludeFunc - -SETTINGS_MANAGER = SettingsManager.get_settings() -DATABASE = AgentDB( - SettingsManager.get_settings().DATABASE_STRING, - debug_enabled=SettingsManager.get_settings().DEBUG_MODE, -) -if SettingsManager.get_settings().SKYVERN_STORAGE_TYPE == "s3": - StorageFactory.set_storage(S3Storage()) -STORAGE = StorageFactory.get_storage() -CACHE = CacheFactory.get_cache() -ARTIFACT_MANAGER = ArtifactManager() -BROWSER_MANAGER = BrowserManager() -EXPERIMENTATION_PROVIDER: BaseExperimentationProvider = NoOpExperimentationProvider() -LLM_API_HANDLER = LLMAPIHandlerFactory.get_llm_api_handler(SettingsManager.get_settings().LLM_KEY) -OPENAI_CLIENT = AsyncOpenAI(api_key=SettingsManager.get_settings().OPENAI_API_KEY or "") -if SettingsManager.get_settings().ENABLE_AZURE_CUA: - OPENAI_CLIENT = AsyncAzureOpenAI( - api_key=SettingsManager.get_settings().AZURE_CUA_API_KEY, - api_version=SettingsManager.get_settings().AZURE_CUA_API_VERSION, - azure_endpoint=SettingsManager.get_settings().AZURE_CUA_ENDPOINT, - azure_deployment=SettingsManager.get_settings().AZURE_CUA_DEPLOYMENT, - ) -ANTHROPIC_CLIENT = AsyncAnthropic(api_key=SettingsManager.get_settings().ANTHROPIC_API_KEY) -if SettingsManager.get_settings().ENABLE_BEDROCK_ANTHROPIC: - ANTHROPIC_CLIENT = AsyncAnthropicBedrock() - -# Add UI-TARS client setup -UI_TARS_CLIENT = None -if SettingsManager.get_settings().ENABLE_VOLCENGINE: - UI_TARS_CLIENT = AsyncOpenAI( - api_key=SettingsManager.get_settings().VOLCENGINE_API_KEY, - base_url=SettingsManager.get_settings().VOLCENGINE_API_BASE, - ) - -SECONDARY_LLM_API_HANDLER = LLMAPIHandlerFactory.get_llm_api_handler( - SETTINGS_MANAGER.SECONDARY_LLM_KEY if SETTINGS_MANAGER.SECONDARY_LLM_KEY else SETTINGS_MANAGER.LLM_KEY -) -SELECT_AGENT_LLM_API_HANDLER = LLMAPIHandlerFactory.get_llm_api_handler( - SETTINGS_MANAGER.SELECT_AGENT_LLM_KEY or SETTINGS_MANAGER.SECONDARY_LLM_KEY or SETTINGS_MANAGER.LLM_KEY -) -NORMAL_SELECT_AGENT_LLM_API_HANDLER = ( - LLMAPIHandlerFactory.get_llm_api_handler(SETTINGS_MANAGER.NORMAL_SELECT_AGENT_LLM_KEY) - if SETTINGS_MANAGER.NORMAL_SELECT_AGENT_LLM_KEY - else SECONDARY_LLM_API_HANDLER -) -CUSTOM_SELECT_AGENT_LLM_API_HANDLER = ( - LLMAPIHandlerFactory.get_llm_api_handler(SETTINGS_MANAGER.CUSTOM_SELECT_AGENT_LLM_KEY) - if SETTINGS_MANAGER.CUSTOM_SELECT_AGENT_LLM_KEY - else SECONDARY_LLM_API_HANDLER -) -SINGLE_CLICK_AGENT_LLM_API_HANDLER = LLMAPIHandlerFactory.get_llm_api_handler( - SETTINGS_MANAGER.SINGLE_CLICK_AGENT_LLM_KEY or SETTINGS_MANAGER.SECONDARY_LLM_KEY or SETTINGS_MANAGER.LLM_KEY -) -SINGLE_INPUT_AGENT_LLM_API_HANDLER = LLMAPIHandlerFactory.get_llm_api_handler( - SETTINGS_MANAGER.SINGLE_INPUT_AGENT_LLM_KEY or SETTINGS_MANAGER.SECONDARY_LLM_KEY or SETTINGS_MANAGER.LLM_KEY -) -PARSE_SELECT_LLM_API_HANDLER = ( - LLMAPIHandlerFactory.get_llm_api_handler(SETTINGS_MANAGER.PARSE_SELECT_LLM_KEY) - if SETTINGS_MANAGER.PARSE_SELECT_LLM_KEY - else SECONDARY_LLM_API_HANDLER -) -EXTRACTION_LLM_API_HANDLER = ( - LLMAPIHandlerFactory.get_llm_api_handler(SETTINGS_MANAGER.EXTRACTION_LLM_KEY) - if SETTINGS_MANAGER.EXTRACTION_LLM_KEY - else LLM_API_HANDLER -) -CHECK_USER_GOAL_LLM_API_HANDLER = ( - LLMAPIHandlerFactory.get_llm_api_handler(SETTINGS_MANAGER.CHECK_USER_GOAL_LLM_KEY) - if SETTINGS_MANAGER.CHECK_USER_GOAL_LLM_KEY - else SECONDARY_LLM_API_HANDLER -) -AUTO_COMPLETION_LLM_API_HANDLER = ( - LLMAPIHandlerFactory.get_llm_api_handler(SETTINGS_MANAGER.AUTO_COMPLETION_LLM_KEY) - if SETTINGS_MANAGER.AUTO_COMPLETION_LLM_KEY - else SECONDARY_LLM_API_HANDLER -) -SVG_CSS_CONVERTER_LLM_API_HANDLER = SECONDARY_LLM_API_HANDLER if SETTINGS_MANAGER.SECONDARY_LLM_KEY else None -SCRIPT_GENERATION_LLM_API_HANDLER = ( - LLMAPIHandlerFactory.get_llm_api_handler(SETTINGS_MANAGER.SCRIPT_GENERATION_LLM_KEY) - if SETTINGS_MANAGER.SCRIPT_GENERATION_LLM_KEY - else SECONDARY_LLM_API_HANDLER -) - -WORKFLOW_CONTEXT_MANAGER = WorkflowContextManager() -WORKFLOW_SERVICE = WorkflowService() -AGENT_FUNCTION = AgentFunction() -PERSISTENT_SESSIONS_MANAGER = PersistentSessionsManager(database=DATABASE) - -BITWARDEN_CREDENTIAL_VAULT_SERVICE: BitwardenCredentialVaultService = BitwardenCredentialVaultService() -AZURE_CREDENTIAL_VAULT_SERVICE: AzureCredentialVaultService | None = None -if SettingsManager.get_settings().AZURE_CREDENTIAL_VAULT: - AZURE_CREDENTIAL_VAULT_SERVICE = AzureCredentialVaultService( - tenant_id=SettingsManager.get_settings().AZURE_TENANT_ID, # type: ignore - client_id=SettingsManager.get_settings().AZURE_CLIENT_ID, # type: ignore - client_secret=SettingsManager.get_settings().AZURE_CLIENT_SECRET, # type: ignore - vault_name=SettingsManager.get_settings().AZURE_CREDENTIAL_VAULT, # type: ignore - ) -CREDENTIAL_VAULT_SERVICES: dict[str, CredentialVaultService | None] = { - CredentialVaultType.BITWARDEN: BITWARDEN_CREDENTIAL_VAULT_SERVICE, - CredentialVaultType.AZURE_VAULT: AZURE_CREDENTIAL_VAULT_SERVICE, -} - -scrape_exclude: ScrapeExcludeFunc | None = None -authentication_function: Callable[[str], Awaitable[Organization]] | None = None -authenticate_user_function: Callable[[str], Awaitable[str | None]] | None = None -setup_api_app: Callable[[FastAPI], None] | None = None -api_app_startup_event: Callable[[], Awaitable[None]] | None = None -api_app_shutdown_event: Callable[[], Awaitable[None]] | None = None - -agent = ForgeAgent() diff --git a/skyvern/forge/forge_app.py b/skyvern/forge/forge_app.py new file mode 100644 index 00000000..6a4b3a27 --- /dev/null +++ b/skyvern/forge/forge_app.py @@ -0,0 +1,194 @@ +from __future__ import annotations + +from typing import Awaitable, Callable + +from anthropic import AsyncAnthropic, AsyncAnthropicBedrock +from fastapi import FastAPI +from openai import AsyncAzureOpenAI, AsyncOpenAI + +from skyvern.config import Settings +from skyvern.forge.agent import ForgeAgent +from skyvern.forge.agent_functions import AgentFunction +from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory +from skyvern.forge.sdk.api.llm.models import LLMAPIHandler +from skyvern.forge.sdk.artifact.manager import ArtifactManager +from skyvern.forge.sdk.artifact.storage.base import BaseStorage +from skyvern.forge.sdk.artifact.storage.factory import StorageFactory +from skyvern.forge.sdk.artifact.storage.s3 import S3Storage +from skyvern.forge.sdk.cache.base import BaseCache +from skyvern.forge.sdk.cache.factory import CacheFactory +from skyvern.forge.sdk.db.client import AgentDB +from skyvern.forge.sdk.experimentation.providers import BaseExperimentationProvider, NoOpExperimentationProvider +from skyvern.forge.sdk.schemas.credentials import CredentialVaultType +from skyvern.forge.sdk.schemas.organizations import Organization +from skyvern.forge.sdk.services.credential.azure_credential_vault_service import AzureCredentialVaultService +from skyvern.forge.sdk.services.credential.bitwarden_credential_service import BitwardenCredentialVaultService +from skyvern.forge.sdk.services.credential.credential_vault_service import CredentialVaultService +from skyvern.forge.sdk.settings_manager import SettingsManager +from skyvern.forge.sdk.workflow.context_manager import WorkflowContextManager +from skyvern.forge.sdk.workflow.service import WorkflowService +from skyvern.webeye.browser_manager import BrowserManager +from skyvern.webeye.persistent_sessions_manager import PersistentSessionsManager +from skyvern.webeye.scraper.scraper import ScrapeExcludeFunc + + +class ForgeApp: + """Container for shared Forge services""" + + SETTINGS_MANAGER: Settings + DATABASE: AgentDB + STORAGE: BaseStorage + CACHE: BaseCache + ARTIFACT_MANAGER: ArtifactManager + BROWSER_MANAGER: BrowserManager + EXPERIMENTATION_PROVIDER: BaseExperimentationProvider + LLM_API_HANDLER: LLMAPIHandler + OPENAI_CLIENT: AsyncOpenAI | AsyncAzureOpenAI + ANTHROPIC_CLIENT: AsyncAnthropic | AsyncAnthropicBedrock + UI_TARS_CLIENT: AsyncOpenAI | None + SECONDARY_LLM_API_HANDLER: LLMAPIHandler + SELECT_AGENT_LLM_API_HANDLER: LLMAPIHandler + NORMAL_SELECT_AGENT_LLM_API_HANDLER: LLMAPIHandler + CUSTOM_SELECT_AGENT_LLM_API_HANDLER: LLMAPIHandler + SINGLE_CLICK_AGENT_LLM_API_HANDLER: LLMAPIHandler + SINGLE_INPUT_AGENT_LLM_API_HANDLER: LLMAPIHandler + PARSE_SELECT_LLM_API_HANDLER: LLMAPIHandler + EXTRACTION_LLM_API_HANDLER: LLMAPIHandler + CHECK_USER_GOAL_LLM_API_HANDLER: LLMAPIHandler + AUTO_COMPLETION_LLM_API_HANDLER: LLMAPIHandler + SVG_CSS_CONVERTER_LLM_API_HANDLER: LLMAPIHandler | None + SCRIPT_GENERATION_LLM_API_HANDLER: LLMAPIHandler + WORKFLOW_CONTEXT_MANAGER: WorkflowContextManager + WORKFLOW_SERVICE: WorkflowService + AGENT_FUNCTION: AgentFunction + PERSISTENT_SESSIONS_MANAGER: PersistentSessionsManager + BITWARDEN_CREDENTIAL_VAULT_SERVICE: BitwardenCredentialVaultService + AZURE_CREDENTIAL_VAULT_SERVICE: AzureCredentialVaultService | None + CREDENTIAL_VAULT_SERVICES: dict[str, CredentialVaultService | None] + scrape_exclude: ScrapeExcludeFunc | None + authentication_function: Callable[[str], Awaitable[Organization]] | None + authenticate_user_function: Callable[[str], Awaitable[str | None]] | None + setup_api_app: Callable[[FastAPI], None] | None + api_app_startup_event: Callable[[], Awaitable[None]] | None + api_app_shutdown_event: Callable[[], Awaitable[None]] | None + agent: ForgeAgent + + +def create_forge_app() -> ForgeApp: + """Create and initialize a ForgeApp instance with all services""" + settings: Settings = SettingsManager.get_settings() + + app = ForgeApp() + + app.SETTINGS_MANAGER = settings + + app.DATABASE = AgentDB(settings.DATABASE_STRING, debug_enabled=settings.DEBUG_MODE) + if settings.SKYVERN_STORAGE_TYPE == "s3": + StorageFactory.set_storage(S3Storage()) + app.STORAGE = StorageFactory.get_storage() + app.CACHE = CacheFactory.get_cache() + app.ARTIFACT_MANAGER = ArtifactManager() + app.BROWSER_MANAGER = BrowserManager() + app.EXPERIMENTATION_PROVIDER = NoOpExperimentationProvider() + + app.LLM_API_HANDLER = LLMAPIHandlerFactory.get_llm_api_handler(settings.LLM_KEY) + app.OPENAI_CLIENT = AsyncOpenAI(api_key=settings.OPENAI_API_KEY or "") + if settings.ENABLE_AZURE_CUA: + app.OPENAI_CLIENT = AsyncAzureOpenAI( + api_key=settings.AZURE_CUA_API_KEY, + api_version=settings.AZURE_CUA_API_VERSION, + azure_endpoint=settings.AZURE_CUA_ENDPOINT, + azure_deployment=settings.AZURE_CUA_DEPLOYMENT, + ) + + app.ANTHROPIC_CLIENT = AsyncAnthropic(api_key=settings.ANTHROPIC_API_KEY) + if settings.ENABLE_BEDROCK_ANTHROPIC: + app.ANTHROPIC_CLIENT = AsyncAnthropicBedrock() + + app.UI_TARS_CLIENT = None + if settings.ENABLE_VOLCENGINE: + app.UI_TARS_CLIENT = AsyncOpenAI( + api_key=settings.VOLCENGINE_API_KEY, + base_url=settings.VOLCENGINE_API_BASE, + ) + + app.SECONDARY_LLM_API_HANDLER = LLMAPIHandlerFactory.get_llm_api_handler( + settings.SECONDARY_LLM_KEY if settings.SECONDARY_LLM_KEY else settings.LLM_KEY + ) + app.SELECT_AGENT_LLM_API_HANDLER = LLMAPIHandlerFactory.get_llm_api_handler( + settings.SELECT_AGENT_LLM_KEY or settings.SECONDARY_LLM_KEY or settings.LLM_KEY + ) + app.NORMAL_SELECT_AGENT_LLM_API_HANDLER = ( + LLMAPIHandlerFactory.get_llm_api_handler(settings.NORMAL_SELECT_AGENT_LLM_KEY) + if settings.NORMAL_SELECT_AGENT_LLM_KEY + else app.SECONDARY_LLM_API_HANDLER + ) + app.CUSTOM_SELECT_AGENT_LLM_API_HANDLER = ( + LLMAPIHandlerFactory.get_llm_api_handler(settings.CUSTOM_SELECT_AGENT_LLM_KEY) + if settings.CUSTOM_SELECT_AGENT_LLM_KEY + else app.SECONDARY_LLM_API_HANDLER + ) + app.SINGLE_CLICK_AGENT_LLM_API_HANDLER = LLMAPIHandlerFactory.get_llm_api_handler( + settings.SINGLE_CLICK_AGENT_LLM_KEY or settings.SECONDARY_LLM_KEY or settings.LLM_KEY + ) + app.SINGLE_INPUT_AGENT_LLM_API_HANDLER = LLMAPIHandlerFactory.get_llm_api_handler( + settings.SINGLE_INPUT_AGENT_LLM_KEY or settings.SECONDARY_LLM_KEY or settings.LLM_KEY + ) + app.PARSE_SELECT_LLM_API_HANDLER = ( + LLMAPIHandlerFactory.get_llm_api_handler(settings.PARSE_SELECT_LLM_KEY) + if settings.PARSE_SELECT_LLM_KEY + else app.SECONDARY_LLM_API_HANDLER + ) + app.EXTRACTION_LLM_API_HANDLER = ( + LLMAPIHandlerFactory.get_llm_api_handler(settings.EXTRACTION_LLM_KEY) + if settings.EXTRACTION_LLM_KEY + else app.LLM_API_HANDLER + ) + app.CHECK_USER_GOAL_LLM_API_HANDLER = ( + LLMAPIHandlerFactory.get_llm_api_handler(settings.CHECK_USER_GOAL_LLM_KEY) + if settings.CHECK_USER_GOAL_LLM_KEY + else app.SECONDARY_LLM_API_HANDLER + ) + app.AUTO_COMPLETION_LLM_API_HANDLER = ( + LLMAPIHandlerFactory.get_llm_api_handler(settings.AUTO_COMPLETION_LLM_KEY) + if settings.AUTO_COMPLETION_LLM_KEY + else app.SECONDARY_LLM_API_HANDLER + ) + app.SVG_CSS_CONVERTER_LLM_API_HANDLER = app.SECONDARY_LLM_API_HANDLER if settings.SECONDARY_LLM_KEY else None + app.SCRIPT_GENERATION_LLM_API_HANDLER = ( + LLMAPIHandlerFactory.get_llm_api_handler(settings.SCRIPT_GENERATION_LLM_KEY) + if settings.SCRIPT_GENERATION_LLM_KEY + else app.SECONDARY_LLM_API_HANDLER + ) + + app.WORKFLOW_CONTEXT_MANAGER = WorkflowContextManager() + app.WORKFLOW_SERVICE = WorkflowService() + app.AGENT_FUNCTION = AgentFunction() + app.PERSISTENT_SESSIONS_MANAGER = PersistentSessionsManager(database=app.DATABASE) + + app.BITWARDEN_CREDENTIAL_VAULT_SERVICE = BitwardenCredentialVaultService() + app.AZURE_CREDENTIAL_VAULT_SERVICE = ( + AzureCredentialVaultService( + tenant_id=settings.AZURE_TENANT_ID, # type: ignore[arg-type] + client_id=settings.AZURE_CLIENT_ID, # type: ignore[arg-type] + client_secret=settings.AZURE_CLIENT_SECRET, # type: ignore[arg-type] + vault_name=settings.AZURE_CREDENTIAL_VAULT, # type: ignore[arg-type] + ) + if settings.AZURE_CREDENTIAL_VAULT + else None + ) + app.CREDENTIAL_VAULT_SERVICES = { + CredentialVaultType.BITWARDEN: app.BITWARDEN_CREDENTIAL_VAULT_SERVICE, + CredentialVaultType.AZURE_VAULT: app.AZURE_CREDENTIAL_VAULT_SERVICE, + } + + app.scrape_exclude = None + app.authentication_function = None + app.authenticate_user_function = None + app.setup_api_app = None + app.api_app_startup_event = None + app.api_app_shutdown_event = None + + app.agent = ForgeAgent() + + return app diff --git a/skyvern/forge/forge_app_initializer.py b/skyvern/forge/forge_app_initializer.py new file mode 100644 index 00000000..52c59f4b --- /dev/null +++ b/skyvern/forge/forge_app_initializer.py @@ -0,0 +1,28 @@ +import structlog + +from skyvern.config import settings +from skyvern.forge import set_force_app_instance +from skyvern.forge.forge_app import ForgeApp, create_forge_app + +LOG = structlog.get_logger() + + +def start_forge_app() -> ForgeApp: + force_app_instance = create_forge_app() + set_force_app_instance(force_app_instance) + + if settings.ADDITIONAL_MODULES: + for module in settings.ADDITIONAL_MODULES: + LOG.info("Loading additional module to set up api app", module=module) + app_module = __import__(module) + configure_app_fn = getattr(app_module, "configure_app", None) + if not configure_app_fn: + raise RuntimeError(f"Missing configure_app function in {module}") + + configure_app_fn(force_app_instance) + LOG.info( + "Additional modules loaded to set up api app", + modules=settings.ADDITIONAL_MODULES, + ) + + return force_app_instance diff --git a/skyvern/library/embedded_server_factory.py b/skyvern/library/embedded_server_factory.py index d1852172..cc4b38ca 100644 --- a/skyvern/library/embedded_server_factory.py +++ b/skyvern/library/embedded_server_factory.py @@ -12,14 +12,15 @@ def create_embedded_server( async def handle_async_request(self, request: httpx.Request) -> httpx.Response: if self._transport is None: from skyvern.config import settings # noqa: PLC0415 - from skyvern.forge.api_app import app # noqa: PLC0415 settings.BROWSER_LOGS_ENABLED = False if openai_api_key: settings.OPENAI_API_KEY = openai_api_key - self._transport = ASGITransport(app=app) + from skyvern.forge.api_app import create_api_app # noqa: PLC0415 + + self._transport = ASGITransport(app=create_api_app()) response = await self._transport.handle_async_request(request) return response diff --git a/tests/unit_tests/test_file_parser_block.py b/tests/unit_tests/test_file_parser_block.py index be845c2c..f0c36175 100644 --- a/tests/unit_tests/test_file_parser_block.py +++ b/tests/unit_tests/test_file_parser_block.py @@ -6,10 +6,18 @@ from unittest.mock import MagicMock, patch import pandas as pd import pytest +from skyvern.forge import app +from skyvern.forge.forge_app_initializer import start_forge_app from skyvern.forge.sdk.workflow.models.block import FileParserBlock, FileType from skyvern.forge.sdk.workflow.models.parameter import OutputParameter +@pytest.fixture(scope="module", autouse=True) +def setup_forge_app(): + start_forge_app() + yield + + class TestFileParserBlock: @pytest.fixture def file_parser_block(self): @@ -172,7 +180,7 @@ class TestFileParserBlock: # Mock the LLM response mock_response = {"extracted_data": {"names": ["John", "Jane"], "total_count": 2}} - with patch("skyvern.forge.sdk.workflow.models.block.app.LLM_API_HANDLER") as mock_llm: + with patch.object(object.__getattribute__(app, "_inst"), "LLM_API_HANDLER") as mock_llm: mock_llm.return_value = mock_response with patch("skyvern.forge.sdk.workflow.models.block.prompt_engine.load_prompt") as mock_prompt: @@ -190,7 +198,7 @@ class TestFileParserBlock: # Mock the LLM response mock_response = {"output": {"summary": "Extracted data from file"}} - with patch("skyvern.forge.sdk.workflow.models.block.app.LLM_API_HANDLER") as mock_llm: + with patch.object(object.__getattribute__(app, "_inst"), "LLM_API_HANDLER") as mock_llm: mock_llm.return_value = mock_response with patch("skyvern.forge.sdk.workflow.models.block.prompt_engine.load_prompt") as mock_prompt: diff --git a/tests/unit_tests/test_openrouter_integration.py b/tests/unit_tests/test_openrouter_integration.py index 93329d5b..de5cbf47 100644 --- a/tests/unit_tests/test_openrouter_integration.py +++ b/tests/unit_tests/test_openrouter_integration.py @@ -8,10 +8,17 @@ import pytest from skyvern import config from skyvern.config import Settings from skyvern.forge import app +from skyvern.forge.forge_app_initializer import start_forge_app from skyvern.forge.sdk.api.llm import api_handler_factory, config_registry from skyvern.forge.sdk.settings_manager import SettingsManager +@pytest.fixture(scope="module", autouse=True) +def setup_forge_app(): + start_forge_app() + yield + + class DummyResponse(dict): def __init__(self, content: str): super().__init__({"choices": [{"message": {"content": content}}], "usage": {}})