support openrouter qwen model (#3630)
This commit is contained in:
@@ -10,6 +10,7 @@ from anthropic import NOT_GIVEN
|
||||
from anthropic.types.beta.beta_message import BetaMessage as AnthropicMessage
|
||||
from jinja2 import Template
|
||||
from litellm.utils import CustomStreamWrapper, ModelResponse
|
||||
from openai import AsyncOpenAI
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -224,6 +225,10 @@ class LLMAPIHandlerFactory:
|
||||
screenshots: list[bytes] | None = None,
|
||||
parameters: dict[str, Any] | None = None,
|
||||
organization_id: str | None = None,
|
||||
tools: list | None = None,
|
||||
use_message_history: bool = False,
|
||||
raw_response: bool = False,
|
||||
window_dimension: Resolution | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Custom LLM API handler that utilizes the LiteLLM router and fallbacks to OpenAI GPT-4 Vision.
|
||||
@@ -499,6 +504,11 @@ class LLMAPIHandlerFactory:
|
||||
if LLMConfigRegistry.is_router_config(llm_key):
|
||||
return LLMAPIHandlerFactory.get_llm_api_handler_with_router(llm_key)
|
||||
|
||||
# For OpenRouter models, use LLMCaller which has native OpenRouter support
|
||||
if llm_key.startswith("openrouter/"):
|
||||
llm_caller = LLMCaller(llm_key=llm_key, base_parameters=base_parameters)
|
||||
return llm_caller.call
|
||||
|
||||
assert isinstance(llm_config, LLMConfig)
|
||||
|
||||
@TraceManager.traced_async(tags=[llm_key], ignore_inputs=["prompt", "screenshots", "parameters"])
|
||||
@@ -512,6 +522,10 @@ class LLMAPIHandlerFactory:
|
||||
screenshots: list[bytes] | None = None,
|
||||
parameters: dict[str, Any] | None = None,
|
||||
organization_id: str | None = None,
|
||||
tools: list | None = None,
|
||||
use_message_history: bool = False,
|
||||
raw_response: bool = False,
|
||||
window_dimension: Resolution | None = None,
|
||||
) -> dict[str, Any]:
|
||||
start_time = time.time()
|
||||
active_parameters = base_parameters or {}
|
||||
@@ -827,6 +841,10 @@ class LLMAPIHandlerFactory:
|
||||
class LLMCaller:
|
||||
"""
|
||||
An LLMCaller instance defines the LLM configs and keeps the chat history if needed.
|
||||
|
||||
A couple of things to keep in mind:
|
||||
- LLMCaller should be compatible with litellm interface
|
||||
- LLMCaller should also support models that are not supported by litellm
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -835,6 +853,7 @@ class LLMCaller:
|
||||
screenshot_scaling_enabled: bool = False,
|
||||
base_parameters: dict[str, Any] | None = None,
|
||||
):
|
||||
self.original_llm_key = llm_key
|
||||
self.llm_key = llm_key
|
||||
self.llm_config = LLMConfigRegistry.get_config(llm_key)
|
||||
self.base_parameters = base_parameters
|
||||
@@ -846,6 +865,11 @@ class LLMCaller:
|
||||
if screenshot_scaling_enabled:
|
||||
self.screenshot_resize_target_dimension = get_resize_target_dimension(self.browser_window_dimension)
|
||||
|
||||
self.openai_client = None
|
||||
if self.llm_key.startswith("openrouter/"):
|
||||
self.llm_key = self.llm_key.replace("openrouter/", "")
|
||||
self.openai_client = AsyncOpenAI(api_key=settings.OPENROUTER_API_KEY, base_url=settings.OPENROUTER_API_BASE)
|
||||
|
||||
def add_tool_result(self, tool_result: dict[str, Any]) -> None:
|
||||
self.current_tool_results.append(tool_result)
|
||||
|
||||
@@ -862,11 +886,11 @@ class LLMCaller:
|
||||
ai_suggestion: AISuggestion | None = None,
|
||||
screenshots: list[bytes] | None = None,
|
||||
parameters: dict[str, Any] | None = None,
|
||||
organization_id: str | None = None,
|
||||
tools: list | None = None,
|
||||
use_message_history: bool = False,
|
||||
raw_response: bool = False,
|
||||
window_dimension: Resolution | None = None,
|
||||
organization_id: str | None = None,
|
||||
**extra_parameters: Any,
|
||||
) -> dict[str, Any]:
|
||||
start_time = time.perf_counter()
|
||||
@@ -1081,6 +1105,34 @@ class LLMCaller:
|
||||
timeout: float = settings.LLM_CONFIG_TIMEOUT,
|
||||
**active_parameters: dict[str, Any],
|
||||
) -> ModelResponse | CustomStreamWrapper | AnthropicMessage | UITarsResponse:
|
||||
if self.openai_client:
|
||||
# Extract OpenRouter-specific parameters
|
||||
extra_headers = {}
|
||||
if settings.SKYVERN_APP_URL:
|
||||
extra_headers["HTTP-Referer"] = settings.SKYVERN_APP_URL
|
||||
extra_headers["X-Title"] = "Skyvern"
|
||||
|
||||
# Filter out parameters that OpenAI client doesn't support
|
||||
openai_params = {}
|
||||
if "max_completion_tokens" in active_parameters:
|
||||
openai_params["max_completion_tokens"] = active_parameters["max_completion_tokens"]
|
||||
elif "max_tokens" in active_parameters:
|
||||
openai_params["max_tokens"] = active_parameters["max_tokens"]
|
||||
if "temperature" in active_parameters:
|
||||
openai_params["temperature"] = active_parameters["temperature"]
|
||||
|
||||
completion = await self.openai_client.chat.completions.create(
|
||||
model=self.llm_key,
|
||||
messages=messages,
|
||||
extra_headers=extra_headers if extra_headers else None,
|
||||
timeout=timeout,
|
||||
**openai_params,
|
||||
)
|
||||
# Convert OpenAI ChatCompletion to litellm ModelResponse format
|
||||
# litellm.utils.convert_to_model_response_object expects a dict
|
||||
response_dict = completion.model_dump()
|
||||
return litellm.ModelResponse(**response_dict)
|
||||
|
||||
if self.llm_key and "ANTHROPIC" in self.llm_key:
|
||||
return await self._call_anthropic(messages, tools, timeout, **active_parameters)
|
||||
|
||||
@@ -1193,6 +1245,8 @@ class LLMCaller:
|
||||
self, response: ModelResponse | CustomStreamWrapper | AnthropicMessage | UITarsResponse
|
||||
) -> LLMCallStats:
|
||||
empty_call_stats = LLMCallStats()
|
||||
if self.original_llm_key.startswith("openrouter/"):
|
||||
return empty_call_stats
|
||||
|
||||
# Handle UI-TARS response (UITarsResponse object from _call_ui_tars)
|
||||
if isinstance(response, UITarsResponse):
|
||||
|
||||
@@ -7,6 +7,7 @@ from skyvern.forge.sdk.models import Step
|
||||
from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion
|
||||
from skyvern.forge.sdk.schemas.task_v2 import TaskV2, Thought
|
||||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||
from skyvern.utils.image_resizer import Resolution
|
||||
|
||||
|
||||
class LiteLLMParams(TypedDict, total=False):
|
||||
@@ -96,6 +97,10 @@ class LLMAPIHandler(Protocol):
|
||||
screenshots: list[bytes] | None = None,
|
||||
parameters: dict[str, Any] | None = None,
|
||||
organization_id: str | None = None,
|
||||
tools: list | None = None,
|
||||
use_message_history: bool = False,
|
||||
raw_response: bool = False,
|
||||
window_dimension: Resolution | None = None,
|
||||
) -> Awaitable[dict[str, Any]]: ...
|
||||
|
||||
|
||||
@@ -109,5 +114,9 @@ async def dummy_llm_api_handler(
|
||||
screenshots: list[bytes] | None = None,
|
||||
parameters: dict[str, Any] | None = None,
|
||||
organization_id: str | None = None,
|
||||
tools: list | None = None,
|
||||
use_message_history: bool = False,
|
||||
raw_response: bool = False,
|
||||
window_dimension: Resolution | None = None,
|
||||
) -> dict[str, Any]:
|
||||
raise NotImplementedError("Your LLM provider is not configured. Please configure it in the .env file.")
|
||||
|
||||
Reference in New Issue
Block a user