ui-tars integration (#2656)

This commit is contained in:
Wyatt Marshall
2025-06-13 01:23:39 -04:00
committed by GitHub
parent 47cf755d9c
commit 15d46aab82
18 changed files with 986 additions and 13 deletions

View File

@@ -59,6 +59,7 @@ from skyvern.forge.sdk.api.files import (
wait_for_download_finished,
)
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory, LLMCaller, LLMCallerManager
from skyvern.forge.sdk.api.llm.ui_tars_llm_caller import UITarsLLMCaller
from skyvern.forge.sdk.artifact.models import ArtifactType
from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.security import generate_skyvern_webhook_headers
@@ -91,7 +92,12 @@ from skyvern.webeye.actions.actions import (
from skyvern.webeye.actions.caching import retrieve_action_plan
from skyvern.webeye.actions.handler import ActionHandler, poll_verification_code
from skyvern.webeye.actions.models import AgentStepOutput, DetailedAgentStepOutput
from skyvern.webeye.actions.parse_actions import parse_actions, parse_anthropic_actions, parse_cua_actions
from skyvern.webeye.actions.parse_actions import (
parse_actions,
parse_anthropic_actions,
parse_cua_actions,
parse_ui_tars_actions,
)
from skyvern.webeye.actions.responses import ActionResult, ActionSuccess
from skyvern.webeye.browser_factory import BrowserState
from skyvern.webeye.scraper.scraper import ElementTreeFormat, ScrapedPage, scrape_website
@@ -393,9 +399,18 @@ class ForgeAgent:
llm_key=llm_key or settings.ANTHROPIC_CUA_LLM_KEY, screenshot_scaling_enabled=True
)
if engine == RunEngine.ui_tars and not llm_caller:
# see if the llm_caller is already set in memory
llm_caller = LLMCallerManager.get_llm_caller(task.task_id)
if not llm_caller:
# create a new UI-TARS llm_caller
llm_key = task.llm_key or settings.UI_TARS_LLM_KEY
llm_caller = UITarsLLMCaller(llm_key=llm_key, screenshot_scaling_enabled=True)
llm_caller.initialize_conversation(task)
# TODO: remove the code after migrating everything to llm callers
# currently, only anthropic cua tasks use llm_caller
if engine == RunEngine.anthropic_cua and llm_caller:
# currently, only anthropic cua and ui_tars tasks use llm_caller
if engine in [RunEngine.anthropic_cua, RunEngine.ui_tars] and llm_caller:
LLMCallerManager.set_llm_caller(task.task_id, llm_caller)
step, detailed_output = await self.agent_step(
@@ -550,6 +565,7 @@ class ForgeAgent:
complete_verification=complete_verification,
engine=engine,
cua_response=cua_response_param,
llm_caller=llm_caller,
)
elif settings.execute_all_steps() and next_step:
return await self.execute_step(
@@ -563,6 +579,7 @@ class ForgeAgent:
complete_verification=complete_verification,
engine=engine,
cua_response=cua_response_param,
llm_caller=llm_caller,
)
else:
LOG.info(
@@ -854,6 +871,15 @@ class ForgeAgent:
scraped_page=scraped_page,
llm_caller=llm_caller,
)
elif engine == RunEngine.ui_tars:
assert llm_caller is not None
actions = await self._generate_ui_tars_actions(
task=task,
step=step,
scraped_page=scraped_page,
llm_caller=llm_caller,
)
else:
using_cached_action_plan = False
if not task.navigation_goal and not isinstance(task_block, ValidationBlock):
@@ -1483,6 +1509,56 @@ class ForgeAgent:
)
return actions
async def _generate_ui_tars_actions(
    self,
    task: Task,
    step: Step,
    scraped_page: ScrapedPage,
    llm_caller: LLMCaller,
) -> list[Action]:
    """Generate actions using UI-TARS (Seed1.5-VL) model through the LLMCaller pattern.

    Adds the current screenshot to the UI-TARS conversation, asks the model for
    its next move, and parses the plain-text response into Skyvern actions.

    Args:
        task: The task whose goal drives the conversation.
        step: The step currently being executed.
        scraped_page: Scraped page state; must contain at least one screenshot.
        llm_caller: Must be a UITarsLLMCaller carrying the conversation history.

    Returns:
        The list of actions parsed from the UI-TARS response.

    Raises:
        ValueError: If llm_caller is not a UITarsLLMCaller, or no screenshot is available.
    """
    LOG.info(
        "UI-TARS action generation starts",
        task_id=task.task_id,
        step_id=step.step_id,
        step_order=step.order,
    )
    # Ensure we have a UITarsLLMCaller instance (the caller only asserts non-None)
    if not isinstance(llm_caller, UITarsLLMCaller):
        raise ValueError(f"Expected UITarsLLMCaller, got {type(llm_caller)}")
    # Guard clause: UI-TARS is vision-driven, so a screenshot is mandatory
    if not scraped_page.screenshots:
        LOG.error("No screenshots found, skipping UI-TARS action generation")
        raise ValueError("No screenshots found, skipping UI-TARS action generation")
    # Add the current screenshot to the conversation
    llm_caller.add_screenshot(scraped_page.screenshots[0])
    # Generate response using the LLMCaller
    response_content = await llm_caller.generate_ui_tars_response(step)
    # Structured key-value logging, consistent with structlog usage elsewhere in this module
    LOG.info("UI-TARS raw response", response_content=response_content)
    # Fall back to a common desktop resolution when the scrape did not record one
    window_dimension = (
        cast(Resolution, scraped_page.window_dimension)
        if scraped_page.window_dimension
        else Resolution(width=1920, height=1080)
    )
    LOG.info("UI-TARS browser window dimension", window_dimension=window_dimension)
    actions = await parse_ui_tars_actions(task, step, response_content, window_dimension)
    LOG.info(
        "UI-TARS action generation completed",
        task_id=task.task_id,
        step_id=step.step_id,
        actions_count=len(actions),
    )
    return actions
