ui-tars integration (#2656)
This commit is contained in:
@@ -2,7 +2,7 @@ import dataclasses
|
||||
import json
|
||||
import time
|
||||
from asyncio import CancelledError
|
||||
from typing import Any
|
||||
from typing import Any, AsyncIterator
|
||||
|
||||
import litellm
|
||||
import structlog
|
||||
@@ -10,6 +10,7 @@ from anthropic import NOT_GIVEN
|
||||
from anthropic.types.beta.beta_message import BetaMessage as AnthropicMessage
|
||||
from jinja2 import Template
|
||||
from litellm.utils import CustomStreamWrapper, ModelResponse
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
from pydantic import BaseModel
|
||||
|
||||
from skyvern.config import settings
|
||||
@@ -23,6 +24,7 @@ from skyvern.forge.sdk.api.llm.exceptions import (
|
||||
LLMProviderErrorRetryableTask,
|
||||
)
|
||||
from skyvern.forge.sdk.api.llm.models import LLMAPIHandler, LLMConfig, LLMRouterConfig, dummy_llm_api_handler
|
||||
from skyvern.forge.sdk.api.llm.ui_tars_response import UITarsResponse
|
||||
from skyvern.forge.sdk.api.llm.utils import llm_messages_builder, llm_messages_builder_with_history, parse_api_response
|
||||
from skyvern.forge.sdk.artifact.models import ArtifactType
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
@@ -744,10 +746,14 @@ class LLMCaller:
|
||||
tools: list | None = None,
|
||||
timeout: float = settings.LLM_CONFIG_TIMEOUT,
|
||||
**active_parameters: dict[str, Any],
|
||||
) -> ModelResponse | CustomStreamWrapper | AnthropicMessage:
|
||||
) -> ModelResponse | CustomStreamWrapper | AnthropicMessage | Any:
|
||||
if self.llm_key and "ANTHROPIC" in self.llm_key:
|
||||
return await self._call_anthropic(messages, tools, timeout, **active_parameters)
|
||||
|
||||
# Route UI-TARS models to custom handler instead of LiteLLM
|
||||
if self.llm_key and "UI_TARS" in self.llm_key:
|
||||
return await self._call_ui_tars(messages, tools, timeout, **active_parameters)
|
||||
|
||||
return await litellm.acompletion(
|
||||
model=self.llm_config.model_name, messages=messages, tools=tools, timeout=timeout, **active_parameters
|
||||
)
|
||||
@@ -790,8 +796,97 @@ class LLMCaller:
|
||||
)
|
||||
return response
|
||||
|
||||
async def get_call_stats(self, response: ModelResponse | CustomStreamWrapper | AnthropicMessage) -> LLMCallStats:
|
||||
async def _call_ui_tars(
    self,
    messages: list[dict[str, Any]],
    tools: list | None = None,
    timeout: float = settings.LLM_CONFIG_TIMEOUT,
    **active_parameters: dict[str, Any],
) -> Any:
    """Call a UI-TARS model through the OpenAI-compatible VolcEngine client.

    Bypasses LiteLLM entirely: streams the completion from
    ``app.UI_TARS_CLIENT`` and aggregates the chunks into a
    ``UITarsResponse`` that mimics the ModelResponse interface.

    Args:
        messages: Chat messages in OpenAI format.
        tools: Unused by UI-TARS; kept for signature parity with the
            other ``_call_*`` helpers.
        timeout: Per-request timeout in seconds.
        **active_parameters: Extra LLM parameters (temperature, max tokens).

    Returns:
        A ``UITarsResponse`` wrapping the aggregated response text.

    Raises:
        ValueError: If the UI-TARS client was never initialized.
    """
    # Fall back to 400 tokens when neither max-token knob is provided.
    max_tokens = active_parameters.get("max_completion_tokens") or active_parameters.get("max_tokens") or 400
    model_name = self.llm_config.model_name

    if not app.UI_TARS_CLIENT:
        raise ValueError(
            "UI_TARS_CLIENT not initialized. Please ensure ENABLE_UI_TARS=true and UI_TARS_API_KEY is set."
        )

    LOG.info(
        "UI-TARS request",
        model_name=model_name,
        timeout=timeout,
        messages_length=len(messages),
    )

    # Use the UI-TARS client (which is OpenAI-compatible with VolcEngine)
    chat_completion: AsyncIterator[ChatCompletionChunk] = await app.UI_TARS_CLIENT.chat.completions.create(
        model=model_name,
        messages=messages,
        top_p=None,
        temperature=active_parameters.get("temperature", 0.0),
        max_tokens=max_tokens,
        stream=True,
        seed=None,
        stop=None,
        frequency_penalty=None,
        presence_penalty=None,
        timeout=timeout,
    )

    # Aggregate streaming response like in ByteDance example.
    response_content = ""
    async for chunk in chat_completion:
        # BUG FIX: OpenAI-compatible streams may emit chunks with an empty
        # `choices` list (e.g. a trailing usage chunk); indexing [0]
        # unconditionally raised IndexError on those.
        if not chunk.choices:
            continue
        delta_content = chunk.choices[0].delta.content
        if delta_content:
            response_content += delta_content

    response = UITarsResponse(response_content, model_name)

    LOG.info(
        "UI-TARS response",
        model_name=model_name,
        response_length=len(response_content),
        timeout=timeout,
    )
    return response
