fix llm key override in llm_api_handler_with_router_and_fallback (#2562)
This commit is contained in:
@@ -1489,8 +1489,10 @@ class ForgeAgent:
|
|||||||
)
|
)
|
||||||
run_obj = await app.DATABASE.get_run(run_id=task.task_id, organization_id=task.organization_id)
|
run_obj = await app.DATABASE.get_run(run_id=task.task_id, organization_id=task.organization_id)
|
||||||
scroll = True
|
scroll = True
|
||||||
|
llm_key_override = task.llm_key
|
||||||
if run_obj and run_obj.task_run_type in CUA_RUN_TYPES:
|
if run_obj and run_obj.task_run_type in CUA_RUN_TYPES:
|
||||||
scroll = False
|
scroll = False
|
||||||
|
llm_key_override = None
|
||||||
|
|
||||||
scraped_page_refreshed = await scraped_page.refresh(draw_boxes=False, scroll=scroll)
|
scraped_page_refreshed = await scraped_page.refresh(draw_boxes=False, scroll=scroll)
|
||||||
|
|
||||||
@@ -1510,11 +1512,13 @@ class ForgeAgent:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# this prompt is critical to our agent so let's use the primary LLM API handler
|
# this prompt is critical to our agent so let's use the primary LLM API handler
|
||||||
|
|
||||||
verification_result = await app.LLM_API_HANDLER(
|
verification_result = await app.LLM_API_HANDLER(
|
||||||
prompt=verification_prompt,
|
prompt=verification_prompt,
|
||||||
step=step,
|
step=step,
|
||||||
screenshots=scraped_page_refreshed.screenshots,
|
screenshots=scraped_page_refreshed.screenshots,
|
||||||
prompt_name="check-user-goal",
|
prompt_name="check-user-goal",
|
||||||
|
llm_key_override=llm_key_override,
|
||||||
)
|
)
|
||||||
return CompleteVerifyResult.model_validate(verification_result)
|
return CompleteVerifyResult.model_validate(verification_result)
|
||||||
|
|
||||||
@@ -2683,11 +2687,16 @@ class ForgeAgent:
|
|||||||
verification_code_check=False,
|
verification_code_check=False,
|
||||||
expire_verification_code=True,
|
expire_verification_code=True,
|
||||||
)
|
)
|
||||||
|
llm_key_override = task.llm_key
|
||||||
|
run_obj = await app.DATABASE.get_run(run_id=task.task_id, organization_id=task.organization_id)
|
||||||
|
if run_obj and run_obj.task_run_type in CUA_RUN_TYPES:
|
||||||
|
llm_key_override = None
|
||||||
return await app.LLM_API_HANDLER(
|
return await app.LLM_API_HANDLER(
|
||||||
prompt=extract_action_prompt,
|
prompt=extract_action_prompt,
|
||||||
step=step,
|
step=step,
|
||||||
screenshots=scraped_page.screenshots,
|
screenshots=scraped_page.screenshots,
|
||||||
prompt_name="extract-actions",
|
prompt_name="extract-actions",
|
||||||
|
llm_key_override=llm_key_override,
|
||||||
)
|
)
|
||||||
return json_response
|
return json_response
|
||||||
|
|
||||||
|
|||||||
@@ -96,10 +96,18 @@ class LLMAPIHandlerFactory:
|
|||||||
Returns:
|
Returns:
|
||||||
The response from the LLM router.
|
The response from the LLM router.
|
||||||
"""
|
"""
|
||||||
|
nonlocal llm_config
|
||||||
|
nonlocal llm_key
|
||||||
|
|
||||||
|
local_llm_config: LLMConfig | LLMRouterConfig = llm_config
|
||||||
|
if llm_key_override:
|
||||||
|
local_llm_config = LLMConfigRegistry.get_config(llm_key_override)
|
||||||
|
|
||||||
|
local_llm_key = llm_key_override or llm_key
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
if parameters is None:
|
if parameters is None:
|
||||||
parameters = LLMAPIHandlerFactory.get_api_parameters(llm_config)
|
parameters = LLMAPIHandlerFactory.get_api_parameters(local_llm_config)
|
||||||
|
|
||||||
context = skyvern_context.current()
|
context = skyvern_context.current()
|
||||||
if context and len(context.hashed_href_map) > 0:
|
if context and len(context.hashed_href_map) > 0:
|
||||||
@@ -120,12 +128,12 @@ class LLMAPIHandlerFactory:
|
|||||||
task_v2=task_v2,
|
task_v2=task_v2,
|
||||||
thought=thought,
|
thought=thought,
|
||||||
)
|
)
|
||||||
messages = await llm_messages_builder(prompt, screenshots, llm_config.add_assistant_prefix)
|
messages = await llm_messages_builder(prompt, screenshots, local_llm_config.add_assistant_prefix)
|
||||||
|
|
||||||
await app.ARTIFACT_MANAGER.create_llm_artifact(
|
await app.ARTIFACT_MANAGER.create_llm_artifact(
|
||||||
data=json.dumps(
|
data=json.dumps(
|
||||||
{
|
{
|
||||||
"model": llm_key,
|
"model": local_llm_key,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
**parameters,
|
**parameters,
|
||||||
}
|
}
|
||||||
@@ -139,28 +147,28 @@ class LLMAPIHandlerFactory:
|
|||||||
try:
|
try:
|
||||||
response = await router.acompletion(model=main_model_group, messages=messages, **parameters)
|
response = await router.acompletion(model=main_model_group, messages=messages, **parameters)
|
||||||
except litellm.exceptions.APIError as e:
|
except litellm.exceptions.APIError as e:
|
||||||
raise LLMProviderErrorRetryableTask(llm_key) from e
|
raise LLMProviderErrorRetryableTask(local_llm_key) from e
|
||||||
except litellm.exceptions.ContextWindowExceededError as e:
|
except litellm.exceptions.ContextWindowExceededError as e:
|
||||||
LOG.exception(
|
LOG.exception(
|
||||||
"Context window exceeded",
|
"Context window exceeded",
|
||||||
llm_key=llm_key,
|
llm_key=local_llm_key,
|
||||||
model=main_model_group,
|
model=main_model_group,
|
||||||
)
|
)
|
||||||
raise SkyvernContextWindowExceededError() from e
|
raise SkyvernContextWindowExceededError() from e
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
LOG.exception(
|
LOG.exception(
|
||||||
"LLM token limit exceeded",
|
"LLM token limit exceeded",
|
||||||
llm_key=llm_key,
|
llm_key=local_llm_key,
|
||||||
model=main_model_group,
|
model=main_model_group,
|
||||||
)
|
)
|
||||||
raise LLMProviderErrorRetryableTask(llm_key) from e
|
raise LLMProviderErrorRetryableTask(local_llm_key) from e
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
LOG.exception(
|
LOG.exception(
|
||||||
"LLM request failed unexpectedly",
|
"LLM request failed unexpectedly",
|
||||||
llm_key=llm_key,
|
llm_key=local_llm_key,
|
||||||
model=main_model_group,
|
model=main_model_group,
|
||||||
)
|
)
|
||||||
raise LLMProviderError(llm_key) from e
|
raise LLMProviderError(local_llm_key) from e
|
||||||
|
|
||||||
await app.ARTIFACT_MANAGER.create_llm_artifact(
|
await app.ARTIFACT_MANAGER.create_llm_artifact(
|
||||||
data=response.model_dump_json(indent=2).encode("utf-8"),
|
data=response.model_dump_json(indent=2).encode("utf-8"),
|
||||||
@@ -207,7 +215,7 @@ class LLMAPIHandlerFactory:
|
|||||||
reasoning_token_count=reasoning_tokens if reasoning_tokens > 0 else None,
|
reasoning_token_count=reasoning_tokens if reasoning_tokens > 0 else None,
|
||||||
cached_token_count=cached_tokens if cached_tokens > 0 else None,
|
cached_token_count=cached_tokens if cached_tokens > 0 else None,
|
||||||
)
|
)
|
||||||
parsed_response = parse_api_response(response, llm_config.add_assistant_prefix)
|
parsed_response = parse_api_response(response, local_llm_config.add_assistant_prefix)
|
||||||
await app.ARTIFACT_MANAGER.create_llm_artifact(
|
await app.ARTIFACT_MANAGER.create_llm_artifact(
|
||||||
data=json.dumps(parsed_response, indent=2).encode("utf-8"),
|
data=json.dumps(parsed_response, indent=2).encode("utf-8"),
|
||||||
artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
|
artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
|
||||||
@@ -234,7 +242,7 @@ class LLMAPIHandlerFactory:
|
|||||||
duration_seconds = time.time() - start_time
|
duration_seconds = time.time() - start_time
|
||||||
LOG.info(
|
LOG.info(
|
||||||
"LLM API handler duration metrics",
|
"LLM API handler duration metrics",
|
||||||
llm_key=llm_key,
|
llm_key=local_llm_key,
|
||||||
model=main_model_group,
|
model=main_model_group,
|
||||||
prompt_name=prompt_name,
|
prompt_name=prompt_name,
|
||||||
duration_seconds=duration_seconds,
|
duration_seconds=duration_seconds,
|
||||||
|
|||||||
Reference in New Issue
Block a user