[SKV-3992] Add OPENAI_COMPATIBLE for githubcopilot.com (#3993)
This commit is contained in:
@@ -35,7 +35,12 @@ from skyvern.forge.sdk.api.llm.models import (
|
||||
LLMRouterConfig,
|
||||
)
|
||||
from skyvern.forge.sdk.api.llm.ui_tars_response import UITarsResponse
|
||||
from skyvern.forge.sdk.api.llm.utils import llm_messages_builder, llm_messages_builder_with_history, parse_api_response
|
||||
from skyvern.forge.sdk.api.llm.utils import (
|
||||
is_image_message,
|
||||
llm_messages_builder,
|
||||
llm_messages_builder_with_history,
|
||||
parse_api_response,
|
||||
)
|
||||
from skyvern.forge.sdk.artifact.manager import BulkArtifactCreationRequest
|
||||
from skyvern.forge.sdk.artifact.models import ArtifactType
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
@@ -223,21 +228,42 @@ class LLMAPIHandlerFactory:
|
||||
|
||||
return _normalize(left) == _normalize(right)
|
||||
|
||||
@staticmethod
|
||||
def _extract_token_counts(response: ModelResponse | CustomStreamWrapper) -> tuple[int, int, int, int]:
|
||||
"""Extract token counts from response usage information."""
|
||||
input_tokens = 0
|
||||
output_tokens = 0
|
||||
reasoning_tokens = 0
|
||||
cached_tokens = 0
|
||||
|
||||
if hasattr(response, "usage") and response.usage:
|
||||
input_tokens = getattr(response.usage, "prompt_tokens", 0)
|
||||
output_tokens = getattr(response.usage, "completion_tokens", 0)
|
||||
|
||||
# Extract reasoning tokens from completion_tokens_details
|
||||
completion_token_detail = getattr(response.usage, "completion_tokens_details", None)
|
||||
if completion_token_detail:
|
||||
reasoning_tokens = getattr(completion_token_detail, "reasoning_tokens", 0) or 0
|
||||
|
||||
# Extract cached tokens from prompt_tokens_details
|
||||
cached_token_detail = getattr(response.usage, "prompt_tokens_details", None)
|
||||
if cached_token_detail:
|
||||
cached_tokens = getattr(cached_token_detail, "cached_tokens", 0) or 0
|
||||
|
||||
# Fallback: Some providers expose cache_read_input_tokens directly on usage
|
||||
if cached_tokens == 0:
|
||||
cached_tokens = getattr(response.usage, "cache_read_input_tokens", 0) or 0
|
||||
|
||||
return input_tokens, output_tokens, reasoning_tokens, cached_tokens
|
||||
|
||||
@staticmethod
|
||||
def _apply_thinking_budget_optimization(
|
||||
parameters: dict[str, Any], new_budget: int, llm_config: LLMConfig | LLMRouterConfig, prompt_name: str
|
||||
) -> None:
|
||||
"""Apply thinking budget optimization based on model type and LiteLLM reasoning support."""
|
||||
# Compute a safe model label and a representative model for capability checks
|
||||
model_label = getattr(llm_config, "model_name", None)
|
||||
if model_label is None and isinstance(llm_config, LLMRouterConfig):
|
||||
model_label = getattr(llm_config, "main_model_group", "router")
|
||||
check_model = model_label
|
||||
if isinstance(llm_config, LLMRouterConfig) and getattr(llm_config, "model_list", None):
|
||||
try:
|
||||
check_model = llm_config.model_list[0].model_name or model_label # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
check_model = model_label
|
||||
model_label = LLMAPIHandlerFactory._get_model_label(llm_config)
|
||||
check_model = LLMAPIHandlerFactory._get_check_model(llm_config, model_label)
|
||||
|
||||
# Check reasoning support (safe call - log but don't fail if litellm errors)
|
||||
supports_reasoning = False
|
||||
@@ -273,7 +299,6 @@ class LLMAPIHandlerFactory:
|
||||
model=model_label,
|
||||
)
|
||||
return
|
||||
|
||||
# Apply optimization based on model type
|
||||
model_label_lower = (model_label or "").lower()
|
||||
if "gemini" in model_label_lower:
|
||||
@@ -296,7 +321,6 @@ class LLMAPIHandlerFactory:
|
||||
reasoning_effort="low",
|
||||
model=model_label,
|
||||
)
|
||||
|
||||
except (AttributeError, KeyError, TypeError) as e:
|
||||
LOG.warning(
|
||||
"Failed to apply thinking budget optimization",
|
||||
@@ -312,6 +336,8 @@ class LLMAPIHandlerFactory:
|
||||
parameters: dict[str, Any], new_budget: int, llm_config: LLMConfig | LLMRouterConfig, prompt_name: str
|
||||
) -> None:
|
||||
"""Apply thinking optimization for Anthropic/Claude models."""
|
||||
model_label = LLMAPIHandlerFactory._get_model_label(llm_config)
|
||||
|
||||
if llm_config.reasoning_effort is not None:
|
||||
# Use reasoning_effort if configured in LLM config - always use "low" per LiteLLM constants
|
||||
parameters["reasoning_effort"] = "low"
|
||||
@@ -350,8 +376,7 @@ class LLMAPIHandlerFactory:
|
||||
parameters: dict[str, Any], new_budget: int, llm_config: LLMConfig | LLMRouterConfig, prompt_name: str
|
||||
) -> None:
|
||||
"""Apply thinking optimization for Gemini models using exact integer budget value."""
|
||||
# Get model label for logging — prefer main_model_group for router configs
|
||||
model_label = llm_config.main_model_group if isinstance(llm_config, LLMRouterConfig) else llm_config.model_name
|
||||
model_label = LLMAPIHandlerFactory._get_model_label(llm_config)
|
||||
|
||||
# Models that use thinking_level (e.g. Gemini 3 Pro/Flash) don't support budget_tokens.
|
||||
# Their reasoning is already bounded by the thinking_level set in their config, so skip.
|
||||
@@ -373,6 +398,34 @@ class LLMAPIHandlerFactory:
|
||||
model=model_label,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_model_label(llm_config: LLMConfig | LLMRouterConfig) -> str | None:
|
||||
"""Extract a safe model label from LLMConfig or LLMRouterConfig for logging and capability checks."""
|
||||
model_label = getattr(llm_config, "model_name", None)
|
||||
# Compute a safe model label and a representative model for capability checks
|
||||
if model_label is None and isinstance(llm_config, LLMRouterConfig):
|
||||
model_label = getattr(llm_config, "main_model_group", "router")
|
||||
return model_label
|
||||
|
||||
@staticmethod
|
||||
def _get_check_model(llm_config: LLMConfig | LLMRouterConfig, model_label: str | None) -> str | None:
|
||||
"""Get a representative model name for capability checks from LLMRouterConfig or use the model label."""
|
||||
check_model = model_label
|
||||
if isinstance(llm_config, LLMRouterConfig) and getattr(llm_config, "model_list", None):
|
||||
try:
|
||||
check_model = llm_config.model_list[0].model_name or model_label # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
check_model = model_label
|
||||
return check_model
|
||||
|
||||
@staticmethod
|
||||
def is_github_copilot_endpoint() -> bool:
|
||||
"""Check if the OPENAI_COMPATIBLE endpoint is GitHub Copilot."""
|
||||
return (
|
||||
settings.OPENAI_COMPATIBLE_API_BASE is not None
|
||||
and settings.OPENAI_COMPATIBLE_GITHUB_COPILOT_DOMAIN in settings.OPENAI_COMPATIBLE_API_BASE
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_override_llm_api_handler(override_llm_key: str | None, *, default: LLMAPIHandler) -> LLMAPIHandler:
|
||||
if not override_llm_key:
|
||||
@@ -856,6 +909,11 @@ class LLMAPIHandlerFactory:
|
||||
llm_caller = LLMCaller(llm_key=llm_key, base_parameters=base_parameters)
|
||||
return llm_caller.call
|
||||
|
||||
# For GitHub Copilot via OPENAI_COMPATIBLE, use LLMCaller for a custom header
|
||||
if llm_key == "OPENAI_COMPATIBLE" and LLMAPIHandlerFactory.is_github_copilot_endpoint():
|
||||
llm_caller = LLMCaller(llm_key=llm_key, base_parameters=base_parameters)
|
||||
return llm_caller.call
|
||||
|
||||
assert isinstance(llm_config, LLMConfig)
|
||||
|
||||
@traced(tags=[llm_key])
|
||||
@@ -1313,6 +1371,14 @@ class LLMCaller:
|
||||
base_url=settings.OPENROUTER_API_BASE,
|
||||
http_client=ForgeAsyncHttpxClientWrapper(),
|
||||
)
|
||||
elif self.llm_key == "OPENAI_COMPATIBLE" and LLMAPIHandlerFactory.is_github_copilot_endpoint():
|
||||
# For GitHub Copilot, use the actual model name from OPENAI_COMPATIBLE_MODEL_NAME
|
||||
self.llm_key = settings.OPENAI_COMPATIBLE_MODEL_NAME or self.llm_key
|
||||
self.openai_client = AsyncOpenAI(
|
||||
api_key=settings.OPENAI_COMPATIBLE_API_KEY,
|
||||
base_url=settings.OPENAI_COMPATIBLE_API_BASE,
|
||||
http_client=ForgeAsyncHttpxClientWrapper(),
|
||||
)
|
||||
|
||||
def add_tool_result(self, tool_result: dict[str, Any]) -> None:
|
||||
self.current_tool_results.append(tool_result)
|
||||
@@ -1598,12 +1664,22 @@ class LLMCaller:
|
||||
**active_parameters: dict[str, Any],
|
||||
) -> ModelResponse | CustomStreamWrapper | AnthropicMessage | UITarsResponse:
|
||||
if self.openai_client:
|
||||
# Extract OpenRouter-specific parameters
|
||||
# Extract OpenRouter-specific and GitHub Copilot-specific parameters
|
||||
extra_headers = {}
|
||||
if settings.SKYVERN_APP_URL:
|
||||
extra_headers["HTTP-Referer"] = settings.SKYVERN_APP_URL
|
||||
extra_headers["X-Title"] = "Skyvern"
|
||||
|
||||
# Add required headers for GitHub Copilot API
|
||||
if LLMAPIHandlerFactory.is_github_copilot_endpoint():
|
||||
extra_headers["Copilot-Integration-Id"] = "copilot-chat"
|
||||
|
||||
# Add vision header when there are images in the request
|
||||
has_images = any(is_image_message(msg) for msg in messages)
|
||||
# Only set the header when there are actual images in the request
|
||||
if has_images:
|
||||
extra_headers["Copilot-Vision-Request"] = "true"
|
||||
|
||||
# Filter out parameters that OpenAI client doesn't support
|
||||
openai_params = {}
|
||||
if "max_completion_tokens" in active_parameters:
|
||||
@@ -1740,6 +1816,19 @@ class LLMCaller:
|
||||
if self.original_llm_key.startswith("openrouter/"):
|
||||
return empty_call_stats
|
||||
|
||||
# Handle OPENAI_COMPATIBLE provider GitHub Copilot
|
||||
if self.original_llm_key == "OPENAI_COMPATIBLE" and isinstance(response, (ModelResponse, CustomStreamWrapper)):
|
||||
input_tokens, output_tokens, reasoning_tokens, cached_tokens = LLMAPIHandlerFactory._extract_token_counts(
|
||||
response
|
||||
)
|
||||
return LLMCallStats(
|
||||
llm_cost=0, # TODO: calculate the cost according to the price: https://github.com/features/copilot/plans
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cached_tokens=cached_tokens,
|
||||
reasoning_tokens=reasoning_tokens,
|
||||
)
|
||||
|
||||
# Handle UI-TARS response (UITarsResponse object from _call_ui_tars)
|
||||
if isinstance(response, UITarsResponse):
|
||||
ui_tars_usage = response.usage
|
||||
@@ -1771,28 +1860,9 @@ class LLMCaller:
|
||||
except Exception as e:
|
||||
LOG.debug("Failed to calculate LLM cost", error=str(e), exc_info=True)
|
||||
llm_cost = 0
|
||||
input_tokens = 0
|
||||
output_tokens = 0
|
||||
reasoning_tokens = 0
|
||||
cached_tokens = 0
|
||||
|
||||
if hasattr(response, "usage") and response.usage:
|
||||
input_tokens = getattr(response.usage, "prompt_tokens", 0)
|
||||
output_tokens = getattr(response.usage, "completion_tokens", 0)
|
||||
|
||||
# Extract reasoning tokens from completion_tokens_details
|
||||
completion_token_detail = getattr(response.usage, "completion_tokens_details", None)
|
||||
if completion_token_detail:
|
||||
reasoning_tokens = getattr(completion_token_detail, "reasoning_tokens", 0) or 0
|
||||
|
||||
# Extract cached tokens from prompt_tokens_details
|
||||
cached_token_detail = getattr(response.usage, "prompt_tokens_details", None)
|
||||
if cached_token_detail:
|
||||
cached_tokens = getattr(cached_token_detail, "cached_tokens", 0) or 0
|
||||
|
||||
# Fallback for Vertex/Gemini: LiteLLM exposes cache_read_input_tokens on usage
|
||||
if cached_tokens == 0:
|
||||
cached_tokens = getattr(response.usage, "cache_read_input_tokens", 0) or 0
|
||||
input_tokens, output_tokens, reasoning_tokens, cached_tokens = LLMAPIHandlerFactory._extract_token_counts(
|
||||
response
|
||||
)
|
||||
return LLMCallStats(
|
||||
llm_cost=llm_cost,
|
||||
input_tokens=input_tokens,
|
||||
|
||||
@@ -15,13 +15,13 @@ UI-TARS LLM Caller that follows the standard LLMCaller pattern.
|
||||
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from typing import Any, Dict
|
||||
|
||||
import structlog
|
||||
from PIL import Image
|
||||
|
||||
from skyvern.forge.prompts import prompt_engine
|
||||
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMCaller
|
||||
from skyvern.forge.sdk.api.llm.utils import is_image_message
|
||||
from skyvern.forge.sdk.models import Step
|
||||
from skyvern.forge.sdk.schemas.tasks import Task
|
||||
|
||||
@@ -33,15 +33,6 @@ def _build_system_prompt(instruction: str, language: str = "English") -> str:
|
||||
return prompt_engine.load_prompt("ui-tars-system-prompt", language=language, instruction=instruction)
|
||||
|
||||
|
||||
def _is_image_message(message: Dict[str, Any]) -> bool:
|
||||
"""Check if message contains an image."""
|
||||
return (
|
||||
message.get("role") == "user"
|
||||
and isinstance(message.get("content"), list)
|
||||
and any(item.get("type") == "image_url" for item in message["content"])
|
||||
)
|
||||
|
||||
|
||||
class UITarsLLMCaller(LLMCaller):
|
||||
"""
|
||||
UI-TARS specific LLM caller that manages conversation history.
|
||||
@@ -114,7 +105,7 @@ class UITarsLLMCaller(LLMCaller):
|
||||
i = 1 # Start after system prompt (index 0)
|
||||
while i < len(self.message_history) and removed_count < images_to_remove:
|
||||
message = self.message_history[i]
|
||||
if _is_image_message(message):
|
||||
if is_image_message(message):
|
||||
# Remove only the screenshot message, keep all assistant responses
|
||||
self.message_history.pop(i)
|
||||
removed_count += 1
|
||||
@@ -131,7 +122,7 @@ class UITarsLLMCaller(LLMCaller):
|
||||
"""Count existing image messages in the conversation history."""
|
||||
count = 0
|
||||
for message in self.message_history:
|
||||
if _is_image_message(message):
|
||||
if is_image_message(message):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
|
||||
@@ -15,6 +15,15 @@ from skyvern.forge.sdk.api.llm.exceptions import EmptyLLMResponseError, InvalidL
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
def is_image_message(message: dict[str, Any]) -> bool:
|
||||
"""Check if message contains an image."""
|
||||
return (
|
||||
message.get("role") == "user"
|
||||
and isinstance(message.get("content"), list)
|
||||
and any(item.get("type") == "image_url" for item in message["content"])
|
||||
)
|
||||
|
||||
|
||||
async def llm_messages_builder(
|
||||
prompt: str,
|
||||
screenshots: list[bytes] | None = None,
|
||||
|
||||
Reference in New Issue
Block a user