"""
Tests for Vertex AI cache model name extraction from LLMRouterConfig.

This tests the fix for the issue where GEMINI_3_0_FLASH_WITH_FALLBACK was
incorrectly using 'gemini-3.0-flash' instead of 'gemini-3-flash-preview'.
"""
import re
from dataclasses import dataclass, field
from typing import Optional
@dataclass
class MockLLMRouterModelConfig:
    """Mock of one entry in a router's ``model_list``.

    Pairs a router group name with the litellm parameters whose ``"model"``
    key carries the actual provider model string
    (e.g. ``"vertex_ai/gemini-3-flash-preview"``).
    """

    # Group identifier matched against the router's main_model_group,
    # e.g. "vertex-gemini-3.0-flash".
    model_name: str
    # litellm call parameters; the "model" key holds the provider model string.
    litellm_params: dict
@dataclass
class MockLLMRouterConfig:
    """Mock of LLMRouterConfig: a named router over a list of model entries.

    ``model_name`` is only a router identifier (not a real Vertex model);
    the actual models live in ``model_list``, with ``main_model_group``
    naming the primary entry.
    """

    # Router identifier, e.g. "gemini-3.0-flash-gpt-5-mini-fallback-router".
    model_name: str
    # MockLLMRouterModelConfig entries (primary model plus fallbacks).
    model_list: list
    # model_name of the primary entry within model_list.
    main_model_group: str
    # Required environment variables. default_factory gives each instance a
    # fresh list and fixes the misleading `list = None` annotation.
    required_env_vars: list = field(default_factory=list)

    def __post_init__(self):
        # Backward compatibility: callers that pass an explicit None still
        # get an empty list, exactly as before.
        if self.required_env_vars is None:
            self.required_env_vars = []
@dataclass
class MockLLMConfig:
    """Mock of a direct (non-router) LLM config.

    ``model_name`` is the full model identifier,
    e.g. ``"vertex_ai/gemini-3-flash-preview"``.
    """

    # Full model identifier, possibly prefixed with a provider path.
    model_name: str
    # Required environment variables. default_factory gives each instance a
    # fresh list and fixes the misleading `list = None` annotation.
    required_env_vars: list = field(default_factory=list)
    # Optional litellm parameters (e.g. api_base); None when not applicable.
    litellm_params: Optional[dict] = None

    def __post_init__(self):
        # Backward compatibility: callers that pass an explicit None still
        # get an empty list, exactly as before.
        if self.required_env_vars is None:
            self.required_env_vars = []
