Pedro/fix vertex cache leak (#4135)

This commit is contained in:
pedrohsdb
2025-11-29 05:39:05 -08:00
committed by GitHub
parent 2eeca1c699
commit 3f11d44762
4 changed files with 282 additions and 88 deletions

View File

@@ -1,3 +1,4 @@
import copy
import dataclasses
import json
import time
@@ -29,6 +30,7 @@ from skyvern.forge.sdk.api.llm.ui_tars_response import UITarsResponse
from skyvern.forge.sdk.api.llm.utils import llm_messages_builder, llm_messages_builder_with_history, parse_api_response
from skyvern.forge.sdk.artifact.models import ArtifactType
from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
from skyvern.forge.sdk.models import SpeculativeLLMMetadata, Step
from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion
from skyvern.forge.sdk.schemas.task_v2 import TaskV2, Thought
@@ -38,6 +40,7 @@ from skyvern.utils.image_resizer import Resolution, get_resize_target_dimension,
LOG = structlog.get_logger()
EXTRACT_ACTION_PROMPT_NAME = "extract-actions"
CHECK_USER_GOAL_PROMPT_NAMES = {"check-user-goal", "check-user-goal-with-termination"}
@runtime_checkable
@@ -61,6 +64,44 @@ class LLMCallStats(BaseModel):
llm_cost: float | None = None
async def _log_hashed_href_map_artifacts_if_needed(
    context: SkyvernContext | None,
    step: Step | None,
    task_v2: TaskV2 | None,
    thought: Thought | None,
    ai_suggestion: AISuggestion | None,
    *,
    is_speculative_step: bool,
) -> None:
    """Persist the context's hashed-href map as an LLM artifact when applicable.

    The artifact is recorded only for real (non-speculative) steps when the
    current Skyvern context carries a non-empty ``hashed_href_map``; in every
    other case this is a no-op.
    """
    # Guard clauses: bail out early instead of one combined condition.
    if not context or not context.hashed_href_map:
        return
    if not step or is_speculative_step:
        return
    serialized_map = json.dumps(context.hashed_href_map, indent=2)
    await app.ARTIFACT_MANAGER.create_llm_artifact(
        data=serialized_map.encode("utf-8"),
        artifact_type=ArtifactType.HASHED_HREF_MAP,
        step=step,
        task_v2=task_v2,
        thought=thought,
        ai_suggestion=ai_suggestion,
    )
def _log_vertex_cache_hit_if_needed(
    context: SkyvernContext | None,
    prompt_name: str,
    llm_identifier: str,
    cached_tokens: int,
) -> None:
    """Emit an info log when a Vertex AI context-cache hit is observed.

    A hit is reported only when the response consumed cached tokens, the
    prompt is the extract-actions prompt, and the current context carries a
    Vertex cache name. All other combinations log nothing.
    """
    # Early returns mirror the original short-circuit chain, in order.
    if cached_tokens <= 0:
        return
    if prompt_name != EXTRACT_ACTION_PROMPT_NAME:
        return
    if not context or not context.vertex_cache_name:
        return
    LOG.info(
        "Vertex cache hit",
        prompt_name=prompt_name,
        llm_key=llm_identifier,
        cached_tokens=cached_tokens,
        cache_name=context.vertex_cache_name,
        cache_key=context.vertex_cache_key,
        cache_variant=context.vertex_cache_variant,
    )
class LLMAPIHandlerFactory:
_custom_handlers: dict[str, LLMAPIHandler] = {}
_thinking_budget_settings: dict[str, int] | None = None
@@ -237,6 +278,7 @@ class LLMAPIHandlerFactory:
if not override_llm_key:
return default
try:
# Explicit overrides should honor the exact model choice and skip experimentation reroutes.
return LLMAPIHandlerFactory.get_llm_api_handler(override_llm_key)
except Exception:
LOG.warning(
@@ -320,17 +362,17 @@ class LLMAPIHandlerFactory:
context = skyvern_context.current()
is_speculative_step = step.is_speculative if step else False
if context and len(context.hashed_href_map) > 0 and step and not is_speculative_step:
await app.ARTIFACT_MANAGER.create_llm_artifact(
data=json.dumps(context.hashed_href_map, indent=2).encode("utf-8"),
artifact_type=ArtifactType.HASHED_HREF_MAP,
step=step,
task_v2=task_v2,
thought=thought,
ai_suggestion=ai_suggestion,
)
await _log_hashed_href_map_artifacts_if_needed(
context,
step,
task_v2,
thought,
ai_suggestion,
is_speculative_step=is_speculative_step,
)
llm_prompt_value = prompt
if step and not is_speculative_step:
await app.ARTIFACT_MANAGER.create_llm_artifact(
data=llm_prompt_value.encode("utf-8"),
@@ -343,6 +385,25 @@ class LLMAPIHandlerFactory:
# Build messages and apply caching in one step
messages = await llm_messages_builder(prompt, screenshots, llm_config.add_assistant_prefix)
async def _log_llm_request_artifact(model_label: str, vertex_cache_attached_flag: bool) -> str:
    """Serialize the outgoing LLM request and persist it as an LLM_REQUEST artifact.

    Returns the request payload as a JSON string so the caller can reuse it
    without re-serializing.
    """
    # NOTE(review): `messages`, `parameters`, `step`, `is_speculative_step`,
    # `task_v2`, `thought`, and `ai_suggestion` are closure variables captured
    # from the enclosing handler.
    llm_request_payload = {
        "model": model_label,
        "messages": messages,
        **parameters,
        # Recorded so artifacts show whether a Vertex cache was attached to
        # this particular request.
        "vertex_cache_attached": vertex_cache_attached_flag,
    }
    llm_request_json = json.dumps(llm_request_payload)
    # Artifacts are only persisted for real (non-speculative) steps.
    if step and not is_speculative_step:
        await app.ARTIFACT_MANAGER.create_llm_artifact(
            data=llm_request_json.encode("utf-8"),
            artifact_type=ArtifactType.LLM_REQUEST,
            step=step,
            task_v2=task_v2,
            thought=thought,
            ai_suggestion=ai_suggestion,
        )
    return llm_request_json
# Inject context caching system message when available
try:
context_cached_static_prompt = getattr(context, "cached_static_prompt", None)
@@ -377,70 +438,96 @@ class LLMAPIHandlerFactory:
except Exception as e:
LOG.warning("Failed to apply context caching system message", error=str(e), exc_info=True)
vertex_cache_attached = False
cache_resource_name = getattr(context, "vertex_cache_name", None)
cache_variant = getattr(context, "vertex_cache_variant", None)
primary_model_dict = _get_primary_model_dict(router, main_model_group)
# Add cached_content to primary model's litellm_params (not global parameters)
# This ensures it's only passed to the Gemini primary, not to fallback models.
# By setting it in the model-specific litellm_params, LiteLLM will only include it
# when calling the primary model. When falling back to GPT-5, the fallback model's
# litellm_params won't have cached_content, so it won't be sent.
if (
cache_resource_name
should_attach_vertex_cache = bool(
cache_resource_name is not None
and prompt_name == EXTRACT_ACTION_PROMPT_NAME
and getattr(context, "use_prompt_caching", False)
and main_model_group
and "gemini" in main_model_group.lower()
and primary_model_dict is not None
):
litellm_params = primary_model_dict.setdefault("litellm_params", {})
litellm_params["cached_content"] = cache_resource_name
vertex_cache_attached = True
)
model_used = main_model_group
llm_request_json = ""
async def _call_primary_with_vertex_cache(
cache_name: str,
cache_variant_name: str | None,
) -> tuple[ModelResponse, str, str]:
if primary_model_dict is None:
raise ValueError("Primary router model missing configuration")
litellm_params = copy.deepcopy(primary_model_dict.get("litellm_params") or {})
if not litellm_params:
raise ValueError("Primary router model missing litellm_params")
active_params = copy.deepcopy(litellm_params)
active_params.update(parameters)
active_params["cached_content"] = cache_name
request_model = active_params.pop("model", primary_model_dict.get("model_name", main_model_group))
LOG.info(
"Adding Vertex AI cache reference to primary model in router",
"Adding Vertex AI cache reference to primary Gemini request",
prompt_name=prompt_name,
primary_model=main_model_group,
fallback_model=llm_config.fallback_model_group,
cache_name=cache_name,
cache_key=getattr(context, "vertex_cache_key", None),
cache_variant=cache_variant_name,
)
elif primary_model_dict and "litellm_params" in primary_model_dict:
if primary_model_dict["litellm_params"].pop("cached_content", None):
LOG.info(
"Removed Vertex AI cache reference from primary model in router",
prompt_name=prompt_name,
primary_model=main_model_group,
)
request_payload_json = await _log_llm_request_artifact(request_model, True)
response = await litellm.acompletion(
model=request_model,
messages=messages,
timeout=settings.LLM_CONFIG_TIMEOUT,
drop_params=True,
**active_params,
)
return response, request_model, request_payload_json
llm_request_payload = {
"model": llm_key,
"messages": messages,
**parameters,
"vertex_cache_attached": vertex_cache_attached,
}
llm_request_json = json.dumps(llm_request_payload)
if step and not is_speculative_step:
await app.ARTIFACT_MANAGER.create_llm_artifact(
data=llm_request_json.encode("utf-8"),
artifact_type=ArtifactType.LLM_REQUEST,
step=step,
task_v2=task_v2,
thought=thought,
ai_suggestion=ai_suggestion,
)
model_used = main_model_group
try:
async def _call_router_without_cache() -> tuple[ModelResponse, str]:
request_payload_json = await _log_llm_request_artifact(llm_key, False)
response = await router.acompletion(
model=main_model_group, messages=messages, timeout=settings.LLM_CONFIG_TIMEOUT, **parameters
model=main_model_group,
messages=messages,
timeout=settings.LLM_CONFIG_TIMEOUT,
**parameters,
)
response_model = response.model or main_model_group
model_used = response_model
if not LLMAPIHandlerFactory._models_equivalent(response_model, main_model_group):
LOG.info(
"LLM router fallback succeeded",
llm_key=llm_key,
prompt_name=prompt_name,
primary_model=main_model_group,
fallback_model=response_model,
)
return response, request_payload_json
try:
response: ModelResponse | None = None
if should_attach_vertex_cache and cache_resource_name:
try:
response, direct_model_used, llm_request_json = await _call_primary_with_vertex_cache(
cache_resource_name,
cache_variant,
)
model_used = response.model or direct_model_used
except CancelledError:
raise
except Exception as cache_error:
LOG.warning(
"Vertex cache primary call failed, retrying via router",
prompt_name=prompt_name,
error=str(cache_error),
cache_name=cache_resource_name,
cache_variant=cache_variant,
)
response = None
if response is None:
response, llm_request_json = await _call_router_without_cache()
response_model = response.model or main_model_group
model_used = response_model
if not LLMAPIHandlerFactory._models_equivalent(response_model, main_model_group):
LOG.info(
"LLM router fallback succeeded",
llm_key=llm_key,
prompt_name=prompt_name,
primary_model=main_model_group,
fallback_model=response_model,
)
except litellm.exceptions.APIError as e:
raise LLMProviderErrorRetryableTask(llm_key) from e
except litellm.exceptions.ContextWindowExceededError as e:
@@ -611,7 +698,10 @@ class LLMAPIHandlerFactory:
return llm_api_handler_with_router_and_fallback
@staticmethod
def get_llm_api_handler(llm_key: str, base_parameters: dict[str, Any] | None = None) -> LLMAPIHandler:
def get_llm_api_handler(
llm_key: str,
base_parameters: dict[str, Any] | None = None,
) -> LLMAPIHandler:
try:
llm_config = LLMConfigRegistry.get_config(llm_key)
except InvalidLLMConfigError:
@@ -668,15 +758,14 @@ class LLMAPIHandlerFactory:
context = skyvern_context.current()
is_speculative_step = step.is_speculative if step else False
if context and len(context.hashed_href_map) > 0 and step and not is_speculative_step:
await app.ARTIFACT_MANAGER.create_llm_artifact(
data=json.dumps(context.hashed_href_map, indent=2).encode("utf-8"),
artifact_type=ArtifactType.HASHED_HREF_MAP,
step=step,
task_v2=task_v2,
thought=thought,
ai_suggestion=ai_suggestion,
)
await _log_hashed_href_map_artifacts_if_needed(
context,
step,
task_v2,
thought,
ai_suggestion,
is_speculative_step=is_speculative_step,
)
llm_prompt_value = prompt
if step and not is_speculative_step:
@@ -746,6 +835,9 @@ class LLMAPIHandlerFactory:
"Adding Vertex AI cache reference to request",
prompt_name=prompt_name,
cache_attached=True,
cache_name=cache_resource_name,
cache_key=getattr(context, "vertex_cache_key", None),
cache_variant=getattr(context, "vertex_cache_variant", None),
)
elif "cached_content" in active_parameters:
removed_cache = active_parameters.pop("cached_content", None)
@@ -754,6 +846,9 @@ class LLMAPIHandlerFactory:
"Removed Vertex AI cache reference from request",
prompt_name=prompt_name,
cache_was_attached=True,
cache_name=cache_resource_name,
cache_key=getattr(context, "vertex_cache_key", None),
cache_variant=getattr(context, "vertex_cache_variant", None),
)
llm_request_payload = {
@@ -863,6 +958,8 @@ class LLMAPIHandlerFactory:
if cached_tokens == 0:
cached_tokens = getattr(response.usage, "cache_read_input_tokens", 0) or 0
_log_vertex_cache_hit_if_needed(context, prompt_name, model_name, cached_tokens)
if step:
await app.DATABASE.update_step(
task_id=step.task_id,
@@ -1041,15 +1138,14 @@ class LLMCaller:
context = skyvern_context.current()
is_speculative_step = step.is_speculative if step else False
if context and len(context.hashed_href_map) > 0 and step and not is_speculative_step:
await app.ARTIFACT_MANAGER.create_llm_artifact(
data=json.dumps(context.hashed_href_map, indent=2).encode("utf-8"),
artifact_type=ArtifactType.HASHED_HREF_MAP,
step=step,
task_v2=task_v2,
thought=thought,
ai_suggestion=ai_suggestion,
)
await _log_hashed_href_map_artifacts_if_needed(
context,
step,
task_v2,
thought,
ai_suggestion,
is_speculative_step=is_speculative_step,
)
if screenshots and self.screenshot_scaling_enabled:
target_dimension = self.get_screenshot_resize_target_dimension(window_dimension)