ui-tars integration (#2656)

This commit is contained in:
Wyatt Marshall
2025-06-13 01:23:39 -04:00
committed by GitHub
parent 47cf755d9c
commit 15d46aab82
18 changed files with 986 additions and 13 deletions

View File

@@ -2,7 +2,7 @@ import dataclasses
import json
import time
from asyncio import CancelledError
from typing import Any
from typing import Any, AsyncIterator
import litellm
import structlog
@@ -10,6 +10,7 @@ from anthropic import NOT_GIVEN
from anthropic.types.beta.beta_message import BetaMessage as AnthropicMessage
from jinja2 import Template
from litellm.utils import CustomStreamWrapper, ModelResponse
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from pydantic import BaseModel
from skyvern.config import settings
@@ -23,6 +24,7 @@ from skyvern.forge.sdk.api.llm.exceptions import (
LLMProviderErrorRetryableTask,
)
from skyvern.forge.sdk.api.llm.models import LLMAPIHandler, LLMConfig, LLMRouterConfig, dummy_llm_api_handler
from skyvern.forge.sdk.api.llm.ui_tars_response import UITarsResponse
from skyvern.forge.sdk.api.llm.utils import llm_messages_builder, llm_messages_builder_with_history, parse_api_response
from skyvern.forge.sdk.artifact.models import ArtifactType
from skyvern.forge.sdk.core import skyvern_context
@@ -744,10 +746,14 @@ class LLMCaller:
tools: list | None = None,
timeout: float = settings.LLM_CONFIG_TIMEOUT,
**active_parameters: dict[str, Any],
) -> ModelResponse | CustomStreamWrapper | AnthropicMessage:
) -> ModelResponse | CustomStreamWrapper | AnthropicMessage | Any:
if self.llm_key and "ANTHROPIC" in self.llm_key:
return await self._call_anthropic(messages, tools, timeout, **active_parameters)
# Route UI-TARS models to custom handler instead of LiteLLM
if self.llm_key and "UI_TARS" in self.llm_key:
return await self._call_ui_tars(messages, tools, timeout, **active_parameters)
return await litellm.acompletion(
model=self.llm_config.model_name, messages=messages, tools=tools, timeout=timeout, **active_parameters
)
@@ -790,8 +796,97 @@ class LLMCaller:
)
return response
async def get_call_stats(self, response: ModelResponse | CustomStreamWrapper | AnthropicMessage) -> LLMCallStats:
async def _call_ui_tars(
    self,
    messages: list[dict[str, Any]],
    tools: list | None = None,
    timeout: float = settings.LLM_CONFIG_TIMEOUT,
    **active_parameters: dict[str, Any],
) -> Any:
    """Custom UI-TARS API call using OpenAI client with VolcEngine endpoint.

    Streams the completion and aggregates the chunks into a single
    UITarsResponse that mimics the ModelResponse interface.

    Args:
        messages: Chat messages in OpenAI format.
        tools: Accepted for signature parity with the other call paths;
            UI-TARS does not take tool definitions, so this is unused.
        timeout: Per-request timeout in seconds.
        **active_parameters: Extra LLM parameters; ``max_completion_tokens``
            / ``max_tokens`` and ``temperature`` are honored.

    Returns:
        A UITarsResponse wrapping the aggregated streamed text.

    Raises:
        ValueError: If the UI-TARS client has not been initialized.
    """
    # Prefer OpenAI-style max_completion_tokens, fall back to max_tokens,
    # then to a conservative 400-token default.
    max_tokens = active_parameters.get("max_completion_tokens") or active_parameters.get("max_tokens") or 400
    model_name = self.llm_config.model_name
    if not app.UI_TARS_CLIENT:
        raise ValueError(
            "UI_TARS_CLIENT not initialized. Please ensure ENABLE_UI_TARS=true and UI_TARS_API_KEY is set."
        )
    LOG.info(
        "UI-TARS request",
        model_name=model_name,
        timeout=timeout,
        messages_length=len(messages),
    )
    # Use the UI-TARS client (which is OpenAI-compatible with VolcEngine)
    chat_completion: AsyncIterator[ChatCompletionChunk] = await app.UI_TARS_CLIENT.chat.completions.create(
        model=model_name,
        messages=messages,
        top_p=None,
        temperature=active_parameters.get("temperature", 0.0),
        max_tokens=max_tokens,
        stream=True,
        seed=None,
        stop=None,
        frequency_penalty=None,
        presence_penalty=None,
        timeout=timeout,
    )
    # Aggregate streaming response like in ByteDance example.
    response_content = ""
    async for message in chat_completion:
        # Some OpenAI-compatible endpoints emit trailing chunks (e.g. usage
        # reports) with an empty choices list; guard before indexing to
        # avoid an IndexError mid-stream.
        if message.choices and message.choices[0].delta.content:
            response_content += message.choices[0].delta.content
    response = UITarsResponse(response_content, model_name)
    LOG.info(
        "UI-TARS response",
        model_name=model_name,
        response_length=len(response_content),
        timeout=timeout,
    )
    return response
