support openrouter qwen model (#3630)
This commit is contained in:
@@ -270,7 +270,7 @@ class Settings(BaseSettings):
|
||||
ENABLE_OPENROUTER: bool = False
|
||||
OPENROUTER_API_KEY: str | None = None
|
||||
OPENROUTER_MODEL: str | None = None
|
||||
OPENROUTER_API_BASE: str = "https://api.openrouter.ai/v1"
|
||||
OPENROUTER_API_BASE: str = "https://openrouter.ai/api/v1"
|
||||
|
||||
# GROQ
|
||||
ENABLE_GROQ: bool = False
|
||||
|
||||
@@ -10,6 +10,7 @@ from anthropic import NOT_GIVEN
|
||||
from anthropic.types.beta.beta_message import BetaMessage as AnthropicMessage
|
||||
from jinja2 import Template
|
||||
from litellm.utils import CustomStreamWrapper, ModelResponse
|
||||
from openai import AsyncOpenAI
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -224,6 +225,10 @@ class LLMAPIHandlerFactory:
|
||||
screenshots: list[bytes] | None = None,
|
||||
parameters: dict[str, Any] | None = None,
|
||||
organization_id: str | None = None,
|
||||
tools: list | None = None,
|
||||
use_message_history: bool = False,
|
||||
raw_response: bool = False,
|
||||
window_dimension: Resolution | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Custom LLM API handler that utilizes the LiteLLM router and fallbacks to OpenAI GPT-4 Vision.
|
||||
@@ -499,6 +504,11 @@ class LLMAPIHandlerFactory:
|
||||
if LLMConfigRegistry.is_router_config(llm_key):
|
||||
return LLMAPIHandlerFactory.get_llm_api_handler_with_router(llm_key)
|
||||
|
||||
# For OpenRouter models, use LLMCaller which has native OpenRouter support
|
||||
if llm_key.startswith("openrouter/"):
|
||||
llm_caller = LLMCaller(llm_key=llm_key, base_parameters=base_parameters)
|
||||
return llm_caller.call
|
||||
|
||||
assert isinstance(llm_config, LLMConfig)
|
||||
|
||||
@TraceManager.traced_async(tags=[llm_key], ignore_inputs=["prompt", "screenshots", "parameters"])
|
||||
@@ -512,6 +522,10 @@ class LLMAPIHandlerFactory:
|
||||
screenshots: list[bytes] | None = None,
|
||||
parameters: dict[str, Any] | None = None,
|
||||
organization_id: str | None = None,
|
||||
tools: list | None = None,
|
||||
use_message_history: bool = False,
|
||||
raw_response: bool = False,
|
||||
window_dimension: Resolution | None = None,
|
||||
) -> dict[str, Any]:
|
||||
start_time = time.time()
|
||||
active_parameters = base_parameters or {}
|
||||
@@ -827,6 +841,10 @@ class LLMAPIHandlerFactory:
|
||||
class LLMCaller:
|
||||
"""
|
||||
An LLMCaller instance defines the LLM configs and keeps the chat history if needed.
|
||||
|
||||
A couple of things to keep in mind:
|
||||
- LLMCaller should be compatible with litellm interface
|
||||
- LLMCaller should also support models that are not supported by litellm
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -835,6 +853,7 @@ class LLMCaller:
|
||||
screenshot_scaling_enabled: bool = False,
|
||||
base_parameters: dict[str, Any] | None = None,
|
||||
):
|
||||
self.original_llm_key = llm_key
|
||||
self.llm_key = llm_key
|
||||
self.llm_config = LLMConfigRegistry.get_config(llm_key)
|
||||
self.base_parameters = base_parameters
|
||||
@@ -846,6 +865,11 @@ class LLMCaller:
|
||||
if screenshot_scaling_enabled:
|
||||
self.screenshot_resize_target_dimension = get_resize_target_dimension(self.browser_window_dimension)
|
||||
|
||||
self.openai_client = None
|
||||
if self.llm_key.startswith("openrouter/"):
|
||||
self.llm_key = self.llm_key.replace("openrouter/", "")
|
||||
self.openai_client = AsyncOpenAI(api_key=settings.OPENROUTER_API_KEY, base_url=settings.OPENROUTER_API_BASE)
|
||||
|
||||
def add_tool_result(self, tool_result: dict[str, Any]) -> None:
|
||||
self.current_tool_results.append(tool_result)
|
||||
|
||||
@@ -862,11 +886,11 @@ class LLMCaller:
|
||||
ai_suggestion: AISuggestion | None = None,
|
||||
screenshots: list[bytes] | None = None,
|
||||
parameters: dict[str, Any] | None = None,
|
||||
organization_id: str | None = None,
|
||||
tools: list | None = None,
|
||||
use_message_history: bool = False,
|
||||
raw_response: bool = False,
|
||||
window_dimension: Resolution | None = None,
|
||||
organization_id: str | None = None,
|
||||
**extra_parameters: Any,
|
||||
) -> dict[str, Any]:
|
||||
start_time = time.perf_counter()
|
||||
@@ -1081,6 +1105,34 @@ class LLMCaller:
|
||||
timeout: float = settings.LLM_CONFIG_TIMEOUT,
|
||||
**active_parameters: dict[str, Any],
|
||||
) -> ModelResponse | CustomStreamWrapper | AnthropicMessage | UITarsResponse:
|
||||
if self.openai_client:
|
||||
# Extract OpenRouter-specific parameters
|
||||
extra_headers = {}
|
||||
if settings.SKYVERN_APP_URL:
|
||||
extra_headers["HTTP-Referer"] = settings.SKYVERN_APP_URL
|
||||
extra_headers["X-Title"] = "Skyvern"
|
||||
|
||||
# Filter out parameters that OpenAI client doesn't support
|
||||
openai_params = {}
|
||||
if "max_completion_tokens" in active_parameters:
|
||||
openai_params["max_completion_tokens"] = active_parameters["max_completion_tokens"]
|
||||
elif "max_tokens" in active_parameters:
|
||||
openai_params["max_tokens"] = active_parameters["max_tokens"]
|
||||
if "temperature" in active_parameters:
|
||||
openai_params["temperature"] = active_parameters["temperature"]
|
||||
|
||||
completion = await self.openai_client.chat.completions.create(
|
||||
model=self.llm_key,
|
||||
messages=messages,
|
||||
extra_headers=extra_headers if extra_headers else None,
|
||||
timeout=timeout,
|
||||
**openai_params,
|
||||
)
|
||||
# Convert OpenAI ChatCompletion to litellm ModelResponse format
|
||||
# litellm.utils.convert_to_model_response_object expects a dict
|
||||
response_dict = completion.model_dump()
|
||||
return litellm.ModelResponse(**response_dict)
|
||||
|
||||
if self.llm_key and "ANTHROPIC" in self.llm_key:
|
||||
return await self._call_anthropic(messages, tools, timeout, **active_parameters)
|
||||
|
||||
@@ -1193,6 +1245,8 @@ class LLMCaller:
|
||||
self, response: ModelResponse | CustomStreamWrapper | AnthropicMessage | UITarsResponse
|
||||
) -> LLMCallStats:
|
||||
empty_call_stats = LLMCallStats()
|
||||
if self.original_llm_key.startswith("openrouter/"):
|
||||
return empty_call_stats
|
||||
|
||||
# Handle UI-TARS response (UITarsResponse object from _call_ui_tars)
|
||||
if isinstance(response, UITarsResponse):
|
||||
|
||||
@@ -7,6 +7,7 @@ from skyvern.forge.sdk.models import Step
|
||||
from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion
|
||||
from skyvern.forge.sdk.schemas.task_v2 import TaskV2, Thought
|
||||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||
from skyvern.utils.image_resizer import Resolution
|
||||
|
||||
|
||||
class LiteLLMParams(TypedDict, total=False):
|
||||
@@ -96,6 +97,10 @@ class LLMAPIHandler(Protocol):
|
||||
screenshots: list[bytes] | None = None,
|
||||
parameters: dict[str, Any] | None = None,
|
||||
organization_id: str | None = None,
|
||||
tools: list | None = None,
|
||||
use_message_history: bool = False,
|
||||
raw_response: bool = False,
|
||||
window_dimension: Resolution | None = None,
|
||||
) -> Awaitable[dict[str, Any]]: ...
|
||||
|
||||
|
||||
@@ -109,5 +114,9 @@ async def dummy_llm_api_handler(
|
||||
screenshots: list[bytes] | None = None,
|
||||
parameters: dict[str, Any] | None = None,
|
||||
organization_id: str | None = None,
|
||||
tools: list | None = None,
|
||||
use_message_history: bool = False,
|
||||
raw_response: bool = False,
|
||||
window_dimension: Resolution | None = None,
|
||||
) -> dict[str, Any]:
|
||||
raise NotImplementedError("Your LLM provider is not configured. Please configure it in the .env file.")
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
import importlib
|
||||
import json
|
||||
import types
|
||||
from unittest.mock import AsyncMock
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from skyvern import config
|
||||
from skyvern.config import Settings
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.api.llm import api_handler_factory, config_registry
|
||||
@@ -19,6 +20,9 @@ class DummyResponse(dict):
|
||||
def model_dump_json(self, indent: int = 2):
|
||||
return json.dumps(self, indent=indent)
|
||||
|
||||
def model_dump(self):
|
||||
return self
|
||||
|
||||
|
||||
class DummyArtifactManager:
|
||||
async def create_llm_artifact(self, *args, **kwargs):
|
||||
@@ -49,27 +53,36 @@ async def test_openrouter_basic_completion(monkeypatch):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openrouter_dynamic_model(monkeypatch):
|
||||
settings = Settings(
|
||||
ENABLE_OPENROUTER=True,
|
||||
OPENROUTER_API_KEY="key",
|
||||
OPENROUTER_MODEL="base-model",
|
||||
LLM_KEY="OPENROUTER",
|
||||
)
|
||||
SettingsManager.set_settings(settings)
|
||||
# Update settings via monkeypatch to ensure config_registry sees them
|
||||
|
||||
monkeypatch.setattr(config.settings, "ENABLE_OPENROUTER", True)
|
||||
monkeypatch.setattr(config.settings, "OPENROUTER_API_KEY", "key")
|
||||
monkeypatch.setattr(config.settings, "OPENROUTER_MODEL", "base-model")
|
||||
monkeypatch.setattr(config.settings, "OPENROUTER_API_BASE", "https://openrouter.ai/api/v1")
|
||||
|
||||
# Clear existing configs before reload
|
||||
config_registry.LLMConfigRegistry._configs.clear()
|
||||
importlib.reload(config_registry)
|
||||
|
||||
monkeypatch.setattr(app, "ARTIFACT_MANAGER", DummyArtifactManager())
|
||||
|
||||
# Mock the AsyncOpenAI client
|
||||
async_mock = AsyncMock(return_value=DummyResponse('{"status": "ok"}'))
|
||||
monkeypatch.setattr(api_handler_factory.litellm, "acompletion", async_mock)
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create = async_mock
|
||||
|
||||
# Patch AsyncOpenAI to return our mock client
|
||||
monkeypatch.setattr(api_handler_factory, "AsyncOpenAI", lambda **kwargs: mock_client)
|
||||
|
||||
base_handler = api_handler_factory.LLMAPIHandlerFactory.get_llm_api_handler("OPENROUTER")
|
||||
override_handler = api_handler_factory.LLMAPIHandlerFactory.get_override_llm_api_handler(
|
||||
"openrouter/other-model", default=base_handler
|
||||
)
|
||||
|
||||
result = await override_handler("hi", "test")
|
||||
assert result == {"status": "ok"}
|
||||
called_model = async_mock.call_args.kwargs.get("model")
|
||||
assert called_model == "openrouter/other-model"
|
||||
assert called_model == "other-model"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
Reference in New Issue
Block a user