From ea92ca4c51e2261260a556392d09b5626eb1fbd3 Mon Sep 17 00:00:00 2001
From: Shuchang Zheng
Date: Mon, 6 Oct 2025 18:55:52 -0700
Subject: [PATCH] support openrouter qwen model (#3630)

---
 skyvern/config.py                              |  2 +-
 .../forge/sdk/api/llm/api_handler_factory.py   | 56 ++++++++++++++++++-
 skyvern/forge/sdk/api/llm/models.py            |  9 +++
 .../unit_tests/test_openrouter_integration.py  | 33 +++++++----
 4 files changed, 88 insertions(+), 12 deletions(-)

diff --git a/skyvern/config.py b/skyvern/config.py
index 551f3905..d2c8ce1e 100644
--- a/skyvern/config.py
+++ b/skyvern/config.py
@@ -270,7 +270,7 @@ class Settings(BaseSettings):
     ENABLE_OPENROUTER: bool = False
     OPENROUTER_API_KEY: str | None = None
     OPENROUTER_MODEL: str | None = None
-    OPENROUTER_API_BASE: str = "https://api.openrouter.ai/v1"
+    OPENROUTER_API_BASE: str = "https://openrouter.ai/api/v1"
 
     # GROQ
     ENABLE_GROQ: bool = False
diff --git a/skyvern/forge/sdk/api/llm/api_handler_factory.py b/skyvern/forge/sdk/api/llm/api_handler_factory.py
index 524003c5..b073f6f3 100644
--- a/skyvern/forge/sdk/api/llm/api_handler_factory.py
+++ b/skyvern/forge/sdk/api/llm/api_handler_factory.py
@@ -10,6 +10,7 @@ from anthropic import NOT_GIVEN
 from anthropic.types.beta.beta_message import BetaMessage as AnthropicMessage
 from jinja2 import Template
 from litellm.utils import CustomStreamWrapper, ModelResponse
+from openai import AsyncOpenAI
 from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
 from pydantic import BaseModel
 
@@ -224,6 +225,10 @@ class LLMAPIHandlerFactory:
         screenshots: list[bytes] | None = None,
         parameters: dict[str, Any] | None = None,
         organization_id: str | None = None,
+        tools: list | None = None,
+        use_message_history: bool = False,
+        raw_response: bool = False,
+        window_dimension: Resolution | None = None,
     ) -> dict[str, Any]:
         """
         Custom LLM API handler that utilizes the LiteLLM router and fallbacks to OpenAI GPT-4 Vision.
@@ -499,6 +504,11 @@
         if LLMConfigRegistry.is_router_config(llm_key):
             return LLMAPIHandlerFactory.get_llm_api_handler_with_router(llm_key)
 
+        # For OpenRouter models, use LLMCaller, which has native OpenRouter support
+        if llm_key.startswith("openrouter/"):
+            llm_caller = LLMCaller(llm_key=llm_key, base_parameters=base_parameters)
+            return llm_caller.call
+
         assert isinstance(llm_config, LLMConfig)
 
         @TraceManager.traced_async(tags=[llm_key], ignore_inputs=["prompt", "screenshots", "parameters"])
@@ -512,6 +522,10 @@
             screenshots: list[bytes] | None = None,
             parameters: dict[str, Any] | None = None,
             organization_id: str | None = None,
+            tools: list | None = None,
+            use_message_history: bool = False,
+            raw_response: bool = False,
+            window_dimension: Resolution | None = None,
         ) -> dict[str, Any]:
             start_time = time.time()
             active_parameters = base_parameters or {}
@@ -827,6 +841,10 @@ class LLMAPIHandlerFactory:
 class LLMCaller:
     """
     An LLMCaller instance defines the LLM configs and keeps the chat history if needed.
