From 9a29d966ab147e5dae89e64974f7abff7fcb93f8 Mon Sep 17 00:00:00 2001
From: Shuchang Zheng
Date: Wed, 11 Jun 2025 08:23:44 -0700
Subject: [PATCH] overhaul llm key override (#2677)

---
 skyvern/forge/agent.py                       | 28 ++++--
 .../forge/sdk/api/llm/api_handler_factory.py | 85 +++++++++----------
 skyvern/forge/sdk/api/llm/models.py          |  2 -
 skyvern/services/task_v2_service.py          |  7 +-
 skyvern/webeye/actions/handler.py            | 12 +--
 5 files changed, 72 insertions(+), 62 deletions(-)

diff --git a/skyvern/forge/agent.py b/skyvern/forge/agent.py
index 1dcf09a5..beb6c15c 100644
--- a/skyvern/forge/agent.py
+++ b/skyvern/forge/agent.py
@@ -58,7 +58,7 @@ from skyvern.forge.sdk.api.files import (
     rename_file,
     wait_for_download_finished,
 )
-from skyvern.forge.sdk.api.llm.api_handler_factory import LLMCaller, LLMCallerManager
+from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory, LLMCaller, LLMCallerManager
 from skyvern.forge.sdk.artifact.models import ArtifactType
 from skyvern.forge.sdk.core import skyvern_context
 from skyvern.forge.sdk.core.security import generate_skyvern_webhook_headers
@@ -865,15 +865,20 @@ class ForgeAgent:
         ):
             using_cached_action_plan = True
         else:
+            llm_key_override = task.llm_key
+            # FIXME: Redundant engine check?
             if engine in CUA_ENGINES:
                 self.async_operation_pool.run_operation(task.task_id, AgentPhase.llm)
+                llm_key_override = None
 
-            json_response = await app.LLM_API_HANDLER(
+            llm_api_handler = LLMAPIHandlerFactory.get_override_llm_api_handler(
+                llm_key_override, default=app.LLM_API_HANDLER
+            )
+            json_response = await llm_api_handler(
                 prompt=extract_action_prompt,
                 prompt_name="extract-actions",
                 step=step,
                 screenshots=scraped_page.screenshots,
-                llm_key_override=task.llm_key,
             )
             try:
                 json_response = await self.handle_potential_verification_code(
@@ -1513,12 +1518,14 @@ class ForgeAgent:
 
         # this prompt is critical to our agent so let's use the primary LLM API handler
-        verification_result = await app.LLM_API_HANDLER(
+        llm_api_handler = LLMAPIHandlerFactory.get_override_llm_api_handler(
+            llm_key_override, default=app.LLM_API_HANDLER
+        )
+        verification_result = await llm_api_handler(
             prompt=verification_prompt,
             step=step,
             screenshots=scraped_page_refreshed.screenshots,
             prompt_name="check-user-goal",
-            llm_key_override=llm_key_override,
         )
 
         return CompleteVerifyResult.model_validate(verification_result)
@@ -1833,7 +1840,10 @@ class ForgeAgent:
         prompt = prompt_engine.load_prompt(
             "infer-action-type", navigation_goal=navigation_goal, prompt_name="infer-action-type"
         )
-        json_response = await app.LLM_API_HANDLER(prompt=prompt, step=step, prompt_name="infer-action-type")
+        llm_api_handler = LLMAPIHandlerFactory.get_override_llm_api_handler(
+            task.llm_key, default=app.LLM_API_HANDLER
+        )
+        json_response = await llm_api_handler(prompt=prompt, step=step, prompt_name="infer-action-type")
         if json_response.get("error"):
             raise FailedToParseActionInstruction(
                 reason=json_response.get("thought"), error_type=json_response.get("error")
@@ -2772,12 +2782,14 @@ class ForgeAgent:
             run_obj = await app.DATABASE.get_run(run_id=task.task_id, organization_id=task.organization_id)
             if run_obj and run_obj.task_run_type in CUA_RUN_TYPES:
                 llm_key_override = None
-            return await app.LLM_API_HANDLER(
+            llm_api_handler = LLMAPIHandlerFactory.get_override_llm_api_handler(
+                llm_key_override, default=app.LLM_API_HANDLER
+            )
+            return await llm_api_handler(
                 prompt=extract_action_prompt,
                 step=step,
                 screenshots=scraped_page.screenshots,
                 prompt_name="extract-actions",
-                llm_key_override=llm_key_override,
             )
 
         return json_response
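
Every call site in this patch follows the same two-step pattern that replaces the old llm_key_override kwarg: resolve a concrete handler once through the new factory helper, then invoke it. A minimal sketch of the pattern, using only names that appear in this diff (the wrapper function itself is illustrative, not code from the repo):

    from skyvern.forge import app
    from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory

    async def extract_actions_sketch(task, step, scraped_page, extract_action_prompt):
        # Before: await app.LLM_API_HANDLER(..., llm_key_override=task.llm_key)
        # After: resolve the handler up front; a falsy key yields the default.
        llm_api_handler = LLMAPIHandlerFactory.get_override_llm_api_handler(
            task.llm_key, default=app.LLM_API_HANDLER
        )
        return await llm_api_handler(
            prompt=extract_action_prompt,
            prompt_name="extract-actions",
            step=step,
            screenshots=scraped_page.screenshots,
        )
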
diff --git a/skyvern/forge/sdk/api/llm/api_handler_factory.py b/skyvern/forge/sdk/api/llm/api_handler_factory.py
index 7aa1ffc8..201120d7 100644
--- a/skyvern/forge/sdk/api/llm/api_handler_factory.py
+++ b/skyvern/forge/sdk/api/llm/api_handler_factory.py
@@ -45,6 +45,20 @@ class LLMCallStats(BaseModel):
 class LLMAPIHandlerFactory:
     _custom_handlers: dict[str, LLMAPIHandler] = {}
 
+    @staticmethod
+    def get_override_llm_api_handler(override_llm_key: str | None, *, default: LLMAPIHandler) -> LLMAPIHandler:
+        if not override_llm_key:
+            return default
+        try:
+            return LLMAPIHandlerFactory.get_llm_api_handler(override_llm_key)
+        except Exception:
+            LOG.warning(
+                "Failed to get override LLM API handler, going to use the default.",
+                override_llm_key=override_llm_key,
+                exc_info=True,
+            )
+            return default
+
     @staticmethod
     def get_llm_api_handler_with_router(llm_key: str) -> LLMAPIHandler:
         llm_config = LLMConfigRegistry.get_config(llm_key)
@@ -82,7 +96,6 @@ class LLMAPIHandlerFactory:
             ai_suggestion: AISuggestion | None = None,
             screenshots: list[bytes] | None = None,
             parameters: dict[str, Any] | None = None,
-            llm_key_override: str | None = None,
         ) -> dict[str, Any]:
             """
             Custom LLM API handler that utilizes the LiteLLM router and fallbacks to OpenAI GPT-4 Vision.
@@ -96,18 +109,10 @@ class LLMAPIHandlerFactory:
             Returns:
                 The response from the LLM router.
             """
-            nonlocal llm_config
-            nonlocal llm_key
-
-            local_llm_config: LLMConfig | LLMRouterConfig = llm_config
-            if llm_key_override:
-                local_llm_config = LLMConfigRegistry.get_config(llm_key_override)
-
-            local_llm_key = llm_key_override or llm_key
             start_time = time.time()
 
             if parameters is None:
-                parameters = LLMAPIHandlerFactory.get_api_parameters(local_llm_config)
+                parameters = LLMAPIHandlerFactory.get_api_parameters(llm_config)
 
             context = skyvern_context.current()
             if context and len(context.hashed_href_map) > 0:
@@ -128,12 +133,12 @@ class LLMAPIHandlerFactory:
                 task_v2=task_v2,
                 thought=thought,
             )
-            messages = await llm_messages_builder(prompt, screenshots, local_llm_config.add_assistant_prefix)
+            messages = await llm_messages_builder(prompt, screenshots, llm_config.add_assistant_prefix)
 
             await app.ARTIFACT_MANAGER.create_llm_artifact(
                 data=json.dumps(
                     {
-                        "model": local_llm_key,
+                        "model": llm_key,
                         "messages": messages,
                         **parameters,
                     }
@@ -149,12 +154,12 @@ class LLMAPIHandlerFactory:
                     model=main_model_group, messages=messages, timeout=settings.LLM_CONFIG_TIMEOUT, **parameters
                 )
             except litellm.exceptions.APIError as e:
-                raise LLMProviderErrorRetryableTask(local_llm_key) from e
+                raise LLMProviderErrorRetryableTask(llm_key) from e
             except litellm.exceptions.ContextWindowExceededError as e:
                 duration_seconds = time.time() - start_time
                 LOG.exception(
                     "Context window exceeded",
-                    llm_key=local_llm_key,
+                    llm_key=llm_key,
                     model=main_model_group,
                     prompt_name=prompt_name,
                     duration_seconds=duration_seconds,
                 )
@@ -164,22 +169,22 @@
                 duration_seconds = time.time() - start_time
                 LOG.exception(
                     "LLM token limit exceeded",
-                    llm_key=local_llm_key,
+                    llm_key=llm_key,
                     model=main_model_group,
                     prompt_name=prompt_name,
                     duration_seconds=duration_seconds,
                 )
-                raise LLMProviderErrorRetryableTask(local_llm_key) from e
+                raise LLMProviderErrorRetryableTask(llm_key) from e
             except Exception as e:
                 duration_seconds = time.time() - start_time
                 LOG.exception(
                     "LLM request failed unexpectedly",
-                    llm_key=local_llm_key,
+                    llm_key=llm_key,
                     model=main_model_group,
                     prompt_name=prompt_name,
                     duration_seconds=duration_seconds,
                 )
-                raise LLMProviderError(local_llm_key) from e
+                raise LLMProviderError(llm_key) from e
 
             await app.ARTIFACT_MANAGER.create_llm_artifact(
                 data=response.model_dump_json(indent=2).encode("utf-8"),
@@ -226,7 +231,7 @@ class LLMAPIHandlerFactory:
                 reasoning_token_count=reasoning_tokens if reasoning_tokens > 0 else None,
                 cached_token_count=cached_tokens if cached_tokens > 0 else None,
             )
-            parsed_response = parse_api_response(response, local_llm_config.add_assistant_prefix)
+            parsed_response = parse_api_response(response, llm_config.add_assistant_prefix)
             await app.ARTIFACT_MANAGER.create_llm_artifact(
                 data=json.dumps(parsed_response, indent=2).encode("utf-8"),
                 artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
@@ -253,7 +258,7 @@ class LLMAPIHandlerFactory:
             duration_seconds = time.time() - start_time
             LOG.info(
                 "LLM API handler duration metrics",
-                llm_key=local_llm_key,
+                llm_key=llm_key,
                 model=main_model_group,
                 prompt_name=prompt_name,
                 duration_seconds=duration_seconds,
@@ -287,25 +292,15 @@ class LLMAPIHandlerFactory:
             ai_suggestion: AISuggestion | None = None,
             screenshots: list[bytes] | None = None,
             parameters: dict[str, Any] | None = None,
-            llm_key_override: str | None = None,
         ) -> dict[str, Any]:
-            nonlocal llm_config
-            nonlocal llm_key
-
-            local_llm_config: LLMConfig | LLMRouterConfig = llm_config
-            if llm_key_override:
-                local_llm_config = LLMConfigRegistry.get_config(llm_key_override)
-
-            local_llm_key = llm_key_override or llm_key
-
             start_time = time.time()
             active_parameters = base_parameters or {}
             if parameters is None:
-                parameters = LLMAPIHandlerFactory.get_api_parameters(local_llm_config)
+                parameters = LLMAPIHandlerFactory.get_api_parameters(llm_config)
 
             active_parameters.update(parameters)
-            if local_llm_config.litellm_params:  # type: ignore
-                active_parameters.update(local_llm_config.litellm_params)  # type: ignore
+            if llm_config.litellm_params:  # type: ignore
+                active_parameters.update(llm_config.litellm_params)  # type: ignore
 
             context = skyvern_context.current()
             if context and len(context.hashed_href_map) > 0:
@@ -328,12 +323,12 @@ class LLMAPIHandlerFactory:
                 ai_suggestion=ai_suggestion,
             )
 
-            if not local_llm_config.supports_vision:
+            if not llm_config.supports_vision:
                 screenshots = None
 
-            model_name = local_llm_config.model_name
+            model_name = llm_config.model_name
 
-            messages = await llm_messages_builder(prompt, screenshots, local_llm_config.add_assistant_prefix)
+            messages = await llm_messages_builder(prompt, screenshots, llm_config.add_assistant_prefix)
             await app.ARTIFACT_MANAGER.create_llm_artifact(
                 data=json.dumps(
                     {
@@ -361,12 +356,12 @@ class LLMAPIHandlerFactory:
                     **active_parameters,
                 )
             except litellm.exceptions.APIError as e:
-                raise LLMProviderErrorRetryableTask(local_llm_key) from e
+                raise LLMProviderErrorRetryableTask(llm_key) from e
             except litellm.exceptions.ContextWindowExceededError as e:
                 duration_seconds = time.time() - start_time
                 LOG.exception(
                     "Context window exceeded",
-                    llm_key=local_llm_key,
+                    llm_key=llm_key,
                     model=model_name,
                     prompt_name=prompt_name,
                     duration_seconds=duration_seconds,
@@ -376,22 +371,22 @@ class LLMAPIHandlerFactory:
                 t_llm_cancelled = time.perf_counter()
                 LOG.error(
                     "LLM request got cancelled",
-                    llm_key=local_llm_key,
+                    llm_key=llm_key,
                     model=model_name,
                     prompt_name=prompt_name,
                     duration=t_llm_cancelled - t_llm_request,
                 )
-                raise LLMProviderError(local_llm_key)
+                raise LLMProviderError(llm_key)
             except Exception as e:
                 duration_seconds = time.time() - start_time
                 LOG.exception(
                     "LLM request failed unexpectedly",
-                    llm_key=local_llm_key,
+                    llm_key=llm_key,
                     model=model_name,
                     prompt_name=prompt_name,
                     duration_seconds=duration_seconds,
                 )
-                raise LLMProviderError(local_llm_key) from e
+                raise LLMProviderError(llm_key) from e
 
             await app.ARTIFACT_MANAGER.create_llm_artifact(
                 data=response.model_dump_json(indent=2).encode("utf-8"),
@@ -439,7 +434,7 @@ class LLMAPIHandlerFactory:
                 cached_token_count=cached_tokens if cached_tokens > 0 else None,
                 thought_cost=llm_cost,
             )
-            parsed_response = parse_api_response(response, local_llm_config.add_assistant_prefix)
+            parsed_response = parse_api_response(response, llm_config.add_assistant_prefix)
             await app.ARTIFACT_MANAGER.create_llm_artifact(
                 data=json.dumps(parsed_response, indent=2).encode("utf-8"),
                 artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
@@ -466,9 +461,9 @@ class LLMAPIHandlerFactory:
             duration_seconds = time.time() - start_time
             LOG.info(
                 "LLM API handler duration metrics",
-                llm_key=local_llm_key,
+                llm_key=llm_key,
                 prompt_name=prompt_name,
-                model=local_llm_config.model_name,
+                model=llm_config.model_name,
                 duration_seconds=duration_seconds,
                 step_id=step.step_id if step else None,
                 thought_id=thought.observer_thought_id if thought else None,
diff --git a/skyvern/forge/sdk/api/llm/models.py b/skyvern/forge/sdk/api/llm/models.py
index 1d4ae9b9..9d231c2a 100644
--- a/skyvern/forge/sdk/api/llm/models.py
+++ b/skyvern/forge/sdk/api/llm/models.py
@@ -94,7 +94,6 @@ class LLMAPIHandler(Protocol):
         ai_suggestion: AISuggestion | None = None,
         screenshots: list[bytes] | None = None,
         parameters: dict[str, Any] | None = None,
-        llm_key_override: str | None = None,
     ) -> Awaitable[dict[str, Any]]: ...
 
 
@@ -107,6 +106,5 @@ async def dummy_llm_api_handler(
     ai_suggestion: AISuggestion | None = None,
     screenshots: list[bytes] | None = None,
     parameters: dict[str, Any] | None = None,
-    llm_key_override: str | None = None,
 ) -> dict[str, Any]:
     raise NotImplementedError("Your LLM provider is not configured. Please configure it in the .env file.")
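
With llm_key_override dropped from the LLMAPIHandler Protocol, a handler becomes a plain prompt-in/JSON-out async callable bound to a single llm_key. A minimal conforming stub, as a sketch only: the leading parameters are reconstructed from the call sites elsewhere in this patch, and types the hunk above does not show are loosened to Any:

    from typing import Any

    async def stub_llm_api_handler(
        prompt: str,
        prompt_name: str,
        step: Any = None,
        task_v2: Any = None,
        thought: Any = None,
        ai_suggestion: Any = None,
        screenshots: list[bytes] | None = None,
        parameters: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        # Key selection now happens before the handler is called, in
        # LLMAPIHandlerFactory.get_override_llm_api_handler, so nothing
        # key-related flows through this signature anymore.
        return {"prompt_name": prompt_name, "answer": "stub"}
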
diff --git a/skyvern/services/task_v2_service.py b/skyvern/services/task_v2_service.py
index 1947d240..ea61e346 100644
--- a/skyvern/services/task_v2_service.py
+++ b/skyvern/services/task_v2_service.py
@@ -20,6 +20,7 @@ from skyvern.exceptions import (
 )
 from skyvern.forge import app
 from skyvern.forge.prompts import prompt_engine
+from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory
 from skyvern.forge.sdk.artifact.models import ArtifactType
 from skyvern.forge.sdk.core import skyvern_context
 from skyvern.forge.sdk.core.hashing import generate_url_hash
@@ -617,12 +618,14 @@ async def run_task_v2_helper(
         thought_type=ThoughtType.plan,
         thought_scenario=ThoughtScenario.generate_plan,
     )
-    task_v2_response = await app.LLM_API_HANDLER(
+    llm_api_handler = LLMAPIHandlerFactory.get_override_llm_api_handler(
+        task_v2.llm_key, default=app.LLM_API_HANDLER
+    )
+    task_v2_response = await llm_api_handler(
         prompt=task_v2_prompt,
         screenshots=scraped_page.screenshots,
         thought=thought,
         prompt_name="task_v2",
-        llm_key_override=task_v2.llm_key,
     )
     LOG.info(
         "Task v2 response",
diff --git a/skyvern/webeye/actions/handler.py b/skyvern/webeye/actions/handler.py
index b972d506..25daa26d 100644
--- a/skyvern/webeye/actions/handler.py
+++ b/skyvern/webeye/actions/handler.py
@@ -60,7 +60,7 @@ from skyvern.forge.sdk.api.files import (
     list_files_in_directory,
     wait_for_download_finished,
 )
-from skyvern.forge.sdk.api.llm.api_handler_factory import LLMCallerManager
+from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory, LLMCallerManager
 from skyvern.forge.sdk.api.llm.exceptions import LLMProviderError
 from skyvern.forge.sdk.core import skyvern_context
 from skyvern.forge.sdk.core.aiohttp_helper import aiohttp_post
@@ -2557,7 +2557,8 @@ async def sequentially_select_from_dropdown(
         select_history=json.dumps(build_sequential_select_history(select_history)),
         local_datetime=datetime.now(ensure_context().tz_info).isoformat(),
     )
-    json_response = await app.LLM_API_HANDLER(
+    llm_api_handler = LLMAPIHandlerFactory.get_override_llm_api_handler(task.llm_key, default=app.LLM_API_HANDLER)
+    json_response = await llm_api_handler(
         prompt=prompt, screenshots=[screenshot], step=step, prompt_name="confirm-multi-selection-finish"
     )
     if json_response.get("is_mini_goal_finished", False):
@@ -2641,7 +2642,8 @@ async def select_from_emerging_elements(
         task_id=task.task_id,
     )
 
-    json_response = await app.LLM_API_HANDLER(prompt=prompt, step=step, prompt_name="custom-select")
+    llm_api_handler = LLMAPIHandlerFactory.get_override_llm_api_handler(task.llm_key, default=app.LLM_API_HANDLER)
+    json_response = await llm_api_handler(prompt=prompt, step=step, prompt_name="custom-select")
     value: str | None = json_response.get("value", None)
     LOG.info(
         "LLM response for the matched element",
@@ -3385,12 +3387,12 @@ async def extract_information_for_navigation_goal(
         # CUA tasks should use the default data extraction llm key
         llm_key_override = None
 
-    json_response = await app.LLM_API_HANDLER(
+    llm_api_handler = LLMAPIHandlerFactory.get_override_llm_api_handler(llm_key_override, default=app.LLM_API_HANDLER)
+    json_response = await llm_api_handler(
         prompt=extract_information_prompt,
         step=step,
         screenshots=scraped_page.screenshots,
         prompt_name="extract-information",
-        llm_key_override=llm_key_override,
     )
     return ScrapeResult(
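
Taken together, the call sites rely on one contract from get_override_llm_api_handler: no key returns the default handler unchanged, and a key that fails to resolve degrades to the default with a warning instead of failing the task. A short behavior sketch, assuming a configured app ("NOT_A_REAL_KEY" is a deliberately invalid key used only for illustration):

    from skyvern.forge import app
    from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory

    # No override requested: the default comes back as-is.
    handler = LLMAPIHandlerFactory.get_override_llm_api_handler(None, default=app.LLM_API_HANDLER)
    assert handler is app.LLM_API_HANDLER

    # A key that get_llm_api_handler cannot resolve is expected to raise inside
    # the helper; the helper logs a warning and falls back to the default.
    handler = LLMAPIHandlerFactory.get_override_llm_api_handler("NOT_A_REAL_KEY", default=app.LLM_API_HANDLER)
    assert handler is app.LLM_API_HANDLER
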