From 48f5f0913e8799d06dd09a8d296b250340526dbe Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Sat, 31 May 2025 11:11:25 -0700 Subject: [PATCH] fix llm_key_override (#2552) --- skyvern/config.py | 4 ++++ skyvern/forge/agent.py | 17 ++++++++++------- skyvern/forge/sdk/executor/async_executor.py | 4 +++- skyvern/forge/sdk/routes/code_samples.py | 9 +++++---- skyvern/forge/sdk/schemas/task_v2.py | 2 +- skyvern/forge/sdk/schemas/tasks.py | 2 +- skyvern/forge/sdk/workflow/models/block.py | 9 ++++++++- skyvern/forge/sdk/workflow/models/workflow.py | 4 ++-- skyvern/library/skyvern.py | 3 +++ skyvern/webeye/actions/handler.py | 9 ++++++++- 10 files changed, 45 insertions(+), 18 deletions(-) diff --git a/skyvern/config.py b/skyvern/config.py index 1f9346a4..991d29f7 100644 --- a/skyvern/config.py +++ b/skyvern/config.py @@ -276,6 +276,8 @@ class Settings(BaseSettings): "Gemini 2.5 Flash": "VERTEX_GEMINI_2.5_FLASH_PREVIEW_05_20", "GPT 4.1": "OPENAI_GPT4_1", "GPT o3-mini": "OPENAI_O3_MINI", + "bedrock/us.anthropic.claude-opus-4-20250514-v1:0": "BEDROCK_ANTHROPIC_CLAUDE4_OPUS_INFERENCE_PROFILE", + "bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0": "BEDROCK_ANTHROPIC_CLAUDE4_SONNET_INFERENCE_PROFILE", } else: # TODO: apparently the list for OSS is to be much larger @@ -284,6 +286,8 @@ class Settings(BaseSettings): "Gemini 2.5 Flash": "VERTEX_GEMINI_2.5_FLASH_PREVIEW_05_20", "GPT 4.1": "OPENAI_GPT4_1", "GPT o3-mini": "OPENAI_O3_MINI", + "bedrock/us.anthropic.claude-opus-4-20250514-v1:0": "BEDROCK_ANTHROPIC_CLAUDE4_OPUS_INFERENCE_PROFILE", + "bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0": "BEDROCK_ANTHROPIC_CLAUDE4_SONNET_INFERENCE_PROFILE", } def is_cloud_environment(self) -> bool: diff --git a/skyvern/forge/agent.py b/skyvern/forge/agent.py index a5d0876f..569e6ec5 100644 --- a/skyvern/forge/agent.py +++ b/skyvern/forge/agent.py @@ -382,14 +382,17 @@ class ForgeAgent: if page := await browser_state.get_working_page(): await self.register_async_operations(organization, task, page) - if not llm_caller: + if engine == RunEngine.anthropic_cua and not llm_caller: + # see if the llm_caller is already set in memory llm_caller = LLMCallerManager.get_llm_caller(task.task_id) - if engine == RunEngine.anthropic_cua and not llm_caller: - # llm_caller = LLMCaller(llm_key="BEDROCK_ANTHROPIC_CLAUDE3.5_SONNET_INFERENCE_PROFILE") - llm_caller = LLMCallerManager.get_llm_caller(task.task_id) - if not llm_caller: - llm_caller = LLMCaller(llm_key=settings.ANTHROPIC_CUA_LLM_KEY, screenshot_scaling_enabled=True) - LLMCallerManager.set_llm_caller(task.task_id, llm_caller) + if not llm_caller: + # if not, create a new llm_caller + llm_caller = LLMCaller(llm_key=settings.ANTHROPIC_CUA_LLM_KEY, screenshot_scaling_enabled=True) + + # TODO: remove the code after migrating everything to llm callers + # currently, only anthropic cua tasks use llm_caller + if engine == RunEngine.anthropic_cua and llm_caller: + LLMCallerManager.set_llm_caller(task.task_id, llm_caller) step, detailed_output = await self.agent_step( task, diff --git a/skyvern/forge/sdk/executor/async_executor.py b/skyvern/forge/sdk/executor/async_executor.py index 2adb64e2..e377cdd9 100644 --- a/skyvern/forge/sdk/executor/async_executor.py +++ b/skyvern/forge/sdk/executor/async_executor.py @@ -96,10 +96,12 @@ class BackgroundTaskExecutor(AsyncExecutor): ) run_obj = await app.DATABASE.get_run(run_id=task_id, organization_id=organization_id) engine = RunEngine.skyvern_v1 + screenshot_scaling_enabled = False if run_obj and run_obj.task_run_type == RunType.openai_cua: engine = RunEngine.openai_cua elif run_obj and run_obj.task_run_type == RunType.anthropic_cua: engine = RunEngine.anthropic_cua + screenshot_scaling_enabled = True context: SkyvernContext = skyvern_context.ensure_context() context.task_id = task.task_id @@ -107,7 +109,7 @@ class BackgroundTaskExecutor(AsyncExecutor): context.max_steps_override = max_steps_override llm_key = task.llm_key - llm_caller = LLMCaller(llm_key) if llm_key else None + llm_caller = LLMCaller(llm_key, screenshot_scaling_enabled=screenshot_scaling_enabled) if llm_key else None if background_tasks: background_tasks.add_task( diff --git a/skyvern/forge/sdk/routes/code_samples.py b/skyvern/forge/sdk/routes/code_samples.py index f7524e3a..2dc5e5d5 100644 --- a/skyvern/forge/sdk/routes/code_samples.py +++ b/skyvern/forge/sdk/routes/code_samples.py @@ -37,7 +37,7 @@ webhook_callback_url: https://example.com/webhook totp_verification_url: https://example.com/totp persist_browser_session: false model: - model: gpt-3.5-turbo + name: gpt-4.1 workflow_definition: parameters: - key: website_url @@ -121,7 +121,7 @@ workflow_definition = { "webhook_callback_url": "https://example.com/webhook", "totp_verification_url": "https://example.com/totp", "totp_identifier": "4155555555", - "model": {"model": "gpt-3.5-turbo"}, + "model": {"name": "gpt-4.1"}, "workflow_definition": { "parameters": [ { @@ -204,7 +204,8 @@ proxy_location: RESIDENTIAL webhook_callback_url: https://example.com/webhook totp_verification_url: https://example.com/totp persist_browser_session: false -model: {model: gpt-3.5-turbo} +model: + name: gpt-4.1 workflow_definition: parameters: - key: website_url @@ -287,7 +288,7 @@ updated_workflow_definition = { "webhook_callback_url": "https://example.com/webhook", "totp_verification_url": "https://example.com/totp", "totp_identifier": "4155555555", - "model": {"model": "gpt-3.5-turbo"}, + "model": {"name": "gpt-4.1"}, "workflow_definition": { "parameters": [ { diff --git a/skyvern/forge/sdk/schemas/task_v2.py b/skyvern/forge/sdk/schemas/task_v2.py index 57c8c516..9d76f001 100644 --- a/skyvern/forge/sdk/schemas/task_v2.py +++ b/skyvern/forge/sdk/schemas/task_v2.py @@ -57,7 +57,7 @@ class TaskV2(BaseModel): """ if self.model: - model_name = self.model.get("model_name") + model_name = self.model.get("name") if model_name: mapping = settings.get_model_name_to_llm_key() llm_key = mapping.get(model_name) diff --git a/skyvern/forge/sdk/schemas/tasks.py b/skyvern/forge/sdk/schemas/tasks.py index 47647984..c154012b 100644 --- a/skyvern/forge/sdk/schemas/tasks.py +++ b/skyvern/forge/sdk/schemas/tasks.py @@ -248,7 +248,7 @@ class Task(TaskBase): Otherwise return `None`. """ if self.model: - model_name = self.model.get("model_name") + model_name = self.model.get("name") if model_name: mapping = settings.get_model_name_to_llm_key() return mapping.get(model_name) diff --git a/skyvern/forge/sdk/workflow/models/block.py b/skyvern/forge/sdk/workflow/models/block.py index 9ddf6e46..037da9ca 100644 --- a/skyvern/forge/sdk/workflow/models/block.py +++ b/skyvern/forge/sdk/workflow/models/block.py @@ -629,7 +629,14 @@ class BaseTaskBlock(Block): current_context = skyvern_context.ensure_context() current_context.task_id = task.task_id llm_key = workflow.determine_llm_key(block=self) - llm_caller = None if not llm_key else LLMCaller(llm_key=llm_key) + screenshot_scaling_enabled = False + if self.engine == RunEngine.anthropic_cua: + screenshot_scaling_enabled = True + llm_caller = ( + None + if not llm_key + else LLMCaller(llm_key=llm_key, screenshot_scaling_enabled=screenshot_scaling_enabled) + ) await app.agent.execute_step( organization=organization, diff --git a/skyvern/forge/sdk/workflow/models/workflow.py b/skyvern/forge/sdk/workflow/models/workflow.py index 61d55b42..254d6dac 100644 --- a/skyvern/forge/sdk/workflow/models/workflow.py +++ b/skyvern/forge/sdk/workflow/models/workflow.py @@ -94,14 +94,14 @@ class Workflow(BaseModel): mapping = settings.get_model_name_to_llm_key() if block: - model_name = (block.model or {}).get("model") + model_name = (block.model or {}).get("name") if model_name: llm_key = mapping.get(model_name) if llm_key: return llm_key - workflow_model_name = (self.model or {}).get("model") + workflow_model_name = (self.model or {}).get("name") if workflow_model_name: llm_key = mapping.get(workflow_model_name) diff --git a/skyvern/library/skyvern.py b/skyvern/library/skyvern.py index c8e7fbbf..1196b382 100644 --- a/skyvern/library/skyvern.py +++ b/skyvern/library/skyvern.py @@ -292,6 +292,7 @@ class Skyvern(AsyncSkyvern): self, prompt: str, engine: RunEngine = RunEngine.skyvern_v2, + model: dict[str, Any] | None = None, url: str | None = None, webhook_url: str | None = None, totp_identifier: str | None = None, @@ -325,6 +326,7 @@ class Skyvern(AsyncSkyvern): task_request = TaskRequest( title=title or task_generation.suggested_title, url=url, + model=model, navigation_goal=navigation_goal, navigation_payload=navigation_payload, data_extraction_goal=data_extraction_goal, @@ -371,6 +373,7 @@ class Skyvern(AsyncSkyvern): extracted_information_schema=data_extraction_schema, error_code_mapping=error_code_mapping, create_task_run=True, + model=model, ) await self._run_task_v2(organization, task_v2) diff --git a/skyvern/webeye/actions/handler.py b/skyvern/webeye/actions/handler.py index 13e673d7..0a4b4fc5 100644 --- a/skyvern/webeye/actions/handler.py +++ b/skyvern/webeye/actions/handler.py @@ -70,6 +70,7 @@ from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType from skyvern.forge.sdk.models import Step from skyvern.forge.sdk.schemas.tasks import Task from skyvern.forge.sdk.services.bitwarden import BitwardenConstants +from skyvern.schemas.runs import CUA_RUN_TYPES from skyvern.utils.prompt_engine import CheckPhoneNumberFormatResponse, load_prompt_with_elements from skyvern.webeye.actions import actions from skyvern.webeye.actions.actions import ( @@ -3363,12 +3364,18 @@ async def extract_information_for_navigation_goal( local_datetime=datetime.now(context.tz_info).isoformat(), ) + task_run = await app.DATABASE.get_run(run_id=task.task_id, organization_id=task.organization_id) + llm_key_override = task.llm_key + if task_run and task_run.task_run_type in CUA_RUN_TYPES: + # CUA tasks should use the default data extraction llm key + llm_key_override = None + json_response = await app.LLM_API_HANDLER( prompt=extract_information_prompt, step=step, screenshots=scraped_page.screenshots, prompt_name="extract-information", - llm_key_override=task.llm_key, + llm_key_override=llm_key_override, ) return ScrapeResult(