async def get_call_stats(
self, response: ModelResponse | CustomStreamWrapper | AnthropicMessage | dict[str, Any] | Any
) -> LLMCallStats:
empty_call_stats = LLMCallStats()
# Handle UI-TARS response (UITarsResponse object from _call_ui_tars)
if hasattr(response, "usage") and hasattr(response, "choices") and hasattr(response, "model"):
usage = response.usage
# Use Doubao pricing: ¥0.8/1M input, ¥2/1M output (convert to USD: ~$0.11/$0.28)
input_token_cost = (0.11 / 1000000) * usage.get("prompt_tokens", 0)
output_token_cost = (0.28 / 1000000) * usage.get("completion_tokens", 0)
llm_cost = input_token_cost + output_token_cost
return LLMCallStats(
llm_cost=llm_cost,
input_tokens=usage.get("prompt_tokens", 0),
output_tokens=usage.get("completion_tokens", 0),
cached_tokens=0, # UI-TARS doesn't have cached tokens
reasoning_tokens=0,
)
# Handle UI-TARS response (dict format - fallback)
if isinstance(response, dict) and "choices" in response and "usage" in response:
usage = response["usage"]
# Use Doubao pricing: ¥0.8/1M input, ¥2/1M output (convert to USD: ~$0.11/$0.28)
input_token_cost = (0.11 / 1000000) * usage.get("prompt_tokens", 0)
output_token_cost = (0.28 / 1000000) * usage.get("completion_tokens", 0)
llm_cost = input_token_cost + output_token_cost
return LLMCallStats(
llm_cost=llm_cost,
input_tokens=usage.get("prompt_tokens", 0),
output_tokens=usage.get("completion_tokens", 0),
cached_tokens=0, # UI-TARS doesn't have cached tokens
reasoning_tokens=0,
)
if isinstance(response, AnthropicMessage):
usage = response.usage
input_token_cost = (3.0 / 1000000) * usage.input_tokens

View File

@@ -568,6 +568,18 @@ if settings.ENABLE_AZURE_O3:
max_completion_tokens=100000,
),
)
# Register the UI-TARS (Seed 1.5 VL) vision model when the integration is
# enabled; requests are routed to the custom UI-TARS handler rather than LiteLLM.
if settings.ENABLE_UI_TARS:
    LLMConfigRegistry.register_config(
        "UI_TARS_SEED1_5_VL",
        LLMConfig(
            settings.UI_TARS_MODEL,  # model name is configured via settings
            ["UI_TARS_API_KEY"],
            supports_vision=True,
            add_assistant_prefix=False,
            max_tokens=400,  # NOTE(review): matches the 400-token default in the UI-TARS caller — confirm
            temperature=0.0,
        ),
    )
if settings.ENABLE_GEMINI:
LLMConfigRegistry.register_config(
@@ -630,6 +642,16 @@ if settings.ENABLE_GEMINI:
max_completion_tokens=65536,
),
)
LLMConfigRegistry.register_config(
"GEMINI_2.5_FLASH_PREVIEW",
LLMConfig(
"gemini/gemini-2.5-flash-preview-05-20",
["GEMINI_API_KEY"],
supports_vision=True,
add_assistant_prefix=False,
max_completion_tokens=65536,
),
)
if settings.ENABLE_NOVITA:

View File

