anthropic CUA (#2231)

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
This commit is contained in:
Shuchang Zheng
2025-04-28 09:49:44 +08:00
committed by GitHub
parent 5582998490
commit 0a0228b341
18 changed files with 378 additions and 45 deletions

View File

@@ -58,6 +58,7 @@ from skyvern.forge.sdk.api.files import (
rename_file,
wait_for_download_finished,
)
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMCaller, LLMCallerManager
from skyvern.forge.sdk.artifact.models import ArtifactType
from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.security import generate_skyvern_webhook_headers
@@ -70,7 +71,7 @@ from skyvern.forge.sdk.schemas.tasks import Task, TaskRequest, TaskResponse, Tas
from skyvern.forge.sdk.workflow.context_manager import WorkflowRunContext
from skyvern.forge.sdk.workflow.models.block import ActionBlock, BaseTaskBlock, ValidationBlock
from skyvern.forge.sdk.workflow.models.workflow import Workflow, WorkflowRun, WorkflowRunStatus
from skyvern.schemas.runs import RunEngine, RunType
from skyvern.schemas.runs import CUA_ENGINES, CUA_RUN_TYPES, RunEngine
from skyvern.utils.prompt_engine import load_prompt_with_elements
from skyvern.webeye.actions.actions import (
Action,
@@ -88,7 +89,7 @@ from skyvern.webeye.actions.actions import (
from skyvern.webeye.actions.caching import retrieve_action_plan
from skyvern.webeye.actions.handler import ActionHandler, poll_verification_code
from skyvern.webeye.actions.models import AgentStepOutput, DetailedAgentStepOutput
from skyvern.webeye.actions.parse_actions import parse_actions, parse_cua_actions
from skyvern.webeye.actions.parse_actions import parse_actions, parse_anthropic_actions, parse_cua_actions
from skyvern.webeye.actions.responses import ActionResult, ActionSuccess
from skyvern.webeye.browser_factory import BrowserState
from skyvern.webeye.scraper.scraper import ElementTreeFormat, ScrapedPage, scrape_website
@@ -253,6 +254,7 @@ class ForgeAgent:
complete_verification: bool = True,
engine: RunEngine = RunEngine.skyvern_v1,
cua_response: OpenAIResponse | None = None,
llm_caller: LLMCaller | None = None,
) -> Tuple[Step, DetailedAgentStepOutput | None, Step | None]:
workflow_run: WorkflowRun | None = None
if task.workflow_run_id:
@@ -378,6 +380,13 @@ class ForgeAgent:
if page := await browser_state.get_working_page():
await self.register_async_operations(organization, task, page)
llm_caller = LLMCallerManager.get_llm_caller(task.task_id)
if engine == RunEngine.anthropic_cua and not llm_caller:
# llm_caller = LLMCaller(llm_key="BEDROCK_ANTHROPIC_CLAUDE3.5_SONNET_INFERENCE_PROFILE")
llm_caller = LLMCallerManager.get_llm_caller(task.task_id)
if not llm_caller:
llm_caller = LLMCaller(llm_key="ANTHROPIC_CLAUDE3.5_SONNET")
LLMCallerManager.set_llm_caller(task.task_id, llm_caller)
step, detailed_output = await self.agent_step(
task,
step,
@@ -387,6 +396,7 @@ class ForgeAgent:
complete_verification=complete_verification,
engine=engine,
cua_response=cua_response,
llm_caller=llm_caller,
)
await app.AGENT_FUNCTION.post_step_execution(task, step)
task = await self.update_task_errors_from_detailed_output(task, detailed_output)
@@ -778,6 +788,7 @@ class ForgeAgent:
task_block: BaseTaskBlock | None = None,
complete_verification: bool = True,
cua_response: OpenAIResponse | None = None,
llm_caller: LLMCaller | None = None,
) -> tuple[Step, DetailedAgentStepOutput]:
detailed_agent_step_output = DetailedAgentStepOutput(
scraped_page=None,
@@ -821,8 +832,17 @@ class ForgeAgent:
step=step,
scraped_page=scraped_page,
previous_response=cua_response,
engine=engine,
)
detailed_agent_step_output.cua_response = new_cua_response
elif engine == RunEngine.anthropic_cua:
assert llm_caller is not None
actions = await self._generate_anthropic_actions(
task=task,
step=step,
scraped_page=scraped_page,
llm_caller=llm_caller,
)
else:
using_cached_action_plan = False
if not task.navigation_goal and not isinstance(task_block, ValidationBlock):
@@ -834,7 +854,7 @@ class ForgeAgent:
):
using_cached_action_plan = True
else:
if engine != RunEngine.openai_cua:
if engine in CUA_ENGINES:
self.async_operation_pool.run_operation(task.task_id, AgentPhase.llm)
json_response = await app.LLM_API_HANDLER(
prompt=extract_action_prompt,
@@ -1219,7 +1239,8 @@ class ForgeAgent:
step: Step,
scraped_page: ScrapedPage,
previous_response: OpenAIResponse | None = None,
) -> tuple[list[Action], OpenAIResponse]:
engine: RunEngine = RunEngine.openai_cua,
) -> tuple[list[Action], OpenAIResponse | None]:
if not previous_response:
# this is the first step
first_response: OpenAIResponse = await app.OPENAI_CLIENT.responses.create(
@@ -1377,6 +1398,48 @@ class ForgeAgent:
return await parse_cua_actions(task, step, current_response), current_response
async def _generate_anthropic_actions(
    self,
    task: Task,
    step: Step,
    scraped_page: ScrapedPage,
    llm_caller: LLMCaller,
) -> list[Action]:
    """Ask Anthropic's computer-use model for the next browser actions.

    Flushes any pending tool results into the conversation as a user turn,
    sends the current screenshots (plus the task's navigation goal on the
    very first turn only), records the assistant reply in the message
    history, and parses the reply into Skyvern actions.

    Args:
        task: Task being executed; ``navigation_goal`` seeds the first turn.
        step: Current step, forwarded to the action parser.
        scraped_page: Source of the screenshots sent to the model.
        llm_caller: Conversation-scoped caller holding the message history
            and any tool results queued by previously executed actions.

    Returns:
        Actions parsed from the assistant's response content.
    """
    # Results of previously executed tool calls must be echoed back to the
    # model as a user turn before requesting the next action.
    if llm_caller.current_tool_results:
        llm_caller.message_history.append({"role": "user", "content": llm_caller.current_tool_results})
        llm_caller.clear_tool_results()
    tools = [
        {
            "type": "computer_20250124",
            "name": "computer",
            "display_height_px": settings.BROWSER_HEIGHT,
            "display_width_px": settings.BROWSER_WIDTH,
        }
    ]
    # The navigation goal is only sent on the first turn; afterwards the
    # conversation history carries the goal, so only screenshots are sent.
    # (Previously this was two near-identical call sites differing only in
    # the prompt argument; llm_caller.call defaults prompt to None.)
    is_first_turn = not llm_caller.message_history
    llm_response = await llm_caller.call(
        prompt=task.navigation_goal if is_first_turn else None,
        screenshots=scraped_page.screenshots,
        use_message_history=True,
        tools=tools,
        raw_response=True,
        betas=["computer-use-2025-01-24"],
    )
    LOG.info("Anthropic response", llm_response=llm_response)
    assistant_content = llm_response["content"]
    llm_caller.message_history.append({"role": "assistant", "content": assistant_content})
    return await parse_anthropic_actions(task, step, assistant_content)
@staticmethod
async def complete_verify(page: Page, scraped_page: ScrapedPage, task: Task, step: Step) -> CompleteVerifyResult:
LOG.info(
@@ -1387,7 +1450,7 @@ class ForgeAgent:
)
run_obj = await app.DATABASE.get_run(run_id=task.task_id, organization_id=task.organization_id)
scroll = True
if run_obj and run_obj.task_run_type == RunType.openai_cua:
if run_obj and run_obj.task_run_type in CUA_RUN_TYPES:
scroll = False
scraped_page_refreshed = await scraped_page.refresh(draw_boxes=False, scroll=scroll)
@@ -1454,7 +1517,7 @@ class ForgeAgent:
raise BrowserStateMissingPage()
fullpage_screenshot = True
if engine == RunEngine.openai_cua:
if engine in CUA_ENGINES:
fullpage_screenshot = False
try:
@@ -1580,7 +1643,7 @@ class ForgeAgent:
max_screenshot_number = settings.MAX_NUM_SCREENSHOTS
draw_boxes = True
scroll = True
if engine == RunEngine.openai_cua:
if engine in CUA_ENGINES:
max_screenshot_number = 1
draw_boxes = False
scroll = False
@@ -1602,7 +1665,7 @@ class ForgeAgent:
engine: RunEngine,
) -> tuple[ScrapedPage, str]:
# start the async tasks while running scrape_website
if engine != RunEngine.openai_cua:
if engine not in CUA_ENGINES:
self.async_operation_pool.run_operation(task.task_id, AgentPhase.scrape)
# Scrape the web page and get the screenshot and the elements
@@ -1653,7 +1716,7 @@ class ForgeAgent:
element_tree_format = ElementTreeFormat.HTML
element_tree_in_prompt: str = scraped_page.build_element_tree(element_tree_format)
extract_action_prompt = ""
if engine != RunEngine.openai_cua:
if engine not in CUA_ENGINES:
extract_action_prompt = await self._build_extract_action_prompt(
task,
step,
@@ -2371,7 +2434,7 @@ class ForgeAgent:
run_obj = await app.DATABASE.get_run(run_id=task.task_id, organization_id=task.organization_id)
scroll = True
if run_obj and run_obj.task_run_type == RunType.openai_cua:
if run_obj and run_obj.task_run_type in CUA_RUN_TYPES:
scroll = False
screenshots: list[bytes] = []

View File

@@ -1,5 +1,6 @@
from typing import Awaitable, Callable
from anthropic import AsyncAnthropic, AsyncAnthropicBedrock
from fastapi import FastAPI
from openai import AsyncAzureOpenAI, AsyncOpenAI
@@ -41,6 +42,9 @@ if SettingsManager.get_settings().ENABLE_AZURE_CUA:
azure_endpoint=SettingsManager.get_settings().AZURE_CUA_ENDPOINT,
azure_deployment=SettingsManager.get_settings().AZURE_CUA_DEPLOYMENT,
)
ANTHROPIC_CLIENT = AsyncAnthropic(api_key=SettingsManager.get_settings().ANTHROPIC_API_KEY)
if SettingsManager.get_settings().ENABLE_BEDROCK_ANTHROPIC:
ANTHROPIC_CLIENT = AsyncAnthropicBedrock()
SECONDARY_LLM_API_HANDLER = LLMAPIHandlerFactory.get_llm_api_handler(
SETTINGS_MANAGER.SECONDARY_LLM_KEY if SETTINGS_MANAGER.SECONDARY_LLM_KEY else SETTINGS_MANAGER.LLM_KEY

View File

@@ -6,7 +6,9 @@ from typing import Any
import litellm
import structlog
from anthropic.types.message import Message as AnthropicMessage
from jinja2 import Template
from litellm.utils import CustomStreamWrapper, ModelResponse
from skyvern.config import settings
from skyvern.exceptions import SkyvernContextWindowExceededError
@@ -456,11 +458,18 @@ class LLMCaller:
self.llm_config = LLMConfigRegistry.get_config(llm_key)
self.base_parameters = base_parameters
self.message_history: list[dict[str, Any]] = []
self.current_tool_results: list[dict[str, Any]] = []
def add_tool_result(self, tool_result: dict[str, Any]) -> None:
    """Queue a tool execution result to be flushed on the next user turn."""
    self.current_tool_results += [tool_result]
def clear_tool_results(self) -> None:
    """Drop all queued tool results (rebinds to a fresh list, as before)."""
    self.current_tool_results = list()
async def call(
self,
prompt: str,
prompt_name: str,
prompt: str | None = None,
prompt_name: str | None = None,
step: Step | None = None,
task_v2: TaskV2 | None = None,
thought: Thought | None = None,
@@ -469,6 +478,8 @@ class LLMCaller:
parameters: dict[str, Any] | None = None,
tools: list | None = None,
use_message_history: bool = False,
raw_response: bool = False,
**extra_parameters: Any,
) -> dict[str, Any]:
start_time = time.perf_counter()
active_parameters = self.base_parameters or {}
@@ -476,6 +487,8 @@ class LLMCaller:
parameters = LLMAPIHandlerFactory.get_api_parameters(self.llm_config)
active_parameters.update(parameters)
if extra_parameters:
active_parameters.update(extra_parameters)
if self.llm_config.litellm_params: # type: ignore
active_parameters.update(self.llm_config.litellm_params) # type: ignore
@@ -491,7 +504,7 @@ class LLMCaller:
)
await app.ARTIFACT_MANAGER.create_llm_artifact(
data=prompt.encode("utf-8"),
data=prompt.encode("utf-8") if prompt else b"",
artifact_type=ArtifactType.LLM_PROMPT,
screenshots=screenshots,
step=step,
@@ -525,8 +538,7 @@ class LLMCaller:
)
t_llm_request = time.perf_counter()
try:
response = await litellm.acompletion(
model=self.llm_config.model_name,
response = await self._dispatch_llm_call(
messages=messages,
tools=tools,
timeout=settings.LLM_CONFIG_TIMEOUT,
@@ -603,6 +615,21 @@ class LLMCaller:
cached_token_count=cached_tokens if cached_tokens > 0 else None,
thought_cost=llm_cost,
)
# Track LLM API handler duration
duration_seconds = time.perf_counter() - start_time
LOG.info(
"LLM API handler duration metrics",
llm_key=self.llm_key,
prompt_name=prompt_name,
model=self.llm_config.model_name,
duration_seconds=duration_seconds,
step_id=step.step_id if step else None,
thought_id=thought.observer_thought_id if thought else None,
organization_id=step.organization_id if step else (thought.organization_id if thought else None),
)
if raw_response:
return response.model_dump()
parsed_response = parse_api_response(response, self.llm_config.add_assistant_prefix)
await app.ARTIFACT_MANAGER.create_llm_artifact(
data=json.dumps(parsed_response, indent=2).encode("utf-8"),
@@ -626,17 +653,53 @@ class LLMCaller:
ai_suggestion=ai_suggestion,
)
# Track LLM API handler duration
duration_seconds = time.perf_counter() - start_time
LOG.info(
"LLM API handler duration metrics",
llm_key=self.llm_key,
prompt_name=prompt_name,
model=self.llm_config.model_name,
duration_seconds=duration_seconds,
step_id=step.step_id if step else None,
thought_id=thought.observer_thought_id if thought else None,
organization_id=step.organization_id if step else (thought.organization_id if thought else None),
return parsed_response
async def _dispatch_llm_call(
    self,
    messages: list[dict[str, Any]],
    tools: list | None = None,
    timeout: float = settings.LLM_CONFIG_TIMEOUT,
    **active_parameters: dict[str, Any],
) -> ModelResponse | CustomStreamWrapper | AnthropicMessage:
    """Route an LLM call to the native Anthropic client or to litellm.

    Configs whose llm_key starts with "ANTHROPIC" bypass litellm so that
    computer-use betas/tools can be passed natively; everything else goes
    through litellm.acompletion unchanged.
    """
    if self.llm_key and self.llm_key.startswith("ANTHROPIC"):
        # Fix: forward **active_parameters. Previously they were dropped
        # here, so _call_anthropic's reads of "betas" and
        # "max_completion_tokens"/"max_tokens" always fell back to defaults.
        return await self._call_anthropic(messages, tools, timeout, **active_parameters)
    return await litellm.acompletion(
        model=self.llm_config.model_name, messages=messages, tools=tools, timeout=timeout, **active_parameters
    )
return parsed_response
async def _call_anthropic(
    self,
    messages: list[dict[str, Any]],
    tools: list | None = None,
    timeout: float = settings.LLM_CONFIG_TIMEOUT,
    **active_parameters: dict[str, Any],
) -> AnthropicMessage:
    """Call the Anthropic messages API directly (bypassing litellm).

    Picks max_tokens from "max_completion_tokens" or "max_tokens" in
    active_parameters, falling back to 4096. Strips litellm-style provider
    prefixes from the model name since the Anthropic/Bedrock SDK expects a
    bare model id.

    NOTE(review): "betas" is passed to messages.create here; the Anthropic
    SDK normally exposes beta flags via client.beta.messages.create —
    confirm the non-beta endpoint accepts this keyword.
    NOTE(review): tools=None is forwarded as-is when no tools are given —
    verify the SDK treats None as "omitted".
    """
    # "max_completion_tokens" takes precedence over "max_tokens"; 4096 is the default cap.
    max_tokens = active_parameters.get("max_completion_tokens") or active_parameters.get("max_tokens") or 4096
    # Drop "bedrock/" / "anthropic/" routing prefixes used by litellm configs.
    model_name = self.llm_config.model_name.replace("bedrock/", "").replace("anthropic/", "")
    return await app.ANTHROPIC_CLIENT.messages.create(
        max_tokens=max_tokens,
        messages=messages,
        model=model_name,
        tools=tools,
        timeout=timeout,
        betas=active_parameters.get("betas", None),
    )
class LLMCallerManager:
    """Process-wide registry mapping a uid (e.g. a task id) to its LLMCaller.

    Lets conversation state (message history, queued tool results) persist
    across the steps of a single task.
    """

    _llm_callers: dict[str, LLMCaller] = {}

    @classmethod
    def get_llm_caller(cls, uid: str) -> LLMCaller | None:
        """Return the caller registered under uid, or None if absent."""
        return cls._llm_callers.get(uid)

    @classmethod
    def set_llm_caller(cls, uid: str, llm_caller: LLMCaller) -> None:
        """Register (or replace) the caller for uid."""
        cls._llm_callers[uid] = llm_caller

    @classmethod
    def clear_llm_caller(cls, uid: str) -> None:
        """Remove the caller for uid; a no-op when uid is not registered."""
        cls._llm_callers.pop(uid, None)

View File

@@ -47,19 +47,22 @@ async def llm_messages_builder(
async def llm_messages_builder_with_history(
prompt: str,
prompt: str | None = None,
screenshots: list[bytes] | None = None,
message_history: list[dict[str, Any]] | None = None,
) -> list[dict[str, Any]]:
messages: list[dict[str, Any]] = []
if message_history:
messages = copy.deepcopy(message_history)
current_user_messages: list[dict[str, Any]] = [
{
"type": "text",
"text": prompt,
}
]
current_user_messages: list[dict[str, Any]] = []
if prompt:
current_user_messages.append(
{
"type": "text",
"text": prompt,
}
)
if screenshots:
for screenshot in screenshots:

View File

@@ -96,6 +96,8 @@ class BackgroundTaskExecutor(AsyncExecutor):
engine = RunEngine.skyvern_v1
if run_obj and run_obj.task_run_type == RunType.openai_cua:
engine = RunEngine.openai_cua
elif run_obj and run_obj.task_run_type == RunType.anthropic_cua:
engine = RunEngine.anthropic_cua
context: SkyvernContext = skyvern_context.ensure_context()
context.task_id = task.task_id

View File

@@ -62,6 +62,7 @@ from skyvern.forge.sdk.workflow.models.workflow import (
)
from skyvern.forge.sdk.workflow.models.yaml import WorkflowCreateYAMLRequest
from skyvern.schemas.runs import (
CUA_ENGINES,
RunEngine,
RunResponse,
RunType,
@@ -1466,7 +1467,7 @@ async def run_task(
analytics.capture("skyvern-oss-run-task", data={"url": run_request.url})
await PermissionCheckerFactory.get_instance().check(current_org, browser_session_id=run_request.browser_session_id)
if run_request.engine in [RunEngine.skyvern_v1, RunEngine.openai_cua]:
if run_request.engine in CUA_ENGINES:
# create task v1
# if there's no url, call task generation first to generate the url, data schema if any
url = run_request.url
@@ -1480,7 +1481,7 @@ async def run_task(
)
url = url or task_generation.url
navigation_goal = task_generation.navigation_goal or run_request.prompt
if run_request.engine == RunEngine.openai_cua:
if run_request.engine in CUA_ENGINES:
navigation_goal = run_request.prompt
navigation_payload = task_generation.navigation_payload
data_extraction_goal = task_generation.data_extraction_goal
@@ -1511,6 +1512,8 @@ async def run_task(
run_type = RunType.task_v1
if run_request.engine == RunEngine.openai_cua:
run_type = RunType.openai_cua
elif run_request.engine == RunEngine.anthropic_cua:
run_type = RunType.anthropic_cua
# build the task run response
return TaskRunResponse(
run_id=task_v1_response.task_id,
@@ -1586,8 +1589,6 @@ async def run_task(
publish_workflow=run_request.publish_workflow,
),
)
if run_request.engine == RunEngine.openai_cua:
pass
raise HTTPException(status_code=400, detail=f"Invalid agent engine: {run_request.engine}")