fix llm_key - render every llm_key for cua in the execute_step (#2556)

This commit is contained in:
Shuchang Zheng
2025-05-31 16:26:02 -07:00
committed by GitHub
parent 48f5f0913e
commit deb38af17d
3 changed files with 6 additions and 20 deletions

View File

@@ -174,6 +174,7 @@ class ForgeAgent:
max_steps_per_run=task_block.max_steps_per_run,
error_code_mapping=task_block.error_code_mapping,
include_action_history_in_verification=task_block.include_action_history_in_verification,
model=task_block.model,
)
LOG.info(
"Created a new task for workflow run",
@@ -387,7 +388,10 @@ class ForgeAgent:
llm_caller = LLMCallerManager.get_llm_caller(task.task_id)
if not llm_caller:
# if not, create a new llm_caller
llm_caller = LLMCaller(llm_key=settings.ANTHROPIC_CUA_LLM_KEY, screenshot_scaling_enabled=True)
llm_key = task.llm_key
llm_caller = LLMCaller(
llm_key=llm_key or settings.ANTHROPIC_CUA_LLM_KEY, screenshot_scaling_enabled=True
)
# TODO: remove the code after migrating everything to llm callers
# currently, only anthropic cua tasks use llm_caller

View File

@@ -5,7 +5,6 @@ from fastapi import BackgroundTasks, Request
from skyvern.exceptions import OrganizationNotFound
from skyvern.forge import app
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMCaller
from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
from skyvern.forge.sdk.schemas.organizations import Organization
@@ -96,21 +95,16 @@ class BackgroundTaskExecutor(AsyncExecutor):
)
run_obj = await app.DATABASE.get_run(run_id=task_id, organization_id=organization_id)
engine = RunEngine.skyvern_v1
screenshot_scaling_enabled = False
if run_obj and run_obj.task_run_type == RunType.openai_cua:
engine = RunEngine.openai_cua
elif run_obj and run_obj.task_run_type == RunType.anthropic_cua:
engine = RunEngine.anthropic_cua
screenshot_scaling_enabled = True
context: SkyvernContext = skyvern_context.ensure_context()
context.task_id = task.task_id
context.organization_id = organization_id
context.max_steps_override = max_steps_override
llm_key = task.llm_key
llm_caller = LLMCaller(llm_key, screenshot_scaling_enabled=screenshot_scaling_enabled) if llm_key else None
if background_tasks:
background_tasks.add_task(
app.agent.execute_step,
@@ -121,7 +115,6 @@ class BackgroundTaskExecutor(AsyncExecutor):
close_browser_on_completion=close_browser_on_completion,
browser_session_id=browser_session_id,
engine=engine,
llm_caller=llm_caller,
)
async def execute_workflow(

View File

@@ -46,7 +46,7 @@ from skyvern.forge.sdk.api.files import (
download_from_s3,
get_path_for_workflow_download_directory,
)
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory, LLMCaller
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory
from skyvern.forge.sdk.artifact.models import ArtifactType
from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.db.enums import TaskType
@@ -628,16 +628,6 @@ class BaseTaskBlock(Block):
try:
current_context = skyvern_context.ensure_context()
current_context.task_id = task.task_id
llm_key = workflow.determine_llm_key(block=self)
screenshot_scaling_enabled = False
if self.engine == RunEngine.anthropic_cua:
screenshot_scaling_enabled = True
llm_caller = (
None
if not llm_key
else LLMCaller(llm_key=llm_key, screenshot_scaling_enabled=screenshot_scaling_enabled)
)
await app.agent.execute_step(
organization=organization,
task=task,
@@ -647,7 +637,6 @@ class BaseTaskBlock(Block):
close_browser_on_completion=browser_session_id is None,
complete_verification=self.complete_verification,
engine=self.engine,
llm_caller=llm_caller,
)
except Exception as e:
# Make sure the task is marked as failed in the database before raising the exception