@@ -0,0 +1,200 @@
#
# SPDX-License-Identifier: Apache-2.0
# Code partially adapted from:
# https://github.com/ByteDance-Seed/Seed1.5-VL/blob/main/GUI/gui.ipynb
#
# Licensed under the Apache License, Version 2.0
#
# For managing the conversation history of the UI-TARS agent.
#
"""
UI-TARS LLM Caller that follows the standard LLMCaller pattern.
"""
import base64
from io import BytesIO
from typing import Any, Dict
import structlog
from PIL import Image
from skyvern.forge.prompts import prompt_engine
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMCaller
from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion
from skyvern.forge.sdk.schemas.task_v2 import TaskV2, Thought
from skyvern.forge.sdk.schemas.tasks import Task
from skyvern.utils.image_resizer import Resolution
LOG = structlog.get_logger()
def _build_system_prompt(instruction: str, language: str = "English") -> str:
    """Render the UI-TARS system prompt template for the given task instruction."""
    template_name = "ui-tars-system-prompt"
    return prompt_engine.load_prompt(template_name, instruction=instruction, language=language)
def _is_image_message(message: Dict[str, Any]) -> bool:
"""Check if message contains an image."""
return (
message.get("role") == "user"
and isinstance(message.get("content"), list)
and any(item.get("type") == "image_url" for item in message["content"])
)
class UITarsLLMCaller(LLMCaller):
    """
    UI-TARS specific LLM caller that manages conversation history.
    Follows the established LLMCaller pattern used by Anthropic CUA.

    History layout: one leading user message carrying the system prompt,
    then alternating screenshot (user) and response (assistant) messages.
    Only the most recent screenshots are retained (see max_history_images);
    assistant responses are never pruned.
    """

    def __init__(self, llm_key: str, screenshot_scaling_enabled: bool = False):
        super().__init__(llm_key, screenshot_scaling_enabled)
        # Maximum number of screenshot messages kept in message_history.
        self.max_history_images = 5
        # Set once initialize_conversation() has seeded the system prompt.
        self._conversation_initialized = False

    def initialize_conversation(self, task: Task) -> None:
        """Initialize conversation with system prompt for the given task.

        Idempotent: subsequent calls are no-ops once the history is seeded.
        """
        if not self._conversation_initialized:
            # Handle None case for navigation_goal
            instruction = task.navigation_goal or "Default navigation task"
            system_prompt = _build_system_prompt(instruction)
            # The system prompt is sent as the first *user* message.
            self.message_history = [{"role": "user", "content": system_prompt}]
            self._conversation_initialized = True
            LOG.debug("Initialized UI-TARS conversation", task_id=task.task_id)

    def add_screenshot(self, screenshot_bytes: bytes) -> None:
        """Add screenshot to conversation history as a base64 data-URI image message."""
        if not screenshot_bytes:
            return
        # Convert to PIL Image to get format
        image = Image.open(BytesIO(screenshot_bytes))
        image_format = self._get_image_format_from_pil(image)
        screenshot_b64 = base64.b64encode(screenshot_bytes).decode("utf-8")
        # Add image message
        image_message = {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": f"data:image/{image_format};base64,{screenshot_b64}"}}
            ],
        }
        self.message_history.append(image_message)
        # Prune oldest screenshots if we exceeded the retention limit.
        self._maintain_history_limit()
        LOG.debug("Added screenshot to conversation", total_messages=len(self.message_history))

    def add_assistant_response(self, response: str) -> None:
        """Add assistant response to conversation history."""
        self.message_history.append({"role": "assistant", "content": response})
        LOG.debug("Added assistant response to conversation")

    def _maintain_history_limit(self) -> None:
        """Maintain history limit: keep system prompt + all assistant responses + last N screenshots."""
        image_count = self._count_image_messages()
        if image_count <= self.max_history_images:
            return
        # Ensure we have a system prompt (first message should be user with string content)
        if (
            not self.message_history
            or self.message_history[0]["role"] != "user"
            or not isinstance(self.message_history[0]["content"], str)
        ):
            LOG.error("Conversation history corrupted - missing system prompt")
            return
        # Remove oldest screenshots only (keep system prompt and all assistant responses)
        removed_count = 0
        images_to_remove = image_count - self.max_history_images
        i = 1  # Start after system prompt (index 0)
        while i < len(self.message_history) and removed_count < images_to_remove:
            message = self.message_history[i]
            if _is_image_message(message):
                # Remove only the screenshot message, keep all assistant responses
                self.message_history.pop(i)
                removed_count += 1
                # Don't increment i since we removed an element
            else:
                i += 1
        LOG.debug(
            f"Maintained history limit, removed {removed_count} old images, "
            f"current messages: {len(self.message_history)}"
        )

    def _count_image_messages(self) -> int:
        """Count existing image messages in the conversation history."""
        count = 0
        for message in self.message_history:
            if _is_image_message(message):
                count += 1
        return count

    def _get_image_format_from_pil(self, image: Image.Image) -> str:
        """Extract and validate image format from PIL Image object.

        Falls back to PNG for unknown or unsupported formats so the data-URI
        MIME type stays valid.
        """
        format_str = image.format.lower() if image.format else "png"
        if format_str not in ["jpg", "jpeg", "png", "webp"]:
            return "png"  # Default to PNG for unsupported formats
        return format_str

    async def call(
        self,
        prompt: str | None = None,
        prompt_name: str | None = None,
        step: Step | None = None,
        task_v2: TaskV2 | None = None,
        thought: Thought | None = None,
        ai_suggestion: AISuggestion | None = None,
        screenshots: list[bytes] | None = None,
        parameters: dict[str, Any] | None = None,
        tools: list[Any] | None = None,
        use_message_history: bool = False,
        raw_response: bool = False,
        window_dimension: Resolution | None = None,
        **extra_parameters: Any,
    ) -> dict[str, Any]:
        """Override call method to use standard LLM routing instead of direct LiteLLM.

        Note: caller-provided use_message_history / raw_response values are
        intentionally overridden below — UI-TARS always uses the managed
        history and returns plain text.
        """
        # Use raw_response=True to bypass JSON parsing since UI-TARS returns plain text
        response = await super().call(
            prompt=prompt,
            prompt_name=prompt_name,
            step=step,
            task_v2=task_v2,
            thought=thought,
            ai_suggestion=ai_suggestion,
            screenshots=screenshots,
            parameters=parameters,
            tools=tools,
            use_message_history=True,  # Use message history for UI-TARS
            raw_response=True,  # Bypass JSON parsing - UI-TARS returns plain text
            window_dimension=window_dimension,
            **extra_parameters,
        )
        # Extract content from the raw response
        if isinstance(response, dict) and "choices" in response:
            content = response["choices"][0]["message"]["content"]
            return {"content": content}
        else:
            # Fallback for unexpected response format
            return {"content": str(response)}

    async def generate_ui_tars_response(self, step: Step) -> str:
        """Generate UI-TARS response using the overridden call method.

        The stripped response text is appended to the conversation history
        before being returned.
        """
        response = await self.call(step=step)
        content = response.get("content", "").strip()
        # Add the response to conversation history
        self.add_assistant_response(content)
        return content

