From 1eca20b78a2464e4e4fd21bbe48654925e54b7e0 Mon Sep 17 00:00:00 2001 From: Stanislav Novosad Date: Wed, 17 Dec 2025 20:15:26 -0700 Subject: [PATCH] Batch LLM artifacts creation (#4322) --- skyvern/forge/agent.py | 56 +- .../forge/sdk/api/llm/api_handler_factory.py | 1632 +++++++++-------- skyvern/forge/sdk/artifact/manager.py | 48 +- 3 files changed, 908 insertions(+), 828 deletions(-) diff --git a/skyvern/forge/agent.py b/skyvern/forge/agent.py index 7aef2181..1ee8040e 100644 --- a/skyvern/forge/agent.py +++ b/skyvern/forge/agent.py @@ -1880,42 +1880,56 @@ class ForgeAgent: LOG.debug("Persisting speculative LLM metadata") + artifacts = [] if metadata.prompt: - await app.ARTIFACT_MANAGER.create_llm_artifact( - data=metadata.prompt.encode("utf-8"), - artifact_type=ArtifactType.LLM_PROMPT, - screenshots=screenshots, - step=step, + artifacts.append( + await app.ARTIFACT_MANAGER.prepare_llm_artifact( + data=metadata.prompt.encode("utf-8"), + artifact_type=ArtifactType.LLM_PROMPT, + screenshots=screenshots, + step=step, + ) ) if metadata.llm_request_json: - await app.ARTIFACT_MANAGER.create_llm_artifact( - data=metadata.llm_request_json.encode("utf-8"), - artifact_type=ArtifactType.LLM_REQUEST, - step=step, + artifacts.append( + await app.ARTIFACT_MANAGER.prepare_llm_artifact( + data=metadata.llm_request_json.encode("utf-8"), + artifact_type=ArtifactType.LLM_REQUEST, + step=step, + ) ) if metadata.llm_response_json: - await app.ARTIFACT_MANAGER.create_llm_artifact( - data=metadata.llm_response_json.encode("utf-8"), - artifact_type=ArtifactType.LLM_RESPONSE, - step=step, + artifacts.append( + await app.ARTIFACT_MANAGER.prepare_llm_artifact( + data=metadata.llm_response_json.encode("utf-8"), + artifact_type=ArtifactType.LLM_RESPONSE, + step=step, + ) ) if metadata.parsed_response_json: - await app.ARTIFACT_MANAGER.create_llm_artifact( - data=metadata.parsed_response_json.encode("utf-8"), - artifact_type=ArtifactType.LLM_RESPONSE_PARSED, - step=step, + artifacts.append( + await app.ARTIFACT_MANAGER.prepare_llm_artifact( + data=metadata.parsed_response_json.encode("utf-8"), + artifact_type=ArtifactType.LLM_RESPONSE_PARSED, + step=step, + ) ) if metadata.rendered_response_json: - await app.ARTIFACT_MANAGER.create_llm_artifact( - data=metadata.rendered_response_json.encode("utf-8"), - artifact_type=ArtifactType.LLM_RESPONSE_RENDERED, - step=step, + artifacts.append( + await app.ARTIFACT_MANAGER.prepare_llm_artifact( + data=metadata.rendered_response_json.encode("utf-8"), + artifact_type=ArtifactType.LLM_RESPONSE_RENDERED, + step=step, + ) ) + if artifacts: + await app.ARTIFACT_MANAGER.bulk_create_artifacts(artifacts) + incremental_cost = metadata.llm_cost if metadata.llm_cost and metadata.llm_cost > 0 else None incremental_input_tokens = ( metadata.input_tokens if metadata.input_tokens and metadata.input_tokens > 0 else None diff --git a/skyvern/forge/sdk/api/llm/api_handler_factory.py b/skyvern/forge/sdk/api/llm/api_handler_factory.py index d658ee0f..d8dc23ee 100644 --- a/skyvern/forge/sdk/api/llm/api_handler_factory.py +++ b/skyvern/forge/sdk/api/llm/api_handler_factory.py @@ -35,6 +35,7 @@ from skyvern.forge.sdk.api.llm.models import ( ) from skyvern.forge.sdk.api.llm.ui_tars_response import UITarsResponse from skyvern.forge.sdk.api.llm.utils import llm_messages_builder, llm_messages_builder_with_history, parse_api_response +from skyvern.forge.sdk.artifact.manager import BulkArtifactCreationRequest from skyvern.forge.sdk.artifact.models import ArtifactType from skyvern.forge.sdk.core import skyvern_context from skyvern.forge.sdk.core.skyvern_context import SkyvernContext @@ -93,6 +94,7 @@ def _get_artifact_targets_and_persist_flag( async def _log_hashed_href_map_artifacts_if_needed( + artifacts: list[BulkArtifactCreationRequest | None], context: SkyvernContext | None, step: Step | None, task_v2: TaskV2 | None, @@ -105,10 +107,12 @@ async def _log_hashed_href_map_artifacts_if_needed( step, is_speculative_step, task_v2, thought, ai_suggestion ) if context and context.hashed_href_map and should_persist_llm_artifacts: - await app.ARTIFACT_MANAGER.create_llm_artifact( - data=json.dumps(context.hashed_href_map, indent=2).encode("utf-8"), - artifact_type=ArtifactType.HASHED_HREF_MAP, - **artifact_targets, + artifacts.append( + await app.ARTIFACT_MANAGER.prepare_llm_artifact( + data=json.dumps(context.hashed_href_map, indent=2).encode("utf-8"), + artifact_type=ArtifactType.HASHED_HREF_MAP, + **artifact_targets, + ) ) @@ -446,334 +450,331 @@ class LLMAPIHandlerFactory: should_persist_llm_artifacts, artifact_targets = _get_artifact_targets_and_persist_flag( step, is_speculative_step, task_v2, thought, ai_suggestion ) - await _log_hashed_href_map_artifacts_if_needed( - context, - step, - task_v2, - thought, - ai_suggestion, - is_speculative_step=is_speculative_step, - ) - llm_prompt_value = prompt - if should_persist_llm_artifacts: - await app.ARTIFACT_MANAGER.create_llm_artifact( - data=llm_prompt_value.encode("utf-8"), - artifact_type=ArtifactType.LLM_PROMPT, - screenshots=screenshots, - **artifact_targets, - ) - # Build messages and apply caching in one step - messages = await llm_messages_builder(prompt, screenshots, llm_config.add_assistant_prefix) - - async def _log_llm_request_artifact(model_label: str, vertex_cache_attached_flag: bool) -> str: - llm_request_payload = { - "model": model_label, - "messages": messages, - **parameters, - "vertex_cache_attached": vertex_cache_attached_flag, - } - llm_request_json = json.dumps(llm_request_payload) - if should_persist_llm_artifacts: - await app.ARTIFACT_MANAGER.create_llm_artifact( - data=llm_request_json.encode("utf-8"), - artifact_type=ArtifactType.LLM_REQUEST, - **artifact_targets, - ) - return llm_request_json - - # Inject context caching system message when available - # IMPORTANT: Only inject for extract-actions prompt to avoid contaminating other prompts - # (e.g., check-user-goal) with the extract-action schema + artifacts: list[BulkArtifactCreationRequest | None] = [] try: - if ( - context - and context.cached_static_prompt - and prompt_name == EXTRACT_ACTION_PROMPT_NAME # Only inject for extract-actions - and isinstance(llm_config, LLMConfig) - and isinstance(llm_config.model_name, str) - ): - # Check if this is an OpenAI model + await _log_hashed_href_map_artifacts_if_needed( + artifacts, + context, + step, + task_v2, + thought, + ai_suggestion, + is_speculative_step=is_speculative_step, + ) + + llm_prompt_value = prompt + if should_persist_llm_artifacts: + artifacts.append( + await app.ARTIFACT_MANAGER.prepare_llm_artifact( + data=llm_prompt_value.encode("utf-8"), + artifact_type=ArtifactType.LLM_PROMPT, + screenshots=screenshots, + **artifact_targets, + ) + ) + # Build messages and apply caching in one step + messages = await llm_messages_builder(prompt, screenshots, llm_config.add_assistant_prefix) + + async def _log_llm_request_artifact(model_label: str, vertex_cache_attached_flag: bool) -> str: + llm_request_payload = { + "model": model_label, + "messages": messages, + **parameters, + "vertex_cache_attached": vertex_cache_attached_flag, + } + llm_request_json = json.dumps(llm_request_payload) + if should_persist_llm_artifacts: + artifacts.append( + await app.ARTIFACT_MANAGER.prepare_llm_artifact( + data=llm_request_json.encode("utf-8"), + artifact_type=ArtifactType.LLM_REQUEST, + **artifact_targets, + ) + ) + return llm_request_json + + # Inject context caching system message when available + # IMPORTANT: Only inject for extract-actions prompt to avoid contaminating other prompts + # (e.g., check-user-goal) with the extract-action schema + try: if ( - llm_config.model_name.startswith("gpt-") - or llm_config.model_name.startswith("o1-") - or llm_config.model_name.startswith("o3-") + context + and context.cached_static_prompt + and prompt_name == EXTRACT_ACTION_PROMPT_NAME # Only inject for extract-actions + and isinstance(llm_config, LLMConfig) + and isinstance(llm_config.model_name, str) ): - # For OpenAI models, we need to add the cached content as a system message - # and mark it for caching using the cache_control parameter - caching_system_message = { - "role": "system", - "content": [ - { - "type": "text", - "text": context.cached_static_prompt, - } - ], - } - messages = [caching_system_message] + messages - LOG.info( - "Applied OpenAI context caching", - prompt_name=prompt_name, - model=llm_config.model_name, + # Check if this is an OpenAI model + if ( + llm_config.model_name.startswith("gpt-") + or llm_config.model_name.startswith("o1-") + or llm_config.model_name.startswith("o3-") + ): + # For OpenAI models, we need to add the cached content as a system message + # and mark it for caching using the cache_control parameter + caching_system_message = { + "role": "system", + "content": [ + { + "type": "text", + "text": context.cached_static_prompt, + } + ], + } + messages = [caching_system_message] + messages + LOG.info( + "Applied OpenAI context caching", + prompt_name=prompt_name, + model=llm_config.model_name, + ) + except Exception as e: + LOG.warning("Failed to apply context caching system message", error=str(e), exc_info=True) + + cache_resource_name = getattr(context, "vertex_cache_name", None) + cache_variant = getattr(context, "vertex_cache_variant", None) + primary_model_dict = _get_primary_model_dict(router, main_model_group) + should_attach_vertex_cache = bool( + cache_resource_name is not None + and prompt_name == EXTRACT_ACTION_PROMPT_NAME + and getattr(context, "use_prompt_caching", False) + and main_model_group + and "gemini" in main_model_group.lower() + and primary_model_dict is not None + ) + + model_used = main_model_group + llm_request_json = "" + + async def _call_primary_with_vertex_cache( + cache_name: str, + cache_variant_name: str | None, + ) -> tuple[ModelResponse, str, str]: + if primary_model_dict is None: + raise ValueError("Primary router model missing configuration") + litellm_params = copy.deepcopy(primary_model_dict.get("litellm_params") or {}) + if not litellm_params: + raise ValueError("Primary router model missing litellm_params") + active_params = copy.deepcopy(litellm_params) + active_params.update(parameters) + active_params["cached_content"] = cache_name + request_model = active_params.pop("model", primary_model_dict.get("model_name", main_model_group)) + + # Clone messages to avoid modifying original list which is needed for fallback + active_messages = copy.deepcopy(messages) + + # Strip static prompt from the request messages because it's already in the cache + # Sending it again causes double-billing (once cached, once uncached) + if context and context.cached_static_prompt: + prompt_stripped = LLMAPIHandlerFactory._strip_static_prompt_from_messages( + active_messages, context.cached_static_prompt ) - except Exception as e: - LOG.warning("Failed to apply context caching system message", error=str(e), exc_info=True) - cache_resource_name = getattr(context, "vertex_cache_name", None) - cache_variant = getattr(context, "vertex_cache_variant", None) - primary_model_dict = _get_primary_model_dict(router, main_model_group) - should_attach_vertex_cache = bool( - cache_resource_name is not None - and prompt_name == EXTRACT_ACTION_PROMPT_NAME - and getattr(context, "use_prompt_caching", False) - and main_model_group - and "gemini" in main_model_group.lower() - and primary_model_dict is not None - ) + if prompt_stripped: + LOG.info("Stripped static prompt from cached request to avoid double-billing") + else: + LOG.warning("Could not find static prompt to strip from cached request") - model_used = main_model_group - llm_request_json = "" - - async def _call_primary_with_vertex_cache( - cache_name: str, - cache_variant_name: str | None, - ) -> tuple[ModelResponse, str, str]: - if primary_model_dict is None: - raise ValueError("Primary router model missing configuration") - litellm_params = copy.deepcopy(primary_model_dict.get("litellm_params") or {}) - if not litellm_params: - raise ValueError("Primary router model missing litellm_params") - active_params = copy.deepcopy(litellm_params) - active_params.update(parameters) - active_params["cached_content"] = cache_name - request_model = active_params.pop("model", primary_model_dict.get("model_name", main_model_group)) - - # Clone messages to avoid modifying original list which is needed for fallback - active_messages = copy.deepcopy(messages) - - # Strip static prompt from the request messages because it's already in the cache - # Sending it again causes double-billing (once cached, once uncached) - if context and context.cached_static_prompt: - prompt_stripped = LLMAPIHandlerFactory._strip_static_prompt_from_messages( - active_messages, context.cached_static_prompt + LOG.info( + "Adding Vertex AI cache reference to primary Gemini request", + prompt_name=prompt_name, + primary_model=main_model_group, + fallback_model=llm_config.fallback_model_group, + cache_name=cache_name, + cache_key=getattr(context, "vertex_cache_key", None), + cache_variant=cache_variant_name, ) + request_payload_json = await _log_llm_request_artifact(request_model, True) + response = await litellm.acompletion( + model=request_model, + messages=active_messages, + timeout=settings.LLM_CONFIG_TIMEOUT, + drop_params=True, + **active_params, + ) + return response, request_model, request_payload_json - if prompt_stripped: - LOG.info("Stripped static prompt from cached request to avoid double-billing") - else: - LOG.warning("Could not find static prompt to strip from cached request") + async def _call_router_without_cache() -> tuple[ModelResponse, str]: + request_payload_json = await _log_llm_request_artifact(llm_key, False) + response = await router.acompletion( + model=main_model_group, + messages=messages, + timeout=settings.LLM_CONFIG_TIMEOUT, + drop_params=True, + **parameters, + ) + return response, request_payload_json - LOG.info( - "Adding Vertex AI cache reference to primary Gemini request", - prompt_name=prompt_name, - primary_model=main_model_group, - fallback_model=llm_config.fallback_model_group, - cache_name=cache_name, - cache_key=getattr(context, "vertex_cache_key", None), - cache_variant=cache_variant_name, - ) - request_payload_json = await _log_llm_request_artifact(request_model, True) - response = await litellm.acompletion( - model=request_model, - messages=active_messages, - timeout=settings.LLM_CONFIG_TIMEOUT, - drop_params=True, - **active_params, - ) - return response, request_model, request_payload_json + try: + response: ModelResponse | None = None + if should_attach_vertex_cache and cache_resource_name: + try: + response, direct_model_used, llm_request_json = await _call_primary_with_vertex_cache( + cache_resource_name, + cache_variant, + ) + model_used = response.model or direct_model_used + except CancelledError: + raise + except Exception as cache_error: + LOG.warning( + "Vertex cache primary call failed, retrying via router", + prompt_name=prompt_name, + error=str(cache_error), + cache_name=cache_resource_name, + cache_variant=cache_variant, + ) + response = None - async def _call_router_without_cache() -> tuple[ModelResponse, str]: - request_payload_json = await _log_llm_request_artifact(llm_key, False) - response = await router.acompletion( - model=main_model_group, - messages=messages, - timeout=settings.LLM_CONFIG_TIMEOUT, - drop_params=True, - **parameters, - ) - return response, request_payload_json + if response is None: + response, llm_request_json = await _call_router_without_cache() + response_model = response.model or main_model_group + model_used = response_model + if not LLMAPIHandlerFactory._models_equivalent(response_model, main_model_group): + LOG.info( + "LLM router fallback succeeded", + llm_key=llm_key, + prompt_name=prompt_name, + primary_model=main_model_group, + fallback_model=response_model, + ) + except litellm.exceptions.APIError as e: + raise LLMProviderErrorRetryableTask(llm_key) from e + except litellm.exceptions.ContextWindowExceededError as e: + duration_seconds = time.time() - start_time + LOG.exception( + "Context window exceeded", + llm_key=llm_key, + model=main_model_group, + prompt_name=prompt_name, + duration_seconds=duration_seconds, + ) + raise SkyvernContextWindowExceededError() from e + except ValueError as e: + duration_seconds = time.time() - start_time + LOG.exception( + "LLM token limit exceeded", + llm_key=llm_key, + model=main_model_group, + prompt_name=prompt_name, + duration_seconds=duration_seconds, + ) + raise LLMProviderErrorRetryableTask(llm_key) from e + except Exception as e: + duration_seconds = time.time() - start_time + LOG.exception( + "LLM request failed unexpectedly", + llm_key=llm_key, + model=main_model_group, + prompt_name=prompt_name, + duration_seconds=duration_seconds, + ) + raise LLMProviderError(llm_key) from e - try: - response: ModelResponse | None = None - if should_attach_vertex_cache and cache_resource_name: - try: - response, direct_model_used, llm_request_json = await _call_primary_with_vertex_cache( - cache_resource_name, - cache_variant, - ) - model_used = response.model or direct_model_used - except CancelledError: - raise - except Exception as cache_error: - LOG.warning( - "Vertex cache primary call failed, retrying via router", - prompt_name=prompt_name, - error=str(cache_error), - cache_name=cache_resource_name, - cache_variant=cache_variant, - ) - response = None - - if response is None: - response, llm_request_json = await _call_router_without_cache() - response_model = response.model or main_model_group - model_used = response_model - if not LLMAPIHandlerFactory._models_equivalent(response_model, main_model_group): - LOG.info( - "LLM router fallback succeeded", - llm_key=llm_key, - prompt_name=prompt_name, - primary_model=main_model_group, - fallback_model=response_model, - ) - except litellm.exceptions.APIError as e: - raise LLMProviderErrorRetryableTask(llm_key) from e - except litellm.exceptions.ContextWindowExceededError as e: - duration_seconds = time.time() - start_time - LOG.exception( - "Context window exceeded", - llm_key=llm_key, - model=main_model_group, - prompt_name=prompt_name, - duration_seconds=duration_seconds, - ) - raise SkyvernContextWindowExceededError() from e - except ValueError as e: - duration_seconds = time.time() - start_time - LOG.exception( - "LLM token limit exceeded", - llm_key=llm_key, - model=main_model_group, - prompt_name=prompt_name, - duration_seconds=duration_seconds, - ) - raise LLMProviderErrorRetryableTask(llm_key) from e - except Exception as e: - duration_seconds = time.time() - start_time - LOG.exception( - "LLM request failed unexpectedly", - llm_key=llm_key, - model=main_model_group, - prompt_name=prompt_name, - duration_seconds=duration_seconds, - ) - raise LLMProviderError(llm_key) from e - - llm_response_json = response.model_dump_json(indent=2) - if should_persist_llm_artifacts: - await app.ARTIFACT_MANAGER.create_llm_artifact( - data=llm_response_json.encode("utf-8"), - artifact_type=ArtifactType.LLM_RESPONSE, - **artifact_targets, - ) - prompt_tokens = 0 - completion_tokens = 0 - reasoning_tokens = 0 - cached_tokens = 0 - completion_token_detail = None - cached_token_detail = None - try: - # FIXME: volcengine doesn't support litellm cost calculation. - llm_cost = litellm.completion_cost(completion_response=response) - except Exception as e: - LOG.debug("Failed to calculate LLM cost", error=str(e), exc_info=True) - llm_cost = 0 - prompt_tokens = 0 - completion_tokens = 0 - reasoning_tokens = 0 - cached_tokens = 0 - - if hasattr(response, "usage") and response.usage: - prompt_tokens = getattr(response.usage, "prompt_tokens", 0) - completion_tokens = getattr(response.usage, "completion_tokens", 0) - - # Extract reasoning tokens from completion_tokens_details - completion_token_detail = getattr(response.usage, "completion_tokens_details", None) - if completion_token_detail: - reasoning_tokens = getattr(completion_token_detail, "reasoning_tokens", 0) or 0 - - # Extract cached tokens from prompt_tokens_details - cached_token_detail = getattr(response.usage, "prompt_tokens_details", None) - if cached_token_detail: - cached_tokens = getattr(cached_token_detail, "cached_tokens", 0) or 0 - - # Fallback for Vertex/Gemini: LiteLLM exposes cache_read_input_tokens on usage - if cached_tokens == 0: - cached_tokens = getattr(response.usage, "cache_read_input_tokens", 0) or 0 - if step and not is_speculative_step: - await app.DATABASE.update_step( - task_id=step.task_id, - step_id=step.step_id, - organization_id=step.organization_id, - incremental_cost=llm_cost, - incremental_input_tokens=prompt_tokens if prompt_tokens > 0 else None, - incremental_output_tokens=completion_tokens if completion_tokens > 0 else None, - incremental_reasoning_tokens=reasoning_tokens if reasoning_tokens > 0 else None, - incremental_cached_tokens=cached_tokens if cached_tokens > 0 else None, - ) - if thought: - await app.DATABASE.update_thought( - thought_id=thought.observer_thought_id, - organization_id=thought.organization_id, - input_token_count=prompt_tokens if prompt_tokens > 0 else None, - output_token_count=completion_tokens if completion_tokens > 0 else None, - thought_cost=llm_cost, - reasoning_token_count=reasoning_tokens if reasoning_tokens > 0 else None, - cached_token_count=cached_tokens if cached_tokens > 0 else None, - ) - parsed_response = parse_api_response(response, llm_config.add_assistant_prefix, force_dict) - parsed_response_json = json.dumps(parsed_response, indent=2) - if should_persist_llm_artifacts: - await app.ARTIFACT_MANAGER.create_llm_artifact( - data=parsed_response_json.encode("utf-8"), - artifact_type=ArtifactType.LLM_RESPONSE_PARSED, - **artifact_targets, - ) - - rendered_response_json = None - if context and len(context.hashed_href_map) > 0: - llm_content = json.dumps(parsed_response) - rendered_content = Template(llm_content).render(context.hashed_href_map) - parsed_response = json.loads(rendered_content) - rendered_response_json = json.dumps(parsed_response, indent=2) + llm_response_json = response.model_dump_json(indent=2) if should_persist_llm_artifacts: - await app.ARTIFACT_MANAGER.create_llm_artifact( - data=rendered_response_json.encode("utf-8"), - artifact_type=ArtifactType.LLM_RESPONSE_RENDERED, - **artifact_targets, + artifacts.append( + await app.ARTIFACT_MANAGER.prepare_llm_artifact( + data=llm_response_json.encode("utf-8"), + artifact_type=ArtifactType.LLM_RESPONSE, + **artifact_targets, + ) + ) + prompt_tokens = 0 + completion_tokens = 0 + reasoning_tokens = 0 + cached_tokens = 0 + completion_token_detail = None + cached_token_detail = None + try: + # FIXME: volcengine doesn't support litellm cost calculation. + llm_cost = litellm.completion_cost(completion_response=response) + except Exception as e: + LOG.debug("Failed to calculate LLM cost", error=str(e), exc_info=True) + llm_cost = 0 + prompt_tokens = 0 + completion_tokens = 0 + reasoning_tokens = 0 + cached_tokens = 0 + + if hasattr(response, "usage") and response.usage: + prompt_tokens = getattr(response.usage, "prompt_tokens", 0) + completion_tokens = getattr(response.usage, "completion_tokens", 0) + + # Extract reasoning tokens from completion_tokens_details + completion_token_detail = getattr(response.usage, "completion_tokens_details", None) + if completion_token_detail: + reasoning_tokens = getattr(completion_token_detail, "reasoning_tokens", 0) or 0 + + # Extract cached tokens from prompt_tokens_details + cached_token_detail = getattr(response.usage, "prompt_tokens_details", None) + if cached_token_detail: + cached_tokens = getattr(cached_token_detail, "cached_tokens", 0) or 0 + + # Fallback for Vertex/Gemini: LiteLLM exposes cache_read_input_tokens on usage + if cached_tokens == 0: + cached_tokens = getattr(response.usage, "cache_read_input_tokens", 0) or 0 + if step and not is_speculative_step: + await app.DATABASE.update_step( + task_id=step.task_id, + step_id=step.step_id, + organization_id=step.organization_id, + incremental_cost=llm_cost, + incremental_input_tokens=prompt_tokens if prompt_tokens > 0 else None, + incremental_output_tokens=completion_tokens if completion_tokens > 0 else None, + incremental_reasoning_tokens=reasoning_tokens if reasoning_tokens > 0 else None, + incremental_cached_tokens=cached_tokens if cached_tokens > 0 else None, + ) + if thought: + await app.DATABASE.update_thought( + thought_id=thought.observer_thought_id, + organization_id=thought.organization_id, + input_token_count=prompt_tokens if prompt_tokens > 0 else None, + output_token_count=completion_tokens if completion_tokens > 0 else None, + thought_cost=llm_cost, + reasoning_token_count=reasoning_tokens if reasoning_tokens > 0 else None, + cached_token_count=cached_tokens if cached_tokens > 0 else None, + ) + parsed_response = parse_api_response(response, llm_config.add_assistant_prefix, force_dict) + parsed_response_json = json.dumps(parsed_response, indent=2) + if should_persist_llm_artifacts: + artifacts.append( + await app.ARTIFACT_MANAGER.prepare_llm_artifact( + data=parsed_response_json.encode("utf-8"), + artifact_type=ArtifactType.LLM_RESPONSE_PARSED, + **artifact_targets, + ) ) - # Track LLM API handler duration, token counts, and cost - organization_id = organization_id or ( - step.organization_id if step else (thought.organization_id if thought else None) - ) - duration_seconds = time.time() - start_time - LOG.info( - "LLM API handler duration metrics", - llm_key=llm_key, - model=model_used, - prompt_name=prompt_name, - duration_seconds=duration_seconds, - step_id=step.step_id if step else None, - thought_id=thought.observer_thought_id if thought else None, - organization_id=organization_id, - input_tokens=prompt_tokens if prompt_tokens > 0 else None, - output_tokens=completion_tokens if completion_tokens > 0 else None, - reasoning_tokens=reasoning_tokens if reasoning_tokens > 0 else None, - cached_tokens=cached_tokens if cached_tokens > 0 else None, - llm_cost=llm_cost if llm_cost > 0 else None, - ) + rendered_response_json = None + if context and len(context.hashed_href_map) > 0: + llm_content = json.dumps(parsed_response) + rendered_content = Template(llm_content).render(context.hashed_href_map) + parsed_response = json.loads(rendered_content) + rendered_response_json = json.dumps(parsed_response, indent=2) + if should_persist_llm_artifacts: + artifacts.append( + await app.ARTIFACT_MANAGER.prepare_llm_artifact( + data=rendered_response_json.encode("utf-8"), + artifact_type=ArtifactType.LLM_RESPONSE_RENDERED, + **artifact_targets, + ) + ) - if step and is_speculative_step: - step.speculative_llm_metadata = SpeculativeLLMMetadata( - prompt=llm_prompt_value, - llm_request_json=llm_request_json, - llm_response_json=llm_response_json, - parsed_response_json=parsed_response_json, - rendered_response_json=rendered_response_json, + # Track LLM API handler duration, token counts, and cost + organization_id = organization_id or ( + step.organization_id if step else (thought.organization_id if thought else None) + ) + duration_seconds = time.time() - start_time + LOG.info( + "LLM API handler duration metrics", llm_key=llm_key, model=model_used, + prompt_name=prompt_name, duration_seconds=duration_seconds, + step_id=step.step_id if step else None, + thought_id=thought.observer_thought_id if thought else None, + organization_id=organization_id, input_tokens=prompt_tokens if prompt_tokens > 0 else None, output_tokens=completion_tokens if completion_tokens > 0 else None, reasoning_tokens=reasoning_tokens if reasoning_tokens > 0 else None, @@ -781,7 +782,29 @@ class LLMAPIHandlerFactory: llm_cost=llm_cost if llm_cost > 0 else None, ) - return parsed_response + if step and is_speculative_step: + step.speculative_llm_metadata = SpeculativeLLMMetadata( + prompt=llm_prompt_value, + llm_request_json=llm_request_json, + llm_response_json=llm_response_json, + parsed_response_json=parsed_response_json, + rendered_response_json=rendered_response_json, + llm_key=llm_key, + model=model_used, + duration_seconds=duration_seconds, + input_tokens=prompt_tokens if prompt_tokens > 0 else None, + output_tokens=completion_tokens if completion_tokens > 0 else None, + reasoning_tokens=reasoning_tokens if reasoning_tokens > 0 else None, + cached_tokens=cached_tokens if cached_tokens > 0 else None, + llm_cost=llm_cost if llm_cost > 0 else None, + ) + + return parsed_response + finally: + try: + await app.ARTIFACT_MANAGER.bulk_create_artifacts(artifacts) + except Exception: + LOG.error("Failed to persist artifacts", exc_info=True) llm_api_handler_with_router_and_fallback.llm_key = llm_key # type: ignore[attr-defined] return llm_api_handler_with_router_and_fallback @@ -855,302 +878,299 @@ class LLMAPIHandlerFactory: should_persist_llm_artifacts, artifact_targets = _get_artifact_targets_and_persist_flag( step, is_speculative_step, task_v2, thought, ai_suggestion ) - await _log_hashed_href_map_artifacts_if_needed( - context, - step, - task_v2, - thought, - ai_suggestion, - is_speculative_step=is_speculative_step, - ) - llm_prompt_value = prompt - if should_persist_llm_artifacts: - await app.ARTIFACT_MANAGER.create_llm_artifact( - data=llm_prompt_value.encode("utf-8"), - artifact_type=ArtifactType.LLM_PROMPT, - screenshots=screenshots, - **artifact_targets, - ) - - if not llm_config.supports_vision: - screenshots = None - - model_name = llm_config.model_name - - messages = await llm_messages_builder(prompt, screenshots, llm_config.add_assistant_prefix) - - # Inject context caching system message when available - # IMPORTANT: Only inject for extract-actions prompt to avoid contaminating other prompts - # (e.g., check-user-goal) with the extract-action schema + artifacts: list[BulkArtifactCreationRequest | None] = [] try: - if ( - context - and context.cached_static_prompt - and prompt_name == EXTRACT_ACTION_PROMPT_NAME # Only inject for extract-actions - and isinstance(llm_config, LLMConfig) - and isinstance(llm_config.model_name, str) - ): - # Check if this is an OpenAI model - if ( - llm_config.model_name.startswith("gpt-") - or llm_config.model_name.startswith("o1-") - or llm_config.model_name.startswith("o3-") - ): - # For OpenAI models, we need to add the cached content as a system message - # and mark it for caching using the cache_control parameter - caching_system_message = { - "role": "system", - "content": [ - { - "type": "text", - "text": context.cached_static_prompt, - } - ], - } - messages = [caching_system_message] + messages - LOG.info( - "Applied OpenAI context caching", - prompt_name=prompt_name, - model=llm_config.model_name, - ) - except Exception as e: - LOG.warning("Failed to apply context caching system message", error=str(e), exc_info=True) - - # Add Vertex AI cache reference only for the intended cached prompt - vertex_cache_attached = False - cache_resource_name = getattr(context, "vertex_cache_name", None) - if ( - cache_resource_name - and prompt_name == EXTRACT_ACTION_PROMPT_NAME - and getattr(context, "use_prompt_caching", False) - and "gemini" in model_name.lower() - ): - active_parameters["cached_content"] = cache_resource_name - vertex_cache_attached = True - LOG.info( - "Adding Vertex AI cache reference to request", - prompt_name=prompt_name, - cache_attached=True, - cache_name=cache_resource_name, - cache_key=getattr(context, "vertex_cache_key", None), - cache_variant=getattr(context, "vertex_cache_variant", None), + await _log_hashed_href_map_artifacts_if_needed( + artifacts, + context, + step, + task_v2, + thought, + ai_suggestion, + is_speculative_step=is_speculative_step, ) - elif "cached_content" in active_parameters: - removed_cache = active_parameters.pop("cached_content", None) - if removed_cache: + + llm_prompt_value = prompt + if should_persist_llm_artifacts: + artifacts.append( + await app.ARTIFACT_MANAGER.prepare_llm_artifact( + data=llm_prompt_value.encode("utf-8"), + artifact_type=ArtifactType.LLM_PROMPT, + screenshots=screenshots, + **artifact_targets, + ) + ) + + if not llm_config.supports_vision: + screenshots = None + + model_name = llm_config.model_name + + messages = await llm_messages_builder(prompt, screenshots, llm_config.add_assistant_prefix) + + # Inject context caching system message when available + # IMPORTANT: Only inject for extract-actions prompt to avoid contaminating other prompts + # (e.g., check-user-goal) with the extract-action schema + try: + if ( + context + and context.cached_static_prompt + and prompt_name == EXTRACT_ACTION_PROMPT_NAME # Only inject for extract-actions + and isinstance(llm_config, LLMConfig) + and isinstance(llm_config.model_name, str) + ): + # Check if this is an OpenAI model + if ( + llm_config.model_name.startswith("gpt-") + or llm_config.model_name.startswith("o1-") + or llm_config.model_name.startswith("o3-") + ): + # For OpenAI models, we need to add the cached content as a system message + # and mark it for caching using the cache_control parameter + caching_system_message = { + "role": "system", + "content": [ + { + "type": "text", + "text": context.cached_static_prompt, + } + ], + } + messages = [caching_system_message] + messages + LOG.info( + "Applied OpenAI context caching", + prompt_name=prompt_name, + model=llm_config.model_name, + ) + except Exception as e: + LOG.warning("Failed to apply context caching system message", error=str(e), exc_info=True) + + # Add Vertex AI cache reference only for the intended cached prompt + vertex_cache_attached = False + cache_resource_name = getattr(context, "vertex_cache_name", None) + if ( + cache_resource_name + and prompt_name == EXTRACT_ACTION_PROMPT_NAME + and getattr(context, "use_prompt_caching", False) + and "gemini" in model_name.lower() + ): + active_parameters["cached_content"] = cache_resource_name + vertex_cache_attached = True LOG.info( - "Removed Vertex AI cache reference from request", + "Adding Vertex AI cache reference to request", prompt_name=prompt_name, - cache_was_attached=True, + cache_attached=True, cache_name=cache_resource_name, cache_key=getattr(context, "vertex_cache_key", None), cache_variant=getattr(context, "vertex_cache_variant", None), ) + elif "cached_content" in active_parameters: + removed_cache = active_parameters.pop("cached_content", None) + if removed_cache: + LOG.info( + "Removed Vertex AI cache reference from request", + prompt_name=prompt_name, + cache_was_attached=True, + cache_name=cache_resource_name, + cache_key=getattr(context, "vertex_cache_key", None), + cache_variant=getattr(context, "vertex_cache_variant", None), + ) - llm_request_payload = { - "model": model_name, - "messages": messages, - # we're not using active_parameters here because it may contain sensitive information - **parameters, - "vertex_cache_attached": vertex_cache_attached, - } - llm_request_json = json.dumps(llm_request_payload) - if should_persist_llm_artifacts: - await app.ARTIFACT_MANAGER.create_llm_artifact( - data=llm_request_json.encode("utf-8"), - artifact_type=ArtifactType.LLM_REQUEST, - **artifact_targets, - ) - - # Strip static prompt from the request messages because it's already in the cache - # Sending it again causes double-billing (once cached, once uncached) - active_messages = messages - if vertex_cache_attached and context and context.cached_static_prompt: - active_messages = copy.deepcopy(messages) - prompt_stripped = LLMAPIHandlerFactory._strip_static_prompt_from_messages( - active_messages, context.cached_static_prompt - ) - - if prompt_stripped: - LOG.info("Stripped static prompt from cached request to avoid double-billing") - else: - LOG.warning("Could not find static prompt to strip from cached request") - - t_llm_request = time.perf_counter() - try: - # TODO (kerem): add a retry mechanism to this call (acompletion_with_retries) - # TODO (kerem): use litellm fallbacks? https://litellm.vercel.app/docs/tutorials/fallbacks#how-does-completion_with_fallbacks-work - response = await litellm.acompletion( - model=model_name, - messages=active_messages, - drop_params=True, # Drop unsupported parameters gracefully - **active_parameters, - ) - except litellm.exceptions.APIError as e: - raise LLMProviderErrorRetryableTask(llm_key) from e - except litellm.exceptions.ContextWindowExceededError as e: - duration_seconds = time.time() - start_time - LOG.exception( - "Context window exceeded", - llm_key=llm_key, - model=model_name, - prompt_name=prompt_name, - duration_seconds=duration_seconds, - ) - raise SkyvernContextWindowExceededError() from e - except CancelledError: - # Speculative steps are intentionally cancelled when goal verification completes first, - # so we log at debug level. Non-speculative cancellations are unexpected errors. - t_llm_cancelled = time.perf_counter() - if is_speculative_step: - LOG.debug( - "LLM request cancelled (speculative step)", - llm_key=llm_key, - model=model_name, - prompt_name=prompt_name, - duration=t_llm_cancelled - t_llm_request, - ) - raise - else: - LOG.error( - "LLM request got cancelled", - llm_key=llm_key, - model=model_name, - prompt_name=prompt_name, - duration=t_llm_cancelled - t_llm_request, - ) - raise LLMProviderError(llm_key) from None - except Exception as e: - duration_seconds = time.time() - start_time - LOG.exception( - "LLM request failed unexpectedly", - llm_key=llm_key, - model=model_name, - prompt_name=prompt_name, - duration_seconds=duration_seconds, - ) - raise LLMProviderError(llm_key) from e - - llm_response_json = response.model_dump_json(indent=2) - if should_persist_llm_artifacts: - await app.ARTIFACT_MANAGER.create_llm_artifact( - data=llm_response_json.encode("utf-8"), - artifact_type=ArtifactType.LLM_RESPONSE, - **artifact_targets, - ) - - prompt_tokens = 0 - completion_tokens = 0 - reasoning_tokens = 0 - cached_tokens = 0 - completion_token_detail = None - cached_token_detail = None - try: - # FIXME: volcengine doesn't support litellm cost calculation. - llm_cost = litellm.completion_cost(completion_response=response) - except Exception as e: - LOG.debug("Failed to calculate LLM cost", error=str(e), exc_info=True) - llm_cost = 0 - prompt_tokens = 0 - completion_tokens = 0 - reasoning_tokens = 0 - cached_tokens = 0 - - if hasattr(response, "usage") and response.usage: - prompt_tokens = getattr(response.usage, "prompt_tokens", 0) - completion_tokens = getattr(response.usage, "completion_tokens", 0) - - # Extract reasoning tokens from completion_tokens_details - completion_token_detail = getattr(response.usage, "completion_tokens_details", None) - if completion_token_detail: - reasoning_tokens = getattr(completion_token_detail, "reasoning_tokens", 0) or 0 - - # Extract cached tokens from prompt_tokens_details - cached_token_detail = getattr(response.usage, "prompt_tokens_details", None) - if cached_token_detail: - cached_tokens = getattr(cached_token_detail, "cached_tokens", 0) or 0 - - # Fallback for Vertex/Gemini: LiteLLM exposes cache_read_input_tokens on usage - if cached_tokens == 0: - cached_tokens = getattr(response.usage, "cache_read_input_tokens", 0) or 0 - - _log_vertex_cache_hit_if_needed(context, prompt_name, model_name, cached_tokens) - - if step and not is_speculative_step: - await app.DATABASE.update_step( - task_id=step.task_id, - step_id=step.step_id, - organization_id=step.organization_id, - incremental_cost=llm_cost, - incremental_input_tokens=prompt_tokens if prompt_tokens > 0 else None, - incremental_output_tokens=completion_tokens if completion_tokens > 0 else None, - incremental_reasoning_tokens=reasoning_tokens if reasoning_tokens > 0 else None, - incremental_cached_tokens=cached_tokens if cached_tokens > 0 else None, - ) - if thought: - await app.DATABASE.update_thought( - thought_id=thought.observer_thought_id, - organization_id=thought.organization_id, - input_token_count=prompt_tokens if prompt_tokens > 0 else None, - output_token_count=completion_tokens if completion_tokens > 0 else None, - reasoning_token_count=reasoning_tokens if reasoning_tokens > 0 else None, - cached_token_count=cached_tokens if cached_tokens > 0 else None, - thought_cost=llm_cost, - ) - parsed_response = parse_api_response(response, llm_config.add_assistant_prefix, force_dict) - parsed_response_json = json.dumps(parsed_response, indent=2) - if should_persist_llm_artifacts: - await app.ARTIFACT_MANAGER.create_llm_artifact( - data=parsed_response_json.encode("utf-8"), - artifact_type=ArtifactType.LLM_RESPONSE_PARSED, - **artifact_targets, - ) - - rendered_response_json = None - if context and len(context.hashed_href_map) > 0: - llm_content = json.dumps(parsed_response) - rendered_content = Template(llm_content).render(context.hashed_href_map) - parsed_response = json.loads(rendered_content) - rendered_response_json = json.dumps(parsed_response, indent=2) + llm_request_payload = { + "model": model_name, + "messages": messages, + # we're not using active_parameters here because it may contain sensitive information + **parameters, + "vertex_cache_attached": vertex_cache_attached, + } + llm_request_json = json.dumps(llm_request_payload) if should_persist_llm_artifacts: - await app.ARTIFACT_MANAGER.create_llm_artifact( - data=rendered_response_json.encode("utf-8"), - artifact_type=ArtifactType.LLM_RESPONSE_RENDERED, - **artifact_targets, + artifacts.append( + await app.ARTIFACT_MANAGER.prepare_llm_artifact( + data=llm_request_json.encode("utf-8"), + artifact_type=ArtifactType.LLM_REQUEST, + **artifact_targets, + ) ) - # Track LLM API handler duration, token counts, and cost - organization_id = organization_id or ( - step.organization_id if step else (thought.organization_id if thought else None) - ) - duration_seconds = time.time() - start_time - LOG.info( - "LLM API handler duration metrics", - llm_key=llm_key, - prompt_name=prompt_name, - model=llm_config.model_name, - duration_seconds=duration_seconds, - step_id=step.step_id if step else None, - thought_id=thought.observer_thought_id if thought else None, - organization_id=organization_id, - input_tokens=prompt_tokens if prompt_tokens > 0 else None, - output_tokens=completion_tokens if completion_tokens > 0 else None, - reasoning_tokens=reasoning_tokens if reasoning_tokens > 0 else None, - cached_tokens=cached_tokens if cached_tokens > 0 else None, - llm_cost=llm_cost if llm_cost > 0 else None, - ) + # Strip static prompt from the request messages because it's already in the cache + # Sending it again causes double-billing (once cached, once uncached) + active_messages = messages + if vertex_cache_attached and context and context.cached_static_prompt: + active_messages = copy.deepcopy(messages) + prompt_stripped = LLMAPIHandlerFactory._strip_static_prompt_from_messages( + active_messages, context.cached_static_prompt + ) - if step and is_speculative_step: - step.speculative_llm_metadata = SpeculativeLLMMetadata( - prompt=llm_prompt_value, - llm_request_json=llm_request_json, - llm_response_json=llm_response_json, - parsed_response_json=parsed_response_json, - rendered_response_json=rendered_response_json, + if prompt_stripped: + LOG.info("Stripped static prompt from cached request to avoid double-billing") + else: + LOG.warning("Could not find static prompt to strip from cached request") + + t_llm_request = time.perf_counter() + try: + # TODO (kerem): add a retry mechanism to this call (acompletion_with_retries) + # TODO (kerem): use litellm fallbacks? https://litellm.vercel.app/docs/tutorials/fallbacks#how-does-completion_with_fallbacks-work + response = await litellm.acompletion( + model=model_name, + messages=active_messages, + drop_params=True, # Drop unsupported parameters gracefully + **active_parameters, + ) + except litellm.exceptions.APIError as e: + raise LLMProviderErrorRetryableTask(llm_key) from e + except litellm.exceptions.ContextWindowExceededError as e: + duration_seconds = time.time() - start_time + LOG.exception( + "Context window exceeded", + llm_key=llm_key, + model=model_name, + prompt_name=prompt_name, + duration_seconds=duration_seconds, + ) + raise SkyvernContextWindowExceededError() from e + except CancelledError: + # Speculative steps are intentionally cancelled when goal verification completes first, + # so we log at debug level. Non-speculative cancellations are unexpected errors. + t_llm_cancelled = time.perf_counter() + if is_speculative_step: + LOG.debug( + "LLM request cancelled (speculative step)", + llm_key=llm_key, + model=model_name, + prompt_name=prompt_name, + duration=t_llm_cancelled - t_llm_request, + ) + raise + else: + LOG.error( + "LLM request got cancelled", + llm_key=llm_key, + model=model_name, + prompt_name=prompt_name, + duration=t_llm_cancelled - t_llm_request, + ) + raise LLMProviderError(llm_key) from None + except Exception as e: + duration_seconds = time.time() - start_time + LOG.exception( + "LLM request failed unexpectedly", + llm_key=llm_key, + model=model_name, + prompt_name=prompt_name, + duration_seconds=duration_seconds, + ) + raise LLMProviderError(llm_key) from e + + llm_response_json = response.model_dump_json(indent=2) + if should_persist_llm_artifacts: + artifacts.append( + await app.ARTIFACT_MANAGER.prepare_llm_artifact( + data=llm_response_json.encode("utf-8"), + artifact_type=ArtifactType.LLM_RESPONSE, + **artifact_targets, + ) + ) + + prompt_tokens = 0 + completion_tokens = 0 + reasoning_tokens = 0 + cached_tokens = 0 + completion_token_detail = None + cached_token_detail = None + try: + # FIXME: volcengine doesn't support litellm cost calculation. + llm_cost = litellm.completion_cost(completion_response=response) + except Exception as e: + LOG.debug("Failed to calculate LLM cost", error=str(e), exc_info=True) + llm_cost = 0 + prompt_tokens = 0 + completion_tokens = 0 + reasoning_tokens = 0 + cached_tokens = 0 + + if hasattr(response, "usage") and response.usage: + prompt_tokens = getattr(response.usage, "prompt_tokens", 0) + completion_tokens = getattr(response.usage, "completion_tokens", 0) + + # Extract reasoning tokens from completion_tokens_details + completion_token_detail = getattr(response.usage, "completion_tokens_details", None) + if completion_token_detail: + reasoning_tokens = getattr(completion_token_detail, "reasoning_tokens", 0) or 0 + + # Extract cached tokens from prompt_tokens_details + cached_token_detail = getattr(response.usage, "prompt_tokens_details", None) + if cached_token_detail: + cached_tokens = getattr(cached_token_detail, "cached_tokens", 0) or 0 + + # Fallback for Vertex/Gemini: LiteLLM exposes cache_read_input_tokens on usage + if cached_tokens == 0: + cached_tokens = getattr(response.usage, "cache_read_input_tokens", 0) or 0 + + _log_vertex_cache_hit_if_needed(context, prompt_name, model_name, cached_tokens) + + if step and not is_speculative_step: + await app.DATABASE.update_step( + task_id=step.task_id, + step_id=step.step_id, + organization_id=step.organization_id, + incremental_cost=llm_cost, + incremental_input_tokens=prompt_tokens if prompt_tokens > 0 else None, + incremental_output_tokens=completion_tokens if completion_tokens > 0 else None, + incremental_reasoning_tokens=reasoning_tokens if reasoning_tokens > 0 else None, + incremental_cached_tokens=cached_tokens if cached_tokens > 0 else None, + ) + if thought: + await app.DATABASE.update_thought( + thought_id=thought.observer_thought_id, + organization_id=thought.organization_id, + input_token_count=prompt_tokens if prompt_tokens > 0 else None, + output_token_count=completion_tokens if completion_tokens > 0 else None, + reasoning_token_count=reasoning_tokens if reasoning_tokens > 0 else None, + cached_token_count=cached_tokens if cached_tokens > 0 else None, + thought_cost=llm_cost, + ) + parsed_response = parse_api_response(response, llm_config.add_assistant_prefix, force_dict) + parsed_response_json = json.dumps(parsed_response, indent=2) + if should_persist_llm_artifacts: + artifacts.append( + await app.ARTIFACT_MANAGER.prepare_llm_artifact( + data=parsed_response_json.encode("utf-8"), + artifact_type=ArtifactType.LLM_RESPONSE_PARSED, + **artifact_targets, + ) + ) + + rendered_response_json = None + if context and len(context.hashed_href_map) > 0: + llm_content = json.dumps(parsed_response) + rendered_content = Template(llm_content).render(context.hashed_href_map) + parsed_response = json.loads(rendered_content) + rendered_response_json = json.dumps(parsed_response, indent=2) + if should_persist_llm_artifacts: + artifacts.append( + await app.ARTIFACT_MANAGER.prepare_llm_artifact( + data=rendered_response_json.encode("utf-8"), + artifact_type=ArtifactType.LLM_RESPONSE_RENDERED, + **artifact_targets, + ) + ) + + # Track LLM API handler duration, token counts, and cost + organization_id = organization_id or ( + step.organization_id if step else (thought.organization_id if thought else None) + ) + duration_seconds = time.time() - start_time + LOG.info( + "LLM API handler duration metrics", llm_key=llm_key, + prompt_name=prompt_name, model=llm_config.model_name, duration_seconds=duration_seconds, + step_id=step.step_id if step else None, + thought_id=thought.observer_thought_id if thought else None, + organization_id=organization_id, input_tokens=prompt_tokens if prompt_tokens > 0 else None, output_tokens=completion_tokens if completion_tokens > 0 else None, reasoning_tokens=reasoning_tokens if reasoning_tokens > 0 else None, @@ -1158,7 +1178,29 @@ class LLMAPIHandlerFactory: llm_cost=llm_cost if llm_cost > 0 else None, ) - return parsed_response + if step and is_speculative_step: + step.speculative_llm_metadata = SpeculativeLLMMetadata( + prompt=llm_prompt_value, + llm_request_json=llm_request_json, + llm_response_json=llm_response_json, + parsed_response_json=parsed_response_json, + rendered_response_json=rendered_response_json, + llm_key=llm_key, + model=llm_config.model_name, + duration_seconds=duration_seconds, + input_tokens=prompt_tokens if prompt_tokens > 0 else None, + output_tokens=completion_tokens if completion_tokens > 0 else None, + reasoning_tokens=reasoning_tokens if reasoning_tokens > 0 else None, + cached_tokens=cached_tokens if cached_tokens > 0 else None, + llm_cost=llm_cost if llm_cost > 0 else None, + ) + + return parsed_response + finally: + try: + await app.ARTIFACT_MANAGER.bulk_create_artifacts(artifacts) + except Exception: + LOG.error("Failed to persist artifacts", exc_info=True) llm_api_handler.llm_key = llm_key # type: ignore[attr-defined] return llm_api_handler @@ -1278,215 +1320,235 @@ class LLMCaller: should_persist_llm_artifacts, artifact_targets = _get_artifact_targets_and_persist_flag( step, is_speculative_step, task_v2, thought, ai_suggestion ) - await _log_hashed_href_map_artifacts_if_needed( - context, - step, - task_v2, - thought, - ai_suggestion, - is_speculative_step=is_speculative_step, - ) - if screenshots and self.screenshot_scaling_enabled: - target_dimension = self.get_screenshot_resize_target_dimension(window_dimension) - if window_dimension and window_dimension != self.browser_window_dimension and tools: - # THIS situation only applies to Anthropic CUA - LOG.info( - "Window dimension is different from the default browser window dimension when making LLM call", - window_dimension=window_dimension, - browser_window_dimension=self.browser_window_dimension, - ) - # update the tools to use the new target dimension - for tool in tools: - if "display_height_px" in tool: - tool["display_height_px"] = target_dimension["height"] - if "display_width_px" in tool: - tool["display_width_px"] = target_dimension["width"] - screenshots = resize_screenshots(screenshots, target_dimension) - - llm_prompt_value = prompt or "" - if prompt and should_persist_llm_artifacts: - await app.ARTIFACT_MANAGER.create_llm_artifact( - data=prompt.encode("utf-8"), - artifact_type=ArtifactType.LLM_PROMPT, - screenshots=screenshots, - **artifact_targets, - ) - - if not self.llm_config.supports_vision: - screenshots = None - - message_pattern = "openai" - if "ANTHROPIC" in self.llm_key: - message_pattern = "anthropic" - - if use_message_history: - # self.message_history will be updated in place - messages = await llm_messages_builder_with_history( - prompt, - screenshots, - self.message_history, - message_pattern=message_pattern, - ) - else: - messages = await llm_messages_builder_with_history( - prompt, - screenshots, - message_pattern=message_pattern, - ) - llm_request_payload = { - "model": self.llm_config.model_name, - "messages": messages, - # we're not using active_parameters here because it may contain sensitive information - **parameters, - } - llm_request_json = json.dumps(llm_request_payload) - if should_persist_llm_artifacts: - await app.ARTIFACT_MANAGER.create_llm_artifact( - data=llm_request_json.encode("utf-8"), - artifact_type=ArtifactType.LLM_REQUEST, - **artifact_targets, - ) - t_llm_request = time.perf_counter() + artifacts: list[BulkArtifactCreationRequest | None] = [] try: - response = await self._dispatch_llm_call( - messages=messages, - tools=tools, - timeout=settings.LLM_CONFIG_TIMEOUT, - **active_parameters, + await _log_hashed_href_map_artifacts_if_needed( + artifacts, + context, + step, + task_v2, + thought, + ai_suggestion, + is_speculative_step=is_speculative_step, ) + + if screenshots and self.screenshot_scaling_enabled: + target_dimension = self.get_screenshot_resize_target_dimension(window_dimension) + if window_dimension and window_dimension != self.browser_window_dimension and tools: + # THIS situation only applies to Anthropic CUA + LOG.info( + "Window dimension is different from the default browser window dimension when making LLM call", + window_dimension=window_dimension, + browser_window_dimension=self.browser_window_dimension, + ) + # update the tools to use the new target dimension + for tool in tools: + if "display_height_px" in tool: + tool["display_height_px"] = target_dimension["height"] + if "display_width_px" in tool: + tool["display_width_px"] = target_dimension["width"] + screenshots = resize_screenshots(screenshots, target_dimension) + + llm_prompt_value = prompt or "" + if prompt and should_persist_llm_artifacts: + artifacts.append( + await app.ARTIFACT_MANAGER.prepare_llm_artifact( + data=prompt.encode("utf-8"), + artifact_type=ArtifactType.LLM_PROMPT, + screenshots=screenshots, + **artifact_targets, + ) + ) + + if not self.llm_config.supports_vision: + screenshots = None + + message_pattern = "openai" + if "ANTHROPIC" in self.llm_key: + message_pattern = "anthropic" + if use_message_history: - # only update message_history when the request is successful - self.message_history = messages - except litellm.exceptions.APIError as e: - raise LLMProviderErrorRetryableTask(self.llm_key) from e - except litellm.exceptions.ContextWindowExceededError as e: - LOG.exception( - "Context window exceeded", - llm_key=self.llm_key, - model=self.llm_config.model_name, - ) - raise SkyvernContextWindowExceededError() from e - except CancelledError: - # Speculative steps are intentionally cancelled when goal verification returns completed, - # so we log at debug level. Non-speculative cancellations are unexpected errors. - t_llm_cancelled = time.perf_counter() - if is_speculative_step: - LOG.debug( - "LLM request cancelled (speculative step)", - llm_key=self.llm_key, - model=self.llm_config.model_name, - duration=t_llm_cancelled - t_llm_request, + # self.message_history will be updated in place + messages = await llm_messages_builder_with_history( + prompt, + screenshots, + self.message_history, + message_pattern=message_pattern, ) - raise else: - LOG.error( - "LLM request got cancelled", + messages = await llm_messages_builder_with_history( + prompt, + screenshots, + message_pattern=message_pattern, + ) + llm_request_payload = { + "model": self.llm_config.model_name, + "messages": messages, + # we're not using active_parameters here because it may contain sensitive information + **parameters, + } + llm_request_json = json.dumps(llm_request_payload) + if should_persist_llm_artifacts: + artifacts.append( + await app.ARTIFACT_MANAGER.prepare_llm_artifact( + data=llm_request_json.encode("utf-8"), + artifact_type=ArtifactType.LLM_REQUEST, + **artifact_targets, + ) + ) + + t_llm_request = time.perf_counter() + try: + response = await self._dispatch_llm_call( + messages=messages, + tools=tools, + timeout=settings.LLM_CONFIG_TIMEOUT, + **active_parameters, + ) + if use_message_history: + # only update message_history when the request is successful + self.message_history = messages + except litellm.exceptions.APIError as e: + raise LLMProviderErrorRetryableTask(self.llm_key) from e + except litellm.exceptions.ContextWindowExceededError as e: + LOG.exception( + "Context window exceeded", llm_key=self.llm_key, model=self.llm_config.model_name, - duration=t_llm_cancelled - t_llm_request, ) - raise LLMProviderError(self.llm_key) from None - except Exception as e: - LOG.exception("LLM request failed unexpectedly", llm_key=self.llm_key) - raise LLMProviderError(self.llm_key) from e + raise SkyvernContextWindowExceededError() from e + except CancelledError: + # Speculative steps are intentionally cancelled when goal verification returns completed, + # so we log at debug level. Non-speculative cancellations are unexpected errors. + t_llm_cancelled = time.perf_counter() + if is_speculative_step: + LOG.debug( + "LLM request cancelled (speculative step)", + llm_key=self.llm_key, + model=self.llm_config.model_name, + duration=t_llm_cancelled - t_llm_request, + ) + raise + else: + LOG.error( + "LLM request got cancelled", + llm_key=self.llm_key, + model=self.llm_config.model_name, + duration=t_llm_cancelled - t_llm_request, + ) + raise LLMProviderError(self.llm_key) from None + except Exception as e: + LOG.exception("LLM request failed unexpectedly", llm_key=self.llm_key) + raise LLMProviderError(self.llm_key) from e - llm_response_json = response.model_dump_json(indent=2) - if should_persist_llm_artifacts: - await app.ARTIFACT_MANAGER.create_llm_artifact( - data=llm_response_json.encode("utf-8"), - artifact_type=ArtifactType.LLM_RESPONSE, - **artifact_targets, - ) - - call_stats = await self.get_call_stats(response) - if step and not is_speculative_step: - await app.DATABASE.update_step( - task_id=step.task_id, - step_id=step.step_id, - organization_id=step.organization_id, - incremental_cost=call_stats.llm_cost, - incremental_input_tokens=call_stats.input_tokens, - incremental_output_tokens=call_stats.output_tokens, - incremental_reasoning_tokens=call_stats.reasoning_tokens, - incremental_cached_tokens=call_stats.cached_tokens, - ) - if thought: - await app.DATABASE.update_thought( - thought_id=thought.observer_thought_id, - organization_id=thought.organization_id, - input_token_count=call_stats.input_tokens, - output_token_count=call_stats.output_tokens, - reasoning_token_count=call_stats.reasoning_tokens, - cached_token_count=call_stats.cached_tokens, - thought_cost=call_stats.llm_cost, - ) - - organization_id = organization_id or ( - step.organization_id if step else (thought.organization_id if thought else None) - ) - # Track LLM API handler duration, token counts, and cost - duration_seconds = time.perf_counter() - start_time - LOG.info( - "LLM API handler duration metrics", - llm_key=self.llm_key, - prompt_name=prompt_name, - model=self.llm_config.model_name, - duration_seconds=duration_seconds, - step_id=step.step_id if step else None, - thought_id=thought.observer_thought_id if thought else None, - organization_id=organization_id, - input_tokens=call_stats.input_tokens if call_stats and call_stats.input_tokens else None, - output_tokens=call_stats.output_tokens if call_stats and call_stats.output_tokens else None, - reasoning_tokens=call_stats.reasoning_tokens if call_stats and call_stats.reasoning_tokens else None, - cached_tokens=call_stats.cached_tokens if call_stats and call_stats.cached_tokens else None, - llm_cost=call_stats.llm_cost if call_stats and call_stats.llm_cost else None, - ) - - # Raw response is used for CUA engine LLM calls. - if raw_response: - return response.model_dump(exclude_none=True) - - parsed_response = parse_api_response(response, self.llm_config.add_assistant_prefix, force_dict) - parsed_response_json = json.dumps(parsed_response, indent=2) - if should_persist_llm_artifacts: - await app.ARTIFACT_MANAGER.create_llm_artifact( - data=parsed_response_json.encode("utf-8"), - artifact_type=ArtifactType.LLM_RESPONSE_PARSED, - **artifact_targets, - ) - - rendered_response_json = None - if context and len(context.hashed_href_map) > 0: - llm_content = json.dumps(parsed_response) - rendered_content = Template(llm_content).render(context.hashed_href_map) - parsed_response = json.loads(rendered_content) - rendered_response_json = json.dumps(parsed_response, indent=2) + llm_response_json = response.model_dump_json(indent=2) if should_persist_llm_artifacts: - await app.ARTIFACT_MANAGER.create_llm_artifact( - data=rendered_response_json.encode("utf-8"), - artifact_type=ArtifactType.LLM_RESPONSE_RENDERED, - **artifact_targets, + artifacts.append( + await app.ARTIFACT_MANAGER.prepare_llm_artifact( + data=llm_response_json.encode("utf-8"), + artifact_type=ArtifactType.LLM_RESPONSE, + **artifact_targets, + ) ) - if step and is_speculative_step: - step.speculative_llm_metadata = SpeculativeLLMMetadata( - prompt=llm_prompt_value, - llm_request_json=llm_request_json, - llm_response_json=llm_response_json, - parsed_response_json=parsed_response_json, - rendered_response_json=rendered_response_json, + call_stats = await self.get_call_stats(response) + if step and not is_speculative_step: + await app.DATABASE.update_step( + task_id=step.task_id, + step_id=step.step_id, + organization_id=step.organization_id, + incremental_cost=call_stats.llm_cost, + incremental_input_tokens=call_stats.input_tokens, + incremental_output_tokens=call_stats.output_tokens, + incremental_reasoning_tokens=call_stats.reasoning_tokens, + incremental_cached_tokens=call_stats.cached_tokens, + ) + if thought: + await app.DATABASE.update_thought( + thought_id=thought.observer_thought_id, + organization_id=thought.organization_id, + input_token_count=call_stats.input_tokens, + output_token_count=call_stats.output_tokens, + reasoning_token_count=call_stats.reasoning_tokens, + cached_token_count=call_stats.cached_tokens, + thought_cost=call_stats.llm_cost, + ) + + organization_id = organization_id or ( + step.organization_id if step else (thought.organization_id if thought else None) + ) + # Track LLM API handler duration, token counts, and cost + duration_seconds = time.perf_counter() - start_time + LOG.info( + "LLM API handler duration metrics", llm_key=self.llm_key, + prompt_name=prompt_name, model=self.llm_config.model_name, duration_seconds=duration_seconds, - input_tokens=call_stats.input_tokens, - output_tokens=call_stats.output_tokens, - reasoning_tokens=call_stats.reasoning_tokens, - cached_tokens=call_stats.cached_tokens, - llm_cost=call_stats.llm_cost, + step_id=step.step_id if step else None, + thought_id=thought.observer_thought_id if thought else None, + organization_id=organization_id, + input_tokens=call_stats.input_tokens if call_stats and call_stats.input_tokens else None, + output_tokens=call_stats.output_tokens if call_stats and call_stats.output_tokens else None, + reasoning_tokens=call_stats.reasoning_tokens if call_stats and call_stats.reasoning_tokens else None, + cached_tokens=call_stats.cached_tokens if call_stats and call_stats.cached_tokens else None, + llm_cost=call_stats.llm_cost if call_stats and call_stats.llm_cost else None, ) - return parsed_response + # Raw response is used for CUA engine LLM calls. + if raw_response: + return response.model_dump(exclude_none=True) + + parsed_response = parse_api_response(response, self.llm_config.add_assistant_prefix, force_dict) + parsed_response_json = json.dumps(parsed_response, indent=2) + if should_persist_llm_artifacts: + artifacts.append( + await app.ARTIFACT_MANAGER.prepare_llm_artifact( + data=parsed_response_json.encode("utf-8"), + artifact_type=ArtifactType.LLM_RESPONSE_PARSED, + **artifact_targets, + ) + ) + + rendered_response_json = None + if context and len(context.hashed_href_map) > 0: + llm_content = json.dumps(parsed_response) + rendered_content = Template(llm_content).render(context.hashed_href_map) + parsed_response = json.loads(rendered_content) + rendered_response_json = json.dumps(parsed_response, indent=2) + if should_persist_llm_artifacts: + artifacts.append( + await app.ARTIFACT_MANAGER.prepare_llm_artifact( + data=rendered_response_json.encode("utf-8"), + artifact_type=ArtifactType.LLM_RESPONSE_RENDERED, + **artifact_targets, + ) + ) + + if step and is_speculative_step: + step.speculative_llm_metadata = SpeculativeLLMMetadata( + prompt=llm_prompt_value, + llm_request_json=llm_request_json, + llm_response_json=llm_response_json, + parsed_response_json=parsed_response_json, + rendered_response_json=rendered_response_json, + llm_key=self.llm_key, + model=self.llm_config.model_name, + duration_seconds=duration_seconds, + input_tokens=call_stats.input_tokens, + output_tokens=call_stats.output_tokens, + reasoning_tokens=call_stats.reasoning_tokens, + cached_tokens=call_stats.cached_tokens, + llm_cost=call_stats.llm_cost, + ) + + return parsed_response + finally: + try: + await app.ARTIFACT_MANAGER.bulk_create_artifacts(artifacts) + except Exception: + LOG.error("Failed to persist artifacts", exc_info=True) def get_screenshot_resize_target_dimension(self, window_dimension: Resolution | None) -> Resolution: if window_dimension and window_dimension != self.browser_window_dimension: diff --git a/skyvern/forge/sdk/artifact/manager.py b/skyvern/forge/sdk/artifact/manager.py index 36548ba9..366a6359 100644 --- a/skyvern/forge/sdk/artifact/manager.py +++ b/skyvern/forge/sdk/artifact/manager.py @@ -367,6 +367,24 @@ class ArtifactManager: data=data, ) + async def bulk_create_artifacts( + self, + requests: list[BulkArtifactCreationRequest | None], + ) -> list[str]: + artifacts: list[ArtifactBatchData] = [] + primary_key: str | None = None + for request in requests: + if request: + artifacts.extend(request.artifacts) + primary_key = request.primary_key + + if primary_key is None or not artifacts: + return [] + + return await self._bulk_create_artifacts( + BulkArtifactCreationRequest(artifacts=artifacts, primary_key=primary_key) + ) + async def _bulk_create_artifacts( self, request: BulkArtifactCreationRequest, @@ -636,7 +654,7 @@ class ArtifactManager: return BulkArtifactCreationRequest(artifacts=artifacts, primary_key=ai_suggestion.ai_suggestion_id) - async def create_llm_artifact( + async def prepare_llm_artifact( self, data: bytes, artifact_type: ArtifactType, @@ -645,54 +663,40 @@ class ArtifactManager: thought: Thought | None = None, task_v2: TaskV2 | None = None, ai_suggestion: AISuggestion | None = None, - ) -> None: - """ - Create LLM artifact with optional screenshots using bulk insert. - - Args: - data: Main artifact data - artifact_type: Type of the main artifact - screenshots: Optional list of screenshot data - step: Optional Step entity - thought: Optional Thought entity - task_v2: Optional TaskV2 entity - ai_suggestion: Optional AISuggestion entity - """ + ) -> BulkArtifactCreationRequest | None: if step: - request = self._prepare_step_artifacts( + return self._prepare_step_artifacts( step=step, artifact_type=artifact_type, data=data, screenshots=screenshots, ) - await self._bulk_create_artifacts(request) elif task_v2: - request = self._prepare_task_v2_artifacts( + return self._prepare_task_v2_artifacts( task_v2=task_v2, artifact_type=artifact_type, data=data, screenshots=screenshots, ) - await self._bulk_create_artifacts(request) elif thought: - request = self._prepare_thought_artifacts( + return self._prepare_thought_artifacts( thought=thought, artifact_type=artifact_type, data=data, screenshots=screenshots, ) - await self._bulk_create_artifacts(request) elif ai_suggestion: - request = self._prepare_ai_suggestion_artifacts( + return self._prepare_ai_suggestion_artifacts( ai_suggestion=ai_suggestion, artifact_type=artifact_type, data=data, screenshots=screenshots, ) - await self._bulk_create_artifacts(request) + else: + return None async def update_artifact_data( self,