async def complete_verify(
self, page: Page, scraped_page: ScrapedPage, task: Task, step: Step
) -> CompleteVerifyResult:
@@ -2105,6 +2181,7 @@ class ForgeAgent:
return
await self.async_operation_pool.remove_task(task.task_id)
await self.cleanup_browser_and_create_artifacts(
close_browser_on_completion, last_step, task, browser_session_id=browser_session_id
)

View File

@@ -46,6 +46,14 @@ ANTHROPIC_CLIENT = AsyncAnthropic(api_key=SettingsManager.get_settings().ANTHROP
if SettingsManager.get_settings().ENABLE_BEDROCK_ANTHROPIC:
ANTHROPIC_CLIENT = AsyncAnthropicBedrock()
# Add UI-TARS client setup
UI_TARS_CLIENT = None
if SettingsManager.get_settings().ENABLE_UI_TARS:
UI_TARS_CLIENT = AsyncOpenAI(
api_key=SettingsManager.get_settings().UI_TARS_API_KEY,
base_url=SettingsManager.get_settings().UI_TARS_API_BASE,
)
SECONDARY_LLM_API_HANDLER = LLMAPIHandlerFactory.get_llm_api_handler(
SETTINGS_MANAGER.SECONDARY_LLM_KEY if SETTINGS_MANAGER.SECONDARY_LLM_KEY else SETTINGS_MANAGER.LLM_KEY
)

View File

@@ -0,0 +1,37 @@
{#
SPDX-License-Identifier: Apache-2.0
Adapted from:
https://github.com/bytedance/UI-TARS/blob/main/codes/ui_tars/prompt.py
Licensed under the Apache License, Version 2.0
This prompt is used for the UI-TARS agent.
#}
You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
## Output Format
```
Thought: ...
Action: ...
```
## Action Space
click(point='<point>x1 y1</point>')
left_double(point='<point>x1 y1</point>')
right_single(point='<point>x1 y1</point>')
drag(start_point='<point>x1 y1</point>', end_point='<point>x2 y2</point>')
hotkey(key='ctrl c') # Split keys with a space and use lowercase. Also, do not use more than 3 keys in one hotkey action.
type(content='xxx') # Use escape characters \\', \\\", and \\n in content part to ensure we can parse the content in normal python string format. If you want to submit your input, use \\n at the end of content.
scroll(point='<point>x1 y1</point>', direction='down or up or right or left') # Show more information on the `direction` side.
wait() # Sleep for 5s and take a screenshot to check for any changes.
finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format.
## Note
- Use {{language}} in `Thought` part.
- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.
## User Instruction
{{instruction}}

View File

@@ -2,7 +2,7 @@ import dataclasses
import json
import time
from asyncio import CancelledError
from typing import Any
from typing import Any, AsyncIterator
import litellm
import structlog
@@ -10,6 +10,7 @@ from anthropic import NOT_GIVEN
from anthropic.types.beta.beta_message import BetaMessage as AnthropicMessage
from jinja2 import Template
from litellm.utils import CustomStreamWrapper, ModelResponse
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from pydantic import BaseModel
from skyvern.config import settings
@@ -23,6 +24,7 @@ from skyvern.forge.sdk.api.llm.exceptions import (
LLMProviderErrorRetryableTask,
)
from skyvern.forge.sdk.api.llm.models import LLMAPIHandler, LLMConfig, LLMRouterConfig, dummy_llm_api_handler
from skyvern.forge.sdk.api.llm.ui_tars_response import UITarsResponse
from skyvern.forge.sdk.api.llm.utils import llm_messages_builder, llm_messages_builder_with_history, parse_api_response
from skyvern.forge.sdk.artifact.models import ArtifactType
from skyvern.forge.sdk.core import skyvern_context
@@ -744,10 +746,14 @@ class LLMCaller:
tools: list | None = None,
timeout: float = settings.LLM_CONFIG_TIMEOUT,
**active_parameters: dict[str, Any],
) -> ModelResponse | CustomStreamWrapper | AnthropicMessage:
) -> ModelResponse | CustomStreamWrapper | AnthropicMessage | Any:
if self.llm_key and "ANTHROPIC" in self.llm_key:
return await self._call_anthropic(messages, tools, timeout, **active_parameters)
# Route UI-TARS models to custom handler instead of LiteLLM
if self.llm_key and "UI_TARS" in self.llm_key:
return await self._call_ui_tars(messages, tools, timeout, **active_parameters)
return await litellm.acompletion(
model=self.llm_config.model_name, messages=messages, tools=tools, timeout=timeout, **active_parameters
)
@@ -790,8 +796,97 @@ class LLMCaller:
)
return response
async def get_call_stats(self, response: ModelResponse | CustomStreamWrapper | AnthropicMessage) -> LLMCallStats:
async def _call_ui_tars(
    self,
    messages: list[dict[str, Any]],
    tools: list | None = None,
    timeout: float = settings.LLM_CONFIG_TIMEOUT,
    **active_parameters: dict[str, Any],
) -> Any:
    """Custom UI-TARS API call using OpenAI client with VolcEngine endpoint.

    Streams the chat completion and aggregates delta content into a single
    UITarsResponse that mimics the ModelResponse interface.

    Args:
        messages: OpenAI-format conversation messages.
        tools: Unused; accepted for interface parity with the other call paths.
        timeout: Per-request timeout in seconds.
        **active_parameters: May carry max_completion_tokens/max_tokens/temperature.

    Returns:
        A UITarsResponse wrapping the aggregated text content.

    Raises:
        ValueError: If the UI-TARS client has not been initialized.
    """
    max_tokens = active_parameters.get("max_completion_tokens") or active_parameters.get("max_tokens") or 400
    model_name = self.llm_config.model_name
    if not app.UI_TARS_CLIENT:
        raise ValueError(
            "UI_TARS_CLIENT not initialized. Please ensure ENABLE_UI_TARS=true and UI_TARS_API_KEY is set."
        )
    LOG.info(
        "UI-TARS request",
        model_name=model_name,
        timeout=timeout,
        messages_length=len(messages),
    )
    # Use the UI-TARS client (which is OpenAI-compatible with VolcEngine)
    chat_completion: AsyncIterator[ChatCompletionChunk] = await app.UI_TARS_CLIENT.chat.completions.create(
        model=model_name,
        messages=messages,
        top_p=None,
        temperature=active_parameters.get("temperature", 0.0),
        max_tokens=max_tokens,
        stream=True,
        seed=None,
        stop=None,
        frequency_penalty=None,
        presence_penalty=None,
        timeout=timeout,
    )
    # Aggregate streaming response like in ByteDance example.
    # Some OpenAI-compatible providers emit chunks with an EMPTY `choices` list
    # (e.g. keep-alive or usage frames), so guard before indexing to avoid an
    # IndexError mid-stream.
    response_content = ""
    async for message in chat_completion:
        if message.choices and message.choices[0].delta.content:
            response_content += message.choices[0].delta.content
    response = UITarsResponse(response_content, model_name)
    LOG.info(
        "UI-TARS response",
        model_name=model_name,
        response_length=len(response_content),
        timeout=timeout,
    )
    return response
