diff --git a/skyvern/forge/agent.py b/skyvern/forge/agent.py
index 829fbc6c..8730d0a1 100644
--- a/skyvern/forge/agent.py
+++ b/skyvern/forge/agent.py
@@ -1489,8 +1489,10 @@ class ForgeAgent:
         )
 
         run_obj = await app.DATABASE.get_run(run_id=task.task_id, organization_id=task.organization_id)
         scroll = True
+        llm_key_override = task.llm_key
         if run_obj and run_obj.task_run_type in CUA_RUN_TYPES:
             scroll = False
+            llm_key_override = None
 
         scraped_page_refreshed = await scraped_page.refresh(draw_boxes=False, scroll=scroll)
@@ -1510,11 +1512,13 @@ class ForgeAgent:
         )
 
         # this prompt is critical to our agent so let's use the primary LLM API handler
+
         verification_result = await app.LLM_API_HANDLER(
             prompt=verification_prompt,
             step=step,
             screenshots=scraped_page_refreshed.screenshots,
             prompt_name="check-user-goal",
+            llm_key_override=llm_key_override,
         )
 
         return CompleteVerifyResult.model_validate(verification_result)
@@ -2683,11 +2687,16 @@ class ForgeAgent:
                 verification_code_check=False,
                 expire_verification_code=True,
             )
+            llm_key_override = task.llm_key
+            run_obj = await app.DATABASE.get_run(run_id=task.task_id, organization_id=task.organization_id)
+            if run_obj and run_obj.task_run_type in CUA_RUN_TYPES:
+                llm_key_override = None
             return await app.LLM_API_HANDLER(
                 prompt=extract_action_prompt,
                 step=step,
                 screenshots=scraped_page.screenshots,
                 prompt_name="extract-actions",
+                llm_key_override=llm_key_override,
             )
 
         return json_response
diff --git a/skyvern/forge/sdk/api/llm/api_handler_factory.py b/skyvern/forge/sdk/api/llm/api_handler_factory.py
index f1299577..c8b3864d 100644
--- a/skyvern/forge/sdk/api/llm/api_handler_factory.py
+++ b/skyvern/forge/sdk/api/llm/api_handler_factory.py
@@ -96,10 +96,18 @@ class LLMAPIHandlerFactory:
             Returns:
                 The response from the LLM router.
             """
+            nonlocal llm_config
+            nonlocal llm_key
+
+            local_llm_config: LLMConfig | LLMRouterConfig = llm_config
+            if llm_key_override:
+                local_llm_config = LLMConfigRegistry.get_config(llm_key_override)
+
+            local_llm_key = llm_key_override or llm_key
             start_time = time.time()
 
             if parameters is None:
-                parameters = LLMAPIHandlerFactory.get_api_parameters(llm_config)
+                parameters = LLMAPIHandlerFactory.get_api_parameters(local_llm_config)
 
             context = skyvern_context.current()
             if context and len(context.hashed_href_map) > 0:
@@ -120,12 +128,12 @@ class LLMAPIHandlerFactory:
                 task_v2=task_v2,
                 thought=thought,
             )
-            messages = await llm_messages_builder(prompt, screenshots, llm_config.add_assistant_prefix)
+            messages = await llm_messages_builder(prompt, screenshots, local_llm_config.add_assistant_prefix)
 
             await app.ARTIFACT_MANAGER.create_llm_artifact(
                 data=json.dumps(
                     {
-                        "model": llm_key,
+                        "model": local_llm_key,
                         "messages": messages,
                         **parameters,
                     }
@@ -139,28 +147,28 @@ class LLMAPIHandlerFactory:
             try:
                 response = await router.acompletion(model=main_model_group, messages=messages, **parameters)
             except litellm.exceptions.APIError as e:
-                raise LLMProviderErrorRetryableTask(llm_key) from e
+                raise LLMProviderErrorRetryableTask(local_llm_key) from e
             except litellm.exceptions.ContextWindowExceededError as e:
                 LOG.exception(
                     "Context window exceeded",
-                    llm_key=llm_key,
+                    llm_key=local_llm_key,
                     model=main_model_group,
                 )
                 raise SkyvernContextWindowExceededError() from e
             except ValueError as e:
                 LOG.exception(
                     "LLM token limit exceeded",
-                    llm_key=llm_key,
+                    llm_key=local_llm_key,
                     model=main_model_group,
                 )
-                raise LLMProviderErrorRetryableTask(llm_key) from e
+                raise LLMProviderErrorRetryableTask(local_llm_key) from e
             except Exception as e:
                 LOG.exception(
                     "LLM request failed unexpectedly",
-                    llm_key=llm_key,
+                    llm_key=local_llm_key,
                     model=main_model_group,
                 )
-                raise LLMProviderError(llm_key) from e
+                raise LLMProviderError(local_llm_key) from e
 
             await app.ARTIFACT_MANAGER.create_llm_artifact(
                 data=response.model_dump_json(indent=2).encode("utf-8"),
@@ -207,7 +215,7 @@ class LLMAPIHandlerFactory:
                 reasoning_token_count=reasoning_tokens if reasoning_tokens > 0 else None,
                 cached_token_count=cached_tokens if cached_tokens > 0 else None,
             )
-            parsed_response = parse_api_response(response, llm_config.add_assistant_prefix)
+            parsed_response = parse_api_response(response, local_llm_config.add_assistant_prefix)
             await app.ARTIFACT_MANAGER.create_llm_artifact(
                 data=json.dumps(parsed_response, indent=2).encode("utf-8"),
                 artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
@@ -234,7 +242,7 @@ class LLMAPIHandlerFactory:
             duration_seconds = time.time() - start_time
             LOG.info(
                 "LLM API handler duration metrics",
-                llm_key=llm_key,
+                llm_key=local_llm_key,
                 model=main_model_group,
                 prompt_name=prompt_name,
                 duration_seconds=duration_seconds,