|
||||
|
||||
async def get_call_stats(
|
||||
self, response: ModelResponse | CustomStreamWrapper | AnthropicMessage | dict[str, Any] | Any
|
||||
) -> LLMCallStats:
|
||||
empty_call_stats = LLMCallStats()
|
||||
|
||||
# Handle UI-TARS response (UITarsResponse object from _call_ui_tars)
|
||||
if hasattr(response, "usage") and hasattr(response, "choices") and hasattr(response, "model"):
|
||||
usage = response.usage
|
||||
# Use Doubao pricing: ¥0.8/1M input, ¥2/1M output (convert to USD: ~$0.11/$0.28)
|
||||
input_token_cost = (0.11 / 1000000) * usage.get("prompt_tokens", 0)
|
||||
output_token_cost = (0.28 / 1000000) * usage.get("completion_tokens", 0)
|
||||
llm_cost = input_token_cost + output_token_cost
|
||||
|
||||
return LLMCallStats(
|
||||
llm_cost=llm_cost,
|
||||
input_tokens=usage.get("prompt_tokens", 0),
|
||||
output_tokens=usage.get("completion_tokens", 0),
|
||||
cached_tokens=0, # UI-TARS doesn't have cached tokens
|
||||
reasoning_tokens=0,
|
||||
)
|
||||
|
||||
# Handle UI-TARS response (dict format - fallback)
|
||||
if isinstance(response, dict) and "choices" in response and "usage" in response:
|
||||
usage = response["usage"]
|
||||
# Use Doubao pricing: ¥0.8/1M input, ¥2/1M output (convert to USD: ~$0.11/$0.28)
|
||||
input_token_cost = (0.11 / 1000000) * usage.get("prompt_tokens", 0)
|
||||
output_token_cost = (0.28 / 1000000) * usage.get("completion_tokens", 0)
|
||||
llm_cost = input_token_cost + output_token_cost
|
||||
|
||||
return LLMCallStats(
|
||||
llm_cost=llm_cost,
|
||||
input_tokens=usage.get("prompt_tokens", 0),
|
||||
output_tokens=usage.get("completion_tokens", 0),
|
||||
cached_tokens=0, # UI-TARS doesn't have cached tokens
|
||||
reasoning_tokens=0,
|
||||
)
|
||||
|
||||
if isinstance(response, AnthropicMessage):
|
||||
usage = response.usage
|
||||
input_token_cost = (3.0 / 1000000) * usage.input_tokens
|
||||
|
||||
@@ -568,6 +568,18 @@ if settings.ENABLE_AZURE_O3:
|
||||
max_completion_tokens=100000,
|
||||
),
|
||||
)
|
||||
if settings.ENABLE_UI_TARS:
    # Register the UI-TARS (Seed 1.5 VL) model when the feature flag is on.
    _ui_tars_config = LLMConfig(
        settings.UI_TARS_MODEL,
        ["UI_TARS_API_KEY"],
        supports_vision=True,
        add_assistant_prefix=False,
        max_tokens=400,
        temperature=0.0,
    )
    LLMConfigRegistry.register_config("UI_TARS_SEED1_5_VL", _ui_tars_config)
