support openrouter qwen model (#3630)

Author: Shuchang Zheng
Date: 2025-10-06 18:55:52 -07:00
Committed by: GitHub
Parent: ccc49902e8
Commit: ea92ca4c51

4 changed files with 88 additions and 12 deletions


@@ -10,6 +10,7 @@ from anthropic import NOT_GIVEN
 from anthropic.types.beta.beta_message import BetaMessage as AnthropicMessage
 from jinja2 import Template
 from litellm.utils import CustomStreamWrapper, ModelResponse
+from openai import AsyncOpenAI
 from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
 from pydantic import BaseModel
@@ -224,6 +225,10 @@ class LLMAPIHandlerFactory:
             screenshots: list[bytes] | None = None,
             parameters: dict[str, Any] | None = None,
             organization_id: str | None = None,
+            tools: list | None = None,
+            use_message_history: bool = False,
+            raw_response: bool = False,
+            window_dimension: Resolution | None = None,
         ) -> dict[str, Any]:
             """
             Custom LLM API handler that utilizes the LiteLLM router and fallbacks to OpenAI GPT-4 Vision.
@@ -499,6 +504,11 @@ class LLMAPIHandlerFactory:
         if LLMConfigRegistry.is_router_config(llm_key):
             return LLMAPIHandlerFactory.get_llm_api_handler_with_router(llm_key)

+        # For OpenRouter models, use LLMCaller which has native OpenRouter support
+        if llm_key.startswith("openrouter/"):
+            llm_caller = LLMCaller(llm_key=llm_key, base_parameters=base_parameters)
+            return llm_caller.call
+
         assert isinstance(llm_config, LLMConfig)

         @TraceManager.traced_async(tags=[llm_key], ignore_inputs=["prompt", "screenshots", "parameters"])
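
With this change, `get_llm_api_handler` short-circuits for any `openrouter/`-prefixed key and hands back the bound `LLMCaller.call` coroutine instead of building a litellm-backed handler. A minimal sketch of what a call site sees, assuming a configured OpenRouter key (the Qwen model slug and prompt below are illustrative, not taken from this commit):

```python
# Illustrative only: the slug and prompt are made up; the factory and the
# dispatch behavior come from the diff above.
handler = LLMAPIHandlerFactory.get_llm_api_handler("openrouter/qwen/qwen2.5-vl-72b-instruct")

# `handler` is LLMCaller.call bound to a caller instance, so existing call
# sites keep the same awaitable interface as the litellm-backed handlers.
response = await handler(prompt="Describe this page", prompt_name="describe-page")
```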
@@ -512,6 +522,10 @@ class LLMAPIHandlerFactory:
             screenshots: list[bytes] | None = None,
             parameters: dict[str, Any] | None = None,
             organization_id: str | None = None,
+            tools: list | None = None,
+            use_message_history: bool = False,
+            raw_response: bool = False,
+            window_dimension: Resolution | None = None,
         ) -> dict[str, Any]:
             start_time = time.time()
             active_parameters = base_parameters or {}
@@ -827,6 +841,10 @@ class LLMAPIHandlerFactory:
 class LLMCaller:
     """
     An LLMCaller instance defines the LLM configs and keeps the chat history if needed.
+
+    A couple of things to keep in mind:
+    - LLMCaller should be compatible with litellm interface
+    - LLMCaller should also support models that are not supported by litellm
     """

     def __init__(
@@ -835,6 +853,7 @@ class LLMCaller:
         screenshot_scaling_enabled: bool = False,
         base_parameters: dict[str, Any] | None = None,
     ):
+        self.original_llm_key = llm_key
         self.llm_key = llm_key
         self.llm_config = LLMConfigRegistry.get_config(llm_key)
         self.base_parameters = base_parameters
@@ -846,6 +865,11 @@ class LLMCaller:
         if screenshot_scaling_enabled:
             self.screenshot_resize_target_dimension = get_resize_target_dimension(self.browser_window_dimension)

+        self.openai_client = None
+        if self.llm_key.startswith("openrouter/"):
+            self.llm_key = self.llm_key.replace("openrouter/", "")
+            self.openai_client = AsyncOpenAI(api_key=settings.OPENROUTER_API_KEY, base_url=settings.OPENROUTER_API_BASE)
+
     def add_tool_result(self, tool_result: dict[str, Any]) -> None:
         self.current_tool_results.append(tool_result)
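
In `__init__`, the `openrouter/` prefix is stripped so that the remaining slug is what OpenRouter's OpenAI-compatible endpoint expects as `model`, while the untouched `original_llm_key` is kept for later branching. A standalone sketch of that setup, with placeholder credentials (OpenRouter's public base URL is `https://openrouter.ai/api/v1`; the key string and slug are made up):

```python
from openai import AsyncOpenAI

llm_key = "openrouter/qwen/qwen2.5-vl-72b-instruct"  # illustrative key
model = llm_key.removeprefix("openrouter/")          # -> "qwen/qwen2.5-vl-72b-instruct"

# OpenRouter exposes an OpenAI-compatible API, so the stock SDK client works
# once it is pointed at OpenRouter's base URL.
client = AsyncOpenAI(
    api_key="sk-or-...",                      # settings.OPENROUTER_API_KEY in the diff
    base_url="https://openrouter.ai/api/v1",  # settings.OPENROUTER_API_BASE in the diff
)
```

Note that `str.removeprefix` only strips a leading occurrence, whereas the diff's `str.replace("openrouter/", "")` would also rewrite any later `openrouter/` substring in the key; for real OpenRouter slugs the two behave identically.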
@@ -862,11 +886,11 @@ class LLMCaller:
         ai_suggestion: AISuggestion | None = None,
         screenshots: list[bytes] | None = None,
         parameters: dict[str, Any] | None = None,
+        organization_id: str | None = None,
         tools: list | None = None,
         use_message_history: bool = False,
         raw_response: bool = False,
         window_dimension: Resolution | None = None,
-        organization_id: str | None = None,
         **extra_parameters: Any,
     ) -> dict[str, Any]:
         start_time = time.perf_counter()
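
Since `LLMCaller` keeps chat history, the reordered `call` signature matters mostly for keyword callers. A hypothetical multi-turn usage, assuming `use_message_history=True` threads prior turns into later requests (the prompts are made up, and the `prompt`/`prompt_name` keyword names are assumed from the handler protocol rather than shown in this hunk):

```python
caller = LLMCaller(llm_key="openrouter/qwen/qwen2.5-vl-72b-instruct")

# First turn establishes context; later turns reuse it via the stored history.
await caller.call(prompt="List the form fields on this page", prompt_name="scan-form", use_message_history=True)
await caller.call(prompt="Now fill the email field", prompt_name="fill-form", use_message_history=True)
```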
@@ -1081,6 +1105,34 @@ class LLMCaller:
         timeout: float = settings.LLM_CONFIG_TIMEOUT,
         **active_parameters: dict[str, Any],
     ) -> ModelResponse | CustomStreamWrapper | AnthropicMessage | UITarsResponse:
+        if self.openai_client:
+            # Extract OpenRouter-specific parameters
+            extra_headers = {}
+            if settings.SKYVERN_APP_URL:
+                extra_headers["HTTP-Referer"] = settings.SKYVERN_APP_URL
+            extra_headers["X-Title"] = "Skyvern"
+
+            # Filter out parameters that OpenAI client doesn't support
+            openai_params = {}
+            if "max_completion_tokens" in active_parameters:
+                openai_params["max_completion_tokens"] = active_parameters["max_completion_tokens"]
+            elif "max_tokens" in active_parameters:
+                openai_params["max_tokens"] = active_parameters["max_tokens"]
+            if "temperature" in active_parameters:
+                openai_params["temperature"] = active_parameters["temperature"]
+
+            completion = await self.openai_client.chat.completions.create(
+                model=self.llm_key,
+                messages=messages,
+                extra_headers=extra_headers if extra_headers else None,
+                timeout=timeout,
+                **openai_params,
+            )
+            # Convert OpenAI ChatCompletion to litellm ModelResponse format
+            # litellm.utils.convert_to_model_response_object expects a dict
+            response_dict = completion.model_dump()
+            return litellm.ModelResponse(**response_dict)
+
         if self.llm_key and "ANTHROPIC" in self.llm_key:
             return await self._call_anthropic(messages, tools, timeout, **active_parameters)
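
The new branch issues a plain OpenAI-SDK chat completion against OpenRouter, forwarding the optional attribution headers OpenRouter recognizes (`HTTP-Referer`, `X-Title`), then rehydrates the result as a litellm `ModelResponse` so downstream parsing stays unchanged. A runnable sketch of the same round trip outside Skyvern (the model slug, API key, and referer are placeholders):

```python
import asyncio

import litellm
from openai import AsyncOpenAI


async def main() -> None:
    client = AsyncOpenAI(api_key="sk-or-...", base_url="https://openrouter.ai/api/v1")
    completion = await client.chat.completions.create(
        model="qwen/qwen2.5-vl-72b-instruct",  # placeholder OpenRouter slug
        messages=[{"role": "user", "content": "Reply with the word 'pong'."}],
        extra_headers={"HTTP-Referer": "https://app.example.com", "X-Title": "Skyvern"},
        timeout=60,
        max_tokens=32,
    )
    # model_dump() yields a plain dict, which litellm's ModelResponse accepts,
    # so callers expecting litellm response objects keep working unchanged.
    response = litellm.ModelResponse(**completion.model_dump())
    print(response.choices[0].message.content)


asyncio.run(main())
```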
@@ -1193,6 +1245,8 @@ class LLMCaller:
         self, response: ModelResponse | CustomStreamWrapper | AnthropicMessage | UITarsResponse
     ) -> LLMCallStats:
         empty_call_stats = LLMCallStats()
+        if self.original_llm_key.startswith("openrouter/"):
+            return empty_call_stats
         # Handle UI-TARS response (UITarsResponse object from _call_ui_tars)
         if isinstance(response, UITarsResponse):


@@ -7,6 +7,7 @@ from skyvern.forge.sdk.models import Step
 from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion
 from skyvern.forge.sdk.schemas.task_v2 import TaskV2, Thought
 from skyvern.forge.sdk.settings_manager import SettingsManager
+from skyvern.utils.image_resizer import Resolution


 class LiteLLMParams(TypedDict, total=False):
@@ -96,6 +97,10 @@ class LLMAPIHandler(Protocol):
         screenshots: list[bytes] | None = None,
         parameters: dict[str, Any] | None = None,
         organization_id: str | None = None,
+        tools: list | None = None,
+        use_message_history: bool = False,
+        raw_response: bool = False,
+        window_dimension: Resolution | None = None,
     ) -> Awaitable[dict[str, Any]]: ...
@@ -109,5 +114,9 @@ async def dummy_llm_api_handler(
     screenshots: list[bytes] | None = None,
     parameters: dict[str, Any] | None = None,
     organization_id: str | None = None,
+    tools: list | None = None,
+    use_message_history: bool = False,
+    raw_response: bool = False,
+    window_dimension: Resolution | None = None,
 ) -> dict[str, Any]:
     raise NotImplementedError("Your LLM provider is not configured. Please configure it in the .env file.")
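
Any custom handler now has to accept the four new keyword parameters to satisfy the widened `LLMAPIHandler` protocol. A minimal conforming stub, assuming leading `prompt`/`prompt_name`/`step` parameters that match the rest of the protocol (only the tail shown in this diff is confirmed; the echo behavior is made up):

```python
from typing import Any

from skyvern.forge.sdk.models import Step            # imports as in the diff
from skyvern.utils.image_resizer import Resolution


async def echo_llm_api_handler(
    prompt: str,                                      # assumed leading parameters;
    prompt_name: str,                                 # not shown in this hunk
    step: Step | None = None,
    screenshots: list[bytes] | None = None,
    parameters: dict[str, Any] | None = None,
    organization_id: str | None = None,
    tools: list | None = None,                        # new in this commit
    use_message_history: bool = False,                # new in this commit
    raw_response: bool = False,                       # new in this commit
    window_dimension: Resolution | None = None,       # new in this commit
) -> dict[str, Any]:
    # Stand-in body: a real handler would dispatch to an LLM provider here.
    return {"echo": prompt}
```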