ui-tars integration (#2656)
This commit is contained in:
@@ -59,6 +59,7 @@ from skyvern.forge.sdk.api.files import (
|
||||
wait_for_download_finished,
|
||||
)
|
||||
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory, LLMCaller, LLMCallerManager
|
||||
from skyvern.forge.sdk.api.llm.ui_tars_llm_caller import UITarsLLMCaller
|
||||
from skyvern.forge.sdk.artifact.models import ArtifactType
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.forge.sdk.core.security import generate_skyvern_webhook_headers
|
||||
@@ -91,7 +92,12 @@ from skyvern.webeye.actions.actions import (
|
||||
from skyvern.webeye.actions.caching import retrieve_action_plan
|
||||
from skyvern.webeye.actions.handler import ActionHandler, poll_verification_code
|
||||
from skyvern.webeye.actions.models import AgentStepOutput, DetailedAgentStepOutput
|
||||
from skyvern.webeye.actions.parse_actions import parse_actions, parse_anthropic_actions, parse_cua_actions
|
||||
from skyvern.webeye.actions.parse_actions import (
|
||||
parse_actions,
|
||||
parse_anthropic_actions,
|
||||
parse_cua_actions,
|
||||
parse_ui_tars_actions,
|
||||
)
|
||||
from skyvern.webeye.actions.responses import ActionResult, ActionSuccess
|
||||
from skyvern.webeye.browser_factory import BrowserState
|
||||
from skyvern.webeye.scraper.scraper import ElementTreeFormat, ScrapedPage, scrape_website
|
||||
@@ -393,9 +399,18 @@ class ForgeAgent:
|
||||
llm_key=llm_key or settings.ANTHROPIC_CUA_LLM_KEY, screenshot_scaling_enabled=True
|
||||
)
|
||||
|
||||
if engine == RunEngine.ui_tars and not llm_caller:
|
||||
# see if the llm_caller is already set in memory
|
||||
llm_caller = LLMCallerManager.get_llm_caller(task.task_id)
|
||||
if not llm_caller:
|
||||
# create a new UI-TARS llm_caller
|
||||
llm_key = task.llm_key or settings.UI_TARS_LLM_KEY
|
||||
llm_caller = UITarsLLMCaller(llm_key=llm_key, screenshot_scaling_enabled=True)
|
||||
llm_caller.initialize_conversation(task)
|
||||
|
||||
# TODO: remove the code after migrating everything to llm callers
|
||||
# currently, only anthropic cua tasks use llm_caller
|
||||
if engine == RunEngine.anthropic_cua and llm_caller:
|
||||
# currently, only anthropic cua and ui_tars tasks use llm_caller
|
||||
if engine in [RunEngine.anthropic_cua, RunEngine.ui_tars] and llm_caller:
|
||||
LLMCallerManager.set_llm_caller(task.task_id, llm_caller)
|
||||
|
||||
step, detailed_output = await self.agent_step(
|
||||
@@ -550,6 +565,7 @@ class ForgeAgent:
|
||||
complete_verification=complete_verification,
|
||||
engine=engine,
|
||||
cua_response=cua_response_param,
|
||||
llm_caller=llm_caller,
|
||||
)
|
||||
elif settings.execute_all_steps() and next_step:
|
||||
return await self.execute_step(
|
||||
@@ -563,6 +579,7 @@ class ForgeAgent:
|
||||
complete_verification=complete_verification,
|
||||
engine=engine,
|
||||
cua_response=cua_response_param,
|
||||
llm_caller=llm_caller,
|
||||
)
|
||||
else:
|
||||
LOG.info(
|
||||
@@ -854,6 +871,15 @@ class ForgeAgent:
|
||||
scraped_page=scraped_page,
|
||||
llm_caller=llm_caller,
|
||||
)
|
||||
elif engine == RunEngine.ui_tars:
|
||||
assert llm_caller is not None
|
||||
actions = await self._generate_ui_tars_actions(
|
||||
task=task,
|
||||
step=step,
|
||||
scraped_page=scraped_page,
|
||||
llm_caller=llm_caller,
|
||||
)
|
||||
|
||||
else:
|
||||
using_cached_action_plan = False
|
||||
if not task.navigation_goal and not isinstance(task_block, ValidationBlock):
|
||||
@@ -1483,6 +1509,56 @@ class ForgeAgent:
|
||||
)
|
||||
return actions
|
||||
|
||||
async def _generate_ui_tars_actions(
    self,
    task: Task,
    step: Step,
    scraped_page: ScrapedPage,
    llm_caller: LLMCaller,
) -> list[Action]:
    """Generate the actions for one step with the UI-TARS (Seed1.5-VL) model.

    The first screenshot of ``scraped_page`` is appended to the caller's
    conversation, the model is queried, and the raw text answer is converted
    into actions by ``parse_ui_tars_actions``.

    Args:
        task: Task being executed; used for logging and action parsing.
        step: Current step; used for logging, the LLM call and action parsing.
        scraped_page: Scrape result supplying the screenshot and window size.
        llm_caller: Must be a ``UITarsLLMCaller`` instance.

    Returns:
        The actions parsed from the model response.

    Raises:
        ValueError: If ``llm_caller`` is not a ``UITarsLLMCaller`` or the
            scraped page carries no screenshots.
    """
    LOG.info(
        "UI-TARS action generation starts",
        task_id=task.task_id,
        step_id=step.step_id,
        step_order=step.order,
    )

    # The annotation only promises an LLMCaller; the UI-TARS specific
    # conversation helpers used below exist on UITarsLLMCaller only.
    if not isinstance(llm_caller, UITarsLLMCaller):
        raise ValueError(f"Expected UITarsLLMCaller, got {type(llm_caller)}")

    # Guard clause: without a screenshot there is nothing to show the model.
    if not scraped_page.screenshots:
        LOG.error(
            "No screenshots found, skipping UI-TARS action generation",
            task_id=task.task_id,
            step_id=step.step_id,
        )
        raise ValueError("No screenshots found, skipping UI-TARS action generation")

    # Only the first screenshot is added to the conversation.
    llm_caller.add_screenshot(scraped_page.screenshots[0])

    response_content = await llm_caller.generate_ui_tars_response(step)
    LOG.info(
        "UI-TARS raw response",
        task_id=task.task_id,
        step_id=step.step_id,
        response=response_content,
    )

    # Fall back to a 1920x1080 window when the scrape did not record one.
    window_dimension = (
        cast(Resolution, scraped_page.window_dimension)
        if scraped_page.window_dimension
        else Resolution(width=1920, height=1080)
    )
    LOG.info(
        "UI-TARS browser window dimension",
        task_id=task.task_id,
        step_id=step.step_id,
        window_dimension=window_dimension,
    )

    actions = await parse_ui_tars_actions(task, step, response_content, window_dimension)

    LOG.info(
        "UI-TARS action generation completed",
        task_id=task.task_id,
        step_id=step.step_id,
        actions_count=len(actions),
    )
    return actions
