Pedro/fix vertex cache leak (#4135)
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
@@ -2467,8 +2468,35 @@ class ForgeAgent:
|
||||
|
||||
return scraped_page, extract_action_prompt, use_caching
|
||||
|
||||
@staticmethod
|
||||
def _build_extract_action_cache_variant(
|
||||
verification_code_check: bool,
|
||||
has_magic_link_page: bool,
|
||||
complete_criterion: str | None,
|
||||
) -> str:
|
||||
"""
|
||||
Build a short-but-unique cache variant identifier so extract-action prompts that
|
||||
differ meaningfully (OTP, magic link flows, complete criteria) do not reuse the
|
||||
same Vertex cache object.
|
||||
"""
|
||||
variant_parts: list[str] = []
|
||||
if verification_code_check:
|
||||
variant_parts.append("vc")
|
||||
if has_magic_link_page:
|
||||
variant_parts.append("ml")
|
||||
if complete_criterion:
|
||||
normalized = " ".join(complete_criterion.split())
|
||||
digest = hashlib.sha1(normalized.encode("utf-8")).hexdigest()[:6]
|
||||
variant_parts.append(f"cc{digest}")
|
||||
return "-".join(variant_parts) if variant_parts else "std"
|
||||
|
||||
async def _create_vertex_cache_for_task(
|
||||
self, task: Task, static_prompt: str, context: SkyvernContext, llm_key_override: str | None
|
||||
self,
|
||||
task: Task,
|
||||
static_prompt: str,
|
||||
context: SkyvernContext,
|
||||
llm_key_override: str | None,
|
||||
prompt_variant: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Create a Vertex AI cache for the task's static prompt.
|
||||
@@ -2479,9 +2507,9 @@ class ForgeAgent:
|
||||
task: The task to create cache for
|
||||
static_prompt: The static prompt content to cache
|
||||
context: The Skyvern context to store the cache name in
|
||||
llm_key_override: Optional override when we explicitly pick an LLM key
|
||||
prompt_variant: Cache variant identifier (std/vc/ml/etc.)
|
||||
"""
|
||||
# Early return if task doesn't have an llm_key
|
||||
# This should not happen given the guard at the call site, but being defensive
|
||||
resolved_llm_key = llm_key_override or task.llm_key
|
||||
|
||||
if not resolved_llm_key:
|
||||
@@ -2491,17 +2519,20 @@ class ForgeAgent:
|
||||
)
|
||||
return
|
||||
|
||||
cache_variant = prompt_variant or "std"
|
||||
|
||||
try:
|
||||
LOG.info(
|
||||
"Attempting Vertex AI cache creation",
|
||||
task_id=task.task_id,
|
||||
llm_key=resolved_llm_key,
|
||||
cache_variant=cache_variant,
|
||||
)
|
||||
cache_manager = get_cache_manager()
|
||||
|
||||
# Use llm_key as cache_key so all tasks with the same model share the same cache
|
||||
# This maximizes cache reuse and reduces cache storage costs
|
||||
cache_key = f"{EXTRACT_ACTION_CACHE_KEY_PREFIX}-{resolved_llm_key}"
|
||||
variant_suffix = f"-{cache_variant}" if cache_variant else ""
|
||||
|
||||
cache_key = f"{EXTRACT_ACTION_CACHE_KEY_PREFIX}{variant_suffix}-{resolved_llm_key}"
|
||||
|
||||
# Get the actual model name from LLM config to ensure correct format
|
||||
# (e.g., "gemini-2.5-flash" with decimal, not "gemini-2-5-flash")
|
||||
@@ -2565,8 +2596,10 @@ class ForgeAgent:
|
||||
ttl_seconds=3600, # 1 hour
|
||||
)
|
||||
|
||||
# Store cache resource name in context
|
||||
# Store cache metadata in context
|
||||
context.vertex_cache_name = cache_data["name"]
|
||||
context.vertex_cache_key = cache_key
|
||||
context.vertex_cache_variant = cache_variant
|
||||
|
||||
LOG.info(
|
||||
"Created Vertex AI cache for task",
|
||||
@@ -2574,6 +2607,7 @@ class ForgeAgent:
|
||||
cache_key=cache_key,
|
||||
cache_name=cache_data["name"],
|
||||
model_name=model_name,
|
||||
cache_variant=cache_variant,
|
||||
)
|
||||
except Exception as e:
|
||||
LOG.warning(
|
||||
@@ -2653,7 +2687,7 @@ class ForgeAgent:
|
||||
|
||||
# Check if prompt caching is enabled for extract-action
|
||||
use_caching = False
|
||||
prompt_caching_settings = LLMAPIHandlerFactory._prompt_caching_settings or {}
|
||||
prompt_caching_settings = await self._get_prompt_caching_settings(context)
|
||||
effective_llm_key = task.llm_key
|
||||
if not effective_llm_key:
|
||||
handler_for_key = LLMAPIHandlerFactory.get_override_llm_api_handler(
|
||||
@@ -2701,6 +2735,11 @@ class ForgeAgent:
|
||||
"parse_select_feature_enabled": context.enable_parse_select_in_extract,
|
||||
"has_magic_link_page": context.has_magic_link_page(task.task_id),
|
||||
}
|
||||
cache_variant = self._build_extract_action_cache_variant(
|
||||
verification_code_check=verification_code_check,
|
||||
has_magic_link_page=context.has_magic_link_page(task.task_id),
|
||||
complete_criterion=task.complete_criterion.strip() if task.complete_criterion else None,
|
||||
)
|
||||
static_prompt = prompt_engine.load_prompt(f"{template}-static", **prompt_kwargs)
|
||||
dynamic_prompt = prompt_engine.load_prompt(
|
||||
f"{template}-dynamic",
|
||||
@@ -2718,7 +2757,13 @@ class ForgeAgent:
|
||||
|
||||
# Create Vertex AI cache for Gemini models
|
||||
if effective_llm_key and "GEMINI" in effective_llm_key:
|
||||
await self._create_vertex_cache_for_task(task, static_prompt, context, effective_llm_key)
|
||||
await self._create_vertex_cache_for_task(
|
||||
task,
|
||||
static_prompt,
|
||||
context,
|
||||
effective_llm_key,
|
||||
prompt_variant=cache_variant,
|
||||
)
|
||||
|
||||
combined_prompt = f"{static_prompt.rstrip()}\n\n{dynamic_prompt.lstrip()}"
|
||||
|
||||
@@ -2726,6 +2771,7 @@ class ForgeAgent:
|
||||
"Using cached prompt",
|
||||
task_id=task.task_id,
|
||||
prompt_name=EXTRACT_ACTION_PROMPT_NAME,
|
||||
cache_variant=cache_variant,
|
||||
)
|
||||
return combined_prompt, use_caching
|
||||
|
||||
@@ -2755,6 +2801,55 @@ class ForgeAgent:
|
||||
|
||||
return full_prompt, use_caching
|
||||
|
||||
async def _get_prompt_caching_settings(self, context: SkyvernContext) -> dict[str, bool]:
    """
    Work out which prompts have caching enabled for the current run.

    Resolution order:
      1. A non-None LLMAPIHandlerFactory._prompt_caching_settings override wins.
      2. Otherwise a value already stored on context.prompt_caching_settings is reused.
      3. Otherwise the "PROMPT_CACHING_OPTIMIZATION" experiment is evaluated and the
         outcome is stored on the context, so later calls take branch 2.

    An empty dict is stored and returned when no distinct id / organization id is
    available, when the experiment lookup raises, or when the experiment is off.
    """
    factory_override = LLMAPIHandlerFactory._prompt_caching_settings
    if factory_override is not None:
        return factory_override

    memoized = context.prompt_caching_settings
    if memoized is not None:
        return memoized

    distinct_id = context.run_id or context.workflow_run_id or context.task_id
    organization_id = context.organization_id

    # Record the "disabled" default up front so every exit path below leaves a
    # non-None value on the context.
    context.prompt_caching_settings = {}

    if not (distinct_id and organization_id):
        return context.prompt_caching_settings

    try:
        experiment_on = await app.EXPERIMENTATION_PROVIDER.is_feature_enabled_cached(
            "PROMPT_CACHING_OPTIMIZATION",
            distinct_id,
            properties={"organization_id": organization_id},
        )
    except Exception as exc:
        LOG.warning(
            "Failed to evaluate prompt caching experiment; defaulting to disabled",
            distinct_id=distinct_id,
            organization_id=organization_id,
            error=str(exc),
        )
    else:
        if experiment_on:
            context.prompt_caching_settings = {
                EXTRACT_ACTION_PROMPT_NAME: True,
                EXTRACT_ACTION_TEMPLATE: True,
            }
            LOG.info(
                "Prompt caching optimization enabled",
                distinct_id=distinct_id,
                organization_id=organization_id,
            )

    return context.prompt_caching_settings
|
||||
|
||||
def _should_process_totp(self, scraped_page: ScrapedPage | None) -> bool:
|
||||
"""Detect TOTP pages by checking for multiple input fields or verification keywords."""
|
||||
if not scraped_page:
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import copy
|
||||
import dataclasses
|
||||
import json
|
||||
import time
|
||||
@@ -29,6 +30,7 @@ from skyvern.forge.sdk.api.llm.ui_tars_response import UITarsResponse
|
||||
from skyvern.forge.sdk.api.llm.utils import llm_messages_builder, llm_messages_builder_with_history, parse_api_response
|
||||
from skyvern.forge.sdk.artifact.models import ArtifactType
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
|
||||
from skyvern.forge.sdk.models import SpeculativeLLMMetadata, Step
|
||||
from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion
|
||||
from skyvern.forge.sdk.schemas.task_v2 import TaskV2, Thought
|
||||
@@ -38,6 +40,7 @@ from skyvern.utils.image_resizer import Resolution, get_resize_target_dimension,
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
EXTRACT_ACTION_PROMPT_NAME = "extract-actions"
|
||||
CHECK_USER_GOAL_PROMPT_NAMES = {"check-user-goal", "check-user-goal-with-termination"}
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@@ -61,6 +64,44 @@ class LLMCallStats(BaseModel):
|
||||
llm_cost: float | None = None
|
||||
|
||||
|
||||
async def _log_hashed_href_map_artifacts_if_needed(
    context: SkyvernContext | None,
    step: Step | None,
    task_v2: TaskV2 | None,
    thought: Thought | None,
    ai_suggestion: AISuggestion | None,
    *,
    is_speculative_step: bool,
) -> None:
    """
    Store the context's hashed href map as a HASHED_HREF_MAP artifact.

    Nothing is written unless there is a context with a non-empty
    hashed_href_map, a step is provided, and the step is not speculative.
    The map is serialized as indented JSON encoded to UTF-8.
    """
    if not context or not context.hashed_href_map:
        return
    if not step or is_speculative_step:
        return

    create_artifact = app.ARTIFACT_MANAGER.create_llm_artifact
    await create_artifact(
        data=json.dumps(context.hashed_href_map, indent=2).encode("utf-8"),
        artifact_type=ArtifactType.HASHED_HREF_MAP,
        step=step,
        task_v2=task_v2,
        thought=thought,
        ai_suggestion=ai_suggestion,
    )
|
||||
|
||||
|
||||
def _log_vertex_cache_hit_if_needed(
    context: SkyvernContext | None,
    prompt_name: str,
    llm_identifier: str,
    cached_tokens: int,
) -> None:
    """
    Emit a "Vertex cache hit" log line for extract-action calls that read cached tokens.

    The line is only logged when cached_tokens is positive, prompt_name equals
    EXTRACT_ACTION_PROMPT_NAME, and the context carries a vertex_cache_name.
    The cache name, key and variant are taken from the context.
    """
    if not cached_tokens > 0 or prompt_name != EXTRACT_ACTION_PROMPT_NAME:
        return
    if not context or not context.vertex_cache_name:
        return

    LOG.info(
        "Vertex cache hit",
        prompt_name=prompt_name,
        llm_key=llm_identifier,
        cached_tokens=cached_tokens,
        cache_name=context.vertex_cache_name,
        cache_key=context.vertex_cache_key,
        cache_variant=context.vertex_cache_variant,
    )
|
||||
|
||||
|
||||
class LLMAPIHandlerFactory:
|
||||
_custom_handlers: dict[str, LLMAPIHandler] = {}
|
||||
_thinking_budget_settings: dict[str, int] | None = None
|
||||
@@ -237,6 +278,7 @@ class LLMAPIHandlerFactory:
|
||||
if not override_llm_key:
|
||||
return default
|
||||
try:
|
||||
# Explicit overrides should honor the exact model choice and skip experimentation reroutes.
|
||||
return LLMAPIHandlerFactory.get_llm_api_handler(override_llm_key)
|
||||
except Exception:
|
||||
LOG.warning(
|
||||
@@ -320,17 +362,17 @@ class LLMAPIHandlerFactory:
|
||||
|
||||
context = skyvern_context.current()
|
||||
is_speculative_step = step.is_speculative if step else False
|
||||
if context and len(context.hashed_href_map) > 0 and step and not is_speculative_step:
|
||||
await app.ARTIFACT_MANAGER.create_llm_artifact(
|
||||
data=json.dumps(context.hashed_href_map, indent=2).encode("utf-8"),
|
||||
artifact_type=ArtifactType.HASHED_HREF_MAP,
|
||||
step=step,
|
||||
task_v2=task_v2,
|
||||
thought=thought,
|
||||
ai_suggestion=ai_suggestion,
|
||||
)
|
||||
await _log_hashed_href_map_artifacts_if_needed(
|
||||
context,
|
||||
step,
|
||||
task_v2,
|
||||
thought,
|
||||
ai_suggestion,
|
||||
is_speculative_step=is_speculative_step,
|
||||
)
|
||||
|
||||
llm_prompt_value = prompt
|
||||
|
||||
if step and not is_speculative_step:
|
||||
await app.ARTIFACT_MANAGER.create_llm_artifact(
|
||||
data=llm_prompt_value.encode("utf-8"),
|
||||
@@ -343,6 +385,25 @@ class LLMAPIHandlerFactory:
|
||||
# Build messages and apply caching in one step
|
||||
messages = await llm_messages_builder(prompt, screenshots, llm_config.add_assistant_prefix)
|
||||
|
||||
async def _log_llm_request_artifact(model_label: str, vertex_cache_attached_flag: bool) -> str:
    """
    Serialize the outgoing LLM request and return it as a JSON string.

    The payload holds the model label, the messages and request parameters
    from the surrounding scope, and whether a Vertex cache was attached. When a
    step is present and it is not speculative, the same JSON is also stored as
    an LLM_REQUEST artifact.
    """
    serialized_request = json.dumps(
        {
            "model": model_label,
            "messages": messages,
            **parameters,
            "vertex_cache_attached": vertex_cache_attached_flag,
        }
    )

    # No artifact is stored when there is no step or the step is speculative.
    if not step or is_speculative_step:
        return serialized_request

    await app.ARTIFACT_MANAGER.create_llm_artifact(
        data=serialized_request.encode("utf-8"),
        artifact_type=ArtifactType.LLM_REQUEST,
        step=step,
        task_v2=task_v2,
        thought=thought,
        ai_suggestion=ai_suggestion,
    )
    return serialized_request
|
||||
|
||||
# Inject context caching system message when available
|
||||
try:
|
||||
context_cached_static_prompt = getattr(context, "cached_static_prompt", None)
|
||||
@@ -377,70 +438,96 @@ class LLMAPIHandlerFactory:
|
||||
except Exception as e:
|
||||
LOG.warning("Failed to apply context caching system message", error=str(e), exc_info=True)
|
||||
|
||||
vertex_cache_attached = False
|
||||
cache_resource_name = getattr(context, "vertex_cache_name", None)
|
||||
cache_variant = getattr(context, "vertex_cache_variant", None)
|
||||
primary_model_dict = _get_primary_model_dict(router, main_model_group)
|
||||
|
||||
# Add cached_content to primary model's litellm_params (not global parameters)
|
||||
# This ensures it's only passed to the Gemini primary, not to fallback models.
|
||||
# By setting it in the model-specific litellm_params, LiteLLM will only include it
|
||||
# when calling the primary model. When falling back to GPT-5, the fallback model's
|
||||
# litellm_params won't have cached_content, so it won't be sent.
|
||||
if (
|
||||
cache_resource_name
|
||||
should_attach_vertex_cache = bool(
|
||||
cache_resource_name is not None
|
||||
and prompt_name == EXTRACT_ACTION_PROMPT_NAME
|
||||
and getattr(context, "use_prompt_caching", False)
|
||||
and main_model_group
|
||||
and "gemini" in main_model_group.lower()
|
||||
and primary_model_dict is not None
|
||||
):
|
||||
litellm_params = primary_model_dict.setdefault("litellm_params", {})
|
||||
litellm_params["cached_content"] = cache_resource_name
|
||||
vertex_cache_attached = True
|
||||
)
|
||||
|
||||
model_used = main_model_group
|
||||
llm_request_json = ""
|
||||
|
||||
async def _call_primary_with_vertex_cache(
|
||||
cache_name: str,
|
||||
cache_variant_name: str | None,
|
||||
) -> tuple[ModelResponse, str, str]:
|
||||
if primary_model_dict is None:
|
||||
raise ValueError("Primary router model missing configuration")
|
||||
litellm_params = copy.deepcopy(primary_model_dict.get("litellm_params") or {})
|
||||
if not litellm_params:
|
||||
raise ValueError("Primary router model missing litellm_params")
|
||||
active_params = copy.deepcopy(litellm_params)
|
||||
active_params.update(parameters)
|
||||
active_params["cached_content"] = cache_name
|
||||
request_model = active_params.pop("model", primary_model_dict.get("model_name", main_model_group))
|
||||
LOG.info(
|
||||
"Adding Vertex AI cache reference to primary model in router",
|
||||
"Adding Vertex AI cache reference to primary Gemini request",
|
||||
prompt_name=prompt_name,
|
||||
primary_model=main_model_group,
|
||||
fallback_model=llm_config.fallback_model_group,
|
||||
cache_name=cache_name,
|
||||
cache_key=getattr(context, "vertex_cache_key", None),
|
||||
cache_variant=cache_variant_name,
|
||||
)
|
||||
elif primary_model_dict and "litellm_params" in primary_model_dict:
|
||||
if primary_model_dict["litellm_params"].pop("cached_content", None):
|
||||
LOG.info(
|
||||
"Removed Vertex AI cache reference from primary model in router",
|
||||
prompt_name=prompt_name,
|
||||
primary_model=main_model_group,
|
||||
)
|
||||
request_payload_json = await _log_llm_request_artifact(request_model, True)
|
||||
response = await litellm.acompletion(
|
||||
model=request_model,
|
||||
messages=messages,
|
||||
timeout=settings.LLM_CONFIG_TIMEOUT,
|
||||
drop_params=True,
|
||||
**active_params,
|
||||
)
|
||||
return response, request_model, request_payload_json
|
||||
|
||||
llm_request_payload = {
|
||||
"model": llm_key,
|
||||
"messages": messages,
|
||||
**parameters,
|
||||
"vertex_cache_attached": vertex_cache_attached,
|
||||
}
|
||||
llm_request_json = json.dumps(llm_request_payload)
|
||||
if step and not is_speculative_step:
|
||||
await app.ARTIFACT_MANAGER.create_llm_artifact(
|
||||
data=llm_request_json.encode("utf-8"),
|
||||
artifact_type=ArtifactType.LLM_REQUEST,
|
||||
step=step,
|
||||
task_v2=task_v2,
|
||||
thought=thought,
|
||||
ai_suggestion=ai_suggestion,
|
||||
)
|
||||
model_used = main_model_group
|
||||
try:
|
||||
async def _call_router_without_cache() -> tuple[ModelResponse, str]:
|
||||
request_payload_json = await _log_llm_request_artifact(llm_key, False)
|
||||
response = await router.acompletion(
|
||||
model=main_model_group, messages=messages, timeout=settings.LLM_CONFIG_TIMEOUT, **parameters
|
||||
model=main_model_group,
|
||||
messages=messages,
|
||||
timeout=settings.LLM_CONFIG_TIMEOUT,
|
||||
**parameters,
|
||||
)
|
||||
response_model = response.model or main_model_group
|
||||
model_used = response_model
|
||||
if not LLMAPIHandlerFactory._models_equivalent(response_model, main_model_group):
|
||||
LOG.info(
|
||||
"LLM router fallback succeeded",
|
||||
llm_key=llm_key,
|
||||
prompt_name=prompt_name,
|
||||
primary_model=main_model_group,
|
||||
fallback_model=response_model,
|
||||
)
|
||||
return response, request_payload_json
|
||||
|
||||
try:
|
||||
response: ModelResponse | None = None
|
||||
if should_attach_vertex_cache and cache_resource_name:
|
||||
try:
|
||||
response, direct_model_used, llm_request_json = await _call_primary_with_vertex_cache(
|
||||
cache_resource_name,
|
||||
cache_variant,
|
||||
)
|
||||
model_used = response.model or direct_model_used
|
||||
except CancelledError:
|
||||
raise
|
||||
except Exception as cache_error:
|
||||
LOG.warning(
|
||||
"Vertex cache primary call failed, retrying via router",
|
||||
prompt_name=prompt_name,
|
||||
error=str(cache_error),
|
||||
cache_name=cache_resource_name,
|
||||
cache_variant=cache_variant,
|
||||
)
|
||||
response = None
|
||||
|
||||
if response is None:
|
||||
response, llm_request_json = await _call_router_without_cache()
|
||||
response_model = response.model or main_model_group
|
||||
model_used = response_model
|
||||
if not LLMAPIHandlerFactory._models_equivalent(response_model, main_model_group):
|
||||
LOG.info(
|
||||
"LLM router fallback succeeded",
|
||||
llm_key=llm_key,
|
||||
prompt_name=prompt_name,
|
||||
primary_model=main_model_group,
|
||||
fallback_model=response_model,
|
||||
)
|
||||
except litellm.exceptions.APIError as e:
|
||||
raise LLMProviderErrorRetryableTask(llm_key) from e
|
||||
except litellm.exceptions.ContextWindowExceededError as e:
|
||||
@@ -611,7 +698,10 @@ class LLMAPIHandlerFactory:
|
||||
return llm_api_handler_with_router_and_fallback
|
||||
|
||||
@staticmethod
|
||||
def get_llm_api_handler(llm_key: str, base_parameters: dict[str, Any] | None = None) -> LLMAPIHandler:
|
||||
def get_llm_api_handler(
|
||||
llm_key: str,
|
||||
base_parameters: dict[str, Any] | None = None,
|
||||
) -> LLMAPIHandler:
|
||||
try:
|
||||
llm_config = LLMConfigRegistry.get_config(llm_key)
|
||||
except InvalidLLMConfigError:
|
||||
@@ -668,15 +758,14 @@ class LLMAPIHandlerFactory:
|
||||
|
||||
context = skyvern_context.current()
|
||||
is_speculative_step = step.is_speculative if step else False
|
||||
if context and len(context.hashed_href_map) > 0 and step and not is_speculative_step:
|
||||
await app.ARTIFACT_MANAGER.create_llm_artifact(
|
||||
data=json.dumps(context.hashed_href_map, indent=2).encode("utf-8"),
|
||||
artifact_type=ArtifactType.HASHED_HREF_MAP,
|
||||
step=step,
|
||||
task_v2=task_v2,
|
||||
thought=thought,
|
||||
ai_suggestion=ai_suggestion,
|
||||
)
|
||||
await _log_hashed_href_map_artifacts_if_needed(
|
||||
context,
|
||||
step,
|
||||
task_v2,
|
||||
thought,
|
||||
ai_suggestion,
|
||||
is_speculative_step=is_speculative_step,
|
||||
)
|
||||
|
||||
llm_prompt_value = prompt
|
||||
if step and not is_speculative_step:
|
||||
@@ -746,6 +835,9 @@ class LLMAPIHandlerFactory:
|
||||
"Adding Vertex AI cache reference to request",
|
||||
prompt_name=prompt_name,
|
||||
cache_attached=True,
|
||||
cache_name=cache_resource_name,
|
||||
cache_key=getattr(context, "vertex_cache_key", None),
|
||||
cache_variant=getattr(context, "vertex_cache_variant", None),
|
||||
)
|
||||
elif "cached_content" in active_parameters:
|
||||
removed_cache = active_parameters.pop("cached_content", None)
|
||||
@@ -754,6 +846,9 @@ class LLMAPIHandlerFactory:
|
||||
"Removed Vertex AI cache reference from request",
|
||||
prompt_name=prompt_name,
|
||||
cache_was_attached=True,
|
||||
cache_name=cache_resource_name,
|
||||
cache_key=getattr(context, "vertex_cache_key", None),
|
||||
cache_variant=getattr(context, "vertex_cache_variant", None),
|
||||
)
|
||||
|
||||
llm_request_payload = {
|
||||
@@ -863,6 +958,8 @@ class LLMAPIHandlerFactory:
|
||||
if cached_tokens == 0:
|
||||
cached_tokens = getattr(response.usage, "cache_read_input_tokens", 0) or 0
|
||||
|
||||
_log_vertex_cache_hit_if_needed(context, prompt_name, model_name, cached_tokens)
|
||||
|
||||
if step:
|
||||
await app.DATABASE.update_step(
|
||||
task_id=step.task_id,
|
||||
@@ -1041,15 +1138,14 @@ class LLMCaller:
|
||||
|
||||
context = skyvern_context.current()
|
||||
is_speculative_step = step.is_speculative if step else False
|
||||
if context and len(context.hashed_href_map) > 0 and step and not is_speculative_step:
|
||||
await app.ARTIFACT_MANAGER.create_llm_artifact(
|
||||
data=json.dumps(context.hashed_href_map, indent=2).encode("utf-8"),
|
||||
artifact_type=ArtifactType.HASHED_HREF_MAP,
|
||||
step=step,
|
||||
task_v2=task_v2,
|
||||
thought=thought,
|
||||
ai_suggestion=ai_suggestion,
|
||||
)
|
||||
await _log_hashed_href_map_artifacts_if_needed(
|
||||
context,
|
||||
step,
|
||||
task_v2,
|
||||
thought,
|
||||
ai_suggestion,
|
||||
is_speculative_step=is_speculative_step,
|
||||
)
|
||||
|
||||
if screenshots and self.screenshot_scaling_enabled:
|
||||
target_dimension = self.get_screenshot_resize_target_dimension(window_dimension)
|
||||
|
||||
@@ -37,6 +37,9 @@ class SkyvernContext:
|
||||
use_prompt_caching: bool = False
|
||||
cached_static_prompt: str | None = None
|
||||
vertex_cache_name: str | None = None # Vertex AI cache resource name for explicit caching
|
||||
vertex_cache_key: str | None = None # Logical cache key (includes variant + llm key)
|
||||
vertex_cache_variant: str | None = None # Variant identifier used when creating the cache
|
||||
prompt_caching_settings: dict[str, bool] | None = None
|
||||
enable_speed_optimizations: bool = False
|
||||
|
||||
# script run context
|
||||
|
||||
@@ -66,7 +66,7 @@ async def collect_experiment_metadata(
|
||||
"LLM_NAME",
|
||||
"LLM_SECONDARY_NAME",
|
||||
# Add more experiment flags as needed
|
||||
"PROMPT_CACHING_ENABLED",
|
||||
"PROMPT_CACHING_OPTIMIZATION",
|
||||
"THINKING_BUDGET_OPTIMIZATION",
|
||||
]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user