async def get_call_stats(
self, response: ModelResponse | CustomStreamWrapper | AnthropicMessage | dict[str, Any] | Any
) -> LLMCallStats:
empty_call_stats = LLMCallStats()
# Handle UI-TARS response (UITarsResponse object from _call_ui_tars)
if hasattr(response, "usage") and hasattr(response, "choices") and hasattr(response, "model"):
usage = response.usage
# Use Doubao pricing: ¥0.8/1M input, ¥2/1M output (convert to USD: ~$0.11/$0.28)
input_token_cost = (0.11 / 1000000) * usage.get("prompt_tokens", 0)
output_token_cost = (0.28 / 1000000) * usage.get("completion_tokens", 0)
llm_cost = input_token_cost + output_token_cost
return LLMCallStats(
llm_cost=llm_cost,
input_tokens=usage.get("prompt_tokens", 0),
output_tokens=usage.get("completion_tokens", 0),
cached_tokens=0, # UI-TARS doesn't have cached tokens
reasoning_tokens=0,
)
# Handle UI-TARS response (dict format - fallback)
if isinstance(response, dict) and "choices" in response and "usage" in response:
usage = response["usage"]
# Use Doubao pricing: ¥0.8/1M input, ¥2/1M output (convert to USD: ~$0.11/$0.28)
input_token_cost = (0.11 / 1000000) * usage.get("prompt_tokens", 0)
output_token_cost = (0.28 / 1000000) * usage.get("completion_tokens", 0)
llm_cost = input_token_cost + output_token_cost
return LLMCallStats(
llm_cost=llm_cost,
input_tokens=usage.get("prompt_tokens", 0),
output_tokens=usage.get("completion_tokens", 0),
cached_tokens=0, # UI-TARS doesn't have cached tokens
reasoning_tokens=0,
)
if isinstance(response, AnthropicMessage):
usage = response.usage
input_token_cost = (3.0 / 1000000) * usage.input_tokens