|
||||
|
||||
async def complete_verify(
|
||||
self, page: Page, scraped_page: ScrapedPage, task: Task, step: Step
|
||||
) -> CompleteVerifyResult:
|
||||
@@ -2105,6 +2181,7 @@ class ForgeAgent:
|
||||
return
|
||||
|
||||
await self.async_operation_pool.remove_task(task.task_id)
|
||||
|
||||
await self.cleanup_browser_and_create_artifacts(
|
||||
close_browser_on_completion, last_step, task, browser_session_id=browser_session_id
|
||||
)
|
||||
|
||||
@@ -46,6 +46,14 @@ ANTHROPIC_CLIENT = AsyncAnthropic(api_key=SettingsManager.get_settings().ANTHROP
|
||||
if SettingsManager.get_settings().ENABLE_BEDROCK_ANTHROPIC:
|
||||
ANTHROPIC_CLIENT = AsyncAnthropicBedrock()
|
||||
|
||||
# UI-TARS client: an AsyncOpenAI client pointed at the configured UI-TARS
# endpoint, built only when ENABLE_UI_TARS is set; otherwise it stays None.
UI_TARS_CLIENT = (
    AsyncOpenAI(
        api_key=SettingsManager.get_settings().UI_TARS_API_KEY,
        base_url=SettingsManager.get_settings().UI_TARS_API_BASE,
    )
    if SettingsManager.get_settings().ENABLE_UI_TARS
    else None
)
|
||||
|
||||
SECONDARY_LLM_API_HANDLER = LLMAPIHandlerFactory.get_llm_api_handler(
|
||||
SETTINGS_MANAGER.SECONDARY_LLM_KEY if SETTINGS_MANAGER.SECONDARY_LLM_KEY else SETTINGS_MANAGER.LLM_KEY
|
||||
)
|
||||
|
||||
37
skyvern/forge/prompts/skyvern/ui-tars-system-prompt.j2
Normal file
37
skyvern/forge/prompts/skyvern/ui-tars-system-prompt.j2
Normal file
@@ -0,0 +1,37 @@
|
||||
{#
|
||||
SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
Adapted from:
|
||||
https://github.com/bytedance/UI-TARS/blob/main/codes/ui_tars/prompt.py
|
||||
|
||||
Licensed under the Apache License, Version 2.0
|
||||
|
||||
This prompt is used for the UI-TARS agent.
|
||||
#}
|
||||
You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
|
||||
|
||||
## Output Format
|
||||
```
|
||||
Thought: ...
|
||||
Action: ...
|
||||
```
|
||||
|
||||
## Action Space
|
||||
|
||||
click(point='<point>x1 y1</point>')
|
||||
left_double(point='<point>x1 y1</point>')
|
||||
right_single(point='<point>x1 y1</point>')
|
||||
drag(start_point='<point>x1 y1</point>', end_point='<point>x2 y2</point>')
|
||||
hotkey(key='ctrl c') # Split keys with a space and use lowercase. Also, do not use more than 3 keys in one hotkey action.
|
||||
type(content='xxx') # Use escape characters \\', \\\", and \\n in content part to ensure we can parse the content in normal python string format. If you want to submit your input, use \\n at the end of content.
|
||||
scroll(point='<point>x1 y1</point>', direction='down or up or right or left') # Show more information on the `direction` side.
|
||||
wait() #Sleep for 5s and take a screenshot to check for any changes.
|
||||
finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format.
|
||||
|
||||
|
||||
## Note
|
||||
- Use {{language}} in `Thought` part.
|
||||
- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.
|
||||
|
||||
## User Instruction
|
||||
{{instruction}}
|
||||
@@ -2,7 +2,7 @@ import dataclasses
|
||||
import json
|
||||
import time
|
||||
from asyncio import CancelledError
|
||||
from typing import Any
|
||||
from typing import Any, AsyncIterator
|
||||
|
||||
import litellm
|
||||
import structlog
|
||||
@@ -10,6 +10,7 @@ from anthropic import NOT_GIVEN
|
||||
from anthropic.types.beta.beta_message import BetaMessage as AnthropicMessage
|
||||
from jinja2 import Template
|
||||
from litellm.utils import CustomStreamWrapper, ModelResponse
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
from pydantic import BaseModel
|
||||
|
||||
from skyvern.config import settings
|
||||
@@ -23,6 +24,7 @@ from skyvern.forge.sdk.api.llm.exceptions import (
|
||||
LLMProviderErrorRetryableTask,
|
||||
)
|
||||
from skyvern.forge.sdk.api.llm.models import LLMAPIHandler, LLMConfig, LLMRouterConfig, dummy_llm_api_handler
|
||||
from skyvern.forge.sdk.api.llm.ui_tars_response import UITarsResponse
|
||||
from skyvern.forge.sdk.api.llm.utils import llm_messages_builder, llm_messages_builder_with_history, parse_api_response
|
||||
from skyvern.forge.sdk.artifact.models import ArtifactType
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
@@ -744,10 +746,14 @@ class LLMCaller:
|
||||
tools: list | None = None,
|
||||
timeout: float = settings.LLM_CONFIG_TIMEOUT,
|
||||
**active_parameters: dict[str, Any],
|
||||
) -> ModelResponse | CustomStreamWrapper | AnthropicMessage:
|
||||
) -> ModelResponse | CustomStreamWrapper | AnthropicMessage | Any:
|
||||
if self.llm_key and "ANTHROPIC" in self.llm_key:
|
||||
return await self._call_anthropic(messages, tools, timeout, **active_parameters)
|
||||
|
||||
# Route UI-TARS models to custom handler instead of LiteLLM
|
||||
if self.llm_key and "UI_TARS" in self.llm_key:
|
||||
return await self._call_ui_tars(messages, tools, timeout, **active_parameters)
|
||||
|
||||
return await litellm.acompletion(
|
||||
model=self.llm_config.model_name, messages=messages, tools=tools, timeout=timeout, **active_parameters
|
||||
)
|
||||
@@ -790,8 +796,97 @@ class LLMCaller:
|
||||
)
|
||||
return response
|
||||
|
||||
async def get_call_stats(self, response: ModelResponse | CustomStreamWrapper | AnthropicMessage) -> LLMCallStats:
|
||||
async def _call_ui_tars(
    self,
    messages: list[dict[str, Any]],
    tools: list | None = None,
    timeout: float = settings.LLM_CONFIG_TIMEOUT,
    **active_parameters: dict[str, Any],
) -> Any:
    """Call the UI-TARS model through ``app.UI_TARS_CLIENT`` and collect the streamed text.

    Args:
        messages: Chat messages forwarded unchanged to the client.
        tools: Accepted for signature parity with the other call paths; it is
            not forwarded to the UI-TARS request.
        timeout: Request timeout passed to the client.
        **active_parameters: Optional ``max_completion_tokens`` / ``max_tokens``
            (first truthy one wins, default 400) and ``temperature``
            (default 0.0).

    Returns:
        A ``UITarsResponse`` wrapping the concatenated streamed content.

    Raises:
        ValueError: If ``app.UI_TARS_CLIENT`` has not been initialized.
    """
    max_tokens = active_parameters.get("max_completion_tokens") or active_parameters.get("max_tokens") or 400
    model_name = self.llm_config.model_name

    if not app.UI_TARS_CLIENT:
        raise ValueError(
            "UI_TARS_CLIENT not initialized. Please ensure ENABLE_UI_TARS=true and UI_TARS_API_KEY is set."
        )

    LOG.info(
        "UI-TARS request",
        model_name=model_name,
        timeout=timeout,
        messages_length=len(messages),
    )

    # The request is streamed; the chunks are aggregated below.
    chat_completion: AsyncIterator[ChatCompletionChunk] = await app.UI_TARS_CLIENT.chat.completions.create(
        model=model_name,
        messages=messages,
        top_p=None,
        temperature=active_parameters.get("temperature", 0.0),
        max_tokens=max_tokens,
        stream=True,
        seed=None,
        stop=None,
        frequency_penalty=None,
        presence_penalty=None,
        timeout=timeout,
    )

    content_parts: list[str] = []
    async for chunk in chat_completion:
        # A streamed chunk may carry an empty ``choices`` list (for example a
        # trailing usage-only chunk); indexing it would raise IndexError.
        if not chunk.choices:
            continue
        delta_content = chunk.choices[0].delta.content
        if delta_content:
            content_parts.append(delta_content)
    response_content = "".join(content_parts)

    response = UITarsResponse(response_content, model_name)

    LOG.info(
        "UI-TARS response",
        model_name=model_name,
        response_length=len(response_content),
        timeout=timeout,
    )
    return response
