Initialize app at runtime instead of import time (#4024)
This commit is contained in:
committed by
GitHub
parent
f7e68141eb
commit
0efae234ab
@@ -9,6 +9,7 @@ from dotenv import load_dotenv
|
||||
from evaluation.core import Evaluator, SkyvernClient
|
||||
from evaluation.core.utils import load_webvoyager_case_from_json
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.forge_app_initializer import start_forge_app
|
||||
from skyvern.forge.prompts import prompt_engine
|
||||
from skyvern.forge.sdk.schemas.task_v2 import TaskV2Request
|
||||
|
||||
@@ -19,6 +20,8 @@ async def create_task_v2(
|
||||
base_url: str,
|
||||
cred: str,
|
||||
) -> None:
|
||||
start_forge_app()
|
||||
|
||||
client = SkyvernClient(base_url=base_url, credentials=cred)
|
||||
group_id = uuid4()
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from dotenv import load_dotenv
|
||||
from evaluation.core import Evaluator, SkyvernClient
|
||||
from evaluation.core.utils import load_webvoyager_case_from_json
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.forge_app_initializer import start_forge_app
|
||||
from skyvern.forge.prompts import prompt_engine
|
||||
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRequestBody
|
||||
from skyvern.schemas.runs import ProxyLocation
|
||||
@@ -69,6 +70,8 @@ def main(
|
||||
None, "--proxy-location", help="overwrite the workflow proxy location"
|
||||
),
|
||||
) -> None:
|
||||
start_forge_app()
|
||||
|
||||
asyncio.run(
|
||||
create_workflow_run(base_url=base_url, cred=cred, workflow_pid=workflow_pid, proxy_location=proxy_location)
|
||||
)
|
||||
|
||||
@@ -7,6 +7,7 @@ import typer
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from evaluation.core import Evaluator, SkyvernClient
|
||||
from skyvern.forge.forge_app_initializer import start_forge_app
|
||||
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus
|
||||
|
||||
load_dotenv()
|
||||
@@ -109,6 +110,8 @@ def main(
|
||||
record_json_path: str = typer.Option(..., "--record-json", help="record json path for evaluation run"),
|
||||
output_csv_path: str = typer.Option("output.csv", "--output-path", help="output csv path for evaluation run"),
|
||||
) -> None:
|
||||
start_forge_app()
|
||||
|
||||
asyncio.run(
|
||||
run_eval(base_url=base_url, cred=cred, record_json_path=record_json_path, output_csv_path=output_csv_path)
|
||||
)
|
||||
|
||||
@@ -95,10 +95,11 @@ def run_server() -> None:
|
||||
port = settings.PORT
|
||||
console.print(Panel(f"[bold green]Starting Skyvern API Server on port {port}...", border_style="green"))
|
||||
uvicorn.run(
|
||||
"skyvern.forge.api_app:app",
|
||||
"skyvern.forge.api_app:create_api_app",
|
||||
host="0.0.0.0",
|
||||
port=port,
|
||||
log_level="info",
|
||||
factory=True,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
from typing import Any
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from skyvern.forge.forge_app import ForgeApp
|
||||
|
||||
|
||||
class AppHolder:
|
||||
def __init__(self) -> None:
|
||||
object.__setattr__(self, "_inst", None)
|
||||
|
||||
def set_app(self, inst: ForgeApp) -> None:
|
||||
object.__setattr__(self, "_inst", inst)
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
inst = object.__getattribute__(self, "_inst")
|
||||
if inst is None:
|
||||
raise RuntimeError("ForgeApp is not initialized, start_forge_app should be called")
|
||||
|
||||
return getattr(inst, name)
|
||||
|
||||
def __setattr__(self, name: str, value: Any) -> None:
|
||||
inst = object.__getattribute__(self, "_inst")
|
||||
if inst is None:
|
||||
raise RuntimeError("ForgeApp is not initialized, start_forge_app should be called")
|
||||
|
||||
setattr(inst, name, value)
|
||||
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
app: ForgeApp
|
||||
else:
|
||||
app = AppHolder() # type: ignore
|
||||
|
||||
|
||||
def set_force_app_instance(inst: ForgeApp) -> None:
|
||||
app.set_app(inst) # type: ignore
|
||||
|
||||
@@ -30,7 +30,7 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
uvicorn.run(
|
||||
"skyvern.forge.api_app:app",
|
||||
"skyvern.forge.api_app:create_api_app",
|
||||
host="0.0.0.0",
|
||||
port=port,
|
||||
log_level="info",
|
||||
@@ -39,4 +39,5 @@ if __name__ == "__main__":
|
||||
f"{temp_path_for_excludes}/**/*.py",
|
||||
f"{artifact_path_for_excludes}/{settings.ENV}/**/scripts/**/**/*.py",
|
||||
],
|
||||
factory=True,
|
||||
)
|
||||
|
||||
@@ -153,14 +153,6 @@ class ActionLinkedNode:
|
||||
|
||||
class ForgeAgent:
|
||||
def __init__(self) -> None:
|
||||
if settings.ADDITIONAL_MODULES:
|
||||
for module in settings.ADDITIONAL_MODULES:
|
||||
LOG.debug("Loading additional module", module=module)
|
||||
__import__(module)
|
||||
LOG.debug(
|
||||
"Additional modules loaded",
|
||||
modules=settings.ADDITIONAL_MODULES,
|
||||
)
|
||||
self.async_operation_pool = AsyncOperationPool()
|
||||
|
||||
async def create_task_and_step_from_block(
|
||||
|
||||
@@ -16,6 +16,7 @@ from starlette_context.plugins.base import Plugin
|
||||
from skyvern.config import settings
|
||||
from skyvern.exceptions import SkyvernHTTPException
|
||||
from skyvern.forge import app as forge_app
|
||||
from skyvern.forge.forge_app_initializer import start_forge_app
|
||||
from skyvern.forge.request_logging import log_raw_request_middleware
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
|
||||
@@ -33,7 +34,7 @@ class ExecutionDatePlugin(Plugin):
|
||||
return datetime.now()
|
||||
|
||||
|
||||
def custom_openapi() -> dict:
|
||||
def custom_openapi(app: FastAPI) -> dict:
|
||||
if app.openapi_schema:
|
||||
return app.openapi_schema
|
||||
openapi_schema = get_openapi(
|
||||
@@ -54,6 +55,7 @@ def custom_openapi() -> dict:
|
||||
@asynccontextmanager
|
||||
async def lifespan(_: FastAPI) -> AsyncGenerator[None, Any]:
|
||||
"""Lifespan context manager for FastAPI app startup and shutdown."""
|
||||
|
||||
LOG.info("Server started")
|
||||
if forge_app.api_app_startup_event:
|
||||
LOG.info("Calling api app startup event")
|
||||
@@ -71,15 +73,17 @@ async def lifespan(_: FastAPI) -> AsyncGenerator[None, Any]:
|
||||
LOG.info("Server shutting down")
|
||||
|
||||
|
||||
def get_agent_app() -> FastAPI:
|
||||
def create_api_app() -> FastAPI:
|
||||
"""
|
||||
Start the agent server.
|
||||
"""
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
forge_app_instance = start_forge_app()
|
||||
|
||||
fastapi_app = FastAPI(lifespan=lifespan)
|
||||
|
||||
# Add CORS middleware
|
||||
app.add_middleware(
|
||||
fastapi_app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.ALLOWED_ORIGINS,
|
||||
allow_credentials=True,
|
||||
@@ -87,19 +91,19 @@ def get_agent_app() -> FastAPI:
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
app.include_router(base_router, prefix="/v1")
|
||||
app.include_router(legacy_base_router, prefix="/api/v1")
|
||||
app.include_router(legacy_v2_router, prefix="/api/v2")
|
||||
fastapi_app.include_router(base_router, prefix="/v1")
|
||||
fastapi_app.include_router(legacy_base_router, prefix="/api/v1")
|
||||
fastapi_app.include_router(legacy_v2_router, prefix="/api/v2")
|
||||
|
||||
# local dev endpoints
|
||||
if settings.ENV == "local":
|
||||
app.include_router(internal_auth.router, prefix="/v1")
|
||||
app.include_router(internal_auth.router, prefix="/api/v1")
|
||||
app.include_router(internal_auth.router, prefix="/api/v2")
|
||||
fastapi_app.include_router(internal_auth.router, prefix="/v1")
|
||||
fastapi_app.include_router(internal_auth.router, prefix="/api/v1")
|
||||
fastapi_app.include_router(internal_auth.router, prefix="/api/v2")
|
||||
|
||||
app.openapi = custom_openapi
|
||||
fastapi_app.openapi = lambda: custom_openapi(fastapi_app)
|
||||
|
||||
app.add_middleware(
|
||||
fastapi_app.add_middleware(
|
||||
RawContextMiddleware,
|
||||
plugins=(
|
||||
# TODO (suchintan): We should set these up
|
||||
@@ -109,27 +113,27 @@ def get_agent_app() -> FastAPI:
|
||||
),
|
||||
)
|
||||
|
||||
@app.exception_handler(NotFoundError)
|
||||
@fastapi_app.exception_handler(NotFoundError)
|
||||
async def handle_not_found_error(request: Request, exc: NotFoundError) -> Response:
|
||||
return Response(status_code=status.HTTP_404_NOT_FOUND)
|
||||
|
||||
@app.exception_handler(SkyvernHTTPException)
|
||||
@fastapi_app.exception_handler(SkyvernHTTPException)
|
||||
async def handle_skyvern_http_exception(request: Request, exc: SkyvernHTTPException) -> JSONResponse:
|
||||
return JSONResponse(status_code=exc.status_code, content={"detail": exc.message})
|
||||
|
||||
@app.exception_handler(ValidationError)
|
||||
@fastapi_app.exception_handler(ValidationError)
|
||||
async def handle_pydantic_validation_error(request: Request, exc: ValidationError) -> JSONResponse:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
content={"detail": str(exc)},
|
||||
)
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
@fastapi_app.exception_handler(Exception)
|
||||
async def unexpected_exception(request: Request, exc: Exception) -> JSONResponse:
|
||||
LOG.exception("Unexpected error in agent server.", exc_info=exc)
|
||||
return JSONResponse(status_code=500, content={"error": f"Unexpected error: {exc}"})
|
||||
|
||||
@app.middleware("http")
|
||||
@fastapi_app.middleware("http")
|
||||
async def request_middleware(request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
|
||||
curr_ctx = skyvern_context.current()
|
||||
if not curr_ctx:
|
||||
@@ -143,23 +147,11 @@ def get_agent_app() -> FastAPI:
|
||||
finally:
|
||||
skyvern_context.reset()
|
||||
|
||||
@app.middleware("http")
|
||||
@fastapi_app.middleware("http")
|
||||
async def raw_request_logging(request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
|
||||
return await log_raw_request_middleware(request, call_next)
|
||||
|
||||
if settings.ADDITIONAL_MODULES:
|
||||
for module in settings.ADDITIONAL_MODULES:
|
||||
LOG.info("Loading additional module to set up api app", module=module)
|
||||
__import__(module)
|
||||
LOG.info(
|
||||
"Additional modules loaded to set up api app",
|
||||
modules=settings.ADDITIONAL_MODULES,
|
||||
)
|
||||
if forge_app_instance.setup_api_app:
|
||||
forge_app_instance.setup_api_app(fastapi_app)
|
||||
|
||||
if forge_app.setup_api_app:
|
||||
forge_app.setup_api_app(app)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
app = get_agent_app()
|
||||
return fastapi_app
|
||||
|
||||
@@ -1,136 +0,0 @@
|
||||
from typing import Awaitable, Callable
|
||||
|
||||
from anthropic import AsyncAnthropic, AsyncAnthropicBedrock
|
||||
from fastapi import FastAPI
|
||||
from openai import AsyncAzureOpenAI, AsyncOpenAI
|
||||
|
||||
from skyvern.forge.agent import ForgeAgent
|
||||
from skyvern.forge.agent_functions import AgentFunction
|
||||
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory
|
||||
from skyvern.forge.sdk.artifact.manager import ArtifactManager
|
||||
from skyvern.forge.sdk.artifact.storage.factory import StorageFactory
|
||||
from skyvern.forge.sdk.artifact.storage.s3 import S3Storage
|
||||
from skyvern.forge.sdk.cache.factory import CacheFactory
|
||||
from skyvern.forge.sdk.db.client import AgentDB
|
||||
from skyvern.forge.sdk.experimentation.providers import BaseExperimentationProvider, NoOpExperimentationProvider
|
||||
from skyvern.forge.sdk.schemas.credentials import CredentialVaultType
|
||||
from skyvern.forge.sdk.schemas.organizations import Organization
|
||||
from skyvern.forge.sdk.services.credential.azure_credential_vault_service import AzureCredentialVaultService
|
||||
from skyvern.forge.sdk.services.credential.bitwarden_credential_service import BitwardenCredentialVaultService
|
||||
from skyvern.forge.sdk.services.credential.credential_vault_service import CredentialVaultService
|
||||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||
from skyvern.forge.sdk.workflow.context_manager import WorkflowContextManager
|
||||
from skyvern.forge.sdk.workflow.service import WorkflowService
|
||||
from skyvern.webeye.browser_manager import BrowserManager
|
||||
from skyvern.webeye.persistent_sessions_manager import PersistentSessionsManager
|
||||
from skyvern.webeye.scraper.scraper import ScrapeExcludeFunc
|
||||
|
||||
SETTINGS_MANAGER = SettingsManager.get_settings()
|
||||
DATABASE = AgentDB(
|
||||
SettingsManager.get_settings().DATABASE_STRING,
|
||||
debug_enabled=SettingsManager.get_settings().DEBUG_MODE,
|
||||
)
|
||||
if SettingsManager.get_settings().SKYVERN_STORAGE_TYPE == "s3":
|
||||
StorageFactory.set_storage(S3Storage())
|
||||
STORAGE = StorageFactory.get_storage()
|
||||
CACHE = CacheFactory.get_cache()
|
||||
ARTIFACT_MANAGER = ArtifactManager()
|
||||
BROWSER_MANAGER = BrowserManager()
|
||||
EXPERIMENTATION_PROVIDER: BaseExperimentationProvider = NoOpExperimentationProvider()
|
||||
LLM_API_HANDLER = LLMAPIHandlerFactory.get_llm_api_handler(SettingsManager.get_settings().LLM_KEY)
|
||||
OPENAI_CLIENT = AsyncOpenAI(api_key=SettingsManager.get_settings().OPENAI_API_KEY or "")
|
||||
if SettingsManager.get_settings().ENABLE_AZURE_CUA:
|
||||
OPENAI_CLIENT = AsyncAzureOpenAI(
|
||||
api_key=SettingsManager.get_settings().AZURE_CUA_API_KEY,
|
||||
api_version=SettingsManager.get_settings().AZURE_CUA_API_VERSION,
|
||||
azure_endpoint=SettingsManager.get_settings().AZURE_CUA_ENDPOINT,
|
||||
azure_deployment=SettingsManager.get_settings().AZURE_CUA_DEPLOYMENT,
|
||||
)
|
||||
ANTHROPIC_CLIENT = AsyncAnthropic(api_key=SettingsManager.get_settings().ANTHROPIC_API_KEY)
|
||||
if SettingsManager.get_settings().ENABLE_BEDROCK_ANTHROPIC:
|
||||
ANTHROPIC_CLIENT = AsyncAnthropicBedrock()
|
||||
|
||||
# Add UI-TARS client setup
|
||||
UI_TARS_CLIENT = None
|
||||
if SettingsManager.get_settings().ENABLE_VOLCENGINE:
|
||||
UI_TARS_CLIENT = AsyncOpenAI(
|
||||
api_key=SettingsManager.get_settings().VOLCENGINE_API_KEY,
|
||||
base_url=SettingsManager.get_settings().VOLCENGINE_API_BASE,
|
||||
)
|
||||
|
||||
SECONDARY_LLM_API_HANDLER = LLMAPIHandlerFactory.get_llm_api_handler(
|
||||
SETTINGS_MANAGER.SECONDARY_LLM_KEY if SETTINGS_MANAGER.SECONDARY_LLM_KEY else SETTINGS_MANAGER.LLM_KEY
|
||||
)
|
||||
SELECT_AGENT_LLM_API_HANDLER = LLMAPIHandlerFactory.get_llm_api_handler(
|
||||
SETTINGS_MANAGER.SELECT_AGENT_LLM_KEY or SETTINGS_MANAGER.SECONDARY_LLM_KEY or SETTINGS_MANAGER.LLM_KEY
|
||||
)
|
||||
NORMAL_SELECT_AGENT_LLM_API_HANDLER = (
|
||||
LLMAPIHandlerFactory.get_llm_api_handler(SETTINGS_MANAGER.NORMAL_SELECT_AGENT_LLM_KEY)
|
||||
if SETTINGS_MANAGER.NORMAL_SELECT_AGENT_LLM_KEY
|
||||
else SECONDARY_LLM_API_HANDLER
|
||||
)
|
||||
CUSTOM_SELECT_AGENT_LLM_API_HANDLER = (
|
||||
LLMAPIHandlerFactory.get_llm_api_handler(SETTINGS_MANAGER.CUSTOM_SELECT_AGENT_LLM_KEY)
|
||||
if SETTINGS_MANAGER.CUSTOM_SELECT_AGENT_LLM_KEY
|
||||
else SECONDARY_LLM_API_HANDLER
|
||||
)
|
||||
SINGLE_CLICK_AGENT_LLM_API_HANDLER = LLMAPIHandlerFactory.get_llm_api_handler(
|
||||
SETTINGS_MANAGER.SINGLE_CLICK_AGENT_LLM_KEY or SETTINGS_MANAGER.SECONDARY_LLM_KEY or SETTINGS_MANAGER.LLM_KEY
|
||||
)
|
||||
SINGLE_INPUT_AGENT_LLM_API_HANDLER = LLMAPIHandlerFactory.get_llm_api_handler(
|
||||
SETTINGS_MANAGER.SINGLE_INPUT_AGENT_LLM_KEY or SETTINGS_MANAGER.SECONDARY_LLM_KEY or SETTINGS_MANAGER.LLM_KEY
|
||||
)
|
||||
PARSE_SELECT_LLM_API_HANDLER = (
|
||||
LLMAPIHandlerFactory.get_llm_api_handler(SETTINGS_MANAGER.PARSE_SELECT_LLM_KEY)
|
||||
if SETTINGS_MANAGER.PARSE_SELECT_LLM_KEY
|
||||
else SECONDARY_LLM_API_HANDLER
|
||||
)
|
||||
EXTRACTION_LLM_API_HANDLER = (
|
||||
LLMAPIHandlerFactory.get_llm_api_handler(SETTINGS_MANAGER.EXTRACTION_LLM_KEY)
|
||||
if SETTINGS_MANAGER.EXTRACTION_LLM_KEY
|
||||
else LLM_API_HANDLER
|
||||
)
|
||||
CHECK_USER_GOAL_LLM_API_HANDLER = (
|
||||
LLMAPIHandlerFactory.get_llm_api_handler(SETTINGS_MANAGER.CHECK_USER_GOAL_LLM_KEY)
|
||||
if SETTINGS_MANAGER.CHECK_USER_GOAL_LLM_KEY
|
||||
else SECONDARY_LLM_API_HANDLER
|
||||
)
|
||||
AUTO_COMPLETION_LLM_API_HANDLER = (
|
||||
LLMAPIHandlerFactory.get_llm_api_handler(SETTINGS_MANAGER.AUTO_COMPLETION_LLM_KEY)
|
||||
if SETTINGS_MANAGER.AUTO_COMPLETION_LLM_KEY
|
||||
else SECONDARY_LLM_API_HANDLER
|
||||
)
|
||||
SVG_CSS_CONVERTER_LLM_API_HANDLER = SECONDARY_LLM_API_HANDLER if SETTINGS_MANAGER.SECONDARY_LLM_KEY else None
|
||||
SCRIPT_GENERATION_LLM_API_HANDLER = (
|
||||
LLMAPIHandlerFactory.get_llm_api_handler(SETTINGS_MANAGER.SCRIPT_GENERATION_LLM_KEY)
|
||||
if SETTINGS_MANAGER.SCRIPT_GENERATION_LLM_KEY
|
||||
else SECONDARY_LLM_API_HANDLER
|
||||
)
|
||||
|
||||
WORKFLOW_CONTEXT_MANAGER = WorkflowContextManager()
|
||||
WORKFLOW_SERVICE = WorkflowService()
|
||||
AGENT_FUNCTION = AgentFunction()
|
||||
PERSISTENT_SESSIONS_MANAGER = PersistentSessionsManager(database=DATABASE)
|
||||
|
||||
BITWARDEN_CREDENTIAL_VAULT_SERVICE: BitwardenCredentialVaultService = BitwardenCredentialVaultService()
|
||||
AZURE_CREDENTIAL_VAULT_SERVICE: AzureCredentialVaultService | None = None
|
||||
if SettingsManager.get_settings().AZURE_CREDENTIAL_VAULT:
|
||||
AZURE_CREDENTIAL_VAULT_SERVICE = AzureCredentialVaultService(
|
||||
tenant_id=SettingsManager.get_settings().AZURE_TENANT_ID, # type: ignore
|
||||
client_id=SettingsManager.get_settings().AZURE_CLIENT_ID, # type: ignore
|
||||
client_secret=SettingsManager.get_settings().AZURE_CLIENT_SECRET, # type: ignore
|
||||
vault_name=SettingsManager.get_settings().AZURE_CREDENTIAL_VAULT, # type: ignore
|
||||
)
|
||||
CREDENTIAL_VAULT_SERVICES: dict[str, CredentialVaultService | None] = {
|
||||
CredentialVaultType.BITWARDEN: BITWARDEN_CREDENTIAL_VAULT_SERVICE,
|
||||
CredentialVaultType.AZURE_VAULT: AZURE_CREDENTIAL_VAULT_SERVICE,
|
||||
}
|
||||
|
||||
scrape_exclude: ScrapeExcludeFunc | None = None
|
||||
authentication_function: Callable[[str], Awaitable[Organization]] | None = None
|
||||
authenticate_user_function: Callable[[str], Awaitable[str | None]] | None = None
|
||||
setup_api_app: Callable[[FastAPI], None] | None = None
|
||||
api_app_startup_event: Callable[[], Awaitable[None]] | None = None
|
||||
api_app_shutdown_event: Callable[[], Awaitable[None]] | None = None
|
||||
|
||||
agent = ForgeAgent()
|
||||
194
skyvern/forge/forge_app.py
Normal file
194
skyvern/forge/forge_app.py
Normal file
@@ -0,0 +1,194 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Awaitable, Callable
|
||||
|
||||
from anthropic import AsyncAnthropic, AsyncAnthropicBedrock
|
||||
from fastapi import FastAPI
|
||||
from openai import AsyncAzureOpenAI, AsyncOpenAI
|
||||
|
||||
from skyvern.config import Settings
|
||||
from skyvern.forge.agent import ForgeAgent
|
||||
from skyvern.forge.agent_functions import AgentFunction
|
||||
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory
|
||||
from skyvern.forge.sdk.api.llm.models import LLMAPIHandler
|
||||
from skyvern.forge.sdk.artifact.manager import ArtifactManager
|
||||
from skyvern.forge.sdk.artifact.storage.base import BaseStorage
|
||||
from skyvern.forge.sdk.artifact.storage.factory import StorageFactory
|
||||
from skyvern.forge.sdk.artifact.storage.s3 import S3Storage
|
||||
from skyvern.forge.sdk.cache.base import BaseCache
|
||||
from skyvern.forge.sdk.cache.factory import CacheFactory
|
||||
from skyvern.forge.sdk.db.client import AgentDB
|
||||
from skyvern.forge.sdk.experimentation.providers import BaseExperimentationProvider, NoOpExperimentationProvider
|
||||
from skyvern.forge.sdk.schemas.credentials import CredentialVaultType
|
||||
from skyvern.forge.sdk.schemas.organizations import Organization
|
||||
from skyvern.forge.sdk.services.credential.azure_credential_vault_service import AzureCredentialVaultService
|
||||
from skyvern.forge.sdk.services.credential.bitwarden_credential_service import BitwardenCredentialVaultService
|
||||
from skyvern.forge.sdk.services.credential.credential_vault_service import CredentialVaultService
|
||||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||
from skyvern.forge.sdk.workflow.context_manager import WorkflowContextManager
|
||||
from skyvern.forge.sdk.workflow.service import WorkflowService
|
||||
from skyvern.webeye.browser_manager import BrowserManager
|
||||
from skyvern.webeye.persistent_sessions_manager import PersistentSessionsManager
|
||||
from skyvern.webeye.scraper.scraper import ScrapeExcludeFunc
|
||||
|
||||
|
||||
class ForgeApp:
|
||||
"""Container for shared Forge services"""
|
||||
|
||||
SETTINGS_MANAGER: Settings
|
||||
DATABASE: AgentDB
|
||||
STORAGE: BaseStorage
|
||||
CACHE: BaseCache
|
||||
ARTIFACT_MANAGER: ArtifactManager
|
||||
BROWSER_MANAGER: BrowserManager
|
||||
EXPERIMENTATION_PROVIDER: BaseExperimentationProvider
|
||||
LLM_API_HANDLER: LLMAPIHandler
|
||||
OPENAI_CLIENT: AsyncOpenAI | AsyncAzureOpenAI
|
||||
ANTHROPIC_CLIENT: AsyncAnthropic | AsyncAnthropicBedrock
|
||||
UI_TARS_CLIENT: AsyncOpenAI | None
|
||||
SECONDARY_LLM_API_HANDLER: LLMAPIHandler
|
||||
SELECT_AGENT_LLM_API_HANDLER: LLMAPIHandler
|
||||
NORMAL_SELECT_AGENT_LLM_API_HANDLER: LLMAPIHandler
|
||||
CUSTOM_SELECT_AGENT_LLM_API_HANDLER: LLMAPIHandler
|
||||
SINGLE_CLICK_AGENT_LLM_API_HANDLER: LLMAPIHandler
|
||||
SINGLE_INPUT_AGENT_LLM_API_HANDLER: LLMAPIHandler
|
||||
PARSE_SELECT_LLM_API_HANDLER: LLMAPIHandler
|
||||
EXTRACTION_LLM_API_HANDLER: LLMAPIHandler
|
||||
CHECK_USER_GOAL_LLM_API_HANDLER: LLMAPIHandler
|
||||
AUTO_COMPLETION_LLM_API_HANDLER: LLMAPIHandler
|
||||
SVG_CSS_CONVERTER_LLM_API_HANDLER: LLMAPIHandler | None
|
||||
SCRIPT_GENERATION_LLM_API_HANDLER: LLMAPIHandler
|
||||
WORKFLOW_CONTEXT_MANAGER: WorkflowContextManager
|
||||
WORKFLOW_SERVICE: WorkflowService
|
||||
AGENT_FUNCTION: AgentFunction
|
||||
PERSISTENT_SESSIONS_MANAGER: PersistentSessionsManager
|
||||
BITWARDEN_CREDENTIAL_VAULT_SERVICE: BitwardenCredentialVaultService
|
||||
AZURE_CREDENTIAL_VAULT_SERVICE: AzureCredentialVaultService | None
|
||||
CREDENTIAL_VAULT_SERVICES: dict[str, CredentialVaultService | None]
|
||||
scrape_exclude: ScrapeExcludeFunc | None
|
||||
authentication_function: Callable[[str], Awaitable[Organization]] | None
|
||||
authenticate_user_function: Callable[[str], Awaitable[str | None]] | None
|
||||
setup_api_app: Callable[[FastAPI], None] | None
|
||||
api_app_startup_event: Callable[[], Awaitable[None]] | None
|
||||
api_app_shutdown_event: Callable[[], Awaitable[None]] | None
|
||||
agent: ForgeAgent
|
||||
|
||||
|
||||
def create_forge_app() -> ForgeApp:
|
||||
"""Create and initialize a ForgeApp instance with all services"""
|
||||
settings: Settings = SettingsManager.get_settings()
|
||||
|
||||
app = ForgeApp()
|
||||
|
||||
app.SETTINGS_MANAGER = settings
|
||||
|
||||
app.DATABASE = AgentDB(settings.DATABASE_STRING, debug_enabled=settings.DEBUG_MODE)
|
||||
if settings.SKYVERN_STORAGE_TYPE == "s3":
|
||||
StorageFactory.set_storage(S3Storage())
|
||||
app.STORAGE = StorageFactory.get_storage()
|
||||
app.CACHE = CacheFactory.get_cache()
|
||||
app.ARTIFACT_MANAGER = ArtifactManager()
|
||||
app.BROWSER_MANAGER = BrowserManager()
|
||||
app.EXPERIMENTATION_PROVIDER = NoOpExperimentationProvider()
|
||||
|
||||
app.LLM_API_HANDLER = LLMAPIHandlerFactory.get_llm_api_handler(settings.LLM_KEY)
|
||||
app.OPENAI_CLIENT = AsyncOpenAI(api_key=settings.OPENAI_API_KEY or "")
|
||||
if settings.ENABLE_AZURE_CUA:
|
||||
app.OPENAI_CLIENT = AsyncAzureOpenAI(
|
||||
api_key=settings.AZURE_CUA_API_KEY,
|
||||
api_version=settings.AZURE_CUA_API_VERSION,
|
||||
azure_endpoint=settings.AZURE_CUA_ENDPOINT,
|
||||
azure_deployment=settings.AZURE_CUA_DEPLOYMENT,
|
||||
)
|
||||
|
||||
app.ANTHROPIC_CLIENT = AsyncAnthropic(api_key=settings.ANTHROPIC_API_KEY)
|
||||
if settings.ENABLE_BEDROCK_ANTHROPIC:
|
||||
app.ANTHROPIC_CLIENT = AsyncAnthropicBedrock()
|
||||
|
||||
app.UI_TARS_CLIENT = None
|
||||
if settings.ENABLE_VOLCENGINE:
|
||||
app.UI_TARS_CLIENT = AsyncOpenAI(
|
||||
api_key=settings.VOLCENGINE_API_KEY,
|
||||
base_url=settings.VOLCENGINE_API_BASE,
|
||||
)
|
||||
|
||||
app.SECONDARY_LLM_API_HANDLER = LLMAPIHandlerFactory.get_llm_api_handler(
|
||||
settings.SECONDARY_LLM_KEY if settings.SECONDARY_LLM_KEY else settings.LLM_KEY
|
||||
)
|
||||
app.SELECT_AGENT_LLM_API_HANDLER = LLMAPIHandlerFactory.get_llm_api_handler(
|
||||
settings.SELECT_AGENT_LLM_KEY or settings.SECONDARY_LLM_KEY or settings.LLM_KEY
|
||||
)
|
||||
app.NORMAL_SELECT_AGENT_LLM_API_HANDLER = (
|
||||
LLMAPIHandlerFactory.get_llm_api_handler(settings.NORMAL_SELECT_AGENT_LLM_KEY)
|
||||
if settings.NORMAL_SELECT_AGENT_LLM_KEY
|
||||
else app.SECONDARY_LLM_API_HANDLER
|
||||
)
|
||||
app.CUSTOM_SELECT_AGENT_LLM_API_HANDLER = (
|
||||
LLMAPIHandlerFactory.get_llm_api_handler(settings.CUSTOM_SELECT_AGENT_LLM_KEY)
|
||||
if settings.CUSTOM_SELECT_AGENT_LLM_KEY
|
||||
else app.SECONDARY_LLM_API_HANDLER
|
||||
)
|
||||
app.SINGLE_CLICK_AGENT_LLM_API_HANDLER = LLMAPIHandlerFactory.get_llm_api_handler(
|
||||
settings.SINGLE_CLICK_AGENT_LLM_KEY or settings.SECONDARY_LLM_KEY or settings.LLM_KEY
|
||||
)
|
||||
app.SINGLE_INPUT_AGENT_LLM_API_HANDLER = LLMAPIHandlerFactory.get_llm_api_handler(
|
||||
settings.SINGLE_INPUT_AGENT_LLM_KEY or settings.SECONDARY_LLM_KEY or settings.LLM_KEY
|
||||
)
|
||||
app.PARSE_SELECT_LLM_API_HANDLER = (
|
||||
LLMAPIHandlerFactory.get_llm_api_handler(settings.PARSE_SELECT_LLM_KEY)
|
||||
if settings.PARSE_SELECT_LLM_KEY
|
||||
else app.SECONDARY_LLM_API_HANDLER
|
||||
)
|
||||
app.EXTRACTION_LLM_API_HANDLER = (
|
||||
LLMAPIHandlerFactory.get_llm_api_handler(settings.EXTRACTION_LLM_KEY)
|
||||
if settings.EXTRACTION_LLM_KEY
|
||||
else app.LLM_API_HANDLER
|
||||
)
|
||||
app.CHECK_USER_GOAL_LLM_API_HANDLER = (
|
||||
LLMAPIHandlerFactory.get_llm_api_handler(settings.CHECK_USER_GOAL_LLM_KEY)
|
||||
if settings.CHECK_USER_GOAL_LLM_KEY
|
||||
else app.SECONDARY_LLM_API_HANDLER
|
||||
)
|
||||
app.AUTO_COMPLETION_LLM_API_HANDLER = (
|
||||
LLMAPIHandlerFactory.get_llm_api_handler(settings.AUTO_COMPLETION_LLM_KEY)
|
||||
if settings.AUTO_COMPLETION_LLM_KEY
|
||||
else app.SECONDARY_LLM_API_HANDLER
|
||||
)
|
||||
app.SVG_CSS_CONVERTER_LLM_API_HANDLER = app.SECONDARY_LLM_API_HANDLER if settings.SECONDARY_LLM_KEY else None
|
||||
app.SCRIPT_GENERATION_LLM_API_HANDLER = (
|
||||
LLMAPIHandlerFactory.get_llm_api_handler(settings.SCRIPT_GENERATION_LLM_KEY)
|
||||
if settings.SCRIPT_GENERATION_LLM_KEY
|
||||
else app.SECONDARY_LLM_API_HANDLER
|
||||
)
|
||||
|
||||
app.WORKFLOW_CONTEXT_MANAGER = WorkflowContextManager()
|
||||
app.WORKFLOW_SERVICE = WorkflowService()
|
||||
app.AGENT_FUNCTION = AgentFunction()
|
||||
app.PERSISTENT_SESSIONS_MANAGER = PersistentSessionsManager(database=app.DATABASE)
|
||||
|
||||
app.BITWARDEN_CREDENTIAL_VAULT_SERVICE = BitwardenCredentialVaultService()
|
||||
app.AZURE_CREDENTIAL_VAULT_SERVICE = (
|
||||
AzureCredentialVaultService(
|
||||
tenant_id=settings.AZURE_TENANT_ID, # type: ignore[arg-type]
|
||||
client_id=settings.AZURE_CLIENT_ID, # type: ignore[arg-type]
|
||||
client_secret=settings.AZURE_CLIENT_SECRET, # type: ignore[arg-type]
|
||||
vault_name=settings.AZURE_CREDENTIAL_VAULT, # type: ignore[arg-type]
|
||||
)
|
||||
if settings.AZURE_CREDENTIAL_VAULT
|
||||
else None
|
||||
)
|
||||
app.CREDENTIAL_VAULT_SERVICES = {
|
||||
CredentialVaultType.BITWARDEN: app.BITWARDEN_CREDENTIAL_VAULT_SERVICE,
|
||||
CredentialVaultType.AZURE_VAULT: app.AZURE_CREDENTIAL_VAULT_SERVICE,
|
||||
}
|
||||
|
||||
app.scrape_exclude = None
|
||||
app.authentication_function = None
|
||||
app.authenticate_user_function = None
|
||||
app.setup_api_app = None
|
||||
app.api_app_startup_event = None
|
||||
app.api_app_shutdown_event = None
|
||||
|
||||
app.agent = ForgeAgent()
|
||||
|
||||
return app
|
||||
28
skyvern/forge/forge_app_initializer.py
Normal file
28
skyvern/forge/forge_app_initializer.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import structlog
|
||||
|
||||
from skyvern.config import settings
|
||||
from skyvern.forge import set_force_app_instance
|
||||
from skyvern.forge.forge_app import ForgeApp, create_forge_app
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
def start_forge_app() -> ForgeApp:
|
||||
force_app_instance = create_forge_app()
|
||||
set_force_app_instance(force_app_instance)
|
||||
|
||||
if settings.ADDITIONAL_MODULES:
|
||||
for module in settings.ADDITIONAL_MODULES:
|
||||
LOG.info("Loading additional module to set up api app", module=module)
|
||||
app_module = __import__(module)
|
||||
configure_app_fn = getattr(app_module, "configure_app", None)
|
||||
if not configure_app_fn:
|
||||
raise RuntimeError(f"Missing configure_app function in {module}")
|
||||
|
||||
configure_app_fn(force_app_instance)
|
||||
LOG.info(
|
||||
"Additional modules loaded to set up api app",
|
||||
modules=settings.ADDITIONAL_MODULES,
|
||||
)
|
||||
|
||||
return force_app_instance
|
||||
@@ -12,14 +12,15 @@ def create_embedded_server(
|
||||
async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
|
||||
if self._transport is None:
|
||||
from skyvern.config import settings # noqa: PLC0415
|
||||
from skyvern.forge.api_app import app # noqa: PLC0415
|
||||
|
||||
settings.BROWSER_LOGS_ENABLED = False
|
||||
|
||||
if openai_api_key:
|
||||
settings.OPENAI_API_KEY = openai_api_key
|
||||
|
||||
self._transport = ASGITransport(app=app)
|
||||
from skyvern.forge.api_app import create_api_app # noqa: PLC0415
|
||||
|
||||
self._transport = ASGITransport(app=create_api_app())
|
||||
|
||||
response = await self._transport.handle_async_request(request)
|
||||
return response
|
||||
|
||||
@@ -6,10 +6,18 @@ from unittest.mock import MagicMock, patch
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.forge_app_initializer import start_forge_app
|
||||
from skyvern.forge.sdk.workflow.models.block import FileParserBlock, FileType
|
||||
from skyvern.forge.sdk.workflow.models.parameter import OutputParameter
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def setup_forge_app():
|
||||
start_forge_app()
|
||||
yield
|
||||
|
||||
|
||||
class TestFileParserBlock:
|
||||
@pytest.fixture
|
||||
def file_parser_block(self):
|
||||
@@ -172,7 +180,7 @@ class TestFileParserBlock:
|
||||
# Mock the LLM response
|
||||
mock_response = {"extracted_data": {"names": ["John", "Jane"], "total_count": 2}}
|
||||
|
||||
with patch("skyvern.forge.sdk.workflow.models.block.app.LLM_API_HANDLER") as mock_llm:
|
||||
with patch.object(object.__getattribute__(app, "_inst"), "LLM_API_HANDLER") as mock_llm:
|
||||
mock_llm.return_value = mock_response
|
||||
|
||||
with patch("skyvern.forge.sdk.workflow.models.block.prompt_engine.load_prompt") as mock_prompt:
|
||||
@@ -190,7 +198,7 @@ class TestFileParserBlock:
|
||||
# Mock the LLM response
|
||||
mock_response = {"output": {"summary": "Extracted data from file"}}
|
||||
|
||||
with patch("skyvern.forge.sdk.workflow.models.block.app.LLM_API_HANDLER") as mock_llm:
|
||||
with patch.object(object.__getattribute__(app, "_inst"), "LLM_API_HANDLER") as mock_llm:
|
||||
mock_llm.return_value = mock_response
|
||||
|
||||
with patch("skyvern.forge.sdk.workflow.models.block.prompt_engine.load_prompt") as mock_prompt:
|
||||
|
||||
@@ -8,10 +8,17 @@ import pytest
|
||||
from skyvern import config
|
||||
from skyvern.config import Settings
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.forge_app_initializer import start_forge_app
|
||||
from skyvern.forge.sdk.api.llm import api_handler_factory, config_registry
|
||||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def setup_forge_app():
|
||||
start_forge_app()
|
||||
yield
|
||||
|
||||
|
||||
class DummyResponse(dict):
|
||||
def __init__(self, content: str):
|
||||
super().__init__({"choices": [{"message": {"content": content}}], "usage": {}})
|
||||
|
||||
Reference in New Issue
Block a user