View File

@@ -568,6 +568,18 @@ if settings.ENABLE_AZURE_O3:
max_completion_tokens=100000,
),
)
if settings.ENABLE_UI_TARS:
LLMConfigRegistry.register_config(
"UI_TARS_SEED1_5_VL",
LLMConfig(
settings.UI_TARS_MODEL,
["UI_TARS_API_KEY"],
supports_vision=True,
add_assistant_prefix=False,
max_tokens=400,
temperature=0.0,
),
)
if settings.ENABLE_GEMINI:
LLMConfigRegistry.register_config(
@@ -630,6 +642,16 @@ if settings.ENABLE_GEMINI:
max_completion_tokens=65536,
),
)
LLMConfigRegistry.register_config(
"GEMINI_2.5_FLASH_PREVIEW",
LLMConfig(
"gemini/gemini-2.5-flash-preview-05-20",
["GEMINI_API_KEY"],
supports_vision=True,
add_assistant_prefix=False,
max_completion_tokens=65536,
),
)
if settings.ENABLE_NOVITA:

View File

@@ -0,0 +1,200 @@
#
# SPDX-License-Identifier: Apache-2.0
# Code partially adapted from:
# https://github.com/ByteDance-Seed/Seed1.5-VL/blob/main/GUI/gui.ipynb
#
# Licensed under the Apache License, Version 2.0
#
# For managing the conversation history of the UI-TARS agent.
#
"""
UI-TARS LLM Caller that follows the standard LLMCaller pattern.
"""
import base64
from io import BytesIO
from typing import Any, Dict
import structlog
from PIL import Image
from skyvern.forge.prompts import prompt_engine
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMCaller
from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion
from skyvern.forge.sdk.schemas.task_v2 import TaskV2, Thought
from skyvern.forge.sdk.schemas.tasks import Task
from skyvern.utils.image_resizer import Resolution
LOG = structlog.get_logger()
def _build_system_prompt(instruction: str, language: str = "English") -> str:
    """Render the UI-TARS system prompt template for the given instruction and language."""
    rendered = prompt_engine.load_prompt("ui-tars-system-prompt", language=language, instruction=instruction)
    return rendered