+
+    A couple of things to keep in mind:
+    - LLMCaller should be compatible with the litellm interface
+    - LLMCaller should also support models that litellm does not support
     """
 
     def __init__(
@@ -835,6 +853,7 @@ class LLMCaller:
         screenshot_scaling_enabled: bool = False,
         base_parameters: dict[str, Any] | None = None,
     ):
+        self.original_llm_key = llm_key
         self.llm_key = llm_key
         self.llm_config = LLMConfigRegistry.get_config(llm_key)
         self.base_parameters = base_parameters
@@ -846,6 +865,11 @@ class LLMCaller:
         if screenshot_scaling_enabled:
             self.screenshot_resize_target_dimension = get_resize_target_dimension(self.browser_window_dimension)
 
+        self.openai_client = None
+        if self.llm_key.startswith("openrouter/"):
+            self.llm_key = self.llm_key.replace("openrouter/", "")
+            self.openai_client = AsyncOpenAI(api_key=settings.OPENROUTER_API_KEY, base_url=settings.OPENROUTER_API_BASE)
+
     def add_tool_result(self, tool_result: dict[str, Any]) -> None:
         self.current_tool_results.append(tool_result)
 
@@ -862,11 +886,11 @@
         ai_suggestion: AISuggestion | None = None,
         screenshots: list[bytes] | None = None,
         parameters: dict[str, Any] | None = None,
+        organization_id: str | None = None,
         tools: list | None = None,
         use_message_history: bool = False,
         raw_response: bool = False,
         window_dimension: Resolution | None = None,
-        organization_id: str | None = None,
         **extra_parameters: Any,
     ) -> dict[str, Any]:
         start_time = time.perf_counter()
@@ -1081,6 +1105,34 @@
         timeout: float = settings.LLM_CONFIG_TIMEOUT,
         **active_parameters: dict[str, Any],
     ) -> ModelResponse | CustomStreamWrapper | AnthropicMessage | UITarsResponse:
+        if self.openai_client:
+            # Build the OpenRouter-specific attribution headers
+            extra_headers = {}
+            if settings.SKYVERN_APP_URL:
+                extra_headers["HTTP-Referer"] = settings.SKYVERN_APP_URL
+                extra_headers["X-Title"] = "Skyvern"
+
+            # Forward only the parameters that the OpenAI client supports
+            openai_params = {}
+            if "max_completion_tokens" in active_parameters:
+                openai_params["max_completion_tokens"] = active_parameters["max_completion_tokens"]
+            elif "max_tokens" in active_parameters:
+                openai_params["max_tokens"] = active_parameters["max_tokens"]
+            if "temperature" in active_parameters:
+                openai_params["temperature"] = active_parameters["temperature"]
+
+            completion = await self.openai_client.chat.completions.create(
+                model=self.llm_key,
+                messages=messages,
+                extra_headers=extra_headers if extra_headers else None,
+                timeout=timeout,
+                **openai_params,
+            )
+            # Convert the OpenAI ChatCompletion to litellm's ModelResponse format
+            # by constructing a ModelResponse from the dumped completion dict
+            response_dict = completion.model_dump()
+            return litellm.ModelResponse(**response_dict)
+
         if self.llm_key and "ANTHROPIC" in self.llm_key:
             return await self._call_anthropic(messages, tools, timeout, **active_parameters)
 
@@ -1193,6 +1245,8 @@
         self, response: ModelResponse | CustomStreamWrapper | AnthropicMessage | UITarsResponse
     ) -> LLMCallStats:
         empty_call_stats = LLMCallStats()
+        if self.original_llm_key.startswith("openrouter/"):
+            return empty_call_stats
 
         # Handle UI-TARS response (UITarsResponse object from _call_ui_tars)
         if isinstance(response, UITarsResponse):
diff --git a/skyvern/forge/sdk/api/llm/models.py b/skyvern/forge/sdk/api/llm/models.py
index b059a4c7..655a1e99 100644
--- a/skyvern/forge/sdk/api/llm/models.py
+++ b/skyvern/forge/sdk/api/llm/models.py
@@ -7,6 +7,7 @@ from skyvern.forge.sdk.models import Step
 from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion
 from skyvern.forge.sdk.schemas.task_v2 import TaskV2, Thought
 from skyvern.forge.sdk.settings_manager import SettingsManager
+from skyvern.utils.image_resizer import Resolution
 
 
 class LiteLLMParams(TypedDict, total=False):
@@ -96,6 +97,10 @@ class LLMAPIHandler(Protocol):
         screenshots: list[bytes] | None = None,
         parameters: dict[str, Any] | None = None,
         organization_id: str | None = None,
+        tools: list | None = None,
+        use_message_history: bool = False,
+        raw_response: bool = False,
+        window_dimension: Resolution | None = None,
     ) -> Awaitable[dict[str, Any]]: ...
 
 
@@ -109,5 +114,9 @@ async def dummy_llm_api_handler(
     screenshots: list[bytes] | None = None,
     parameters: dict[str, Any] | None = None,
     organization_id: str | None = None,
+    tools: list | None = None,
+    use_message_history: bool = False,
+    raw_response: bool = False,
+    window_dimension: Resolution | None = None,
 ) -> dict[str, Any]:
     raise NotImplementedError("Your LLM provider is not configured. Please configure it in the .env file.")
diff --git a/tests/unit_tests/test_openrouter_integration.py b/tests/unit_tests/test_openrouter_integration.py
index 97faa89a..93329d5b 100644
--- a/tests/unit_tests/test_openrouter_integration.py
+++ b/tests/unit_tests/test_openrouter_integration.py
@@ -1,10 +1,11 @@
 import importlib
 import json
 import types
-from unittest.mock import AsyncMock
+from unittest.mock import AsyncMock, MagicMock
 
 import pytest
 
+from skyvern import config
 from skyvern.config import Settings
 from skyvern.forge import app
 from skyvern.forge.sdk.api.llm import api_handler_factory, config_registry
@@ -19,6 +20,9 @@ class DummyResponse(dict):
     def model_dump_json(self, indent: int = 2):
         return json.dumps(self, indent=indent)
 
+    def model_dump(self):
+        return self
+
 
 class DummyArtifactManager:
     async def create_llm_artifact(self, *args, **kwargs):
@@ -49,27 +53,36 @@ async def test_openrouter_basic_completion(monkeypatch):
 
 
 @pytest.mark.asyncio
 async def test_openrouter_dynamic_model(monkeypatch):
-    settings = Settings(
-        ENABLE_OPENROUTER=True,
-        OPENROUTER_API_KEY="key",
-        OPENROUTER_MODEL="base-model",
-        LLM_KEY="OPENROUTER",
-    )
-    SettingsManager.set_settings(settings)
+    # Update settings via monkeypatch to ensure config_registry sees them
+
+    monkeypatch.setattr(config.settings, "ENABLE_OPENROUTER", True)
+    monkeypatch.setattr(config.settings, "OPENROUTER_API_KEY", "key")
+    monkeypatch.setattr(config.settings, "OPENROUTER_MODEL", "base-model")
+    monkeypatch.setattr(config.settings, "OPENROUTER_API_BASE", "https://openrouter.ai/api/v1")
+
+    # Clear existing configs before reload
+    config_registry.LLMConfigRegistry._configs.clear()
     importlib.reload(config_registry)
     monkeypatch.setattr(app, "ARTIFACT_MANAGER", DummyArtifactManager())
+
+    # Mock the AsyncOpenAI client
     async_mock = AsyncMock(return_value=DummyResponse('{"status": "ok"}'))
-    monkeypatch.setattr(api_handler_factory.litellm, "acompletion", async_mock)
+    mock_client = MagicMock()
+    mock_client.chat.completions.create = async_mock
+
+    # Patch AsyncOpenAI to return our mock client
+    monkeypatch.setattr(api_handler_factory, "AsyncOpenAI", lambda **kwargs: mock_client)
 
     base_handler = api_handler_factory.LLMAPIHandlerFactory.get_llm_api_handler("OPENROUTER")
     override_handler = api_handler_factory.LLMAPIHandlerFactory.get_override_llm_api_handler(
         "openrouter/other-model", default=base_handler
     )
+
     result = await override_handler("hi", "test")
     assert result == {"status": "ok"}
     called_model = async_mock.call_args.kwargs.get("model")
-    assert called_model == "openrouter/other-model"
"openrouter/other-model" + assert called_model == "other-model" @pytest.mark.asyncio