Pedro/vertex cache minimal fix (#3981)

This commit is contained in:
pedrohsdb
2025-11-12 10:40:52 -08:00
committed by GitHub
parent e8e8481f78
commit d88ca1ca27
3 changed files with 114 additions and 40 deletions

View File

@@ -37,6 +37,8 @@ from skyvern.utils.image_resizer import Resolution, get_resize_target_dimension,
LOG = structlog.get_logger()
EXTRACT_ACTION_PROMPT_NAME = "extract-actions"
class LLMCallStats(BaseModel):
input_tokens: int | None = None
@@ -313,12 +315,28 @@ class LLMAPIHandlerFactory:
except Exception as e:
LOG.warning("Failed to apply context caching system message", error=str(e), exc_info=True)
vertex_cache_attached = False
cache_resource_name = getattr(context, "vertex_cache_name", None)
if (
cache_resource_name
and prompt_name == EXTRACT_ACTION_PROMPT_NAME
and getattr(context, "use_prompt_caching", False)
):
parameters = {**parameters, "cached_content": cache_resource_name}
vertex_cache_attached = True
LOG.info(
"Adding Vertex AI cache reference to router request",
prompt_name=prompt_name,
cache_attached=True,
)
await app.ARTIFACT_MANAGER.create_llm_artifact(
data=json.dumps(
{
"model": llm_key,
"messages": messages,
**parameters,
"vertex_cache_attached": vertex_cache_attached,
}
).encode("utf-8"),
artifact_type=ArtifactType.LLM_REQUEST,
@@ -473,6 +491,7 @@ class LLMAPIHandlerFactory:
return parsed_response
llm_api_handler_with_router_and_fallback.llm_key = llm_key # type: ignore[attr-defined]
return llm_api_handler_with_router_and_fallback
@staticmethod
@@ -592,10 +611,15 @@ class LLMAPIHandlerFactory:
# Add Vertex AI cache reference only for the intended cached prompt
vertex_cache_attached = False
cache_resource_name = getattr(context, "vertex_cache_name", None)
LOG.info(
"Vertex cache attachment check",
cache_resource_name=cache_resource_name,
prompt_name=prompt_name,
use_prompt_caching=getattr(context, "use_prompt_caching", None) if context else None,
)
if (
cache_resource_name
and "vertex_ai/" in model_name
and prompt_name == "extract-actions"
and prompt_name == EXTRACT_ACTION_PROMPT_NAME
and getattr(context, "use_prompt_caching", False)
):
active_parameters["cached_content"] = cache_resource_name
@@ -779,6 +803,7 @@ class LLMAPIHandlerFactory:
return parsed_response
llm_api_handler.llm_key = llm_key # type: ignore[attr-defined]
return llm_api_handler
@staticmethod