overhual llm key override (#2677)
This commit is contained in:
@@ -58,7 +58,7 @@ from skyvern.forge.sdk.api.files import (
|
||||
rename_file,
|
||||
wait_for_download_finished,
|
||||
)
|
||||
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMCaller, LLMCallerManager
|
||||
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory, LLMCaller, LLMCallerManager
|
||||
from skyvern.forge.sdk.artifact.models import ArtifactType
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.forge.sdk.core.security import generate_skyvern_webhook_headers
|
||||
@@ -865,15 +865,20 @@ class ForgeAgent:
|
||||
):
|
||||
using_cached_action_plan = True
|
||||
else:
|
||||
llm_key_override = task.llm_key
|
||||
# FIXME: Redundant engine check?
|
||||
if engine in CUA_ENGINES:
|
||||
self.async_operation_pool.run_operation(task.task_id, AgentPhase.llm)
|
||||
llm_key_override = None
|
||||
|
||||
json_response = await app.LLM_API_HANDLER(
|
||||
llm_api_handler = LLMAPIHandlerFactory.get_override_llm_api_handler(
|
||||
llm_key_override, default=app.LLM_API_HANDLER
|
||||
)
|
||||
json_response = await llm_api_handler(
|
||||
prompt=extract_action_prompt,
|
||||
prompt_name="extract-actions",
|
||||
step=step,
|
||||
screenshots=scraped_page.screenshots,
|
||||
llm_key_override=task.llm_key,
|
||||
)
|
||||
try:
|
||||
json_response = await self.handle_potential_verification_code(
|
||||
@@ -1513,12 +1518,14 @@ class ForgeAgent:
|
||||
|
||||
# this prompt is critical to our agent so let's use the primary LLM API handler
|
||||
|
||||
verification_result = await app.LLM_API_HANDLER(
|
||||
llm_api_handler = LLMAPIHandlerFactory.get_override_llm_api_handler(
|
||||
llm_key_override, default=app.LLM_API_HANDLER
|
||||
)
|
||||
verification_result = await llm_api_handler(
|
||||
prompt=verification_prompt,
|
||||
step=step,
|
||||
screenshots=scraped_page_refreshed.screenshots,
|
||||
prompt_name="check-user-goal",
|
||||
llm_key_override=llm_key_override,
|
||||
)
|
||||
return CompleteVerifyResult.model_validate(verification_result)
|
||||
|
||||
@@ -1833,7 +1840,10 @@ class ForgeAgent:
|
||||
prompt = prompt_engine.load_prompt(
|
||||
"infer-action-type", navigation_goal=navigation_goal, prompt_name="infer-action-type"
|
||||
)
|
||||
json_response = await app.LLM_API_HANDLER(prompt=prompt, step=step, prompt_name="infer-action-type")
|
||||
llm_api_handler = LLMAPIHandlerFactory.get_override_llm_api_handler(
|
||||
task.llm_key, default=app.LLM_API_HANDLER
|
||||
)
|
||||
json_response = await llm_api_handler(prompt=prompt, step=step, prompt_name="infer-action-type")
|
||||
if json_response.get("error"):
|
||||
raise FailedToParseActionInstruction(
|
||||
reason=json_response.get("thought"), error_type=json_response.get("error")
|
||||
@@ -2772,12 +2782,14 @@ class ForgeAgent:
|
||||
run_obj = await app.DATABASE.get_run(run_id=task.task_id, organization_id=task.organization_id)
|
||||
if run_obj and run_obj.task_run_type in CUA_RUN_TYPES:
|
||||
llm_key_override = None
|
||||
return await app.LLM_API_HANDLER(
|
||||
llm_api_handler = LLMAPIHandlerFactory.get_override_llm_api_handler(
|
||||
llm_key_override, default=app.LLM_API_HANDLER
|
||||
)
|
||||
return await llm_api_handler(
|
||||
prompt=extract_action_prompt,
|
||||
step=step,
|
||||
screenshots=scraped_page.screenshots,
|
||||
prompt_name="extract-actions",
|
||||
llm_key_override=llm_key_override,
|
||||
)
|
||||
return json_response
|
||||
|
||||
|
||||
Reference in New Issue
Block a user