diff --git a/skyvern/forge/sdk/api/llm/api_handler_factory.py b/skyvern/forge/sdk/api/llm/api_handler_factory.py
index 0972bd37..46655c2d 100644
--- a/skyvern/forge/sdk/api/llm/api_handler_factory.py
+++ b/skyvern/forge/sdk/api/llm/api_handler_factory.py
@@ -2,7 +2,7 @@ import dataclasses
 import json
 import time
 from asyncio import CancelledError
-from typing import Any, AsyncIterator
+from typing import Any, AsyncIterator, Protocol, runtime_checkable
 
 import litellm
 import structlog
@@ -40,6 +40,19 @@ LOG = structlog.get_logger()
 
 EXTRACT_ACTION_PROMPT_NAME = "extract-actions"
 
+@runtime_checkable
+class RouterWithModelList(Protocol):
+    model_list: list[dict[str, Any]]
+
+
+def _get_primary_model_dict(router: Any, main_model_group: str) -> dict[str, Any] | None:
+    if isinstance(router, RouterWithModelList):
+        for model_dict in router.model_list:
+            if model_dict.get("model_name") == main_model_group:
+                return model_dict
+    return None
+
+
 class LLMCallStats(BaseModel):
     input_tokens: int | None = None
     output_tokens: int | None = None
@@ -366,6 +379,8 @@
         vertex_cache_attached = False
         cache_resource_name = getattr(context, "vertex_cache_name", None)
 
+        primary_model_dict = _get_primary_model_dict(router, main_model_group)
+
         # Add cached_content to primary model's litellm_params (not global parameters)
         # This ensures it's only passed to the Gemini primary, not to fallback models.
         # By setting it in the model-specific litellm_params, LiteLLM will only include it
@@ -376,23 +391,24 @@
             and prompt_name == EXTRACT_ACTION_PROMPT_NAME
             and getattr(context, "use_prompt_caching", False)
             and "gemini" in main_model_group.lower()
+            and primary_model_dict is not None
         ):
-            # Modify the router's model_list to add cached_content only to the primary model
-            # The router is created per-handler-instance, so this modification is safe
-            # and idempotent (setting the same value multiple times is fine)
-            for model_dict in router.model_list:
-                if model_dict.get("model_name") == main_model_group:
-                    if "litellm_params" not in model_dict:
-                        model_dict["litellm_params"] = {}
-                    model_dict["litellm_params"]["cached_content"] = cache_resource_name
-                    vertex_cache_attached = True
-                    LOG.info(
-                        "Adding Vertex AI cache reference to primary model in router",
-                        prompt_name=prompt_name,
-                        primary_model=main_model_group,
-                        fallback_model=llm_config.fallback_model_group,
-                    )
-                    break
+            litellm_params = primary_model_dict.setdefault("litellm_params", {})
+            litellm_params["cached_content"] = cache_resource_name
+            vertex_cache_attached = True
+            LOG.info(
+                "Adding Vertex AI cache reference to primary model in router",
+                prompt_name=prompt_name,
+                primary_model=main_model_group,
+                fallback_model=llm_config.fallback_model_group,
+            )
+        elif primary_model_dict and "litellm_params" in primary_model_dict:
+            if primary_model_dict["litellm_params"].pop("cached_content", None):
+                LOG.info(
+                    "Removed Vertex AI cache reference from primary model in router",
+                    prompt_name=prompt_name,
+                    primary_model=main_model_group,
+                )
 
         llm_request_payload = {
             "model": llm_key,
@@ -728,6 +744,14 @@
                 prompt_name=prompt_name,
                 cache_attached=True,
             )
+        elif "cached_content" in active_parameters:
+            removed_cache = active_parameters.pop("cached_content", None)
+            if removed_cache:
+                LOG.info(
+                    "Removed Vertex AI cache reference from request",
+                    prompt_name=prompt_name,
+                    cache_was_attached=True,
+                )
 
         llm_request_payload = {
             "model": model_name,