support volcengine + migrate ui tars to volcengine (#2705)
This commit is contained in:
14
.env.example
14
.env.example
@@ -43,14 +43,12 @@ ENABLE_NOVITA=false
|
||||
# NOVITA_API_KEY: Your Novita AI API key.
|
||||
NOVITA_API_KEY=""
|
||||
|
||||
# ENABLE_UI_TARS: Set to true to enable UI-TARS (Seed1.5-VL) as a language model provider.
|
||||
ENABLE_UI_TARS=false
|
||||
# UI_TARS_API_KEY: Your ByteDance Doubao API key for accessing UI-TARS models.
|
||||
UI_TARS_API_KEY=""
|
||||
# UI_TARS_API_BASE: The base URL for ByteDance Doubao API.
|
||||
UI_TARS_API_BASE="https://ark.cn-beijing.volces.com/api/v3"
|
||||
# UI_TARS_MODEL: Your UI-TARS model endpoint ID from ByteDance Doubao.
|
||||
UI_TARS_MODEL="doubao-1-5-thinking-vision-pro-250428"
|
||||
# ENABLE_VOLCENGINE: Set to true to enable Volcengine(ByteDance Doubao) as a language model provider.
|
||||
ENABLE_VOLCENGINE=false
|
||||
# VOLCENGINE_API_KEY: Your Volcengine(ByteDance Doubao) API key.
|
||||
VOLCENGINE_API_KEY=""
|
||||
# VOLCENGINE_API_BASE: The base URL for Volcengine(ByteDance Doubao) API.
|
||||
VOLCENGINE_API_BASE="https://ark.cn-beijing.volces.com/api/v3"
|
||||
|
||||
# LLM_KEY: The chosen language model to use. This should be one of the models
|
||||
# provided by the enabled LLM providers (e.g., OPENAI_GPT4_TURBO, OPENAI_GPT4V, ANTHROPIC_CLAUDE3, AZURE_OPENAI_GPT4V).
|
||||
|
||||
@@ -160,30 +160,26 @@ def setup_llm_providers() -> None:
|
||||
else:
|
||||
update_or_add_env_var("ENABLE_NOVITA", "false")
|
||||
|
||||
console.print("\n[bold blue]--- UI-TARS Configuration ---[/bold blue]")
|
||||
console.print("To enable UI-TARS (Seed1.5-VL), you must have a ByteDance Doubao API key.")
|
||||
console.print("UI-TARS now uses direct VolcEngine API calls for improved compatibility.")
|
||||
enable_ui_tars = Confirm.ask("Do you want to enable UI-TARS?")
|
||||
if enable_ui_tars:
|
||||
ui_tars_api_key = Prompt.ask("Enter your ByteDance Doubao API key", password=True)
|
||||
if not ui_tars_api_key:
|
||||
console.print("[red]Error: UI-TARS API key is required. UI-TARS will not be enabled.[/red]")
|
||||
console.print("\n[bold blue]--- VolcEngine Configuration ---[/bold blue]")
|
||||
console.print("To enable VolcEngine, you must have a ByteDance Doubao API key.")
|
||||
enable_volcengine = Confirm.ask("Do you want to enable VolcEngine?")
|
||||
if enable_volcengine:
|
||||
volcengine_api_key = Prompt.ask("Enter your VolcEngine(ByteDance Doubao) API key", password=True)
|
||||
if not volcengine_api_key:
|
||||
console.print("[red]Error: VolcEngine key is required. VolcEngine will not be enabled.[/red]")
|
||||
else:
|
||||
update_or_add_env_var("UI_TARS_API_KEY", ui_tars_api_key)
|
||||
update_or_add_env_var("ENABLE_UI_TARS", "true")
|
||||
update_or_add_env_var("VOLCENGINE_API_KEY", volcengine_api_key)
|
||||
update_or_add_env_var("ENABLE_VOLCENGINE", "true")
|
||||
|
||||
# Optional: Allow customizing model endpoint
|
||||
custom_model = Confirm.ask(
|
||||
"Do you want to use a custom model endpoint? (default: doubao-1-5-thinking-vision-pro-250428)"
|
||||
model_options.extend(
|
||||
[
|
||||
"VOLCENGINE_DOUBAO_SEED_1_6",
|
||||
"VOLCENGINE_DOUBAO_SEED_1_6_FLASH",
|
||||
"VOLCENGINE_DOUBAO_1_5_THINKING_VISION_PRO",
|
||||
]
|
||||
)
|
||||
if custom_model:
|
||||
ui_tars_model = Prompt.ask("Enter your UI-TARS model endpoint ID")
|
||||
if ui_tars_model:
|
||||
update_or_add_env_var("UI_TARS_MODEL", ui_tars_model)
|
||||
|
||||
model_options.append("UI_TARS_SEED1_5_VL")
|
||||
else:
|
||||
update_or_add_env_var("ENABLE_UI_TARS", "false")
|
||||
update_or_add_env_var("ENABLE_VOLCENGINE", "false")
|
||||
|
||||
console.print("\n[bold blue]--- OpenAI-Compatible Provider Configuration ---[/bold blue]")
|
||||
console.print("To enable an OpenAI-compatible provider, you must have a model name, API key, and API base URL.")
|
||||
|
||||
@@ -134,12 +134,11 @@ class Settings(BaseSettings):
|
||||
ANTHROPIC_API_KEY: str | None = None
|
||||
ANTHROPIC_CUA_LLM_KEY: str = "ANTHROPIC_CLAUDE3.7_SONNET"
|
||||
|
||||
# UI-TARS (Seed1.5-VL via Doubao)
|
||||
UI_TARS_API_KEY: str | None = None
|
||||
UI_TARS_API_BASE: str = "https://ark.cn-beijing.volces.com/api/v3"
|
||||
UI_TARS_MODEL: str = "doubao-1-5-thinking-vision-pro-250428"
|
||||
UI_TARS_LLM_KEY: str = "UI_TARS_SEED1_5_VL"
|
||||
ENABLE_UI_TARS: bool = False
|
||||
# VOLCENGINE (Doubao)
|
||||
ENABLE_VOLCENGINE: bool = False
|
||||
VOLCENGINE_API_KEY: str | None = None
|
||||
VOLCENGINE_API_BASE: str = "https://ark.cn-beijing.volces.com/api/v3"
|
||||
VOLCENGINE_CUA_LLM_KEY: str = "VOLCENGINE_DOUBAO_1_5_THINKING_VISION_PRO"
|
||||
|
||||
# OPENAI COMPATIBLE
|
||||
OPENAI_COMPATIBLE_MODEL_NAME: str | None = None
|
||||
|
||||
@@ -404,7 +404,7 @@ class ForgeAgent:
|
||||
llm_caller = LLMCallerManager.get_llm_caller(task.task_id)
|
||||
if not llm_caller:
|
||||
# create a new UI-TARS llm_caller
|
||||
llm_key = task.llm_key or settings.UI_TARS_LLM_KEY
|
||||
llm_key = task.llm_key or settings.VOLCENGINE_CUA_LLM_KEY
|
||||
llm_caller = UITarsLLMCaller(llm_key=llm_key, screenshot_scaling_enabled=True)
|
||||
llm_caller.initialize_conversation(task)
|
||||
|
||||
|
||||
@@ -48,10 +48,10 @@ if SettingsManager.get_settings().ENABLE_BEDROCK_ANTHROPIC:
|
||||
|
||||
# Add UI-TARS client setup
|
||||
UI_TARS_CLIENT = None
|
||||
if SettingsManager.get_settings().ENABLE_UI_TARS:
|
||||
if SettingsManager.get_settings().ENABLE_VOLCENGINE:
|
||||
UI_TARS_CLIENT = AsyncOpenAI(
|
||||
api_key=SettingsManager.get_settings().UI_TARS_API_KEY,
|
||||
base_url=SettingsManager.get_settings().UI_TARS_API_BASE,
|
||||
api_key=SettingsManager.get_settings().VOLCENGINE_API_KEY,
|
||||
base_url=SettingsManager.get_settings().VOLCENGINE_API_BASE,
|
||||
)
|
||||
|
||||
SECONDARY_LLM_API_HANDLER = LLMAPIHandlerFactory.get_llm_api_handler(
|
||||
|
||||
@@ -198,6 +198,7 @@ class LLMAPIHandlerFactory:
|
||||
)
|
||||
if step or thought:
|
||||
try:
|
||||
# FIXME: volcengine doesn't support litellm cost calculation.
|
||||
llm_cost = litellm.completion_cost(completion_response=response)
|
||||
except Exception as e:
|
||||
LOG.debug("Failed to calculate LLM cost", error=str(e), exc_info=True)
|
||||
@@ -401,6 +402,7 @@ class LLMAPIHandlerFactory:
|
||||
|
||||
if step or thought:
|
||||
try:
|
||||
# FIXME: volcengine doesn't support litellm cost calculation.
|
||||
llm_cost = litellm.completion_cost(completion_response=response)
|
||||
except Exception as e:
|
||||
LOG.debug("Failed to calculate LLM cost", error=str(e), exc_info=True)
|
||||
@@ -746,7 +748,7 @@ class LLMCaller:
|
||||
tools: list | None = None,
|
||||
timeout: float = settings.LLM_CONFIG_TIMEOUT,
|
||||
**active_parameters: dict[str, Any],
|
||||
) -> ModelResponse | CustomStreamWrapper | AnthropicMessage | Any:
|
||||
) -> ModelResponse | CustomStreamWrapper | AnthropicMessage | UITarsResponse:
|
||||
if self.llm_key and "ANTHROPIC" in self.llm_key:
|
||||
return await self._call_anthropic(messages, tools, timeout, **active_parameters)
|
||||
|
||||
@@ -802,14 +804,14 @@ class LLMCaller:
|
||||
tools: list | None = None,
|
||||
timeout: float = settings.LLM_CONFIG_TIMEOUT,
|
||||
**active_parameters: dict[str, Any],
|
||||
) -> Any:
|
||||
) -> UITarsResponse:
|
||||
"""Custom UI-TARS API call using OpenAI client with VolcEngine endpoint."""
|
||||
max_tokens = active_parameters.get("max_completion_tokens") or active_parameters.get("max_tokens") or 400
|
||||
model_name = self.llm_config.model_name
|
||||
model_name = self.llm_config.model_name.replace("volcengine/", "")
|
||||
|
||||
if not app.UI_TARS_CLIENT:
|
||||
raise ValueError(
|
||||
"UI_TARS_CLIENT not initialized. Please ensure ENABLE_UI_TARS=true and UI_TARS_API_KEY is set."
|
||||
"UI_TARS_CLIENT not initialized. Please ensure ENABLE_VOLCENGINE=true and VOLCENGINE_API_KEY is set."
|
||||
)
|
||||
|
||||
LOG.info(
|
||||
@@ -851,39 +853,18 @@ class LLMCaller:
|
||||
return response
|
||||
|
||||
async def get_call_stats(
|
||||
self, response: ModelResponse | CustomStreamWrapper | AnthropicMessage | dict[str, Any] | Any
|
||||
self, response: ModelResponse | CustomStreamWrapper | AnthropicMessage | UITarsResponse
|
||||
) -> LLMCallStats:
|
||||
empty_call_stats = LLMCallStats()
|
||||
|
||||
# Handle UI-TARS response (UITarsResponse object from _call_ui_tars)
|
||||
if hasattr(response, "usage") and hasattr(response, "choices") and hasattr(response, "model"):
|
||||
usage = response.usage
|
||||
# Use Doubao pricing: ¥0.8/1M input, ¥2/1M output (convert to USD: ~$0.11/$0.28)
|
||||
input_token_cost = (0.11 / 1000000) * usage.get("prompt_tokens", 0)
|
||||
output_token_cost = (0.28 / 1000000) * usage.get("completion_tokens", 0)
|
||||
llm_cost = input_token_cost + output_token_cost
|
||||
|
||||
if isinstance(response, UITarsResponse):
|
||||
ui_tars_usage = response.usage
|
||||
return LLMCallStats(
|
||||
llm_cost=llm_cost,
|
||||
input_tokens=usage.get("prompt_tokens", 0),
|
||||
output_tokens=usage.get("completion_tokens", 0),
|
||||
cached_tokens=0, # UI-TARS doesn't have cached tokens
|
||||
reasoning_tokens=0,
|
||||
)
|
||||
|
||||
# Handle UI-TARS response (dict format - fallback)
|
||||
if isinstance(response, dict) and "choices" in response and "usage" in response:
|
||||
usage = response["usage"]
|
||||
# Use Doubao pricing: ¥0.8/1M input, ¥2/1M output (convert to USD: ~$0.11/$0.28)
|
||||
input_token_cost = (0.11 / 1000000) * usage.get("prompt_tokens", 0)
|
||||
output_token_cost = (0.28 / 1000000) * usage.get("completion_tokens", 0)
|
||||
llm_cost = input_token_cost + output_token_cost
|
||||
|
||||
return LLMCallStats(
|
||||
llm_cost=llm_cost,
|
||||
input_tokens=usage.get("prompt_tokens", 0),
|
||||
output_tokens=usage.get("completion_tokens", 0),
|
||||
cached_tokens=0, # UI-TARS doesn't have cached tokens
|
||||
llm_cost=0, # TODO: calculate the cost according to the price: https://www.volcengine.com/docs/82379/1544106
|
||||
input_tokens=ui_tars_usage.get("prompt_tokens", 0),
|
||||
output_tokens=ui_tars_usage.get("completion_tokens", 0),
|
||||
cached_tokens=0, # only part of model support cached tokens
|
||||
reasoning_tokens=0,
|
||||
)
|
||||
|
||||
|
||||
@@ -568,16 +568,46 @@ if settings.ENABLE_AZURE_O3:
|
||||
max_completion_tokens=100000,
|
||||
),
|
||||
)
|
||||
if settings.ENABLE_UI_TARS:
|
||||
if settings.ENABLE_VOLCENGINE:
|
||||
LLMConfigRegistry.register_config(
|
||||
"UI_TARS_SEED1_5_VL",
|
||||
"VOLCENGINE_DOUBAO_SEED_1_6",
|
||||
LLMConfig(
|
||||
settings.UI_TARS_MODEL,
|
||||
["UI_TARS_API_KEY"],
|
||||
"volcengine/doubao-seed-1.6-250615",
|
||||
["VOLCENGINE_API_KEY"],
|
||||
litellm_params=LiteLLMParams(
|
||||
api_base=settings.VOLCENGINE_API_BASE,
|
||||
api_key=settings.VOLCENGINE_API_KEY,
|
||||
),
|
||||
supports_vision=True,
|
||||
add_assistant_prefix=False,
|
||||
),
|
||||
)
|
||||
|
||||
LLMConfigRegistry.register_config(
|
||||
"VOLCENGINE_DOUBAO_SEED_1_6_FLASH",
|
||||
LLMConfig(
|
||||
"volcengine/doubao-seed-1.6-flash-250615",
|
||||
["VOLCENGINE_API_KEY"],
|
||||
litellm_params=LiteLLMParams(
|
||||
api_base=settings.VOLCENGINE_API_BASE,
|
||||
api_key=settings.VOLCENGINE_API_KEY,
|
||||
),
|
||||
supports_vision=True,
|
||||
add_assistant_prefix=False,
|
||||
),
|
||||
)
|
||||
|
||||
LLMConfigRegistry.register_config(
|
||||
"VOLCENGINE_DOUBAO_1_5_THINKING_VISION_PRO",
|
||||
LLMConfig(
|
||||
"volcengine/doubao-1-5-thinking-vision-pro-250428",
|
||||
["VOLCENGINE_API_KEY"],
|
||||
litellm_params=LiteLLMParams(
|
||||
api_base=settings.VOLCENGINE_API_BASE,
|
||||
api_key=settings.VOLCENGINE_API_KEY,
|
||||
),
|
||||
supports_vision=True,
|
||||
add_assistant_prefix=False,
|
||||
max_tokens=400,
|
||||
temperature=0.0,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -62,7 +62,7 @@ class UITarsLLMCaller(LLMCaller):
|
||||
# Handle None case for navigation_goal
|
||||
instruction = task.navigation_goal or "Default navigation task"
|
||||
system_prompt = _build_system_prompt(instruction)
|
||||
self.message_history = [{"role": "user", "content": system_prompt}]
|
||||
self.message_history: list = [{"role": "user", "content": system_prompt}]
|
||||
self._conversation_initialized = True
|
||||
LOG.debug("Initialized UI-TARS conversation", task_id=task.task_id)
|
||||
|
||||
|
||||
@@ -3,21 +3,25 @@
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from anthropic import BaseModel
|
||||
|
||||
class UITarsResponse:
|
||||
|
||||
class Message:
|
||||
def __init__(self, content: str):
|
||||
self.content = content
|
||||
self.role = "assistant"
|
||||
|
||||
|
||||
class Choice:
|
||||
def __init__(self, content: str):
|
||||
self.message = Message(content)
|
||||
|
||||
|
||||
class UITarsResponse(BaseModel):
|
||||
"""A response object that mimics the ModelResponse interface for UI-TARS API responses."""
|
||||
|
||||
def __init__(self, content: str, model: str):
|
||||
# Create choice objects with proper nested structure for parse_api_response
|
||||
class Message:
|
||||
def __init__(self, content: str):
|
||||
self.content = content
|
||||
self.role = "assistant"
|
||||
|
||||
class Choice:
|
||||
def __init__(self, content: str):
|
||||
self.message = Message(content)
|
||||
|
||||
self.choices = [Choice(content)]
|
||||
self.usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
|
||||
self.model = model
|
||||
|
||||
Reference in New Issue
Block a user