|
||||
|
||||
if settings.ENABLE_GEMINI:
|
||||
LLMConfigRegistry.register_config(
|
||||
@@ -630,6 +642,16 @@ if settings.ENABLE_GEMINI:
|
||||
max_completion_tokens=65536,
|
||||
),
|
||||
)
|
||||
# Gemini 2.5 Flash preview model registration.
_gemini_25_flash_preview_config = LLMConfig(
    "gemini/gemini-2.5-flash-preview-05-20",
    ["GEMINI_API_KEY"],
    supports_vision=True,
    add_assistant_prefix=False,
    max_completion_tokens=65536,
)
LLMConfigRegistry.register_config("GEMINI_2.5_FLASH_PREVIEW", _gemini_25_flash_preview_config)
|
||||
|
||||
|
||||
if settings.ENABLE_NOVITA:
|
||||
|
||||
200
skyvern/forge/sdk/api/llm/ui_tars_llm_caller.py
Normal file
200
skyvern/forge/sdk/api/llm/ui_tars_llm_caller.py
Normal file
@@ -0,0 +1,200 @@
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# Code partially adapted from:
|
||||
# https://github.com/ByteDance-Seed/Seed1.5-VL/blob/main/GUI/gui.ipynb
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
#
|
||||
# For managing the conversation history of the UI-TARS agent.
|
||||
#
|
||||
|
||||
"""
|
||||
UI-TARS LLM Caller that follows the standard LLMCaller pattern.
|
||||
"""
|
||||
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from typing import Any, Dict
|
||||
|
||||
import structlog
|
||||
from PIL import Image
|
||||
|
||||
from skyvern.forge.prompts import prompt_engine
|
||||
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMCaller
|
||||
from skyvern.forge.sdk.models import Step
|
||||
from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion
|
||||
from skyvern.forge.sdk.schemas.task_v2 import TaskV2, Thought
|
||||
from skyvern.forge.sdk.schemas.tasks import Task
|
||||
from skyvern.utils.image_resizer import Resolution
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
def _build_system_prompt(instruction: str, language: str = "English") -> str:
    """Render the UI-TARS system prompt template for the given task instruction."""
    template_kwargs = {"language": language, "instruction": instruction}
    return prompt_engine.load_prompt("ui-tars-system-prompt", **template_kwargs)
|
||||
|
||||
|
||||
def _is_image_message(message: Dict[str, Any]) -> bool:
|
||||
"""Check if message contains an image."""
|
||||
return (
|
||||
message.get("role") == "user"
|
||||
and isinstance(message.get("content"), list)
|
||||
and any(item.get("type") == "image_url" for item in message["content"])
|
||||
)
|
||||
|
||||
|
||||
class UITarsLLMCaller(LLMCaller):
    """
    LLM caller specialized for UI-TARS models.

    Keeps the running conversation (system prompt, screenshots, assistant
    turns) in ``self.message_history`` and trims old screenshots so the
    context stays bounded, mirroring the pattern used by the Anthropic CUA
    caller.
    """

    def __init__(self, llm_key: str, screenshot_scaling_enabled: bool = False):
        super().__init__(llm_key, screenshot_scaling_enabled)
        # Keep at most this many screenshot messages in the rolling history.
        self.max_history_images = 5
        self._conversation_initialized = False

    def initialize_conversation(self, task: Task) -> None:
        """Seed the history with the UI-TARS system prompt (idempotent)."""
        if self._conversation_initialized:
            return
        # navigation_goal may be None; fall back to a generic instruction.
        goal = task.navigation_goal or "Default navigation task"
        self.message_history = [{"role": "user", "content": _build_system_prompt(goal)}]
        self._conversation_initialized = True
        LOG.debug("Initialized UI-TARS conversation", task_id=task.task_id)

    def add_screenshot(self, screenshot_bytes: bytes) -> None:
        """Append a screenshot (as a base64 data URL message) to the history."""
        if not screenshot_bytes:
            return

        # PIL is only used here to sniff the image format for the data URL.
        pil_image = Image.open(BytesIO(screenshot_bytes))
        fmt = self._get_image_format_from_pil(pil_image)
        encoded = base64.b64encode(screenshot_bytes).decode("utf-8")

        self.message_history.append(
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": f"data:image/{fmt};base64,{encoded}"}}
                ],
            }
        )
        self._maintain_history_limit()

        LOG.debug("Added screenshot to conversation", total_messages=len(self.message_history))

    def add_assistant_response(self, response: str) -> None:
        """Record the assistant's reply in the conversation history."""
        self.message_history.append({"role": "assistant", "content": response})
        LOG.debug("Added assistant response to conversation")

    def _maintain_history_limit(self) -> None:
        """Evict the oldest screenshots once more than ``max_history_images`` exist.

        The system prompt (index 0) and every assistant response are always
        preserved; only user image messages are dropped, oldest first.
        """
        surplus = self._count_image_messages() - self.max_history_images
        if surplus <= 0:
            return

        # The first message must be the system prompt (user role, plain string).
        head = self.message_history[0] if self.message_history else None
        if head is None or head["role"] != "user" or not isinstance(head["content"], str):
            LOG.error("Conversation history corrupted - missing system prompt")
            return

        kept = [head]
        dropped = 0
        for message in self.message_history[1:]:
            if dropped < surplus and _is_image_message(message):
                dropped += 1
                continue
            kept.append(message)
        # In-place slice assignment keeps the same list object alive.
        self.message_history[:] = kept

        LOG.debug(
            f"Maintained history limit, removed {dropped} old images, "
            f"current messages: {len(self.message_history)}"
        )

    def _count_image_messages(self) -> int:
        """Return how many screenshot messages the history currently holds."""
        return sum(1 for message in self.message_history if _is_image_message(message))

    def _get_image_format_from_pil(self, image: Image.Image) -> str:
        """Map a PIL image format to a supported data-URL format, defaulting to png."""
        fmt = (image.format or "png").lower()
        return fmt if fmt in ("jpg", "jpeg", "png", "webp") else "png"

    async def call(
        self,
        prompt: str | None = None,
        prompt_name: str | None = None,
        step: Step | None = None,
        task_v2: TaskV2 | None = None,
        thought: Thought | None = None,
        ai_suggestion: AISuggestion | None = None,
        screenshots: list[bytes] | None = None,
        parameters: dict[str, Any] | None = None,
        tools: list[Any] | None = None,
        use_message_history: bool = False,
        raw_response: bool = False,
        window_dimension: Resolution | None = None,
        **extra_parameters: Any,
    ) -> dict[str, Any]:
        """Route through the base ``LLMCaller.call`` and normalize the output.

        Always forces ``use_message_history=True`` and ``raw_response=True``
        (UI-TARS answers in plain text, not JSON), then returns a dict with a
        single ``content`` key.
        """
        raw = await super().call(
            prompt=prompt,
            prompt_name=prompt_name,
            step=step,
            task_v2=task_v2,
            thought=thought,
            ai_suggestion=ai_suggestion,
            screenshots=screenshots,
            parameters=parameters,
            tools=tools,
            use_message_history=True,  # UI-TARS keeps its own rolling history
            raw_response=True,  # plain-text output; skip JSON parsing
            window_dimension=window_dimension,
            **extra_parameters,
        )

        # Normalize: OpenAI-style dicts expose the text under choices[0].
        if isinstance(raw, dict) and "choices" in raw:
            return {"content": raw["choices"][0]["message"]["content"]}
        # Fallback for unexpected response shapes.
        return {"content": str(raw)}

    async def generate_ui_tars_response(self, step: Step) -> str:
        """Run one UI-TARS turn for *step* and append the reply to the history."""
        result = await self.call(step=step)
        content = result.get("content", "").strip()
        self.add_assistant_response(content)
        return content
