fix llm_key_override (#2552)

This commit is contained in:
Shuchang Zheng
2025-05-31 11:11:25 -07:00
committed by GitHub
parent 07bf256779
commit 48f5f0913e
10 changed files with 45 additions and 18 deletions

View File

@@ -276,6 +276,8 @@ class Settings(BaseSettings):
"Gemini 2.5 Flash": "VERTEX_GEMINI_2.5_FLASH_PREVIEW_05_20", "Gemini 2.5 Flash": "VERTEX_GEMINI_2.5_FLASH_PREVIEW_05_20",
"GPT 4.1": "OPENAI_GPT4_1", "GPT 4.1": "OPENAI_GPT4_1",
"GPT o3-mini": "OPENAI_O3_MINI", "GPT o3-mini": "OPENAI_O3_MINI",
"bedrock/us.anthropic.claude-opus-4-20250514-v1:0": "BEDROCK_ANTHROPIC_CLAUDE4_OPUS_INFERENCE_PROFILE",
"bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0": "BEDROCK_ANTHROPIC_CLAUDE4_SONNET_INFERENCE_PROFILE",
} }
else: else:
# TODO: apparently the list for OSS is to be much larger # TODO: apparently the list for OSS is to be much larger
@@ -284,6 +286,8 @@ class Settings(BaseSettings):
"Gemini 2.5 Flash": "VERTEX_GEMINI_2.5_FLASH_PREVIEW_05_20", "Gemini 2.5 Flash": "VERTEX_GEMINI_2.5_FLASH_PREVIEW_05_20",
"GPT 4.1": "OPENAI_GPT4_1", "GPT 4.1": "OPENAI_GPT4_1",
"GPT o3-mini": "OPENAI_O3_MINI", "GPT o3-mini": "OPENAI_O3_MINI",
"bedrock/us.anthropic.claude-opus-4-20250514-v1:0": "BEDROCK_ANTHROPIC_CLAUDE4_OPUS_INFERENCE_PROFILE",
"bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0": "BEDROCK_ANTHROPIC_CLAUDE4_SONNET_INFERENCE_PROFILE",
} }
def is_cloud_environment(self) -> bool: def is_cloud_environment(self) -> bool:

View File

