From 0a0228b341c512ce2edbe83d95e14f6c387965dc Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Mon, 28 Apr 2025 09:49:44 +0800 Subject: [PATCH] anthropic CUA (#2231) Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com> --- poetry.lock | 27 +++- pyproject.toml | 1 + skyvern/agent/client.py | 2 +- skyvern/config.py | 1 + skyvern/exceptions.py | 5 + skyvern/forge/agent.py | 83 +++++++++-- skyvern/forge/app.py | 4 + .../forge/sdk/api/llm/api_handler_factory.py | 97 ++++++++++--- skyvern/forge/sdk/api/llm/utils.py | 17 ++- skyvern/forge/sdk/executor/async_executor.py | 2 + skyvern/forge/sdk/routes/agent_protocol.py | 9 +- skyvern/schemas/runs.py | 10 +- skyvern/services/run_service.py | 10 +- skyvern/services/task_v1_service.py | 2 + skyvern/webeye/actions/actions.py | 1 + skyvern/webeye/actions/handler.py | 20 ++- skyvern/webeye/actions/parse_actions.py | 131 ++++++++++++++++++ skyvern/webeye/actions/responses.py | 1 + 18 files changed, 378 insertions(+), 45 deletions(-) diff --git a/poetry.lock b/poetry.lock index a5b7819c..33fa1c19 100644 --- a/poetry.lock +++ b/poetry.lock @@ -270,6 +270,31 @@ files = [ {file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"}, ] +[[package]] +name = "anthropic" +version = "0.50.0" +description = "The official Python library for the anthropic API" +optional = false +python-versions = ">=3.8" +groups = ["main"] +files = [ + {file = "anthropic-0.50.0-py3-none-any.whl", hash = "sha256:defbd79327ca2fa61fd7b9eb2f1627dfb1f69c25d49288c52e167ddb84574f80"}, + {file = "anthropic-0.50.0.tar.gz", hash = "sha256:42175ec04ce4ff2fa37cd436710206aadff546ee99d70d974699f59b49adc66f"}, +] + +[package.dependencies] +anyio = ">=3.5.0,<5" +distro = ">=1.7.0,<2" +httpx = ">=0.25.0,<1" +jiter = ">=0.4.0,<1" +pydantic = ">=1.9.0,<3" +sniffio = "*" +typing-extensions = ">=4.10,<5" + +[package.extras] +bedrock = ["boto3 (>=1.28.57)", "botocore (>=1.31.57)"] +vertex = ["google-auth[requests] (>=2,<3)"] + [[package]] name = "anyio" version = "4.9.0" @@ -6804,4 +6829,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.1" python-versions = "^3.11,<3.12" -content-hash = "b8883bdb02803bdb77dfe2de47aca0a28b509720513bf2d7fc2ee001bedf05fb" +content-hash = "926815050df2b2d2fbdb96ac5084cb0e19a628a04d29cbc78ea63936e11b213c" diff --git a/pyproject.toml b/pyproject.toml index b48061ac..fb9550f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,7 @@ pypdf = "^5.1.0" fastmcp = "^0.4.1" psutil = ">=7.0.0" tiktoken = ">=0.9.0" +anthropic = "^0.50.0" [tool.poetry.group.dev.dependencies] isort = "^5.13.2" diff --git a/skyvern/agent/client.py b/skyvern/agent/client.py index 30057f94..39c9b9c2 100644 --- a/skyvern/agent/client.py +++ b/skyvern/agent/client.py @@ -81,7 +81,7 @@ class SkyvernClient: run_id: str, ) -> RunResponse: run_obj = await self.client.agent.get_run(run_id=run_id) - if run_obj.run_type in [RunType.task_v1, RunType.task_v2, RunType.openai_cua]: + if run_obj.run_type in [RunType.task_v1, RunType.task_v2, RunType.openai_cua, RunType.anthropic_cua]: return TaskRunResponse.model_validate(run_obj.dict()) elif run_obj.run_type == RunType.workflow_run: return WorkflowRunResponse.model_validate(run_obj.dict()) diff --git a/skyvern/config.py b/skyvern/config.py index 1cbbc93a..9114e24d 100644 --- a/skyvern/config.py +++ b/skyvern/config.py @@ -119,6 +119,7 @@ class Settings(BaseSettings): # LLM PROVIDER SPECIFIC ENABLE_OPENAI: bool = False ENABLE_ANTHROPIC: bool = False + ENABLE_BEDROCK_ANTHROPIC: bool = False ENABLE_AZURE: bool = False ENABLE_AZURE_GPT4O_MINI: bool = False ENABLE_AZURE_O3_MINI: bool = False diff --git a/skyvern/exceptions.py b/skyvern/exceptions.py index 42bb772f..62bce444 100644 --- a/skyvern/exceptions.py +++ b/skyvern/exceptions.py @@ -672,3 +672,8 @@ class SkyvernContextWindowExceededError(SkyvernException): def __init__(self) -> None: message = "Context window exceeded. Please contact support@skyvern.com for help." super().__init__(message) + + +class LLMCallerNotFoundError(SkyvernException): + def __init__(self, uid: str) -> None: + super().__init__(f"LLM caller for {uid} is not found") diff --git a/skyvern/forge/agent.py b/skyvern/forge/agent.py index ff3abff0..37e37032 100644 --- a/skyvern/forge/agent.py +++ b/skyvern/forge/agent.py @@ -58,6 +58,7 @@ from skyvern.forge.sdk.api.files import ( rename_file, wait_for_download_finished, ) +from skyvern.forge.sdk.api.llm.api_handler_factory import LLMCaller, LLMCallerManager from skyvern.forge.sdk.artifact.models import ArtifactType from skyvern.forge.sdk.core import skyvern_context from skyvern.forge.sdk.core.security import generate_skyvern_webhook_headers @@ -70,7 +71,7 @@ from skyvern.forge.sdk.schemas.tasks import Task, TaskRequest, TaskResponse, Tas from skyvern.forge.sdk.workflow.context_manager import WorkflowRunContext from skyvern.forge.sdk.workflow.models.block import ActionBlock, BaseTaskBlock, ValidationBlock from skyvern.forge.sdk.workflow.models.workflow import Workflow, WorkflowRun, WorkflowRunStatus -from skyvern.schemas.runs import RunEngine, RunType +from skyvern.schemas.runs import CUA_ENGINES, CUA_RUN_TYPES, RunEngine from skyvern.utils.prompt_engine import load_prompt_with_elements from skyvern.webeye.actions.actions import ( Action, @@ -88,7 +89,7 @@ from skyvern.webeye.actions.actions import ( from skyvern.webeye.actions.caching import retrieve_action_plan from skyvern.webeye.actions.handler import ActionHandler, poll_verification_code from skyvern.webeye.actions.models import AgentStepOutput, DetailedAgentStepOutput -from skyvern.webeye.actions.parse_actions import parse_actions, parse_cua_actions +from skyvern.webeye.actions.parse_actions import parse_actions, parse_anthropic_actions, parse_cua_actions from skyvern.webeye.actions.responses import ActionResult, ActionSuccess from skyvern.webeye.browser_factory import BrowserState from skyvern.webeye.scraper.scraper import ElementTreeFormat, ScrapedPage, scrape_website @@ -253,6 +254,7 @@ class ForgeAgent: complete_verification: bool = True, engine: RunEngine = RunEngine.skyvern_v1, cua_response: OpenAIResponse | None = None, + llm_caller: LLMCaller | None = None, ) -> Tuple[Step, DetailedAgentStepOutput | None, Step | None]: workflow_run: WorkflowRun | None = None if task.workflow_run_id: @@ -378,6 +380,13 @@ class ForgeAgent: if page := await browser_state.get_working_page(): await self.register_async_operations(organization, task, page) + llm_caller = LLMCallerManager.get_llm_caller(task.task_id) + if engine == RunEngine.anthropic_cua and not llm_caller: + # llm_caller = LLMCaller(llm_key="BEDROCK_ANTHROPIC_CLAUDE3.5_SONNET_INFERENCE_PROFILE") + llm_caller = LLMCallerManager.get_llm_caller(task.task_id) + if not llm_caller: + llm_caller = LLMCaller(llm_key="ANTHROPIC_CLAUDE3.5_SONNET") + LLMCallerManager.set_llm_caller(task.task_id, llm_caller) step, detailed_output = await self.agent_step( task, step, @@ -387,6 +396,7 @@ class ForgeAgent: complete_verification=complete_verification, engine=engine, cua_response=cua_response, + llm_caller=llm_caller, ) await app.AGENT_FUNCTION.post_step_execution(task, step) task = await self.update_task_errors_from_detailed_output(task, detailed_output) @@ -778,6 +788,7 @@ class ForgeAgent: task_block: BaseTaskBlock | None = None, complete_verification: bool = True, cua_response: OpenAIResponse | None = None, + llm_caller: LLMCaller | None = None, ) -> tuple[Step, DetailedAgentStepOutput]: detailed_agent_step_output = DetailedAgentStepOutput( scraped_page=None, @@ -821,8 +832,17 @@ class ForgeAgent: step=step, scraped_page=scraped_page, previous_response=cua_response, + engine=engine, ) detailed_agent_step_output.cua_response = new_cua_response + elif engine == RunEngine.anthropic_cua: + assert llm_caller is not None + actions = await self._generate_anthropic_actions( + task=task, + step=step, + scraped_page=scraped_page, + llm_caller=llm_caller, + ) else: using_cached_action_plan = False if not task.navigation_goal and not isinstance(task_block, ValidationBlock): @@ -834,7 +854,7 @@ class ForgeAgent: ): using_cached_action_plan = True else: - if engine != RunEngine.openai_cua: + if engine in CUA_ENGINES: self.async_operation_pool.run_operation(task.task_id, AgentPhase.llm) json_response = await app.LLM_API_HANDLER( prompt=extract_action_prompt, @@ -1219,7 +1239,8 @@ class ForgeAgent: step: Step, scraped_page: ScrapedPage, previous_response: OpenAIResponse | None = None, - ) -> tuple[list[Action], OpenAIResponse]: + engine: RunEngine = RunEngine.openai_cua, + ) -> tuple[list[Action], OpenAIResponse | None]: if not previous_response: # this is the first step first_response: OpenAIResponse = await app.OPENAI_CLIENT.responses.create( @@ -1377,6 +1398,48 @@ class ForgeAgent: return await parse_cua_actions(task, step, current_response), current_response + async def _generate_anthropic_actions( + self, + task: Task, + step: Step, + scraped_page: ScrapedPage, + llm_caller: LLMCaller, + ) -> list[Action]: + if llm_caller.current_tool_results: + llm_caller.message_history.append({"role": "user", "content": llm_caller.current_tool_results}) + llm_caller.clear_tool_results() + tools = [ + { + "type": "computer_20250124", + "name": "computer", + "display_height_px": settings.BROWSER_HEIGHT, + "display_width_px": settings.BROWSER_WIDTH, + } + ] + if not llm_caller.message_history: + llm_response = await llm_caller.call( + prompt=task.navigation_goal, + screenshots=scraped_page.screenshots, + use_message_history=True, + tools=tools, + raw_response=True, + betas=["computer-use-2025-01-24"], + ) + else: + llm_response = await llm_caller.call( + screenshots=scraped_page.screenshots, + use_message_history=True, + tools=tools, + raw_response=True, + betas=["computer-use-2025-01-24"], + ) + LOG.info("Anthropic response", llm_response=llm_response) + assistant_content = llm_response["content"] + llm_caller.message_history.append({"role": "assistant", "content": assistant_content}) + + actions = await parse_anthropic_actions(task, step, assistant_content) + return actions + @staticmethod async def complete_verify(page: Page, scraped_page: ScrapedPage, task: Task, step: Step) -> CompleteVerifyResult: LOG.info( @@ -1387,7 +1450,7 @@ class ForgeAgent: ) run_obj = await app.DATABASE.get_run(run_id=task.task_id, organization_id=task.organization_id) scroll = True - if run_obj and run_obj.task_run_type == RunType.openai_cua: + if run_obj and run_obj.task_run_type in CUA_RUN_TYPES: scroll = False scraped_page_refreshed = await scraped_page.refresh(draw_boxes=False, scroll=scroll) @@ -1454,7 +1517,7 @@ class ForgeAgent: raise BrowserStateMissingPage() fullpage_screenshot = True - if engine == RunEngine.openai_cua: + if engine in CUA_ENGINES: fullpage_screenshot = False try: @@ -1580,7 +1643,7 @@ class ForgeAgent: max_screenshot_number = settings.MAX_NUM_SCREENSHOTS draw_boxes = True scroll = True - if engine == RunEngine.openai_cua: + if engine in CUA_ENGINES: max_screenshot_number = 1 draw_boxes = False scroll = False @@ -1602,7 +1665,7 @@ class ForgeAgent: engine: RunEngine, ) -> tuple[ScrapedPage, str]: # start the async tasks while running scrape_website - if engine != RunEngine.openai_cua: + if engine not in CUA_ENGINES: self.async_operation_pool.run_operation(task.task_id, AgentPhase.scrape) # Scrape the web page and get the screenshot and the elements @@ -1653,7 +1716,7 @@ class ForgeAgent: element_tree_format = ElementTreeFormat.HTML element_tree_in_prompt: str = scraped_page.build_element_tree(element_tree_format) extract_action_prompt = "" - if engine != RunEngine.openai_cua: + if engine not in CUA_ENGINES: extract_action_prompt = await self._build_extract_action_prompt( task, step, @@ -2371,7 +2434,7 @@ class ForgeAgent: run_obj = await app.DATABASE.get_run(run_id=task.task_id, organization_id=task.organization_id) scroll = True - if run_obj and run_obj.task_run_type == RunType.openai_cua: + if run_obj and run_obj.task_run_type in CUA_RUN_TYPES: scroll = False screenshots: list[bytes] = [] diff --git a/skyvern/forge/app.py b/skyvern/forge/app.py index ef4aac10..f464fe78 100644 --- a/skyvern/forge/app.py +++ b/skyvern/forge/app.py @@ -1,5 +1,6 @@ from typing import Awaitable, Callable +from anthropic import AsyncAnthropic, AsyncAnthropicBedrock from fastapi import FastAPI from openai import AsyncAzureOpenAI, AsyncOpenAI @@ -41,6 +42,9 @@ if SettingsManager.get_settings().ENABLE_AZURE_CUA: azure_endpoint=SettingsManager.get_settings().AZURE_CUA_ENDPOINT, azure_deployment=SettingsManager.get_settings().AZURE_CUA_DEPLOYMENT, ) +ANTHROPIC_CLIENT = AsyncAnthropic(api_key=SettingsManager.get_settings().ANTHROPIC_API_KEY) +if SettingsManager.get_settings().ENABLE_BEDROCK_ANTHROPIC: + ANTHROPIC_CLIENT = AsyncAnthropicBedrock() SECONDARY_LLM_API_HANDLER = LLMAPIHandlerFactory.get_llm_api_handler( SETTINGS_MANAGER.SECONDARY_LLM_KEY if SETTINGS_MANAGER.SECONDARY_LLM_KEY else SETTINGS_MANAGER.LLM_KEY diff --git a/skyvern/forge/sdk/api/llm/api_handler_factory.py b/skyvern/forge/sdk/api/llm/api_handler_factory.py index e30aeb2a..1386d295 100644 --- a/skyvern/forge/sdk/api/llm/api_handler_factory.py +++ b/skyvern/forge/sdk/api/llm/api_handler_factory.py @@ -6,7 +6,9 @@ from typing import Any import litellm import structlog +from anthropic.types.message import Message as AnthropicMessage from jinja2 import Template +from litellm.utils import CustomStreamWrapper, ModelResponse from skyvern.config import settings from skyvern.exceptions import SkyvernContextWindowExceededError @@ -456,11 +458,18 @@ class LLMCaller: self.llm_config = LLMConfigRegistry.get_config(llm_key) self.base_parameters = base_parameters self.message_history: list[dict[str, Any]] = [] + self.current_tool_results: list[dict[str, Any]] = [] + + def add_tool_result(self, tool_result: dict[str, Any]) -> None: + self.current_tool_results.append(tool_result) + + def clear_tool_results(self) -> None: + self.current_tool_results = [] async def call( self, - prompt: str, - prompt_name: str, + prompt: str | None = None, + prompt_name: str | None = None, step: Step | None = None, task_v2: TaskV2 | None = None, thought: Thought | None = None, @@ -469,6 +478,8 @@ class LLMCaller: parameters: dict[str, Any] | None = None, tools: list | None = None, use_message_history: bool = False, + raw_response: bool = False, + **extra_parameters: Any, ) -> dict[str, Any]: start_time = time.perf_counter() active_parameters = self.base_parameters or {} @@ -476,6 +487,8 @@ class LLMCaller: parameters = LLMAPIHandlerFactory.get_api_parameters(self.llm_config) active_parameters.update(parameters) + if extra_parameters: + active_parameters.update(extra_parameters) if self.llm_config.litellm_params: # type: ignore active_parameters.update(self.llm_config.litellm_params) # type: ignore @@ -491,7 +504,7 @@ class LLMCaller: ) await app.ARTIFACT_MANAGER.create_llm_artifact( - data=prompt.encode("utf-8"), + data=prompt.encode("utf-8") if prompt else b"", artifact_type=ArtifactType.LLM_PROMPT, screenshots=screenshots, step=step, @@ -525,8 +538,7 @@ class LLMCaller: ) t_llm_request = time.perf_counter() try: - response = await litellm.acompletion( - model=self.llm_config.model_name, + response = await self._dispatch_llm_call( messages=messages, tools=tools, timeout=settings.LLM_CONFIG_TIMEOUT, @@ -603,6 +615,21 @@ class LLMCaller: cached_token_count=cached_tokens if cached_tokens > 0 else None, thought_cost=llm_cost, ) + # Track LLM API handler duration + duration_seconds = time.perf_counter() - start_time + LOG.info( + "LLM API handler duration metrics", + llm_key=self.llm_key, + prompt_name=prompt_name, + model=self.llm_config.model_name, + duration_seconds=duration_seconds, + step_id=step.step_id if step else None, + thought_id=thought.observer_thought_id if thought else None, + organization_id=step.organization_id if step else (thought.organization_id if thought else None), + ) + if raw_response: + return response.model_dump() + parsed_response = parse_api_response(response, self.llm_config.add_assistant_prefix) await app.ARTIFACT_MANAGER.create_llm_artifact( data=json.dumps(parsed_response, indent=2).encode("utf-8"), @@ -626,17 +653,53 @@ class LLMCaller: ai_suggestion=ai_suggestion, ) - # Track LLM API handler duration - duration_seconds = time.perf_counter() - start_time - LOG.info( - "LLM API handler duration metrics", - llm_key=self.llm_key, - prompt_name=prompt_name, - model=self.llm_config.model_name, - duration_seconds=duration_seconds, - step_id=step.step_id if step else None, - thought_id=thought.observer_thought_id if thought else None, - organization_id=step.organization_id if step else (thought.organization_id if thought else None), + return parsed_response + + async def _dispatch_llm_call( + self, + messages: list[dict[str, Any]], + tools: list | None = None, + timeout: float = settings.LLM_CONFIG_TIMEOUT, + **active_parameters: dict[str, Any], + ) -> ModelResponse | CustomStreamWrapper | AnthropicMessage: + if self.llm_key and self.llm_key.startswith("ANTHROPIC"): + return await self._call_anthropic(messages, tools, timeout) + + return await litellm.acompletion( + model=self.llm_config.model_name, messages=messages, tools=tools, timeout=timeout, **active_parameters ) - return parsed_response + async def _call_anthropic( + self, + messages: list[dict[str, Any]], + tools: list | None = None, + timeout: float = settings.LLM_CONFIG_TIMEOUT, + **active_parameters: dict[str, Any], + ) -> AnthropicMessage: + max_tokens = active_parameters.get("max_completion_tokens") or active_parameters.get("max_tokens") or 4096 + model_name = self.llm_config.model_name.replace("bedrock/", "").replace("anthropic/", "") + return await app.ANTHROPIC_CLIENT.messages.create( + max_tokens=max_tokens, + messages=messages, + model=model_name, + tools=tools, + timeout=timeout, + betas=active_parameters.get("betas", None), + ) + + +class LLMCallerManager: + _llm_callers: dict[str, LLMCaller] = {} + + @classmethod + def get_llm_caller(cls, uid: str) -> LLMCaller | None: + return cls._llm_callers.get(uid) + + @classmethod + def set_llm_caller(cls, uid: str, llm_caller: LLMCaller) -> None: + cls._llm_callers[uid] = llm_caller + + @classmethod + def clear_llm_caller(cls, uid: str) -> None: + if uid in cls._llm_callers: + del cls._llm_callers[uid] diff --git a/skyvern/forge/sdk/api/llm/utils.py b/skyvern/forge/sdk/api/llm/utils.py index 0bbb00a8..811a2247 100644 --- a/skyvern/forge/sdk/api/llm/utils.py +++ b/skyvern/forge/sdk/api/llm/utils.py @@ -47,19 +47,22 @@ async def llm_messages_builder( async def llm_messages_builder_with_history( - prompt: str, + prompt: str | None = None, screenshots: list[bytes] | None = None, message_history: list[dict[str, Any]] | None = None, ) -> list[dict[str, Any]]: messages: list[dict[str, Any]] = [] if message_history: messages = copy.deepcopy(message_history) - current_user_messages: list[dict[str, Any]] = [ - { - "type": "text", - "text": prompt, - } - ] + + current_user_messages: list[dict[str, Any]] = [] + if prompt: + current_user_messages.append( + { + "type": "text", + "text": prompt, + } + ) if screenshots: for screenshot in screenshots: diff --git a/skyvern/forge/sdk/executor/async_executor.py b/skyvern/forge/sdk/executor/async_executor.py index 2bc36c15..135cc57b 100644 --- a/skyvern/forge/sdk/executor/async_executor.py +++ b/skyvern/forge/sdk/executor/async_executor.py @@ -96,6 +96,8 @@ class BackgroundTaskExecutor(AsyncExecutor): engine = RunEngine.skyvern_v1 if run_obj and run_obj.task_run_type == RunType.openai_cua: engine = RunEngine.openai_cua + elif run_obj and run_obj.task_run_type == RunType.anthropic_cua: + engine = RunEngine.anthropic_cua context: SkyvernContext = skyvern_context.ensure_context() context.task_id = task.task_id diff --git a/skyvern/forge/sdk/routes/agent_protocol.py b/skyvern/forge/sdk/routes/agent_protocol.py index 917ef20c..1bb27887 100644 --- a/skyvern/forge/sdk/routes/agent_protocol.py +++ b/skyvern/forge/sdk/routes/agent_protocol.py @@ -62,6 +62,7 @@ from skyvern.forge.sdk.workflow.models.workflow import ( ) from skyvern.forge.sdk.workflow.models.yaml import WorkflowCreateYAMLRequest from skyvern.schemas.runs import ( + CUA_ENGINES, RunEngine, RunResponse, RunType, @@ -1466,7 +1467,7 @@ async def run_task( analytics.capture("skyvern-oss-run-task", data={"url": run_request.url}) await PermissionCheckerFactory.get_instance().check(current_org, browser_session_id=run_request.browser_session_id) - if run_request.engine in [RunEngine.skyvern_v1, RunEngine.openai_cua]: + if run_request.engine in CUA_ENGINES: # create task v1 # if there's no url, call task generation first to generate the url, data schema if any url = run_request.url @@ -1480,7 +1481,7 @@ async def run_task( ) url = url or task_generation.url navigation_goal = task_generation.navigation_goal or run_request.prompt - if run_request.engine == RunEngine.openai_cua: + if run_request.engine in CUA_ENGINES: navigation_goal = run_request.prompt navigation_payload = task_generation.navigation_payload data_extraction_goal = task_generation.data_extraction_goal @@ -1511,6 +1512,8 @@ async def run_task( run_type = RunType.task_v1 if run_request.engine == RunEngine.openai_cua: run_type = RunType.openai_cua + elif run_request.engine == RunEngine.anthropic_cua: + run_type = RunType.anthropic_cua # build the task run response return TaskRunResponse( run_id=task_v1_response.task_id, @@ -1586,8 +1589,6 @@ async def run_task( publish_workflow=run_request.publish_workflow, ), ) - if run_request.engine == RunEngine.openai_cua: - pass raise HTTPException(status_code=400, detail=f"Invalid agent engine: {run_request.engine}") diff --git a/skyvern/schemas/runs.py b/skyvern/schemas/runs.py index b4f3ddf5..18276cd7 100644 --- a/skyvern/schemas/runs.py +++ b/skyvern/schemas/runs.py @@ -93,12 +93,18 @@ class RunType(StrEnum): task_v2 = "task_v2" workflow_run = "workflow_run" openai_cua = "openai_cua" + anthropic_cua = "anthropic_cua" class RunEngine(StrEnum): skyvern_v1 = "skyvern-1.0" skyvern_v2 = "skyvern-2.0" openai_cua = "openai-cua" + anthropic_cua = "anthropic-cua" + + +CUA_ENGINES = [RunEngine.openai_cua, RunEngine.anthropic_cua] +CUA_RUN_TYPES = [RunType.openai_cua, RunType.anthropic_cua] class RunStatus(StrEnum): @@ -217,8 +223,8 @@ class BaseRunResponse(BaseModel): class TaskRunResponse(BaseRunResponse): - run_type: Literal[RunType.task_v1, RunType.task_v2, RunType.openai_cua] = Field( - description="Types of a task run - task_v1, task_v2, openai_cua" + run_type: Literal[RunType.task_v1, RunType.task_v2, RunType.openai_cua, RunType.anthropic_cua] = Field( + description="Types of a task run - task_v1, task_v2, openai_cua, anthropic_cua" ) run_request: TaskRunRequest | None = Field( default=None, description="The original request parameters used to start this task run" diff --git a/skyvern/services/run_service.py b/skyvern/services/run_service.py index 247d202a..487751c8 100644 --- a/skyvern/services/run_service.py +++ b/skyvern/services/run_service.py @@ -13,7 +13,11 @@ async def get_run_response(run_id: str, organization_id: str | None = None) -> R if not run: return None - if run.task_run_type == RunType.task_v1 or run.task_run_type == RunType.openai_cua: + if ( + run.task_run_type == RunType.task_v1 + or run.task_run_type == RunType.openai_cua + or run.task_run_type == RunType.anthropic_cua + ): # fetch task v1 from db and transform to task run response task_v1 = await app.DATABASE.get_task(run.run_id, organization_id=organization_id) if not task_v1: @@ -21,6 +25,8 @@ async def get_run_response(run_id: str, organization_id: str | None = None) -> R run_engine = RunEngine.skyvern_v1 if run.task_run_type == RunType.openai_cua: run_engine = RunEngine.openai_cua + elif run.task_run_type == RunType.anthropic_cua: + run_engine = RunEngine.anthropic_cua return TaskRunResponse( run_id=run.run_id, run_type=run.task_run_type, @@ -136,7 +142,7 @@ async def cancel_run(run_id: str, organization_id: str | None = None, api_key: s detail=f"Run not found {run_id}", ) - if run.task_run_type in [RunType.task_v1, RunType.openai_cua]: + if run.task_run_type in [RunType.task_v1, RunType.openai_cua, RunType.anthropic_cua]: await cancel_task_v1(run_id, organization_id=organization_id, api_key=api_key) elif run.task_run_type == RunType.task_v2: await cancel_task_v2(run_id, organization_id=organization_id) diff --git a/skyvern/services/task_v1_service.py b/skyvern/services/task_v1_service.py index 5a319e4c..199377a4 100644 --- a/skyvern/services/task_v1_service.py +++ b/skyvern/services/task_v1_service.py @@ -87,6 +87,8 @@ async def run_task( run_type = RunType.task_v1 if engine == RunEngine.openai_cua: run_type = RunType.openai_cua + elif engine == RunEngine.anthropic_cua: + run_type = RunType.anthropic_cua await app.DATABASE.create_task_run( task_run_type=run_type, organization_id=organization.organization_id, diff --git a/skyvern/webeye/actions/actions.py b/skyvern/webeye/actions/actions.py index 9342fee3..72fe8f2e 100644 --- a/skyvern/webeye/actions/actions.py +++ b/skyvern/webeye/actions/actions.py @@ -113,6 +113,7 @@ class Action(BaseModel): element_id: Annotated[str, Field(coerce_numbers_to_str=True)] | None = None skyvern_element_hash: str | None = None skyvern_element_data: dict[str, Any] | None = None + tool_call_id: str | None = None # DecisiveAction (CompleteAction, TerminateAction) fields errors: list[UserDefinedError] | None = None diff --git a/skyvern/webeye/actions/handler.py b/skyvern/webeye/actions/handler.py index c6838630..a2f00890 100644 --- a/skyvern/webeye/actions/handler.py +++ b/skyvern/webeye/actions/handler.py @@ -59,6 +59,7 @@ from skyvern.forge.sdk.api.files import ( list_files_in_directory, wait_for_download_finished, ) +from skyvern.forge.sdk.api.llm.api_handler_factory import LLMCallerManager from skyvern.forge.sdk.api.llm.exceptions import LLMProviderError from skyvern.forge.sdk.core import skyvern_context from skyvern.forge.sdk.core.aiohttp_helper import aiohttp_post @@ -363,9 +364,26 @@ class ActionHandler: handler = ActionHandler._handled_action_types[action.action_type] results = await handler(action, page, scraped_page, task, step) actions_result.extend(results) + llm_caller = LLMCallerManager.get_llm_caller(task.task_id) if not results or not isinstance(actions_result[-1], ActionSuccess): + if llm_caller and action.tool_call_id: + # add failure message to the llm caller + tool_call_result = { + "type": "tool_result", + "tool_use_id": action.tool_call_id, + "content": {"result": "Tool execution failed"}, + } + llm_caller.add_tool_result(tool_call_result) return actions_result + if llm_caller and action.tool_call_id: + tool_call_result = { + "type": "tool_result", + "tool_use_id": action.tool_call_id, + "content": {"result": "Tool executed successfully"}, + } + llm_caller.add_tool_result(tool_call_result) + # do the teardown teardown = ActionHandler._teardown_action_types.get(action.action_type) if teardown: @@ -1532,7 +1550,7 @@ async def handle_keypress_action( ) -> list[ActionResult]: updated_keys = [] for key in action.keys: - if key.lower() == "enter": + if key.lower() in ("enter", "return"): updated_keys.append("Enter") elif key.lower() == "space": updated_keys.append(" ") diff --git a/skyvern/webeye/actions/parse_actions.py b/skyvern/webeye/actions/parse_actions.py index 747cceb0..8e0cd40d 100644 --- a/skyvern/webeye/actions/parse_actions.py +++ b/skyvern/webeye/actions/parse_actions.py @@ -1,3 +1,4 @@ +import json from typing import Any, Dict import structlog @@ -448,3 +449,133 @@ async def parse_cua_actions( action.action_order = 0 return [action] return actions + + +async def parse_anthropic_actions( + task: Task, + step: Step, + assistant_content: list[dict[str, Any]], +) -> list[Action]: + tool_calls = [block for block in assistant_content if block["type"] == "tool_use"] + idx = 0 + actions: list[Action] = [] + while idx < len(tool_calls): + tool_call = tool_calls[idx] + tool_call_id = tool_call["id"] + parsed_args = _parse_anthropic_computer_args(tool_call) + if not parsed_args: + idx += 1 + continue + action = parsed_args["action"] + if action == "mouse_move": + x, y = parsed_args["coordinate"] + actions.append( + MoveAction( + x=x, + y=y, + organization_id=task.organization_id, + workflow_run_id=task.workflow_run_id, + task_id=task.task_id, + step_id=step.step_id, + step_order=step.order, + action_order=idx, + tool_call_id=tool_call_id, + ) + ) + idx += 1 + elif action == "left_click": + if idx - 1 >= 0: + prev_tool_call = tool_calls[idx - 1] + prev_parsed_args = _parse_anthropic_computer_args(prev_tool_call) + if prev_parsed_args and prev_parsed_args["action"] == "mouse_move": + coordinate = prev_parsed_args["coordinate"] + else: + coordinate = parsed_args.get("coordinate") + else: + coordinate = parsed_args.get("coordinate") + + idx += 1 + if not coordinate: + LOG.warning( + "Left click action has no coordinate and it doesn't have mouse_move before it", + tool_call=tool_call, + ) + continue + x, y = coordinate + actions.append( + ClickAction( + element_id="", + x=x, + y=y, + button="left", + organization_id=task.organization_id, + workflow_run_id=task.workflow_run_id, + task_id=task.task_id, + step_id=step.step_id, + step_order=step.order, + action_order=idx - 1, + tool_call_id=tool_call_id, + ) + ) + elif action == "type": + text = parsed_args.get("text") + idx += 1 + if not text: + LOG.warning( + "Type action has no text", + tool_call=tool_call, + ) + continue + actions.append( + InputTextAction( + element_id="", + text=text, + organization_id=task.organization_id, + workflow_run_id=task.workflow_run_id, + task_id=task.task_id, + step_id=step.step_id, + step_order=step.order, + action_order=idx, + tool_call_id=tool_call_id, + ) + ) + elif action == "key": + text = parsed_args.get("text") + idx += 1 + if not text: + LOG.warning( + "Key action has no text", + tool_call=tool_call, + ) + continue + actions.append( + KeypressAction( + element_id="", + keys=[text], + organization_id=task.organization_id, + workflow_run_id=task.workflow_run_id, + task_id=task.task_id, + step_id=step.step_id, + step_order=step.order, + action_order=idx, + tool_call_id=tool_call_id, + ) + ) + else: + LOG.error( + "Unsupported action", + tool_call=tool_call, + ) + idx += 1 + return actions + + +def _parse_anthropic_computer_args(tool_call: dict[str, Any]) -> dict[str, Any] | None: + tool_call_type = tool_call["type"] + if tool_call_type != "function": + return None + tool_call_name = tool_call["function"]["name"] + if tool_call_name != "computer": + return None + tool_call_arguments = tool_call["function"]["arguments"] + return json.loads(tool_call_arguments) diff --git a/skyvern/webeye/actions/responses.py b/skyvern/webeye/actions/responses.py index 76e1744a..1fc1c445 100644 --- a/skyvern/webeye/actions/responses.py +++ b/skyvern/webeye/actions/responses.py @@ -18,6 +18,7 @@ class ActionResult(BaseModel): interacted_with_sibling: bool | None = None interacted_with_parent: bool | None = None skip_remaining_actions: bool | None = None + tool_call_result: dict[str, Any] | None = None def __str__(self) -> str: results = [f"ActionResult(success={self.success}"]