def _is_image_message(message: Dict[str, Any]) -> bool:
"""Check if message contains an image."""
return (
message.get("role") == "user"
and isinstance(message.get("content"), list)
and any(item.get("type") == "image_url" for item in message["content"])
)
class UITarsLLMCaller(LLMCaller):
    """
    UI-TARS specific LLM caller that manages conversation history.
    Follows the established LLMCaller pattern used by Anthropic CUA.

    The conversation is kept in `self.message_history` (inherited from LLMCaller):
    one system-prompt user message, followed by alternating screenshot (user)
    and response (assistant) messages. Old screenshots are pruned to bound
    request size; assistant responses are always retained.
    """

    def __init__(self, llm_key: str, screenshot_scaling_enabled: bool = False) -> None:
        super().__init__(llm_key, screenshot_scaling_enabled)
        # Keep at most this many screenshots in the rolling conversation history.
        self.max_history_images = 5
        # Guards initialize_conversation() so the system prompt is added only once.
        self._conversation_initialized = False

    def initialize_conversation(self, task: Task) -> None:
        """Initialize conversation with system prompt for the given task.

        Idempotent: a second call is a no-op once the conversation exists.
        """
        if not self._conversation_initialized:
            # Handle None case for navigation_goal
            instruction = task.navigation_goal or "Default navigation task"
            system_prompt = _build_system_prompt(instruction)
            # UI-TARS expects the system prompt as the first *user* message.
            self.message_history = [{"role": "user", "content": system_prompt}]
            self._conversation_initialized = True
            LOG.debug("Initialized UI-TARS conversation", task_id=task.task_id)

    def add_screenshot(self, screenshot_bytes: bytes) -> None:
        """Add screenshot to conversation history as a base64 image_url message.

        Empty input is silently ignored. Triggers history pruning afterwards.
        """
        if not screenshot_bytes:
            return
        # Convert to PIL Image to get format
        image = Image.open(BytesIO(screenshot_bytes))
        image_format = self._get_image_format_from_pil(image)
        screenshot_b64 = base64.b64encode(screenshot_bytes).decode("utf-8")
        # Add image message
        image_message = {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": f"data:image/{image_format};base64,{screenshot_b64}"}}
            ],
        }
        self.message_history.append(image_message)
        self._maintain_history_limit()
        LOG.debug("Added screenshot to conversation", total_messages=len(self.message_history))

    def add_assistant_response(self, response: str) -> None:
        """Add assistant response to conversation history."""
        self.message_history.append({"role": "assistant", "content": response})
        LOG.debug("Added assistant response to conversation")

    def _maintain_history_limit(self) -> None:
        """Maintain history limit: keep system prompt + all assistant responses + last N screenshots."""
        image_count = self._count_image_messages()
        if image_count <= self.max_history_images:
            return
        # Ensure we have a system prompt (first message should be user with string content)
        if (
            not self.message_history
            or self.message_history[0]["role"] != "user"
            or not isinstance(self.message_history[0]["content"], str)
        ):
            LOG.error("Conversation history corrupted - missing system prompt")
            return
        # Remove oldest screenshots only (keep system prompt and all assistant responses)
        removed_count = 0
        images_to_remove = image_count - self.max_history_images
        i = 1  # Start after system prompt (index 0)
        while i < len(self.message_history) and removed_count < images_to_remove:
            message = self.message_history[i]
            if _is_image_message(message):
                # Remove only the screenshot message, keep all assistant responses
                self.message_history.pop(i)
                removed_count += 1
                # Don't increment i since we removed an element
            else:
                i += 1
        LOG.debug(
            f"Maintained history limit, removed {removed_count} old images, "
            f"current messages: {len(self.message_history)}"
        )

    def _count_image_messages(self) -> int:
        """Count existing image messages in the conversation history."""
        count = 0
        for message in self.message_history:
            if _is_image_message(message):
                count += 1
        return count

    def _get_image_format_from_pil(self, image: Image.Image) -> str:
        """Extract and validate image format from PIL Image object.

        Falls back to "png" when the format is missing or not in the allowed set.
        """
        format_str = image.format.lower() if image.format else "png"
        if format_str not in ["jpg", "jpeg", "png", "webp"]:
            return "png"  # Default to PNG for unsupported formats
        return format_str

    async def call(
        self,
        prompt: str | None = None,
        prompt_name: str | None = None,
        step: Step | None = None,
        task_v2: TaskV2 | None = None,
        thought: Thought | None = None,
        ai_suggestion: AISuggestion | None = None,
        screenshots: list[bytes] | None = None,
        parameters: dict[str, Any] | None = None,
        tools: list[Any] | None = None,
        use_message_history: bool = False,
        raw_response: bool = False,
        window_dimension: Resolution | None = None,
        **extra_parameters: Any,
    ) -> dict[str, Any]:
        """Override call method to use standard LLM routing instead of direct LiteLLM.

        NOTE(review): the `use_message_history` and `raw_response` arguments are
        accepted for signature compatibility with LLMCaller.call but are forced
        to True in the super().call() invocation below.

        Returns:
            A dict with a single "content" key holding the model's plain-text reply.
        """
        # Use raw_response=True to bypass JSON parsing since UI-TARS returns plain text
        response = await super().call(
            prompt=prompt,
            prompt_name=prompt_name,
            step=step,
            task_v2=task_v2,
            thought=thought,
            ai_suggestion=ai_suggestion,
            screenshots=screenshots,
            parameters=parameters,
            tools=tools,
            use_message_history=True,  # Use message history for UI-TARS
            raw_response=True,  # Bypass JSON parsing - UI-TARS returns plain text
            window_dimension=window_dimension,
            **extra_parameters,
        )
        # Extract content from the raw response
        if isinstance(response, dict) and "choices" in response:
            content = response["choices"][0]["message"]["content"]
            return {"content": content}
        else:
            # Fallback for unexpected response format
            return {"content": str(response)}

    async def generate_ui_tars_response(self, step: Step) -> str:
        """Generate UI-TARS response using the overridden call method.

        The stripped reply is appended to the conversation history before returning.
        """
        response = await self.call(step=step)
        content = response.get("content", "").strip()
        # Add the response to conversation history
        self.add_assistant_response(content)
        return content