@@ -382,14 +382,17 @@ class ForgeAgent:
if page := await browser_state.get_working_page(): if page := await browser_state.get_working_page():
await self.register_async_operations(organization, task, page) await self.register_async_operations(organization, task, page)
if not llm_caller: if engine == RunEngine.anthropic_cua and not llm_caller:
# see if the llm_caller is already set in memory
llm_caller = LLMCallerManager.get_llm_caller(task.task_id) llm_caller = LLMCallerManager.get_llm_caller(task.task_id)
if engine == RunEngine.anthropic_cua and not llm_caller: if not llm_caller:
# llm_caller = LLMCaller(llm_key="BEDROCK_ANTHROPIC_CLAUDE3.5_SONNET_INFERENCE_PROFILE") # if not, create a new llm_caller
llm_caller = LLMCallerManager.get_llm_caller(task.task_id) llm_caller = LLMCaller(llm_key=settings.ANTHROPIC_CUA_LLM_KEY, screenshot_scaling_enabled=True)
if not llm_caller:
llm_caller = LLMCaller(llm_key=settings.ANTHROPIC_CUA_LLM_KEY, screenshot_scaling_enabled=True) # TODO: remove the code after migrating everything to llm callers
LLMCallerManager.set_llm_caller(task.task_id, llm_caller) # currently, only anthropic cua tasks use llm_caller
if engine == RunEngine.anthropic_cua and llm_caller:
LLMCallerManager.set_llm_caller(task.task_id, llm_caller)
step, detailed_output = await self.agent_step( step, detailed_output = await self.agent_step(
task, task,

View File

@@ -96,10 +96,12 @@ class BackgroundTaskExecutor(AsyncExecutor):
) )
run_obj = await app.DATABASE.get_run(run_id=task_id, organization_id=organization_id) run_obj = await app.DATABASE.get_run(run_id=task_id, organization_id=organization_id)
engine = RunEngine.skyvern_v1 engine = RunEngine.skyvern_v1
screenshot_scaling_enabled = False
if run_obj and run_obj.task_run_type == RunType.openai_cua: if run_obj and run_obj.task_run_type == RunType.openai_cua:
engine = RunEngine.openai_cua engine = RunEngine.openai_cua
elif run_obj and run_obj.task_run_type == RunType.anthropic_cua: elif run_obj and run_obj.task_run_type == RunType.anthropic_cua:
engine = RunEngine.anthropic_cua engine = RunEngine.anthropic_cua
screenshot_scaling_enabled = True
context: SkyvernContext = skyvern_context.ensure_context() context: SkyvernContext = skyvern_context.ensure_context()
context.task_id = task.task_id context.task_id = task.task_id
@@ -107,7 +109,7 @@ class BackgroundTaskExecutor(AsyncExecutor):
context.max_steps_override = max_steps_override context.max_steps_override = max_steps_override
llm_key = task.llm_key llm_key = task.llm_key
llm_caller = LLMCaller(llm_key) if llm_key else None llm_caller = LLMCaller(llm_key, screenshot_scaling_enabled=screenshot_scaling_enabled) if llm_key else None
if background_tasks: if background_tasks:
background_tasks.add_task( background_tasks.add_task(

View File

@@ -37,7 +37,7 @@ webhook_callback_url: https://example.com/webhook
totp_verification_url: https://example.com/totp totp_verification_url: https://example.com/totp
persist_browser_session: false persist_browser_session: false
model: model:
model: gpt-3.5-turbo name: gpt-4.1
workflow_definition: workflow_definition:
parameters: parameters:
- key: website_url - key: website_url
@@ -121,7 +121,7 @@ workflow_definition = {
"webhook_callback_url": "https://example.com/webhook", "webhook_callback_url": "https://example.com/webhook",
"totp_verification_url": "https://example.com/totp", "totp_verification_url": "https://example.com/totp",
"totp_identifier": "4155555555", "totp_identifier": "4155555555",
"model": {"model": "gpt-3.5-turbo"}, "model": {"name": "gpt-4.1"},
"workflow_definition": { "workflow_definition": {
"parameters": [ "parameters": [
{ {
@@ -204,7 +204,8 @@ proxy_location: RESIDENTIAL
webhook_callback_url: https://example.com/webhook webhook_callback_url: https://example.com/webhook
totp_verification_url: https://example.com/totp totp_verification_url: https://example.com/totp
persist_browser_session: false persist_browser_session: false
model: {model: gpt-3.5-turbo} model:
name: gpt-4.1
workflow_definition: workflow_definition:
parameters: parameters:
- key: website_url - key: website_url
@@ -287,7 +288,7 @@ updated_workflow_definition = {
"webhook_callback_url": "https://example.com/webhook", "webhook_callback_url": "https://example.com/webhook",
"totp_verification_url": "https://example.com/totp", "totp_verification_url": "https://example.com/totp",
"totp_identifier": "4155555555", "totp_identifier": "4155555555",
"model": {"model": "gpt-3.5-turbo"}, "model": {"name": "gpt-4.1"},
"workflow_definition": { "workflow_definition": {
"parameters": [ "parameters": [
{ {

View File

@@ -57,7 +57,7 @@ class TaskV2(BaseModel):
""" """
if self.model: if self.model:
model_name = self.model.get("model_name") model_name = self.model.get("name")
if model_name: if model_name:
mapping = settings.get_model_name_to_llm_key() mapping = settings.get_model_name_to_llm_key()
llm_key = mapping.get(model_name) llm_key = mapping.get(model_name)

View File

@@ -248,7 +248,7 @@ class Task(TaskBase):
Otherwise return `None`. Otherwise return `None`.
""" """
if self.model: if self.model:
model_name = self.model.get("model_name") model_name = self.model.get("name")
if model_name: if model_name:
mapping = settings.get_model_name_to_llm_key() mapping = settings.get_model_name_to_llm_key()
return mapping.get(model_name) return mapping.get(model_name)

View File

@@ -629,7 +629,14 @@ class BaseTaskBlock(Block):
current_context = skyvern_context.ensure_context() current_context = skyvern_context.ensure_context()
current_context.task_id = task.task_id current_context.task_id = task.task_id
llm_key = workflow.determine_llm_key(block=self) llm_key = workflow.determine_llm_key(block=self)
llm_caller = None if not llm_key else LLMCaller(llm_key=llm_key) screenshot_scaling_enabled = False
if self.engine == RunEngine.anthropic_cua:
screenshot_scaling_enabled = True
llm_caller = (
None
if not llm_key
else LLMCaller(llm_key=llm_key, screenshot_scaling_enabled=screenshot_scaling_enabled)
)
await app.agent.execute_step( await app.agent.execute_step(
organization=organization, organization=organization,

View File

@@ -94,14 +94,14 @@ class Workflow(BaseModel):
mapping = settings.get_model_name_to_llm_key() mapping = settings.get_model_name_to_llm_key()
if block: if block:
model_name = (block.model or {}).get("model") model_name = (block.model or {}).get("name")
if model_name: if model_name:
llm_key = mapping.get(model_name) llm_key = mapping.get(model_name)
if llm_key: if llm_key:
return llm_key return llm_key
workflow_model_name = (self.model or {}).get("model") workflow_model_name = (self.model or {}).get("name")
if workflow_model_name: if workflow_model_name:
llm_key = mapping.get(workflow_model_name) llm_key = mapping.get(workflow_model_name)

View File

@@ -292,6 +292,7 @@ class Skyvern(AsyncSkyvern):
self, self,
prompt: str, prompt: str,
engine: RunEngine = RunEngine.skyvern_v2, engine: RunEngine = RunEngine.skyvern_v2,
model: dict[str, Any] | None = None,
url: str | None = None, url: str | None = None,
webhook_url: str | None = None, webhook_url: str | None = None,
totp_identifier: str | None = None, totp_identifier: str | None = None,
@@ -325,6 +326,7 @@ class Skyvern(AsyncSkyvern):
task_request = TaskRequest( task_request = TaskRequest(
title=title or task_generation.suggested_title, title=title or task_generation.suggested_title,
url=url, url=url,
model=model,
navigation_goal=navigation_goal, navigation_goal=navigation_goal,
navigation_payload=navigation_payload, navigation_payload=navigation_payload,
data_extraction_goal=data_extraction_goal, data_extraction_goal=data_extraction_goal,
@@ -371,6 +373,7 @@ class Skyvern(AsyncSkyvern):
extracted_information_schema=data_extraction_schema, extracted_information_schema=data_extraction_schema,
error_code_mapping=error_code_mapping, error_code_mapping=error_code_mapping,
create_task_run=True, create_task_run=True,
model=model,
) )
await self._run_task_v2(organization, task_v2) await self._run_task_v2(organization, task_v2)

View File

@@ -70,6 +70,7 @@ from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
from skyvern.forge.sdk.models import Step from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.schemas.tasks import Task from skyvern.forge.sdk.schemas.tasks import Task
from skyvern.forge.sdk.services.bitwarden import BitwardenConstants from skyvern.forge.sdk.services.bitwarden import BitwardenConstants
from skyvern.schemas.runs import CUA_RUN_TYPES
from skyvern.utils.prompt_engine import CheckPhoneNumberFormatResponse, load_prompt_with_elements from skyvern.utils.prompt_engine import CheckPhoneNumberFormatResponse, load_prompt_with_elements
from skyvern.webeye.actions import actions from skyvern.webeye.actions import actions
from skyvern.webeye.actions.actions import ( from skyvern.webeye.actions.actions import (
@@ -3363,12 +3364,18 @@ async def extract_information_for_navigation_goal(
local_datetime=datetime.now(context.tz_info).isoformat(), local_datetime=datetime.now(context.tz_info).isoformat(),
) )
task_run = await app.DATABASE.get_run(run_id=task.task_id, organization_id=task.organization_id)
llm_key_override = task.llm_key
if task_run and task_run.task_run_type in CUA_RUN_TYPES:
# CUA tasks should use the default data extraction llm key
llm_key_override = None
json_response = await app.LLM_API_HANDLER( json_response = await app.LLM_API_HANDLER(
prompt=extract_information_prompt, prompt=extract_information_prompt,
step=step, step=step,
screenshots=scraped_page.screenshots, screenshots=scraped_page.screenshots,
prompt_name="extract-information", prompt_name="extract-information",
llm_key_override=task.llm_key, llm_key_override=llm_key_override,
) )
return ScrapeResult( return ScrapeResult(