anthropic CUA (#2231)
Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
This commit is contained in:
@@ -58,6 +58,7 @@ from skyvern.forge.sdk.api.files import (
|
||||
rename_file,
|
||||
wait_for_download_finished,
|
||||
)
|
||||
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMCaller, LLMCallerManager
|
||||
from skyvern.forge.sdk.artifact.models import ArtifactType
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.forge.sdk.core.security import generate_skyvern_webhook_headers
|
||||
@@ -70,7 +71,7 @@ from skyvern.forge.sdk.schemas.tasks import Task, TaskRequest, TaskResponse, Tas
|
||||
from skyvern.forge.sdk.workflow.context_manager import WorkflowRunContext
|
||||
from skyvern.forge.sdk.workflow.models.block import ActionBlock, BaseTaskBlock, ValidationBlock
|
||||
from skyvern.forge.sdk.workflow.models.workflow import Workflow, WorkflowRun, WorkflowRunStatus
|
||||
from skyvern.schemas.runs import RunEngine, RunType
|
||||
from skyvern.schemas.runs import CUA_ENGINES, CUA_RUN_TYPES, RunEngine
|
||||
from skyvern.utils.prompt_engine import load_prompt_with_elements
|
||||
from skyvern.webeye.actions.actions import (
|
||||
Action,
|
||||
@@ -88,7 +89,7 @@ from skyvern.webeye.actions.actions import (
|
||||
from skyvern.webeye.actions.caching import retrieve_action_plan
|
||||
from skyvern.webeye.actions.handler import ActionHandler, poll_verification_code
|
||||
from skyvern.webeye.actions.models import AgentStepOutput, DetailedAgentStepOutput
|
||||
from skyvern.webeye.actions.parse_actions import parse_actions, parse_cua_actions
|
||||
from skyvern.webeye.actions.parse_actions import parse_actions, parse_anthropic_actions, parse_cua_actions
|
||||
from skyvern.webeye.actions.responses import ActionResult, ActionSuccess
|
||||
from skyvern.webeye.browser_factory import BrowserState
|
||||
from skyvern.webeye.scraper.scraper import ElementTreeFormat, ScrapedPage, scrape_website
|
||||
@@ -253,6 +254,7 @@ class ForgeAgent:
|
||||
complete_verification: bool = True,
|
||||
engine: RunEngine = RunEngine.skyvern_v1,
|
||||
cua_response: OpenAIResponse | None = None,
|
||||
llm_caller: LLMCaller | None = None,
|
||||
) -> Tuple[Step, DetailedAgentStepOutput | None, Step | None]:
|
||||
workflow_run: WorkflowRun | None = None
|
||||
if task.workflow_run_id:
|
||||
@@ -378,6 +380,13 @@ class ForgeAgent:
|
||||
if page := await browser_state.get_working_page():
|
||||
await self.register_async_operations(organization, task, page)
|
||||
|
||||
llm_caller = LLMCallerManager.get_llm_caller(task.task_id)
|
||||
if engine == RunEngine.anthropic_cua and not llm_caller:
|
||||
# llm_caller = LLMCaller(llm_key="BEDROCK_ANTHROPIC_CLAUDE3.5_SONNET_INFERENCE_PROFILE")
|
||||
llm_caller = LLMCallerManager.get_llm_caller(task.task_id)
|
||||
if not llm_caller:
|
||||
llm_caller = LLMCaller(llm_key="ANTHROPIC_CLAUDE3.5_SONNET")
|
||||
LLMCallerManager.set_llm_caller(task.task_id, llm_caller)
|
||||
step, detailed_output = await self.agent_step(
|
||||
task,
|
||||
step,
|
||||
@@ -387,6 +396,7 @@ class ForgeAgent:
|
||||
complete_verification=complete_verification,
|
||||
engine=engine,
|
||||
cua_response=cua_response,
|
||||
llm_caller=llm_caller,
|
||||
)
|
||||
await app.AGENT_FUNCTION.post_step_execution(task, step)
|
||||
task = await self.update_task_errors_from_detailed_output(task, detailed_output)
|
||||
@@ -778,6 +788,7 @@ class ForgeAgent:
|
||||
task_block: BaseTaskBlock | None = None,
|
||||
complete_verification: bool = True,
|
||||
cua_response: OpenAIResponse | None = None,
|
||||
llm_caller: LLMCaller | None = None,
|
||||
) -> tuple[Step, DetailedAgentStepOutput]:
|
||||
detailed_agent_step_output = DetailedAgentStepOutput(
|
||||
scraped_page=None,
|
||||
@@ -821,8 +832,17 @@ class ForgeAgent:
|
||||
step=step,
|
||||
scraped_page=scraped_page,
|
||||
previous_response=cua_response,
|
||||
engine=engine,
|
||||
)
|
||||
detailed_agent_step_output.cua_response = new_cua_response
|
||||
elif engine == RunEngine.anthropic_cua:
|
||||
assert llm_caller is not None
|
||||
actions = await self._generate_anthropic_actions(
|
||||
task=task,
|
||||
step=step,
|
||||
scraped_page=scraped_page,
|
||||
llm_caller=llm_caller,
|
||||
)
|
||||
else:
|
||||
using_cached_action_plan = False
|
||||
if not task.navigation_goal and not isinstance(task_block, ValidationBlock):
|
||||
@@ -834,7 +854,7 @@ class ForgeAgent:
|
||||
):
|
||||
using_cached_action_plan = True
|
||||
else:
|
||||
if engine != RunEngine.openai_cua:
|
||||
if engine in CUA_ENGINES:
|
||||
self.async_operation_pool.run_operation(task.task_id, AgentPhase.llm)
|
||||
json_response = await app.LLM_API_HANDLER(
|
||||
prompt=extract_action_prompt,
|
||||
@@ -1219,7 +1239,8 @@ class ForgeAgent:
|
||||
step: Step,
|
||||
scraped_page: ScrapedPage,
|
||||
previous_response: OpenAIResponse | None = None,
|
||||
) -> tuple[list[Action], OpenAIResponse]:
|
||||
engine: RunEngine = RunEngine.openai_cua,
|
||||
) -> tuple[list[Action], OpenAIResponse | None]:
|
||||
if not previous_response:
|
||||
# this is the first step
|
||||
first_response: OpenAIResponse = await app.OPENAI_CLIENT.responses.create(
|
||||
@@ -1377,6 +1398,48 @@ class ForgeAgent:
|
||||
|
||||
return await parse_cua_actions(task, step, current_response), current_response
|
||||
|
||||
async def _generate_anthropic_actions(
    self,
    task: Task,
    step: Step,
    scraped_page: ScrapedPage,
    llm_caller: LLMCaller,
) -> list[Action]:
    """Ask the Anthropic computer-use model for the next actions of this step.

    Any tool results queued on ``llm_caller`` are first appended to its
    message history as a user turn and the queue is cleared. The model is
    then called with the page screenshots and a single ``computer`` tool
    sized from ``settings.BROWSER_WIDTH`` / ``settings.BROWSER_HEIGHT``.
    The task's navigation goal is sent as the prompt only while the message
    history is still empty; later turns send screenshots alone.

    The ``content`` of the raw response is appended to the history as the
    assistant turn and handed to ``parse_anthropic_actions``.

    Returns:
        The actions parsed from the assistant content.
    """
    # Flush queued tool results into the conversation before the next turn.
    # This must happen before the emptiness check below, since it can make
    # the history non-empty.
    queued_results = llm_caller.current_tool_results
    if queued_results:
        llm_caller.message_history.append({"role": "user", "content": queued_results})
        llm_caller.clear_tool_results()

    computer_tool = {
        "type": "computer_20250124",
        "name": "computer",
        "display_height_px": settings.BROWSER_HEIGHT,
        "display_width_px": settings.BROWSER_WIDTH,
    }

    # Only the opening turn carries the navigation goal; ``call`` treats a
    # ``None`` prompt the same as an omitted one.
    is_first_turn = not llm_caller.message_history
    llm_response = await llm_caller.call(
        prompt=task.navigation_goal if is_first_turn else None,
        screenshots=scraped_page.screenshots,
        use_message_history=True,
        tools=[computer_tool],
        raw_response=True,
        betas=["computer-use-2025-01-24"],
    )
    LOG.info("Anthropic response", llm_response=llm_response)

    # Record the assistant turn so the next call continues the conversation.
    content_blocks = llm_response["content"]
    llm_caller.message_history.append({"role": "assistant", "content": content_blocks})

    return await parse_anthropic_actions(task, step, content_blocks)
|
||||
|
||||
@staticmethod
|
||||
async def complete_verify(page: Page, scraped_page: ScrapedPage, task: Task, step: Step) -> CompleteVerifyResult:
|
||||
LOG.info(
|
||||
@@ -1387,7 +1450,7 @@ class ForgeAgent:
|
||||
)
|
||||
run_obj = await app.DATABASE.get_run(run_id=task.task_id, organization_id=task.organization_id)
|
||||
scroll = True
|
||||
if run_obj and run_obj.task_run_type == RunType.openai_cua:
|
||||
if run_obj and run_obj.task_run_type in CUA_RUN_TYPES:
|
||||
scroll = False
|
||||
|
||||
scraped_page_refreshed = await scraped_page.refresh(draw_boxes=False, scroll=scroll)
|
||||
@@ -1454,7 +1517,7 @@ class ForgeAgent:
|
||||
raise BrowserStateMissingPage()
|
||||
|
||||
fullpage_screenshot = True
|
||||
if engine == RunEngine.openai_cua:
|
||||
if engine in CUA_ENGINES:
|
||||
fullpage_screenshot = False
|
||||
|
||||
try:
|
||||
@@ -1580,7 +1643,7 @@ class ForgeAgent:
|
||||
max_screenshot_number = settings.MAX_NUM_SCREENSHOTS
|
||||
draw_boxes = True
|
||||
scroll = True
|
||||
if engine == RunEngine.openai_cua:
|
||||
if engine in CUA_ENGINES:
|
||||
max_screenshot_number = 1
|
||||
draw_boxes = False
|
||||
scroll = False
|
||||
@@ -1602,7 +1665,7 @@ class ForgeAgent:
|
||||
engine: RunEngine,
|
||||
) -> tuple[ScrapedPage, str]:
|
||||
# start the async tasks while running scrape_website
|
||||
if engine != RunEngine.openai_cua:
|
||||
if engine not in CUA_ENGINES:
|
||||
self.async_operation_pool.run_operation(task.task_id, AgentPhase.scrape)
|
||||
|
||||
# Scrape the web page and get the screenshot and the elements
|
||||
@@ -1653,7 +1716,7 @@ class ForgeAgent:
|
||||
element_tree_format = ElementTreeFormat.HTML
|
||||
element_tree_in_prompt: str = scraped_page.build_element_tree(element_tree_format)
|
||||
extract_action_prompt = ""
|
||||
if engine != RunEngine.openai_cua:
|
||||
if engine not in CUA_ENGINES:
|
||||
extract_action_prompt = await self._build_extract_action_prompt(
|
||||
task,
|
||||
step,
|
||||
@@ -2371,7 +2434,7 @@ class ForgeAgent:
|
||||
|
||||
run_obj = await app.DATABASE.get_run(run_id=task.task_id, organization_id=task.organization_id)
|
||||
scroll = True
|
||||
if run_obj and run_obj.task_run_type == RunType.openai_cua:
|
||||
if run_obj and run_obj.task_run_type in CUA_RUN_TYPES:
|
||||
scroll = False
|
||||
|
||||
screenshots: list[bytes] = []
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Awaitable, Callable
|
||||
|
||||
from anthropic import AsyncAnthropic, AsyncAnthropicBedrock
|
||||
from fastapi import FastAPI
|
||||
from openai import AsyncAzureOpenAI, AsyncOpenAI
|
||||
|
||||
@@ -41,6 +42,9 @@ if SettingsManager.get_settings().ENABLE_AZURE_CUA:
|
||||
azure_endpoint=SettingsManager.get_settings().AZURE_CUA_ENDPOINT,
|
||||
azure_deployment=SettingsManager.get_settings().AZURE_CUA_DEPLOYMENT,
|
||||
)
|
||||
ANTHROPIC_CLIENT = AsyncAnthropic(api_key=SettingsManager.get_settings().ANTHROPIC_API_KEY)
|
||||
if SettingsManager.get_settings().ENABLE_BEDROCK_ANTHROPIC:
|
||||
ANTHROPIC_CLIENT = AsyncAnthropicBedrock()
|
||||
|
||||
SECONDARY_LLM_API_HANDLER = LLMAPIHandlerFactory.get_llm_api_handler(
|
||||
SETTINGS_MANAGER.SECONDARY_LLM_KEY if SETTINGS_MANAGER.SECONDARY_LLM_KEY else SETTINGS_MANAGER.LLM_KEY
|
||||
|
||||
@@ -6,7 +6,9 @@ from typing import Any
|
||||
|
||||
import litellm
|
||||
import structlog
|
||||
from anthropic.types.message import Message as AnthropicMessage
|
||||
from jinja2 import Template
|
||||
from litellm.utils import CustomStreamWrapper, ModelResponse
|
||||
|
||||
from skyvern.config import settings
|
||||
from skyvern.exceptions import SkyvernContextWindowExceededError
|
||||
@@ -456,11 +458,18 @@ class LLMCaller:
|
||||
self.llm_config = LLMConfigRegistry.get_config(llm_key)
|
||||
self.base_parameters = base_parameters
|
||||
self.message_history: list[dict[str, Any]] = []
|
||||
self.current_tool_results: list[dict[str, Any]] = []
|
||||
|
||||
def add_tool_result(self, tool_result: dict[str, Any]) -> None:
|
||||
self.current_tool_results.append(tool_result)
|
||||
|
||||
def clear_tool_results(self) -> None:
|
||||
self.current_tool_results = []
|
||||
|
||||
async def call(
|
||||
self,
|
||||
prompt: str,
|
||||
prompt_name: str,
|
||||
prompt: str | None = None,
|
||||
prompt_name: str | None = None,
|
||||
step: Step | None = None,
|
||||
task_v2: TaskV2 | None = None,
|
||||
thought: Thought | None = None,
|
||||
@@ -469,6 +478,8 @@ class LLMCaller:
|
||||
parameters: dict[str, Any] | None = None,
|
||||
tools: list | None = None,
|
||||
use_message_history: bool = False,
|
||||
raw_response: bool = False,
|
||||
**extra_parameters: Any,
|
||||
) -> dict[str, Any]:
|
||||
start_time = time.perf_counter()
|
||||
active_parameters = self.base_parameters or {}
|
||||
@@ -476,6 +487,8 @@ class LLMCaller:
|
||||
parameters = LLMAPIHandlerFactory.get_api_parameters(self.llm_config)
|
||||
|
||||
active_parameters.update(parameters)
|
||||
if extra_parameters:
|
||||
active_parameters.update(extra_parameters)
|
||||
if self.llm_config.litellm_params: # type: ignore
|
||||
active_parameters.update(self.llm_config.litellm_params) # type: ignore
|
||||
|
||||
@@ -491,7 +504,7 @@ class LLMCaller:
|
||||
)
|
||||
|
||||
await app.ARTIFACT_MANAGER.create_llm_artifact(
|
||||
data=prompt.encode("utf-8"),
|
||||
data=prompt.encode("utf-8") if prompt else b"",
|
||||
artifact_type=ArtifactType.LLM_PROMPT,
|
||||
screenshots=screenshots,
|
||||
step=step,
|
||||
@@ -525,8 +538,7 @@ class LLMCaller:
|
||||
)
|
||||
t_llm_request = time.perf_counter()
|
||||
try:
|
||||
response = await litellm.acompletion(
|
||||
model=self.llm_config.model_name,
|
||||
response = await self._dispatch_llm_call(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
timeout=settings.LLM_CONFIG_TIMEOUT,
|
||||
@@ -603,6 +615,21 @@ class LLMCaller:
|
||||
cached_token_count=cached_tokens if cached_tokens > 0 else None,
|
||||
thought_cost=llm_cost,
|
||||
)
|
||||
# Track LLM API handler duration
|
||||
duration_seconds = time.perf_counter() - start_time
|
||||
LOG.info(
|
||||
"LLM API handler duration metrics",
|
||||
llm_key=self.llm_key,
|
||||
prompt_name=prompt_name,
|
||||
model=self.llm_config.model_name,
|
||||
duration_seconds=duration_seconds,
|
||||
step_id=step.step_id if step else None,
|
||||
thought_id=thought.observer_thought_id if thought else None,
|
||||
organization_id=step.organization_id if step else (thought.organization_id if thought else None),
|
||||
)
|
||||
if raw_response:
|
||||
return response.model_dump()
|
||||
|
||||
parsed_response = parse_api_response(response, self.llm_config.add_assistant_prefix)
|
||||
await app.ARTIFACT_MANAGER.create_llm_artifact(
|
||||
data=json.dumps(parsed_response, indent=2).encode("utf-8"),
|
||||
@@ -626,17 +653,53 @@ class LLMCaller:
|
||||
ai_suggestion=ai_suggestion,
|
||||
)
|
||||
|
||||
# Track LLM API handler duration
|
||||
duration_seconds = time.perf_counter() - start_time
|
||||
LOG.info(
|
||||
"LLM API handler duration metrics",
|
||||
llm_key=self.llm_key,
|
||||
prompt_name=prompt_name,
|
||||
model=self.llm_config.model_name,
|
||||
duration_seconds=duration_seconds,
|
||||
step_id=step.step_id if step else None,
|
||||
thought_id=thought.observer_thought_id if thought else None,
|
||||
organization_id=step.organization_id if step else (thought.organization_id if thought else None),
|
||||
return parsed_response
|
||||
|
||||
async def _dispatch_llm_call(
    self,
    messages: list[dict[str, Any]],
    tools: list | None = None,
    timeout: float = settings.LLM_CONFIG_TIMEOUT,
    **active_parameters: Any,
) -> ModelResponse | CustomStreamWrapper | AnthropicMessage:
    """Send the request to the backend selected by ``self.llm_key``.

    Keys starting with ``"ANTHROPIC"`` go through ``self._call_anthropic``;
    every other key goes through ``litellm.acompletion`` using
    ``self.llm_config.model_name``.

    Args:
        messages: Chat messages to send.
        tools: Optional tool definitions passed through to the backend.
        timeout: Request timeout; defaults to ``settings.LLM_CONFIG_TIMEOUT``.
        **active_parameters: Extra request parameters forwarded to the
            chosen backend.

    Returns:
        The backend's response object.
    """
    if self.llm_key and self.llm_key.startswith("ANTHROPIC"):
        # The extra parameters must be forwarded: _call_anthropic reads
        # max_completion_tokens / max_tokens and betas from them, and would
        # otherwise silently fall back to its defaults.
        return await self._call_anthropic(messages, tools, timeout, **active_parameters)

    return await litellm.acompletion(
        model=self.llm_config.model_name,
        messages=messages,
        tools=tools,
        timeout=timeout,
        **active_parameters,
    )
|
||||
|
||||
return parsed_response
|
||||
async def _call_anthropic(
    self,
    messages: list[dict[str, Any]],
    tools: list | None = None,
    timeout: float = settings.LLM_CONFIG_TIMEOUT,
    **active_parameters: Any,
) -> AnthropicMessage:
    """Create a message through ``app.ANTHROPIC_CLIENT``.

    Args:
        messages: Conversation to send.
        tools: Optional tool definitions; omitted from the request when None.
        timeout: Request timeout; defaults to ``settings.LLM_CONFIG_TIMEOUT``.
        **active_parameters: Extra parameters. Only ``max_completion_tokens``,
            ``max_tokens`` and ``betas`` are read; anything else is ignored.

    Returns:
        The message object returned by the Anthropic client.
        NOTE(review): when ``betas`` is supplied the request goes through the
        beta endpoint, whose response is the SDK's beta message type - confirm
        callers only rely on fields shared with ``AnthropicMessage``.
    """
    # First truthy value wins; 4096 is the fallback when neither is set.
    max_tokens = active_parameters.get("max_completion_tokens") or active_parameters.get("max_tokens") or 4096
    # The client expects the bare model id, without the provider prefixes
    # carried by the configured model name.
    model_name = self.llm_config.model_name.replace("bedrock/", "").replace("anthropic/", "")

    request_kwargs: dict[str, Any] = {
        "max_tokens": max_tokens,
        "messages": messages,
        "model": model_name,
        "timeout": timeout,
    }
    # Leave optional fields out entirely rather than sending an explicit None.
    if tools is not None:
        request_kwargs["tools"] = tools

    betas = active_parameters.get("betas")
    if betas is not None:
        # ``betas`` is a parameter of the SDK's beta messages endpoint, which
        # is also where beta tools such as computer use are documented.
        return await app.ANTHROPIC_CLIENT.beta.messages.create(betas=betas, **request_kwargs)

    return await app.ANTHROPIC_CLIENT.messages.create(**request_kwargs)
|
||||
|
||||
|
||||
class LLMCallerManager:
    """Class-level registry that maps a unique id to an ``LLMCaller``.

    All access goes through classmethods, so the mapping is shared by every
    user of the class and no instance is needed.
    """

    # uid -> caller; lives on the class, not on instances.
    _llm_callers: dict[str, LLMCaller] = {}

    @classmethod
    def get_llm_caller(cls, uid: str) -> LLMCaller | None:
        """Return the caller registered under ``uid``, or None if there is none."""
        registry = cls._llm_callers
        return registry.get(uid)

    @classmethod
    def set_llm_caller(cls, uid: str, llm_caller: LLMCaller) -> None:
        """Register ``llm_caller`` under ``uid``, replacing any existing entry."""
        registry = cls._llm_callers
        registry[uid] = llm_caller

    @classmethod
    def clear_llm_caller(cls, uid: str) -> None:
        """Remove the entry for ``uid``; do nothing if it is not registered."""
        cls._llm_callers.pop(uid, None)
|
||||
|
||||
@@ -47,19 +47,22 @@ async def llm_messages_builder(
|
||||
|
||||
|
||||
async def llm_messages_builder_with_history(
|
||||
prompt: str,
|
||||
prompt: str | None = None,
|
||||
screenshots: list[bytes] | None = None,
|
||||
message_history: list[dict[str, Any]] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
messages: list[dict[str, Any]] = []
|
||||
if message_history:
|
||||
messages = copy.deepcopy(message_history)
|
||||
current_user_messages: list[dict[str, Any]] = [
|
||||
{
|
||||
"type": "text",
|
||||
"text": prompt,
|
||||
}
|
||||
]
|
||||
|
||||
current_user_messages: list[dict[str, Any]] = []
|
||||
if prompt:
|
||||
current_user_messages.append(
|
||||
{
|
||||
"type": "text",
|
||||
"text": prompt,
|
||||
}
|
||||
)
|
||||
|
||||
if screenshots:
|
||||
for screenshot in screenshots:
|
||||
|
||||
@@ -96,6 +96,8 @@ class BackgroundTaskExecutor(AsyncExecutor):
|
||||
engine = RunEngine.skyvern_v1
|
||||
if run_obj and run_obj.task_run_type == RunType.openai_cua:
|
||||
engine = RunEngine.openai_cua
|
||||
elif run_obj and run_obj.task_run_type == RunType.anthropic_cua:
|
||||
engine = RunEngine.anthropic_cua
|
||||
|
||||
context: SkyvernContext = skyvern_context.ensure_context()
|
||||
context.task_id = task.task_id
|
||||
|
||||
@@ -62,6 +62,7 @@ from skyvern.forge.sdk.workflow.models.workflow import (
|
||||
)
|
||||
from skyvern.forge.sdk.workflow.models.yaml import WorkflowCreateYAMLRequest
|
||||
from skyvern.schemas.runs import (
|
||||
CUA_ENGINES,
|
||||
RunEngine,
|
||||
RunResponse,
|
||||
RunType,
|
||||
@@ -1466,7 +1467,7 @@ async def run_task(
|
||||
analytics.capture("skyvern-oss-run-task", data={"url": run_request.url})
|
||||
await PermissionCheckerFactory.get_instance().check(current_org, browser_session_id=run_request.browser_session_id)
|
||||
|
||||
if run_request.engine in [RunEngine.skyvern_v1, RunEngine.openai_cua]:
|
||||
if run_request.engine in CUA_ENGINES:
|
||||
# create task v1
|
||||
# if there's no url, call task generation first to generate the url, data schema if any
|
||||
url = run_request.url
|
||||
@@ -1480,7 +1481,7 @@ async def run_task(
|
||||
)
|
||||
url = url or task_generation.url
|
||||
navigation_goal = task_generation.navigation_goal or run_request.prompt
|
||||
if run_request.engine == RunEngine.openai_cua:
|
||||
if run_request.engine in CUA_ENGINES:
|
||||
navigation_goal = run_request.prompt
|
||||
navigation_payload = task_generation.navigation_payload
|
||||
data_extraction_goal = task_generation.data_extraction_goal
|
||||
@@ -1511,6 +1512,8 @@ async def run_task(
|
||||
run_type = RunType.task_v1
|
||||
if run_request.engine == RunEngine.openai_cua:
|
||||
run_type = RunType.openai_cua
|
||||
elif run_request.engine == RunEngine.anthropic_cua:
|
||||
run_type = RunType.anthropic_cua
|
||||
# build the task run response
|
||||
return TaskRunResponse(
|
||||
run_id=task_v1_response.task_id,
|
||||
@@ -1586,8 +1589,6 @@ async def run_task(
|
||||
publish_workflow=run_request.publish_workflow,
|
||||
),
|
||||
)
|
||||
if run_request.engine == RunEngine.openai_cua:
|
||||
pass
|
||||
raise HTTPException(status_code=400, detail=f"Invalid agent engine: {run_request.engine}")
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user