fix llm_key_override (#2552)

This commit is contained in:
Shuchang Zheng
2025-05-31 11:11:25 -07:00
committed by GitHub
parent 07bf256779
commit 48f5f0913e
10 changed files with 45 additions and 18 deletions

View File

@@ -382,14 +382,17 @@ class ForgeAgent:
if page := await browser_state.get_working_page():
await self.register_async_operations(organization, task, page)
if not llm_caller:
if engine == RunEngine.anthropic_cua and not llm_caller:
# see if the llm_caller is already set in memory
llm_caller = LLMCallerManager.get_llm_caller(task.task_id)
if engine == RunEngine.anthropic_cua and not llm_caller:
# llm_caller = LLMCaller(llm_key="BEDROCK_ANTHROPIC_CLAUDE3.5_SONNET_INFERENCE_PROFILE")
llm_caller = LLMCallerManager.get_llm_caller(task.task_id)
if not llm_caller:
llm_caller = LLMCaller(llm_key=settings.ANTHROPIC_CUA_LLM_KEY, screenshot_scaling_enabled=True)
LLMCallerManager.set_llm_caller(task.task_id, llm_caller)
if not llm_caller:
# if not, create a new llm_caller
llm_caller = LLMCaller(llm_key=settings.ANTHROPIC_CUA_LLM_KEY, screenshot_scaling_enabled=True)
# TODO: remove the code after migrating everything to llm callers
# currently, only anthropic cua tasks use llm_caller
if engine == RunEngine.anthropic_cua and llm_caller:
LLMCallerManager.set_llm_caller(task.task_id, llm_caller)
step, detailed_output = await self.agent_step(
task,

View File

@@ -96,10 +96,12 @@ class BackgroundTaskExecutor(AsyncExecutor):
)
run_obj = await app.DATABASE.get_run(run_id=task_id, organization_id=organization_id)
engine = RunEngine.skyvern_v1
screenshot_scaling_enabled = False
if run_obj and run_obj.task_run_type == RunType.openai_cua:
engine = RunEngine.openai_cua
elif run_obj and run_obj.task_run_type == RunType.anthropic_cua:
engine = RunEngine.anthropic_cua
screenshot_scaling_enabled = True
context: SkyvernContext = skyvern_context.ensure_context()
context.task_id = task.task_id
@@ -107,7 +109,7 @@ class BackgroundTaskExecutor(AsyncExecutor):
context.max_steps_override = max_steps_override
llm_key = task.llm_key
llm_caller = LLMCaller(llm_key) if llm_key else None
llm_caller = LLMCaller(llm_key, screenshot_scaling_enabled=screenshot_scaling_enabled) if llm_key else None
if background_tasks:
background_tasks.add_task(

View File

@@ -37,7 +37,7 @@ webhook_callback_url: https://example.com/webhook
totp_verification_url: https://example.com/totp
persist_browser_session: false
model:
model: gpt-3.5-turbo
name: gpt-4.1
workflow_definition:
parameters:
- key: website_url
@@ -121,7 +121,7 @@ workflow_definition = {
"webhook_callback_url": "https://example.com/webhook",
"totp_verification_url": "https://example.com/totp",
"totp_identifier": "4155555555",
"model": {"model": "gpt-3.5-turbo"},
"model": {"name": "gpt-4.1"},
"workflow_definition": {
"parameters": [
{
@@ -204,7 +204,8 @@ proxy_location: RESIDENTIAL
webhook_callback_url: https://example.com/webhook
totp_verification_url: https://example.com/totp
persist_browser_session: false
model: {model: gpt-3.5-turbo}
model:
name: gpt-4.1
workflow_definition:
parameters:
- key: website_url
@@ -287,7 +288,7 @@ updated_workflow_definition = {
"webhook_callback_url": "https://example.com/webhook",
"totp_verification_url": "https://example.com/totp",
"totp_identifier": "4155555555",
"model": {"model": "gpt-3.5-turbo"},
"model": {"name": "gpt-4.1"},
"workflow_definition": {
"parameters": [
{

View File

@@ -57,7 +57,7 @@ class TaskV2(BaseModel):
"""
if self.model:
model_name = self.model.get("model_name")
model_name = self.model.get("name")
if model_name:
mapping = settings.get_model_name_to_llm_key()
llm_key = mapping.get(model_name)

View File

@@ -248,7 +248,7 @@ class Task(TaskBase):
Otherwise return `None`.
"""
if self.model:
model_name = self.model.get("model_name")
model_name = self.model.get("name")
if model_name:
mapping = settings.get_model_name_to_llm_key()
return mapping.get(model_name)

View File

@@ -629,7 +629,14 @@ class BaseTaskBlock(Block):
current_context = skyvern_context.ensure_context()
current_context.task_id = task.task_id
llm_key = workflow.determine_llm_key(block=self)
llm_caller = None if not llm_key else LLMCaller(llm_key=llm_key)
screenshot_scaling_enabled = False
if self.engine == RunEngine.anthropic_cua:
screenshot_scaling_enabled = True
llm_caller = (
None
if not llm_key
else LLMCaller(llm_key=llm_key, screenshot_scaling_enabled=screenshot_scaling_enabled)
)
await app.agent.execute_step(
organization=organization,

View File

@@ -94,14 +94,14 @@ class Workflow(BaseModel):
mapping = settings.get_model_name_to_llm_key()
if block:
model_name = (block.model or {}).get("model")
model_name = (block.model or {}).get("name")
if model_name:
llm_key = mapping.get(model_name)
if llm_key:
return llm_key
workflow_model_name = (self.model or {}).get("model")
workflow_model_name = (self.model or {}).get("name")
if workflow_model_name:
llm_key = mapping.get(workflow_model_name)