class TestVertexCacheModelExtraction:
    """Test that model names are correctly extracted for Vertex AI caching."""

    def _extract_model_name(self, llm_config, resolved_llm_key: str) -> str:
        """
        Mimics the model name extraction logic from _create_vertex_cache_for_task.

        Resolution order:
          1. router configs: the primary entry of ``model_list``,
          2. the config's own ``model_name``,
          3. an ``api_base`` URL in ``litellm_params``,
          4. inference from the llm_key itself,
        after which the result is normalized to the canonical Vertex identifier.
        """
        candidate = None

        # Router configs (LLMRouterConfig) are resolved from model_list FIRST:
        # their model_name is only a router identifier (e.g.
        # "gemini-3.0-flash-gpt-5-mini-fallback-router"), not a Vertex model.
        if hasattr(llm_config, "model_list") and hasattr(llm_config, "main_model_group"):
            primary = next(
                (
                    entry
                    for entry in llm_config.model_list
                    if entry.model_name == llm_config.main_model_group
                ),
                None,
            )
            if primary is not None:
                # The actual model name lives in the entry's litellm_params.
                raw = primary.litellm_params.get("model", "")
                if "vertex_ai/" in raw:
                    candidate = raw.split("/")[-1]
                elif raw.startswith("gemini-"):
                    candidate = raw

        # Next, the config's own model_name, when it is a Vertex path or is
        # already a bare gemini identifier.
        if not candidate:
            own_name = getattr(llm_config, "model_name", None)
            if isinstance(own_name, str):
                if "vertex_ai/" in own_name:
                    # Direct Vertex config: "vertex_ai/gemini-2.5-flash" -> "gemini-2.5-flash"
                    candidate = own_name.split("/")[-1]
                elif own_name.startswith("gemini-"):
                    # Already in the correct format.
                    candidate = own_name

        # Router/fallback configs may expose the model through an api_base URL.
        if not candidate and getattr(llm_config, "litellm_params", None):
            params = llm_config.litellm_params
            if isinstance(params, dict):
                api_base = params.get("api_base")
            else:
                api_base = getattr(params, "api_base", None)
            if isinstance(api_base, str) and "/models/" in api_base:
                # Extract from URL: .../models/gemini-2.5-flash -> "gemini-2.5-flash"
                candidate = api_base.split("/models/")[-1]

        # Last resort: infer version and flavor from the llm_key itself.
        if not candidate:
            version_match = re.search(r"GEMINI[_-](\d+[._-]\d+)", resolved_llm_key, re.IGNORECASE)
            version = (
                version_match.group(1).replace("_", ".").replace("-", ".")
                if version_match
                else "2.5"
            )
            if "_PRO_" in resolved_llm_key or resolved_llm_key.endswith("_PRO"):
                flavor = "pro"
            elif "_FLASH_LITE_" in resolved_llm_key or resolved_llm_key.endswith("_FLASH_LITE"):
                flavor = "flash-lite"
            else:
                # Default to the flash flavor.
                flavor = "flash"
            candidate = f"gemini-{version}-{flavor}"

        # Fall back to the historical default if nothing was extracted.
        model_name = candidate or "gemini-2.5-flash"

        # Normalize to the canonical Vertex identifier while preserving preview
        # suffixes (e.g. gemini-3-flash-preview must keep "-preview").
        normalized = re.search(
            r"(gemini-\d+(?:\.\d+)?-(?:flash-lite|flash|pro)(?:-preview)?)",
            model_name,
            re.IGNORECASE,
        )
        if normalized:
            model_name = normalized.group(1).lower()

        return model_name

    def test_router_config_extracts_gemini_3_flash_preview(self):
        """
        GEMINI_3_0_FLASH_WITH_FALLBACK should extract 'gemini-3-flash-preview',
        NOT 'gemini-3.0-flash'.
        """
        # Mirrors the real GEMINI_3_0_FLASH_WITH_FALLBACK router definition.
        config = MockLLMRouterConfig(
            model_name="gemini-3.0-flash-gpt-5-mini-fallback-router",
            model_list=[
                MockLLMRouterModelConfig(
                    model_name="vertex-gemini-3.0-flash",
                    litellm_params={"model": "vertex_ai/gemini-3-flash-preview"},
                ),
                MockLLMRouterModelConfig(
                    model_name="gpt-5-mini-fallback",
                    litellm_params={"model": "gpt-5-mini-2025-08-07"},
                ),
            ],
            main_model_group="vertex-gemini-3.0-flash",
        )

        extracted = self._extract_model_name(config, "GEMINI_3_0_FLASH_WITH_FALLBACK")

        # The -preview suffix must survive extraction.
        assert extracted == "gemini-3-flash-preview", (
            f"Expected 'gemini-3-flash-preview' but got '{extracted}'. "
            "The router config should extract from model_list, not infer from llm_key."
        )

    def test_direct_vertex_config_extracts_correctly(self):
        """Direct VERTEX_GEMINI_3.0_FLASH should extract correctly."""
        config = MockLLMConfig(
            model_name="vertex_ai/gemini-3-flash-preview",
        )

        extracted = self._extract_model_name(config, "VERTEX_GEMINI_3.0_FLASH")
        assert extracted == "gemini-3-flash-preview"

    def test_router_config_extracts_gemini_2_5_flash(self):
        """GEMINI_2_5_FLASH_WITH_FALLBACK should extract 'gemini-2.5-flash'."""
        config = MockLLMRouterConfig(
            model_name="gemini-2.5-flash-gpt-5-mini-fallback-router",
            model_list=[
                MockLLMRouterModelConfig(
                    model_name="vertex-gemini-2.5-flash",
                    litellm_params={"model": "vertex_ai/gemini-2.5-flash"},
                ),
                MockLLMRouterModelConfig(
                    model_name="gpt-5-mini-fallback",
                    litellm_params={"model": "gpt-5-mini-2025-08-07"},
                ),
            ],
            main_model_group="vertex-gemini-2.5-flash",
        )

        extracted = self._extract_model_name(config, "GEMINI_2_5_FLASH_WITH_FALLBACK")
        assert extracted == "gemini-2.5-flash"

    def test_fallback_to_llm_key_inference_when_no_model_list(self):
        """When there's no model_list, should fall back to llm_key inference."""
        # A config that doesn't have model_list (not a router config).
        config = MockLLMConfig(
            model_name="some-unrelated-name",
        )

        extracted = self._extract_model_name(config, "GEMINI_2_5_FLASH")
        # Should fall back to inference from the llm_key.
        assert extracted == "gemini-2.5-flash"