View File

@@ -0,0 +1,66 @@
"""UI-TARS response model that mimics the ModelResponse interface."""
import json
from typing import Any
class UITarsResponse:
    """A response object that mimics the ModelResponse interface for UI-TARS API responses.

    Exposes `choices[0].message.{content, role}`, a zeroed `usage` dict,
    `model`, and `object`, plus pydantic-style `model_dump`/`model_dump_json`
    and dict-like access (`get`, `[]`, `in`) for downstream compatibility.
    """

    def __init__(self, content: str, model: str):
        # Create choice objects with proper nested structure for parse_api_response
        class Message:
            def __init__(self, content: str):
                self.content = content
                self.role = "assistant"

        class Choice:
            def __init__(self, content: str):
                self.message = Message(content)

        self.choices = [Choice(content)]
        # Streaming aggregation does not capture token usage, so counts stay zero.
        self.usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
        self.model = model
        self.object = "chat.completion"

    def model_dump(self, exclude_none: bool = True) -> dict:
        """Provide model_dump compatibility for raw_response.

        `exclude_none` is accepted for signature parity but has no effect:
        the payload never contains None values.
        """
        return {
            "choices": [
                {"message": {"content": self.choices[0].message.content, "role": self.choices[0].message.role}}
            ],
            "usage": self.usage,
            "model": self.model,
            "object": self.object,
        }

    def model_dump_json(self, indent: int = 2) -> str:
        """Provide model_dump_json compatibility for artifact creation.

        Serializes the exact payload produced by model_dump() so the two
        representations cannot drift apart.
        """
        return json.dumps(self.model_dump(), indent=indent)

    def get(self, key: str, default: Any = None) -> Any:
        """Provide dict-like access for compatibility."""
        return getattr(self, key, default)

    def __getitem__(self, key: str) -> Any:
        """Provide dict-like subscript access for compatibility.

        Raises:
            KeyError: For missing keys, matching dict semantics (callers using
                ``try/except KeyError`` would not catch AttributeError).
        """
        try:
            return getattr(self, key)
        except AttributeError:
            raise KeyError(key) from None

    def __contains__(self, key: str) -> bool:
        """Provide dict-like `in` checks for compatibility."""
        return hasattr(self, key)

View File

@@ -100,6 +100,8 @@ class BackgroundTaskExecutor(AsyncExecutor):
engine = RunEngine.openai_cua
elif run_obj and run_obj.task_run_type == RunType.anthropic_cua:
engine = RunEngine.anthropic_cua
elif run_obj and run_obj.task_run_type == RunType.ui_tars:
engine = RunEngine.ui_tars
context: SkyvernContext = skyvern_context.ensure_context()
context.task_id = task.task_id