|
||||
|
||||
async def get_call_stats(
|
||||
self, response: ModelResponse | CustomStreamWrapper | AnthropicMessage | dict[str, Any] | Any
|
||||
) -> LLMCallStats:
|
||||
empty_call_stats = LLMCallStats()
|
||||
|
||||
# Handle UI-TARS response (UITarsResponse object from _call_ui_tars)
|
||||
if hasattr(response, "usage") and hasattr(response, "choices") and hasattr(response, "model"):
|
||||
usage = response.usage
|
||||
# Use Doubao pricing: ¥0.8/1M input, ¥2/1M output (convert to USD: ~$0.11/$0.28)
|
||||
input_token_cost = (0.11 / 1000000) * usage.get("prompt_tokens", 0)
|
||||
output_token_cost = (0.28 / 1000000) * usage.get("completion_tokens", 0)
|
||||
llm_cost = input_token_cost + output_token_cost
|
||||
|
||||
return LLMCallStats(
|
||||
llm_cost=llm_cost,
|
||||
input_tokens=usage.get("prompt_tokens", 0),
|
||||
output_tokens=usage.get("completion_tokens", 0),
|
||||
cached_tokens=0, # UI-TARS doesn't have cached tokens
|
||||
reasoning_tokens=0,
|
||||
)
|
||||
|
||||
# Handle UI-TARS response (dict format - fallback)
|
||||
if isinstance(response, dict) and "choices" in response and "usage" in response:
|
||||
usage = response["usage"]
|
||||
# Use Doubao pricing: ¥0.8/1M input, ¥2/1M output (convert to USD: ~$0.11/$0.28)
|
||||
input_token_cost = (0.11 / 1000000) * usage.get("prompt_tokens", 0)
|
||||
output_token_cost = (0.28 / 1000000) * usage.get("completion_tokens", 0)
|
||||
llm_cost = input_token_cost + output_token_cost
|
||||
|
||||
return LLMCallStats(
|
||||
llm_cost=llm_cost,
|
||||
input_tokens=usage.get("prompt_tokens", 0),
|
||||
output_tokens=usage.get("completion_tokens", 0),
|
||||
cached_tokens=0, # UI-TARS doesn't have cached tokens
|
||||
reasoning_tokens=0,
|
||||
)
|
||||
|
||||
if isinstance(response, AnthropicMessage):
|
||||
usage = response.usage
|
||||
input_token_cost = (3.0 / 1000000) * usage.input_tokens
|
||||
|
||||
@@ -568,6 +568,18 @@ if settings.ENABLE_AZURE_O3:
|
||||
max_completion_tokens=100000,
|
||||
),
|
||||
)
|
||||
if settings.ENABLE_UI_TARS:
|
||||
LLMConfigRegistry.register_config(
|
||||
"UI_TARS_SEED1_5_VL",
|
||||
LLMConfig(
|
||||
settings.UI_TARS_MODEL,
|
||||
["UI_TARS_API_KEY"],
|
||||
supports_vision=True,
|
||||
add_assistant_prefix=False,
|
||||
max_tokens=400,
|
||||
temperature=0.0,
|
||||
),
|
||||
)
|
||||
|
||||
if settings.ENABLE_GEMINI:
|
||||
LLMConfigRegistry.register_config(
|
||||
@@ -630,6 +642,16 @@ if settings.ENABLE_GEMINI:
|
||||
max_completion_tokens=65536,
|
||||
),
|
||||
)
|
||||
LLMConfigRegistry.register_config(
|
||||
"GEMINI_2.5_FLASH_PREVIEW",
|
||||
LLMConfig(
|
||||
"gemini/gemini-2.5-flash-preview-05-20",
|
||||
["GEMINI_API_KEY"],
|
||||
supports_vision=True,
|
||||
add_assistant_prefix=False,
|
||||
max_completion_tokens=65536,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
if settings.ENABLE_NOVITA:
|
||||
|
||||
200
skyvern/forge/sdk/api/llm/ui_tars_llm_caller.py
Normal file
200
skyvern/forge/sdk/api/llm/ui_tars_llm_caller.py
Normal file
@@ -0,0 +1,200 @@
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# Code partially adapted from:
|
||||
# https://github.com/ByteDance-Seed/Seed1.5-VL/blob/main/GUI/gui.ipynb
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
#
|
||||
# For managing the conversation history of the UI-TARS agent.
|
||||
#
|
||||
|
||||
"""
|
||||
UI-TARS LLM Caller that follows the standard LLMCaller pattern.
|
||||
"""
|
||||
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from typing import Any, Dict
|
||||
|
||||
import structlog
|
||||
from PIL import Image
|
||||
|
||||
from skyvern.forge.prompts import prompt_engine
|
||||
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMCaller
|
||||
from skyvern.forge.sdk.models import Step
|
||||
from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion
|
||||
from skyvern.forge.sdk.schemas.task_v2 import TaskV2, Thought
|
||||
from skyvern.forge.sdk.schemas.tasks import Task
|
||||
from skyvern.utils.image_resizer import Resolution
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
def _build_system_prompt(instruction: str, language: str = "English") -> str:
    """Render the ``ui-tars-system-prompt`` template via the prompt engine.

    Args:
        instruction: Value passed to the template as ``instruction``.
        language: Value passed to the template as ``language``.

    Returns:
        Whatever ``prompt_engine.load_prompt`` returns for the template.
    """
    template_values = {"language": language, "instruction": instruction}
    return prompt_engine.load_prompt("ui-tars-system-prompt", **template_values)
|
||||
|
||||
|
||||
def _is_image_message(message: Dict[str, Any]) -> bool:
|
||||
"""Check if message contains an image."""
|
||||
return (
|
||||
message.get("role") == "user"
|
||||
and isinstance(message.get("content"), list)
|
||||
and any(item.get("type") == "image_url" for item in message["content"])
|
||||
)
|
||||
|
||||
|
||||
class UITarsLLMCaller(LLMCaller):
    """UI-TARS specific LLM caller that manages the conversation history.

    ``message_history`` holds the rendered system prompt (sent with the
    ``"user"`` role) at index 0, followed by screenshot messages and assistant
    replies. Only the newest ``max_history_images`` screenshots are retained;
    assistant replies are never trimmed.
    """

    def __init__(self, llm_key: str, screenshot_scaling_enabled: bool = False):
        super().__init__(llm_key, screenshot_scaling_enabled)
        # Maximum number of screenshot messages kept in the history.
        self.max_history_images = 5
        self._conversation_initialized = False

    def initialize_conversation(self, task: Task) -> None:
        """Seed the history with the system prompt for ``task`` (first call only)."""
        if self._conversation_initialized:
            return

        # navigation_goal may be empty/None; fall back to a placeholder instruction.
        instruction = task.navigation_goal or "Default navigation task"
        system_prompt = _build_system_prompt(instruction)
        self.message_history = [{"role": "user", "content": system_prompt}]
        self._conversation_initialized = True
        LOG.debug("Initialized UI-TARS conversation", task_id=task.task_id)

    def add_screenshot(self, screenshot_bytes: bytes) -> None:
        """Append a screenshot as a base64 data-URL user message, then trim the history.

        Empty input is ignored.
        """
        if not screenshot_bytes:
            return

        # The image is opened only to detect its format; close it right away.
        with Image.open(BytesIO(screenshot_bytes)) as image:
            image_format = self._get_image_format_from_pil(image)
        screenshot_b64 = base64.b64encode(screenshot_bytes).decode("utf-8")

        image_message = {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": f"data:image/{image_format};base64,{screenshot_b64}"}}
            ],
        }
        self.message_history.append(image_message)
        self._maintain_history_limit()

        LOG.debug("Added screenshot to conversation", total_messages=len(self.message_history))

    def add_assistant_response(self, response: str) -> None:
        """Append an assistant reply to the conversation history."""
        self.message_history.append({"role": "assistant", "content": response})
        LOG.debug("Added assistant response to conversation")

    def _maintain_history_limit(self) -> None:
        """Drop the oldest screenshots so at most ``max_history_images`` remain.

        The system prompt (index 0) and every assistant reply are kept. If the
        first message is not a user message with string content, the history
        is left untouched and an error is logged.
        """
        image_count = self._count_image_messages()
        if image_count <= self.max_history_images:
            return

        history = self.message_history
        if not history or history[0]["role"] != "user" or not isinstance(history[0]["content"], str):
            LOG.error("Conversation history corrupted - missing system prompt")
            return

        images_to_remove = image_count - self.max_history_images
        removed_count = 0
        kept: list[dict[str, Any]] = []
        # Single pass over everything after the system prompt: skip the first
        # ``images_to_remove`` screenshot messages, keep the rest in order.
        for message in history[1:]:
            if removed_count < images_to_remove and _is_image_message(message):
                removed_count += 1
                continue
            kept.append(message)

        # Slice assignment mutates the existing list object in place.
        history[1:] = kept

        LOG.debug(
            "Maintained history limit",
            removed_images=removed_count,
            total_messages=len(history),
        )

    def _count_image_messages(self) -> int:
        """Return the number of screenshot messages currently in the history."""
        return sum(1 for message in self.message_history if _is_image_message(message))

    def _get_image_format_from_pil(self, image: Image.Image) -> str:
        """Return the lowercase image format, defaulting to ``"png"``.

        ``"png"`` is used when PIL reports no format or one outside
        jpg/jpeg/png/webp.
        """
        format_str = image.format.lower() if image.format else "png"
        if format_str in ("jpg", "jpeg", "png", "webp"):
            return format_str
        return "png"

    async def call(
        self,
        prompt: str | None = None,
        prompt_name: str | None = None,
        step: Step | None = None,
        task_v2: TaskV2 | None = None,
        thought: Thought | None = None,
        ai_suggestion: AISuggestion | None = None,
        screenshots: list[bytes] | None = None,
        parameters: dict[str, Any] | None = None,
        tools: list[Any] | None = None,
        use_message_history: bool = False,
        raw_response: bool = False,
        window_dimension: Resolution | None = None,
        **extra_parameters: Any,
    ) -> dict[str, Any]:
        """Delegate to ``LLMCaller.call`` and return ``{"content": <text>}``.

        ``use_message_history`` and ``raw_response`` are accepted for signature
        compatibility but always forwarded as ``True``: UI-TARS works on the
        stored conversation and answers in plain text, not JSON.
        """
        response = await super().call(
            prompt=prompt,
            prompt_name=prompt_name,
            step=step,
            task_v2=task_v2,
            thought=thought,
            ai_suggestion=ai_suggestion,
            screenshots=screenshots,
            parameters=parameters,
            tools=tools,
            use_message_history=True,  # Use message history for UI-TARS
            raw_response=True,  # Bypass JSON parsing - UI-TARS returns plain text
            window_dimension=window_dimension,
            **extra_parameters,
        )

        if isinstance(response, dict) and "choices" in response:
            return {"content": response["choices"][0]["message"]["content"]}

        # Fallback for an unexpected response shape: pass its string form on.
        return {"content": str(response)}

    async def generate_ui_tars_response(self, step: Step) -> str:
        """Query the model for ``step``, record the reply in the history and return it."""
        response = await self.call(step=step)

        # ``content`` may be missing or None; treat both as an empty reply
        # instead of failing on ``None.strip()``.
        content = (response.get("content") or "").strip()

        self.add_assistant_response(content)
        return content
|
||||
66
skyvern/forge/sdk/api/llm/ui_tars_response.py
Normal file
66
skyvern/forge/sdk/api/llm/ui_tars_response.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""UI-TARS response model that mimics the ModelResponse interface."""
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
|
||||
class _UITarsMessage:
    """Assistant message holding the text returned by the model."""

    def __init__(self, content: str):
        self.content = content
        self.role = "assistant"


class _UITarsChoice:
    """Single completion choice wrapping one assistant message."""

    def __init__(self, content: str):
        self.message = _UITarsMessage(content)


class UITarsResponse:
    """A response object that mimics the ModelResponse interface for UI-TARS API responses.

    It exposes ``choices[0].message.content`` / ``.role``, ``usage`` (all
    counters are zero), ``model`` and ``object``, plus ``model_dump`` /
    ``model_dump_json`` and dict-style access backed by attribute lookup.
    """

    def __init__(self, content: str, model: str):
        self.choices = [_UITarsChoice(content)]
        self.usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
        self.model = model
        self.object = "chat.completion"

    def model_dump_json(self, indent: int = 2) -> str:
        """Return the ``model_dump()`` payload serialized as JSON with ``indent``."""
        return json.dumps(self.model_dump(), indent=indent)

    def model_dump(self, exclude_none: bool = True) -> dict:
        """Return the response as a plain dict.

        ``exclude_none`` is accepted for signature compatibility only; it is
        not consulted when building the dict.
        """
        message = self.choices[0].message
        return {
            "choices": [{"message": {"content": message.content, "role": message.role}}],
            "usage": self.usage,
            "model": self.model,
            "object": self.object,
        }

    def get(self, key: str, default: Any = None) -> Any:
        """Dict-style ``get``: return attribute ``key`` or ``default`` when absent."""
        return getattr(self, key, default)

    def __getitem__(self, key: str) -> Any:
        """Dict-style indexing backed by attribute lookup (missing keys raise AttributeError)."""
        return getattr(self, key)

    def __contains__(self, key: str) -> bool:
        """Dict-style membership test: True when an attribute named ``key`` exists."""
        return hasattr(self, key)
|
||||
@@ -100,6 +100,8 @@ class BackgroundTaskExecutor(AsyncExecutor):
|
||||
engine = RunEngine.openai_cua
|
||||
elif run_obj and run_obj.task_run_type == RunType.anthropic_cua:
|
||||
engine = RunEngine.anthropic_cua
|
||||
elif run_obj and run_obj.task_run_type == RunType.ui_tars:
|
||||
engine = RunEngine.ui_tars
|
||||
|
||||
context: SkyvernContext = skyvern_context.ensure_context()
|
||||
context.task_id = task.task_id
|
||||
|
||||
Reference in New Issue
Block a user