View File

@@ -0,0 +1,66 @@
"""UI-TARS response model that mimics the ModelResponse interface."""
import json
from typing import Any
class UITarsResponse:
    """A response object that mimics the ModelResponse interface for UI-TARS API responses.

    Exposes ``choices[0].message.content``, a ``usage`` dict, and
    ``model_dump`` / ``model_dump_json`` plus dict-style access so the rest
    of the pipeline can treat it like a LiteLLM ModelResponse.
    """

    def __init__(self, content: str, model: str):
        # Nested helper types give the instance the choices[0].message.content
        # shape that parse_api_response expects from a ModelResponse.
        class _Message:
            def __init__(self, text: str):
                self.content = text
                self.role = "assistant"

        class _Choice:
            def __init__(self, text: str):
                self.message = _Message(text)

        self.choices = [_Choice(content)]
        self.usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
        self.model = model
        self.object = "chat.completion"

    def model_dump(self, exclude_none: bool = True) -> dict:
        """Provide model_dump compatibility for raw_response."""
        first = self.choices[0].message
        return {
            "choices": [{"message": {"content": first.content, "role": first.role}}],
            "usage": self.usage,
            "model": self.model,
            "object": self.object,
        }

    def model_dump_json(self, indent: int = 2) -> str:
        """Provide model_dump_json compatibility for artifact creation."""
        return json.dumps(self.model_dump(), indent=indent)

    def get(self, key: str, default: Any = None) -> Any:
        """Provide dict-like access for compatibility."""
        return getattr(self, key, default)

    def __getitem__(self, key: str) -> Any:
        """Provide dict-like access for compatibility."""
        return getattr(self, key)

    def __contains__(self, key: str) -> bool:
        """Provide dict-like access for compatibility."""
        return hasattr(self, key)