|
||||
66
skyvern/forge/sdk/api/llm/ui_tars_response.py
Normal file
66
skyvern/forge/sdk/api/llm/ui_tars_response.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""UI-TARS response model that mimics the ModelResponse interface."""
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
|
||||
class UITarsResponse:
    """A response object that mimics the ModelResponse interface for UI-TARS API responses.

    Exposes ``choices[0].message.content``, a ``usage`` dict, and the
    pydantic-style ``model_dump`` / ``model_dump_json`` methods, plus
    dict-like access (``get``, ``[]``, ``in``) for compatibility with
    callers that treat responses as plain dicts.
    """

    # FIX: these were previously defined inside __init__, creating two fresh
    # class objects on every instantiation (wasteful, and isinstance checks
    # across instances could never match). Hoisted to class level.
    class _Message:
        """Assistant message payload, shaped like an OpenAI chat message."""

        def __init__(self, content: str):
            self.content = content
            self.role = "assistant"

    class _Choice:
        """Single completion choice wrapping a message."""

        def __init__(self, content: str):
            self.message = UITarsResponse._Message(content)

    def __init__(self, content: str, model: str):
        # Nested structure matches what parse_api_response expects.
        self.choices = [UITarsResponse._Choice(content)]
        # Token counts are unknown for the streamed UI-TARS path; zeros keep
        # downstream accounting code working.
        self.usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
        self.model = model
        self.object = "chat.completion"

    def model_dump(self, exclude_none: bool = True) -> dict:
        """Provide model_dump compatibility for raw_response.

        ``exclude_none`` is accepted for pydantic signature parity but has no
        effect: no field here is ever None.
        """
        return {
            "choices": [
                {"message": {"content": self.choices[0].message.content, "role": self.choices[0].message.role}}
            ],
            "usage": self.usage,
            "model": self.model,
            "object": self.object,
        }

    def model_dump_json(self, indent: int = 2) -> str:
        """Provide model_dump_json compatibility for artifact creation."""
        # Delegates to model_dump so the two serializations can never drift.
        return json.dumps(self.model_dump(), indent=indent)

    def get(self, key: str, default: Any = None) -> Any:
        """Provide dict-like access for compatibility."""
        return getattr(self, key, default)

    def __getitem__(self, key: str) -> Any:
        """Provide dict-like access for compatibility."""
        return getattr(self, key)

    def __contains__(self, key: str) -> bool:
        """Provide dict-like access for compatibility."""
        return hasattr(self, key)
||||
Reference in New Issue
Block a user