fix llm key override in llm_api_handler_with_router_and_fallback (#2562)

Shuchang Zheng committed 2025-05-31 23:23:37 -07:00 (committed by GitHub)
parent 2167d88c20
commit aef945cb63
2 changed files with 28 additions and 11 deletions


@@ -1489,8 +1489,10 @@ class ForgeAgent:
         )
         run_obj = await app.DATABASE.get_run(run_id=task.task_id, organization_id=task.organization_id)
         scroll = True
+        llm_key_override = task.llm_key
         if run_obj and run_obj.task_run_type in CUA_RUN_TYPES:
             scroll = False
+            llm_key_override = None
         scraped_page_refreshed = await scraped_page.refresh(draw_boxes=False, scroll=scroll)
@@ -1510,11 +1512,13 @@ class ForgeAgent:
         )
         # this prompt is critical to our agent so let's use the primary LLM API handler
         verification_result = await app.LLM_API_HANDLER(
             prompt=verification_prompt,
             step=step,
             screenshots=scraped_page_refreshed.screenshots,
             prompt_name="check-user-goal",
+            llm_key_override=llm_key_override,
         )
         return CompleteVerifyResult.model_validate(verification_result)
@@ -2683,11 +2687,16 @@ class ForgeAgent:
                 verification_code_check=False,
                 expire_verification_code=True,
             )
+            llm_key_override = task.llm_key
+            run_obj = await app.DATABASE.get_run(run_id=task.task_id, organization_id=task.organization_id)
+            if run_obj and run_obj.task_run_type in CUA_RUN_TYPES:
+                llm_key_override = None
             return await app.LLM_API_HANDLER(
                 prompt=extract_action_prompt,
                 step=step,
                 screenshots=scraped_page.screenshots,
                 prompt_name="extract-actions",
+                llm_key_override=llm_key_override,
             )
         return json_response
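
Both ForgeAgent call sites above follow the same pattern before invoking app.LLM_API_HANDLER: default to the task's own llm_key, then drop it for computer-use-agent runs. A minimal sketch of that shared pattern, where resolve_llm_key_override is a hypothetical helper and CUA_RUN_TYPES, app, and the task/run objects are assumed to behave as they do in the diff:

async def resolve_llm_key_override(task) -> str | None:
    # Default to the per-task key so a task-level model choice is honored.
    llm_key_override = task.llm_key
    run_obj = await app.DATABASE.get_run(run_id=task.task_id, organization_id=task.organization_id)
    if run_obj and run_obj.task_run_type in CUA_RUN_TYPES:
        # CUA runs use their own pinned model, so the task key must not override it.
        llm_key_override = None
    return llm_key_override

The resolved value is then threaded into the handler through the new llm_key_override keyword argument, which the second file teaches the handler to honor.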


@@ -96,10 +96,18 @@ class LLMAPIHandlerFactory:
            Returns:
                The response from the LLM router.
            """
+           nonlocal llm_config
+           nonlocal llm_key
+           local_llm_config: LLMConfig | LLMRouterConfig = llm_config
+           if llm_key_override:
+               local_llm_config = LLMConfigRegistry.get_config(llm_key_override)
+           local_llm_key = llm_key_override or llm_key
           start_time = time.time()
           if parameters is None:
-              parameters = LLMAPIHandlerFactory.get_api_parameters(llm_config)
+              parameters = LLMAPIHandlerFactory.get_api_parameters(local_llm_config)

           context = skyvern_context.current()
           if context and len(context.hashed_href_map) > 0:
@@ -120,12 +128,12 @@ class LLMAPIHandlerFactory:
               task_v2=task_v2,
               thought=thought,
           )
-          messages = await llm_messages_builder(prompt, screenshots, llm_config.add_assistant_prefix)
+          messages = await llm_messages_builder(prompt, screenshots, local_llm_config.add_assistant_prefix)
           await app.ARTIFACT_MANAGER.create_llm_artifact(
               data=json.dumps(
                   {
-                      "model": llm_key,
+                      "model": local_llm_key,
                       "messages": messages,
                       **parameters,
                   }
@@ -139,28 +147,28 @@ class LLMAPIHandlerFactory:
           try:
               response = await router.acompletion(model=main_model_group, messages=messages, **parameters)
           except litellm.exceptions.APIError as e:
-              raise LLMProviderErrorRetryableTask(llm_key) from e
+              raise LLMProviderErrorRetryableTask(local_llm_key) from e
           except litellm.exceptions.ContextWindowExceededError as e:
               LOG.exception(
                   "Context window exceeded",
-                  llm_key=llm_key,
+                  llm_key=local_llm_key,
                   model=main_model_group,
               )
               raise SkyvernContextWindowExceededError() from e
           except ValueError as e:
               LOG.exception(
                   "LLM token limit exceeded",
-                  llm_key=llm_key,
+                  llm_key=local_llm_key,
                   model=main_model_group,
               )
-              raise LLMProviderErrorRetryableTask(llm_key) from e
+              raise LLMProviderErrorRetryableTask(local_llm_key) from e
           except Exception as e:
               LOG.exception(
                   "LLM request failed unexpectedly",
-                  llm_key=llm_key,
+                  llm_key=local_llm_key,
                   model=main_model_group,
               )
-              raise LLMProviderError(llm_key) from e
+              raise LLMProviderError(local_llm_key) from e

           await app.ARTIFACT_MANAGER.create_llm_artifact(
               data=response.model_dump_json(indent=2).encode("utf-8"),
@@ -207,7 +215,7 @@ class LLMAPIHandlerFactory:
               reasoning_token_count=reasoning_tokens if reasoning_tokens > 0 else None,
               cached_token_count=cached_tokens if cached_tokens > 0 else None,
           )
-          parsed_response = parse_api_response(response, llm_config.add_assistant_prefix)
+          parsed_response = parse_api_response(response, local_llm_config.add_assistant_prefix)
           await app.ARTIFACT_MANAGER.create_llm_artifact(
               data=json.dumps(parsed_response, indent=2).encode("utf-8"),
               artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
@@ -234,7 +242,7 @@ class LLMAPIHandlerFactory:
           duration_seconds = time.time() - start_time
           LOG.info(
               "LLM API handler duration metrics",
-              llm_key=llm_key,
+              llm_key=local_llm_key,
               model=main_model_group,
               prompt_name=prompt_name,
               duration_seconds=duration_seconds,
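
The handler-side half of the fix resolves the override into call-local values instead of reusing the llm_key and llm_config the factory closed over, so one request's override cannot leak into later calls through the shared handler. A minimal sketch of that resolution step, where resolve_local_config is a hypothetical helper and LLMConfigRegistry.get_config is the lookup shown in the diff:

def resolve_local_config(llm_key, llm_config, llm_key_override):
    # Fall back to the factory-level config unless an override key is supplied.
    local_llm_config = llm_config
    if llm_key_override:
        # Look up the override's config rather than reusing the default model's.
        local_llm_config = LLMConfigRegistry.get_config(llm_key_override)
    # The resolved key is what gets logged, attached to artifacts, and raised in errors.
    local_llm_key = llm_key_override or llm_key
    return local_llm_key, local_llm_config

Every downstream use in the hunks above (message building, artifact metadata, exception reporting, response parsing, and duration metrics) switches from llm_key/llm_config to these local values, which is what makes the override take effect consistently for the whole call.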