fix llm_key_override (#2552)
This commit is contained in:
@@ -276,6 +276,8 @@ class Settings(BaseSettings):
|
||||
"Gemini 2.5 Flash": "VERTEX_GEMINI_2.5_FLASH_PREVIEW_05_20",
|
||||
"GPT 4.1": "OPENAI_GPT4_1",
|
||||
"GPT o3-mini": "OPENAI_O3_MINI",
|
||||
"bedrock/us.anthropic.claude-opus-4-20250514-v1:0": "BEDROCK_ANTHROPIC_CLAUDE4_OPUS_INFERENCE_PROFILE",
|
||||
"bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0": "BEDROCK_ANTHROPIC_CLAUDE4_SONNET_INFERENCE_PROFILE",
|
||||
}
|
||||
else:
|
||||
# TODO: apparently the list for OSS is to be much larger
|
||||
@@ -284,6 +286,8 @@ class Settings(BaseSettings):
|
||||
"Gemini 2.5 Flash": "VERTEX_GEMINI_2.5_FLASH_PREVIEW_05_20",
|
||||
"GPT 4.1": "OPENAI_GPT4_1",
|
||||
"GPT o3-mini": "OPENAI_O3_MINI",
|
||||
"bedrock/us.anthropic.claude-opus-4-20250514-v1:0": "BEDROCK_ANTHROPIC_CLAUDE4_OPUS_INFERENCE_PROFILE",
|
||||
"bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0": "BEDROCK_ANTHROPIC_CLAUDE4_SONNET_INFERENCE_PROFILE",
|
||||
}
|
||||
|
||||
def is_cloud_environment(self) -> bool:
|
||||
|
||||
@@ -382,14 +382,17 @@ class ForgeAgent:
|
||||
if page := await browser_state.get_working_page():
|
||||
await self.register_async_operations(organization, task, page)
|
||||
|
||||
if not llm_caller:
|
||||
if engine == RunEngine.anthropic_cua and not llm_caller:
|
||||
# see if the llm_caller is already set in memory
|
||||
llm_caller = LLMCallerManager.get_llm_caller(task.task_id)
|
||||
if engine == RunEngine.anthropic_cua and not llm_caller:
|
||||
# llm_caller = LLMCaller(llm_key="BEDROCK_ANTHROPIC_CLAUDE3.5_SONNET_INFERENCE_PROFILE")
|
||||
llm_caller = LLMCallerManager.get_llm_caller(task.task_id)
|
||||
if not llm_caller:
|
||||
llm_caller = LLMCaller(llm_key=settings.ANTHROPIC_CUA_LLM_KEY, screenshot_scaling_enabled=True)
|
||||
LLMCallerManager.set_llm_caller(task.task_id, llm_caller)
|
||||
if not llm_caller:
|
||||
# if not, create a new llm_caller
|
||||
llm_caller = LLMCaller(llm_key=settings.ANTHROPIC_CUA_LLM_KEY, screenshot_scaling_enabled=True)
|
||||
|
||||
# TODO: remove the code after migrating everything to llm callers
|
||||
# currently, only anthropic cua tasks use llm_caller
|
||||
if engine == RunEngine.anthropic_cua and llm_caller:
|
||||
LLMCallerManager.set_llm_caller(task.task_id, llm_caller)
|
||||
|
||||
step, detailed_output = await self.agent_step(
|
||||
task,
|
||||
|
||||
@@ -96,10 +96,12 @@ class BackgroundTaskExecutor(AsyncExecutor):
|
||||
)
|
||||
run_obj = await app.DATABASE.get_run(run_id=task_id, organization_id=organization_id)
|
||||
engine = RunEngine.skyvern_v1
|
||||
screenshot_scaling_enabled = False
|
||||
if run_obj and run_obj.task_run_type == RunType.openai_cua:
|
||||
engine = RunEngine.openai_cua
|
||||
elif run_obj and run_obj.task_run_type == RunType.anthropic_cua:
|
||||
engine = RunEngine.anthropic_cua
|
||||
screenshot_scaling_enabled = True
|
||||
|
||||
context: SkyvernContext = skyvern_context.ensure_context()
|
||||
context.task_id = task.task_id
|
||||
@@ -107,7 +109,7 @@ class BackgroundTaskExecutor(AsyncExecutor):
|
||||
context.max_steps_override = max_steps_override
|
||||
|
||||
llm_key = task.llm_key
|
||||
llm_caller = LLMCaller(llm_key) if llm_key else None
|
||||
llm_caller = LLMCaller(llm_key, screenshot_scaling_enabled=screenshot_scaling_enabled) if llm_key else None
|
||||
|
||||
if background_tasks:
|
||||
background_tasks.add_task(
|
||||
|
||||
@@ -37,7 +37,7 @@ webhook_callback_url: https://example.com/webhook
|
||||
totp_verification_url: https://example.com/totp
|
||||
persist_browser_session: false
|
||||
model:
|
||||
model: gpt-3.5-turbo
|
||||
name: gpt-4.1
|
||||
workflow_definition:
|
||||
parameters:
|
||||
- key: website_url
|
||||
@@ -121,7 +121,7 @@ workflow_definition = {
|
||||
"webhook_callback_url": "https://example.com/webhook",
|
||||
"totp_verification_url": "https://example.com/totp",
|
||||
"totp_identifier": "4155555555",
|
||||
"model": {"model": "gpt-3.5-turbo"},
|
||||
"model": {"name": "gpt-4.1"},
|
||||
"workflow_definition": {
|
||||
"parameters": [
|
||||
{
|
||||
@@ -204,7 +204,8 @@ proxy_location: RESIDENTIAL
|
||||
webhook_callback_url: https://example.com/webhook
|
||||
totp_verification_url: https://example.com/totp
|
||||
persist_browser_session: false
|
||||
model: {model: gpt-3.5-turbo}
|
||||
model:
|
||||
name: gpt-4.1
|
||||
workflow_definition:
|
||||
parameters:
|
||||
- key: website_url
|
||||
@@ -287,7 +288,7 @@ updated_workflow_definition = {
|
||||
"webhook_callback_url": "https://example.com/webhook",
|
||||
"totp_verification_url": "https://example.com/totp",
|
||||
"totp_identifier": "4155555555",
|
||||
"model": {"model": "gpt-3.5-turbo"},
|
||||
"model": {"name": "gpt-4.1"},
|
||||
"workflow_definition": {
|
||||
"parameters": [
|
||||
{
|
||||
|
||||
@@ -57,7 +57,7 @@ class TaskV2(BaseModel):
|
||||
"""
|
||||
|
||||
if self.model:
|
||||
model_name = self.model.get("model_name")
|
||||
model_name = self.model.get("name")
|
||||
if model_name:
|
||||
mapping = settings.get_model_name_to_llm_key()
|
||||
llm_key = mapping.get(model_name)
|
||||
|
||||
@@ -248,7 +248,7 @@ class Task(TaskBase):
|
||||
Otherwise return `None`.
|
||||
"""
|
||||
if self.model:
|
||||
model_name = self.model.get("model_name")
|
||||
model_name = self.model.get("name")
|
||||
if model_name:
|
||||
mapping = settings.get_model_name_to_llm_key()
|
||||
return mapping.get(model_name)
|
||||
|
||||
@@ -629,7 +629,14 @@ class BaseTaskBlock(Block):
|
||||
current_context = skyvern_context.ensure_context()
|
||||
current_context.task_id = task.task_id
|
||||
llm_key = workflow.determine_llm_key(block=self)
|
||||
llm_caller = None if not llm_key else LLMCaller(llm_key=llm_key)
|
||||
screenshot_scaling_enabled = False
|
||||
if self.engine == RunEngine.anthropic_cua:
|
||||
screenshot_scaling_enabled = True
|
||||
llm_caller = (
|
||||
None
|
||||
if not llm_key
|
||||
else LLMCaller(llm_key=llm_key, screenshot_scaling_enabled=screenshot_scaling_enabled)
|
||||
)
|
||||
|
||||
await app.agent.execute_step(
|
||||
organization=organization,
|
||||
|
||||
@@ -94,14 +94,14 @@ class Workflow(BaseModel):
|
||||
mapping = settings.get_model_name_to_llm_key()
|
||||
|
||||
if block:
|
||||
model_name = (block.model or {}).get("model")
|
||||
model_name = (block.model or {}).get("name")
|
||||
|
||||
if model_name:
|
||||
llm_key = mapping.get(model_name)
|
||||
if llm_key:
|
||||
return llm_key
|
||||
|
||||
workflow_model_name = (self.model or {}).get("model")
|
||||
workflow_model_name = (self.model or {}).get("name")
|
||||
|
||||
if workflow_model_name:
|
||||
llm_key = mapping.get(workflow_model_name)
|
||||
|
||||
@@ -292,6 +292,7 @@ class Skyvern(AsyncSkyvern):
|
||||
self,
|
||||
prompt: str,
|
||||
engine: RunEngine = RunEngine.skyvern_v2,
|
||||
model: dict[str, Any] | None = None,
|
||||
url: str | None = None,
|
||||
webhook_url: str | None = None,
|
||||
totp_identifier: str | None = None,
|
||||
@@ -325,6 +326,7 @@ class Skyvern(AsyncSkyvern):
|
||||
task_request = TaskRequest(
|
||||
title=title or task_generation.suggested_title,
|
||||
url=url,
|
||||
model=model,
|
||||
navigation_goal=navigation_goal,
|
||||
navigation_payload=navigation_payload,
|
||||
data_extraction_goal=data_extraction_goal,
|
||||
@@ -371,6 +373,7 @@ class Skyvern(AsyncSkyvern):
|
||||
extracted_information_schema=data_extraction_schema,
|
||||
error_code_mapping=error_code_mapping,
|
||||
create_task_run=True,
|
||||
model=model,
|
||||
)
|
||||
|
||||
await self._run_task_v2(organization, task_v2)
|
||||
|
||||
@@ -70,6 +70,7 @@ from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
|
||||
from skyvern.forge.sdk.models import Step
|
||||
from skyvern.forge.sdk.schemas.tasks import Task
|
||||
from skyvern.forge.sdk.services.bitwarden import BitwardenConstants
|
||||
from skyvern.schemas.runs import CUA_RUN_TYPES
|
||||
from skyvern.utils.prompt_engine import CheckPhoneNumberFormatResponse, load_prompt_with_elements
|
||||
from skyvern.webeye.actions import actions
|
||||
from skyvern.webeye.actions.actions import (
|
||||
@@ -3363,12 +3364,18 @@ async def extract_information_for_navigation_goal(
|
||||
local_datetime=datetime.now(context.tz_info).isoformat(),
|
||||
)
|
||||
|
||||
task_run = await app.DATABASE.get_run(run_id=task.task_id, organization_id=task.organization_id)
|
||||
llm_key_override = task.llm_key
|
||||
if task_run and task_run.task_run_type in CUA_RUN_TYPES:
|
||||
# CUA tasks should use the default data extraction llm key
|
||||
llm_key_override = None
|
||||
|
||||
json_response = await app.LLM_API_HANDLER(
|
||||
prompt=extract_information_prompt,
|
||||
step=step,
|
||||
screenshots=scraped_page.screenshots,
|
||||
prompt_name="extract-information",
|
||||
llm_key_override=task.llm_key,
|
||||
llm_key_override=llm_key_override,
|
||||
)
|
||||
|
||||
return ScrapeResult(
|
||||
|
||||
Reference in New Issue
Block a user