support openrouter qwen model (#3630)
This commit is contained in:
@@ -270,7 +270,7 @@ class Settings(BaseSettings):
|
|||||||
ENABLE_OPENROUTER: bool = False
|
ENABLE_OPENROUTER: bool = False
|
||||||
OPENROUTER_API_KEY: str | None = None
|
OPENROUTER_API_KEY: str | None = None
|
||||||
OPENROUTER_MODEL: str | None = None
|
OPENROUTER_MODEL: str | None = None
|
||||||
OPENROUTER_API_BASE: str = "https://api.openrouter.ai/v1"
|
OPENROUTER_API_BASE: str = "https://openrouter.ai/api/v1"
|
||||||
|
|
||||||
# GROQ
|
# GROQ
|
||||||
ENABLE_GROQ: bool = False
|
ENABLE_GROQ: bool = False
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from anthropic import NOT_GIVEN
|
|||||||
from anthropic.types.beta.beta_message import BetaMessage as AnthropicMessage
|
from anthropic.types.beta.beta_message import BetaMessage as AnthropicMessage
|
||||||
from jinja2 import Template
|
from jinja2 import Template
|
||||||
from litellm.utils import CustomStreamWrapper, ModelResponse
|
from litellm.utils import CustomStreamWrapper, ModelResponse
|
||||||
|
from openai import AsyncOpenAI
|
||||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
@@ -224,6 +225,10 @@ class LLMAPIHandlerFactory:
|
|||||||
screenshots: list[bytes] | None = None,
|
screenshots: list[bytes] | None = None,
|
||||||
parameters: dict[str, Any] | None = None,
|
parameters: dict[str, Any] | None = None,
|
||||||
organization_id: str | None = None,
|
organization_id: str | None = None,
|
||||||
|
tools: list | None = None,
|
||||||
|
use_message_history: bool = False,
|
||||||
|
raw_response: bool = False,
|
||||||
|
window_dimension: Resolution | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Custom LLM API handler that utilizes the LiteLLM router and falls back to OpenAI GPT-4 Vision.
|
Custom LLM API handler that utilizes the LiteLLM router and falls back to OpenAI GPT-4 Vision.
|
||||||
@@ -499,6 +504,11 @@ class LLMAPIHandlerFactory:
|
|||||||
if LLMConfigRegistry.is_router_config(llm_key):
|
if LLMConfigRegistry.is_router_config(llm_key):
|
||||||
return LLMAPIHandlerFactory.get_llm_api_handler_with_router(llm_key)
|
return LLMAPIHandlerFactory.get_llm_api_handler_with_router(llm_key)
|
||||||
|
|
||||||
|
# For OpenRouter models, use LLMCaller which has native OpenRouter support
|
||||||
|
if llm_key.startswith("openrouter/"):
|
||||||
|
llm_caller = LLMCaller(llm_key=llm_key, base_parameters=base_parameters)
|
||||||
|
return llm_caller.call
|
||||||
|
|
||||||
assert isinstance(llm_config, LLMConfig)
|
assert isinstance(llm_config, LLMConfig)
|
||||||
|
|
||||||
@TraceManager.traced_async(tags=[llm_key], ignore_inputs=["prompt", "screenshots", "parameters"])
|
@TraceManager.traced_async(tags=[llm_key], ignore_inputs=["prompt", "screenshots", "parameters"])
|
||||||
@@ -512,6 +522,10 @@ class LLMAPIHandlerFactory:
|
|||||||
screenshots: list[bytes] | None = None,
|
screenshots: list[bytes] | None = None,
|
||||||
parameters: dict[str, Any] | None = None,
|
parameters: dict[str, Any] | None = None,
|
||||||
organization_id: str | None = None,
|
organization_id: str | None = None,
|
||||||
|
tools: list | None = None,
|
||||||
|
use_message_history: bool = False,
|
||||||
|
raw_response: bool = False,
|
||||||
|
window_dimension: Resolution | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
active_parameters = base_parameters or {}
|
active_parameters = base_parameters or {}
|
||||||
@@ -827,6 +841,10 @@ class LLMAPIHandlerFactory:
|
|||||||
class LLMCaller:
|
class LLMCaller:
|
||||||
"""
|
"""
|
||||||
An LLMCaller instance defines the LLM configs and keeps the chat history if needed.
|
An LLMCaller instance defines the LLM configs and keeps the chat history if needed.
|
||||||
|
|
||||||
|
A couple of things to keep in mind:
|
||||||
|
- LLMCaller should be compatible with litellm interface
|
||||||
|
- LLMCaller should also support models that are not supported by litellm
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -835,6 +853,7 @@ class LLMCaller:
|
|||||||
screenshot_scaling_enabled: bool = False,
|
screenshot_scaling_enabled: bool = False,
|
||||||
base_parameters: dict[str, Any] | None = None,
|
base_parameters: dict[str, Any] | None = None,
|
||||||
):
|
):
|
||||||
|
self.original_llm_key = llm_key
|
||||||
self.llm_key = llm_key
|
self.llm_key = llm_key
|
||||||
self.llm_config = LLMConfigRegistry.get_config(llm_key)
|
self.llm_config = LLMConfigRegistry.get_config(llm_key)
|
||||||
self.base_parameters = base_parameters
|
self.base_parameters = base_parameters
|
||||||
@@ -846,6 +865,11 @@ class LLMCaller:
|
|||||||
if screenshot_scaling_enabled:
|
if screenshot_scaling_enabled:
|
||||||
self.screenshot_resize_target_dimension = get_resize_target_dimension(self.browser_window_dimension)
|
self.screenshot_resize_target_dimension = get_resize_target_dimension(self.browser_window_dimension)
|
||||||
|
|
||||||
|
self.openai_client = None
|
||||||
|
if self.llm_key.startswith("openrouter/"):
|
||||||
|
self.llm_key = self.llm_key.replace("openrouter/", "")
|
||||||
|
self.openai_client = AsyncOpenAI(api_key=settings.OPENROUTER_API_KEY, base_url=settings.OPENROUTER_API_BASE)
|
||||||
|
|
||||||
def add_tool_result(self, tool_result: dict[str, Any]) -> None:
|
def add_tool_result(self, tool_result: dict[str, Any]) -> None:
|
||||||
self.current_tool_results.append(tool_result)
|
self.current_tool_results.append(tool_result)
|
||||||
|
|
||||||
@@ -862,11 +886,11 @@ class LLMCaller:
|
|||||||
ai_suggestion: AISuggestion | None = None,
|
ai_suggestion: AISuggestion | None = None,
|
||||||
screenshots: list[bytes] | None = None,
|
screenshots: list[bytes] | None = None,
|
||||||
parameters: dict[str, Any] | None = None,
|
parameters: dict[str, Any] | None = None,
|
||||||
|
organization_id: str | None = None,
|
||||||
tools: list | None = None,
|
tools: list | None = None,
|
||||||
use_message_history: bool = False,
|
use_message_history: bool = False,
|
||||||
raw_response: bool = False,
|
raw_response: bool = False,
|
||||||
window_dimension: Resolution | None = None,
|
window_dimension: Resolution | None = None,
|
||||||
organization_id: str | None = None,
|
|
||||||
**extra_parameters: Any,
|
**extra_parameters: Any,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
@@ -1081,6 +1105,34 @@ class LLMCaller:
|
|||||||
timeout: float = settings.LLM_CONFIG_TIMEOUT,
|
timeout: float = settings.LLM_CONFIG_TIMEOUT,
|
||||||
**active_parameters: dict[str, Any],
|
**active_parameters: dict[str, Any],
|
||||||
) -> ModelResponse | CustomStreamWrapper | AnthropicMessage | UITarsResponse:
|
) -> ModelResponse | CustomStreamWrapper | AnthropicMessage | UITarsResponse:
|
||||||
|
if self.openai_client:
|
||||||
|
# Build the optional HTTP-Referer / X-Title headers sent with the OpenRouter request
|
||||||
|
extra_headers = {}
|
||||||
|
if settings.SKYVERN_APP_URL:
|
||||||
|
extra_headers["HTTP-Referer"] = settings.SKYVERN_APP_URL
|
||||||
|
extra_headers["X-Title"] = "Skyvern"
|
||||||
|
|
||||||
|
# Forward only a token limit and temperature to the OpenAI client; all other active_parameters are dropped
|
||||||
|
openai_params = {}
|
||||||
|
if "max_completion_tokens" in active_parameters:
|
||||||
|
openai_params["max_completion_tokens"] = active_parameters["max_completion_tokens"]
|
||||||
|
elif "max_tokens" in active_parameters:
|
||||||
|
openai_params["max_tokens"] = active_parameters["max_tokens"]
|
||||||
|
if "temperature" in active_parameters:
|
||||||
|
openai_params["temperature"] = active_parameters["temperature"]
|
||||||
|
|
||||||
|
completion = await self.openai_client.chat.completions.create(
|
||||||
|
model=self.llm_key,
|
||||||
|
messages=messages,
|
||||||
|
extra_headers=extra_headers if extra_headers else None,
|
||||||
|
timeout=timeout,
|
||||||
|
**openai_params,
|
||||||
|
)
|
||||||
|
# Convert OpenAI ChatCompletion to litellm ModelResponse format
|
||||||
|
# ModelResponse is built from keyword arguments, so dump the completion to a plain dict first
|
||||||
|
response_dict = completion.model_dump()
|
||||||
|
return litellm.ModelResponse(**response_dict)
|
||||||
|
|
||||||
if self.llm_key and "ANTHROPIC" in self.llm_key:
|
if self.llm_key and "ANTHROPIC" in self.llm_key:
|
||||||
return await self._call_anthropic(messages, tools, timeout, **active_parameters)
|
return await self._call_anthropic(messages, tools, timeout, **active_parameters)
|
||||||
|
|
||||||
@@ -1193,6 +1245,8 @@ class LLMCaller:
|
|||||||
self, response: ModelResponse | CustomStreamWrapper | AnthropicMessage | UITarsResponse
|
self, response: ModelResponse | CustomStreamWrapper | AnthropicMessage | UITarsResponse
|
||||||
) -> LLMCallStats:
|
) -> LLMCallStats:
|
||||||
empty_call_stats = LLMCallStats()
|
empty_call_stats = LLMCallStats()
|
||||||
|
if self.original_llm_key.startswith("openrouter/"):
|
||||||
|
return empty_call_stats
|
||||||
|
|
||||||
# Handle UI-TARS response (UITarsResponse object from _call_ui_tars)
|
# Handle UI-TARS response (UITarsResponse object from _call_ui_tars)
|
||||||
if isinstance(response, UITarsResponse):
|
if isinstance(response, UITarsResponse):
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from skyvern.forge.sdk.models import Step
|
|||||||
from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion
|
from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion
|
||||||
from skyvern.forge.sdk.schemas.task_v2 import TaskV2, Thought
|
from skyvern.forge.sdk.schemas.task_v2 import TaskV2, Thought
|
||||||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||||
|
from skyvern.utils.image_resizer import Resolution
|
||||||
|
|
||||||
|
|
||||||
class LiteLLMParams(TypedDict, total=False):
|
class LiteLLMParams(TypedDict, total=False):
|
||||||
@@ -96,6 +97,10 @@ class LLMAPIHandler(Protocol):
|
|||||||
screenshots: list[bytes] | None = None,
|
screenshots: list[bytes] | None = None,
|
||||||
parameters: dict[str, Any] | None = None,
|
parameters: dict[str, Any] | None = None,
|
||||||
organization_id: str | None = None,
|
organization_id: str | None = None,
|
||||||
|
tools: list | None = None,
|
||||||
|
use_message_history: bool = False,
|
||||||
|
raw_response: bool = False,
|
||||||
|
window_dimension: Resolution | None = None,
|
||||||
) -> Awaitable[dict[str, Any]]: ...
|
) -> Awaitable[dict[str, Any]]: ...
|
||||||
|
|
||||||
|
|
||||||
@@ -109,5 +114,9 @@ async def dummy_llm_api_handler(
|
|||||||
screenshots: list[bytes] | None = None,
|
screenshots: list[bytes] | None = None,
|
||||||
parameters: dict[str, Any] | None = None,
|
parameters: dict[str, Any] | None = None,
|
||||||
organization_id: str | None = None,
|
organization_id: str | None = None,
|
||||||
|
tools: list | None = None,
|
||||||
|
use_message_history: bool = False,
|
||||||
|
raw_response: bool = False,
|
||||||
|
window_dimension: Resolution | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
raise NotImplementedError("Your LLM provider is not configured. Please configure it in the .env file.")
|
raise NotImplementedError("Your LLM provider is not configured. Please configure it in the .env file.")
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
import importlib
|
import importlib
|
||||||
import json
|
import json
|
||||||
import types
|
import types
|
||||||
from unittest.mock import AsyncMock
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from skyvern import config
|
||||||
from skyvern.config import Settings
|
from skyvern.config import Settings
|
||||||
from skyvern.forge import app
|
from skyvern.forge import app
|
||||||
from skyvern.forge.sdk.api.llm import api_handler_factory, config_registry
|
from skyvern.forge.sdk.api.llm import api_handler_factory, config_registry
|
||||||
@@ -19,6 +20,9 @@ class DummyResponse(dict):
|
|||||||
def model_dump_json(self, indent: int = 2):
|
def model_dump_json(self, indent: int = 2):
|
||||||
return json.dumps(self, indent=indent)
|
return json.dumps(self, indent=indent)
|
||||||
|
|
||||||
|
def model_dump(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
class DummyArtifactManager:
|
class DummyArtifactManager:
|
||||||
async def create_llm_artifact(self, *args, **kwargs):
|
async def create_llm_artifact(self, *args, **kwargs):
|
||||||
@@ -49,27 +53,36 @@ async def test_openrouter_basic_completion(monkeypatch):
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_openrouter_dynamic_model(monkeypatch):
|
async def test_openrouter_dynamic_model(monkeypatch):
|
||||||
settings = Settings(
|
# Update settings via monkeypatch to ensure config_registry sees them
|
||||||
ENABLE_OPENROUTER=True,
|
|
||||||
OPENROUTER_API_KEY="key",
|
monkeypatch.setattr(config.settings, "ENABLE_OPENROUTER", True)
|
||||||
OPENROUTER_MODEL="base-model",
|
monkeypatch.setattr(config.settings, "OPENROUTER_API_KEY", "key")
|
||||||
LLM_KEY="OPENROUTER",
|
monkeypatch.setattr(config.settings, "OPENROUTER_MODEL", "base-model")
|
||||||
)
|
monkeypatch.setattr(config.settings, "OPENROUTER_API_BASE", "https://openrouter.ai/api/v1")
|
||||||
SettingsManager.set_settings(settings)
|
|
||||||
|
# Clear existing configs before reload
|
||||||
|
config_registry.LLMConfigRegistry._configs.clear()
|
||||||
importlib.reload(config_registry)
|
importlib.reload(config_registry)
|
||||||
|
|
||||||
monkeypatch.setattr(app, "ARTIFACT_MANAGER", DummyArtifactManager())
|
monkeypatch.setattr(app, "ARTIFACT_MANAGER", DummyArtifactManager())
|
||||||
|
|
||||||
|
# Mock the AsyncOpenAI client
|
||||||
async_mock = AsyncMock(return_value=DummyResponse('{"status": "ok"}'))
|
async_mock = AsyncMock(return_value=DummyResponse('{"status": "ok"}'))
|
||||||
monkeypatch.setattr(api_handler_factory.litellm, "acompletion", async_mock)
|
mock_client = MagicMock()
|
||||||
|
mock_client.chat.completions.create = async_mock
|
||||||
|
|
||||||
|
# Patch AsyncOpenAI to return our mock client
|
||||||
|
monkeypatch.setattr(api_handler_factory, "AsyncOpenAI", lambda **kwargs: mock_client)
|
||||||
|
|
||||||
base_handler = api_handler_factory.LLMAPIHandlerFactory.get_llm_api_handler("OPENROUTER")
|
base_handler = api_handler_factory.LLMAPIHandlerFactory.get_llm_api_handler("OPENROUTER")
|
||||||
override_handler = api_handler_factory.LLMAPIHandlerFactory.get_override_llm_api_handler(
|
override_handler = api_handler_factory.LLMAPIHandlerFactory.get_override_llm_api_handler(
|
||||||
"openrouter/other-model", default=base_handler
|
"openrouter/other-model", default=base_handler
|
||||||
)
|
)
|
||||||
|
|
||||||
result = await override_handler("hi", "test")
|
result = await override_handler("hi", "test")
|
||||||
assert result == {"status": "ok"}
|
assert result == {"status": "ok"}
|
||||||
called_model = async_mock.call_args.kwargs.get("model")
|
called_model = async_mock.call_args.kwargs.get("model")
|
||||||
assert called_model == "openrouter/other-model"
|
assert called_model == "other-model"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|||||||
Reference in New Issue
Block a user