diff --git a/skyvern/forge/agent.py b/skyvern/forge/agent.py
index e519085c..231998a6 100644
--- a/skyvern/forge/agent.py
+++ b/skyvern/forge/agent.py
@@ -864,6 +864,7 @@ class ForgeAgent:
                 prompt_name="extract-actions",
                 step=step,
                 screenshots=scraped_page.screenshots,
+                llm_key_override=llm_caller.llm_key if llm_caller else None,
             )
             try:
                 json_response = await self.handle_potential_verification_code(
diff --git a/skyvern/forge/sdk/api/llm/api_handler_factory.py b/skyvern/forge/sdk/api/llm/api_handler_factory.py
index 13863091..f1299577 100644
--- a/skyvern/forge/sdk/api/llm/api_handler_factory.py
+++ b/skyvern/forge/sdk/api/llm/api_handler_factory.py
@@ -82,6 +82,7 @@ class LLMAPIHandlerFactory:
             ai_suggestion: AISuggestion | None = None,
             screenshots: list[bytes] | None = None,
             parameters: dict[str, Any] | None = None,
+            llm_key_override: str | None = None,
         ) -> dict[str, Any]:
             """
             Custom LLM API handler that utilizes the LiteLLM router and fallbacks to OpenAI GPT-4 Vision.
@@ -267,15 +268,25 @@ class LLMAPIHandlerFactory:
             ai_suggestion: AISuggestion | None = None,
             screenshots: list[bytes] | None = None,
             parameters: dict[str, Any] | None = None,
+            llm_key_override: str | None = None,
         ) -> dict[str, Any]:
+            nonlocal llm_config
+            nonlocal llm_key
+
+            local_llm_config: LLMConfig | LLMRouterConfig = llm_config
+            if llm_key_override:
+                local_llm_config = LLMConfigRegistry.get_config(llm_key_override)
+
+            local_llm_key = llm_key_override or llm_key
+
             start_time = time.time()
             active_parameters = base_parameters or {}
             if parameters is None:
-                parameters = LLMAPIHandlerFactory.get_api_parameters(llm_config)
+                parameters = LLMAPIHandlerFactory.get_api_parameters(local_llm_config)
 
             active_parameters.update(parameters)
-            if llm_config.litellm_params:  # type: ignore
-                active_parameters.update(llm_config.litellm_params)  # type: ignore
+            if local_llm_config.litellm_params:  # type: ignore
+                active_parameters.update(local_llm_config.litellm_params)  # type: ignore
 
             context = skyvern_context.current()
             if context and len(context.hashed_href_map) > 0:
@@ -298,14 +309,16 @@ class LLMAPIHandlerFactory:
                 ai_suggestion=ai_suggestion,
             )
-            if not llm_config.supports_vision:
+            if not local_llm_config.supports_vision:
                 screenshots = None
 
-            messages = await llm_messages_builder(prompt, screenshots, llm_config.add_assistant_prefix)
+            model_name = local_llm_config.model_name
+
+            messages = await llm_messages_builder(prompt, screenshots, local_llm_config.add_assistant_prefix)
 
             await app.ARTIFACT_MANAGER.create_llm_artifact(
                 data=json.dumps(
                     {
-                        "model": llm_config.model_name,
+                        "model": model_name,
                         "messages": messages,
                         # we're not using active_parameters here because it may contain sensitive information
                         **parameters,
@@ -323,32 +336,32 @@ class LLMAPIHandlerFactory:
                 # TODO (kerem): add a retry mechanism to this call (acompletion_with_retries)
                 # TODO (kerem): use litellm fallbacks? https://litellm.vercel.app/docs/tutorials/fallbacks#how-does-completion_with_fallbacks-work
                 response = await litellm.acompletion(
-                    model=llm_config.model_name,
+                    model=model_name,
                     messages=messages,
                     timeout=settings.LLM_CONFIG_TIMEOUT,
                     **active_parameters,
                 )
             except litellm.exceptions.APIError as e:
-                raise LLMProviderErrorRetryableTask(llm_key) from e
+                raise LLMProviderErrorRetryableTask(local_llm_key) from e
             except litellm.exceptions.ContextWindowExceededError as e:
                 LOG.exception(
                     "Context window exceeded",
-                    llm_key=llm_key,
-                    model=llm_config.model_name,
+                    llm_key=local_llm_key,
+                    model=model_name,
                 )
                 raise SkyvernContextWindowExceededError() from e
             except CancelledError:
                 t_llm_cancelled = time.perf_counter()
                 LOG.error(
                     "LLM request got cancelled",
-                    llm_key=llm_key,
-                    model=llm_config.model_name,
+                    llm_key=local_llm_key,
+                    model=model_name,
                     duration=t_llm_cancelled - t_llm_request,
                 )
-                raise LLMProviderError(llm_key)
+                raise LLMProviderError(local_llm_key)
             except Exception as e:
-                LOG.exception("LLM request failed unexpectedly", llm_key=llm_key)
-                raise LLMProviderError(llm_key) from e
+                LOG.exception("LLM request failed unexpectedly", llm_key=local_llm_key)
+                raise LLMProviderError(local_llm_key) from e
 
             await app.ARTIFACT_MANAGER.create_llm_artifact(
                 data=response.model_dump_json(indent=2).encode("utf-8"),
@@ -396,7 +409,7 @@ class LLMAPIHandlerFactory:
                 cached_token_count=cached_tokens if cached_tokens > 0 else None,
                 thought_cost=llm_cost,
             )
-            parsed_response = parse_api_response(response, llm_config.add_assistant_prefix)
+            parsed_response = parse_api_response(response, local_llm_config.add_assistant_prefix)
             await app.ARTIFACT_MANAGER.create_llm_artifact(
                 data=json.dumps(parsed_response, indent=2).encode("utf-8"),
                 artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
@@ -423,9 +436,9 @@ class LLMAPIHandlerFactory:
             duration_seconds = time.time() - start_time
             LOG.info(
                 "LLM API handler duration metrics",
-                llm_key=llm_key,
+                llm_key=local_llm_key,
                 prompt_name=prompt_name,
-                model=llm_config.model_name,
+                model=local_llm_config.model_name,
                 duration_seconds=duration_seconds,
                 step_id=step.step_id if step else None,
                 thought_id=thought.observer_thought_id if thought else None,
diff --git a/skyvern/forge/sdk/api/llm/models.py b/skyvern/forge/sdk/api/llm/models.py
index 9d231c2a..1d4ae9b9 100644
--- a/skyvern/forge/sdk/api/llm/models.py
+++ b/skyvern/forge/sdk/api/llm/models.py
@@ -94,6 +94,7 @@ class LLMAPIHandler(Protocol):
         ai_suggestion: AISuggestion | None = None,
         screenshots: list[bytes] | None = None,
         parameters: dict[str, Any] | None = None,
+        llm_key_override: str | None = None,
     ) -> Awaitable[dict[str, Any]]: ...
 
 
@@ -106,5 +107,6 @@ async def dummy_llm_api_handler(
     ai_suggestion: AISuggestion | None = None,
     screenshots: list[bytes] | None = None,
     parameters: dict[str, Any] | None = None,
+    llm_key_override: str | None = None,
 ) -> dict[str, Any]:
     raise NotImplementedError("Your LLM provider is not configured. Please configure it in the .env file.")
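Note: below is a minimal usage sketch of the new llm_key_override parameter, not part of the diff. It assumes the shared handler is exposed as app.LLM_API_HANDLER and that a key such as "OPENAI_GPT4O" is registered in LLMConfigRegistry; the function name and the key are illustrative only.

    from typing import Any

    from skyvern.forge import app


    async def extract_actions_with_override(
        prompt: str, step: Any, screenshots: list[bytes] | None = None
    ) -> dict[str, Any]:
        # Passing llm_key_override makes this single call resolve its config via
        # LLMConfigRegistry.get_config("OPENAI_GPT4O") instead of the handler's default
        # llm_key; error logs and duration metrics are attributed to the overridden key.
        # "OPENAI_GPT4O" is an assumed example key, not confirmed by this diff.
        return await app.LLM_API_HANDLER(
            prompt=prompt,
            prompt_name="extract-actions",
            step=step,
            screenshots=screenshots,
            llm_key_override="OPENAI_GPT4O",
        )

Because the override is resolved per call inside the handler closure, the same handler instance can serve callers that need different models (as agent.py does by forwarding llm_caller.llm_key when an llm_caller is present) without rebuilding the handler.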