diff --git a/skyvern/forge/sdk/api/llm/api_handler_factory.py b/skyvern/forge/sdk/api/llm/api_handler_factory.py
index a3576137..e30aeb2a 100644
--- a/skyvern/forge/sdk/api/llm/api_handler_factory.py
+++ b/skyvern/forge/sdk/api/llm/api_handler_factory.py
@@ -19,7 +19,7 @@ from skyvern.forge.sdk.api.llm.exceptions import (
     LLMProviderErrorRetryableTask,
 )
 from skyvern.forge.sdk.api.llm.models import LLMAPIHandler, LLMConfig, LLMRouterConfig, dummy_llm_api_handler
-from skyvern.forge.sdk.api.llm.utils import llm_messages_builder, parse_api_response
+from skyvern.forge.sdk.api.llm.utils import llm_messages_builder, llm_messages_builder_with_history, parse_api_response
 from skyvern.forge.sdk.artifact.models import ArtifactType
 from skyvern.forge.sdk.core import skyvern_context
 from skyvern.forge.sdk.models import Step
@@ -444,3 +444,199 @@ class LLMAPIHandlerFactory:
         if llm_key in cls._custom_handlers:
             raise DuplicateCustomLLMProviderError(llm_key)
         cls._custom_handlers[llm_key] = handler
+
+
+class LLMCaller:
+    """
+    An LLMCaller instance holds the LLM configuration and, when requested, keeps the chat message history across calls.
+    """
+
+    def __init__(self, llm_key: str, base_parameters: dict[str, Any] | None = None):
+        self.llm_key = llm_key
+        self.llm_config = LLMConfigRegistry.get_config(llm_key)
+        self.base_parameters = base_parameters
+        self.message_history: list[dict[str, Any]] = []
+
+    async def call(
+        self,
+        prompt: str,
+        prompt_name: str,
+        step: Step | None = None,
+        task_v2: TaskV2 | None = None,
+        thought: Thought | None = None,
+        ai_suggestion: AISuggestion | None = None,
+        screenshots: list[bytes] | None = None,
+        parameters: dict[str, Any] | None = None,
+        tools: list | None = None,
+        use_message_history: bool = False,
+    ) -> dict[str, Any]:
+        start_time = time.perf_counter()
+        active_parameters = dict(self.base_parameters or {})  # copy so repeated calls don't mutate base_parameters
+        if parameters is None:
+            parameters = LLMAPIHandlerFactory.get_api_parameters(self.llm_config)
+
+        active_parameters.update(parameters)
+        if self.llm_config.litellm_params:  # type: ignore
+            active_parameters.update(self.llm_config.litellm_params)  # type: ignore
+
+        context = skyvern_context.current()
+        if context and len(context.hashed_href_map) > 0:
+            await app.ARTIFACT_MANAGER.create_llm_artifact(
+                data=json.dumps(context.hashed_href_map, indent=2).encode("utf-8"),
+                artifact_type=ArtifactType.HASHED_HREF_MAP,
+                step=step,
+                task_v2=task_v2,
+                thought=thought,
+                ai_suggestion=ai_suggestion,
+            )
+
+        await app.ARTIFACT_MANAGER.create_llm_artifact(
+            data=prompt.encode("utf-8"),
+            artifact_type=ArtifactType.LLM_PROMPT,
+            screenshots=screenshots,
+            step=step,
+            task_v2=task_v2,
+            thought=thought,
+            ai_suggestion=ai_suggestion,
+        )
+
+        if not self.llm_config.supports_vision:
+            screenshots = None
+
+        if use_message_history:
+            # the builder copies the history; self.message_history is only replaced after a successful request
+            messages = await llm_messages_builder_with_history(prompt, screenshots, self.message_history)
+        else:
+            messages = await llm_messages_builder_with_history(prompt, screenshots)
+        await app.ARTIFACT_MANAGER.create_llm_artifact(
+            data=json.dumps(
+                {
+                    "model": self.llm_config.model_name,
+                    "messages": messages,
+                    # we're not using active_parameters here because it may contain sensitive information
+                    **parameters,
+                }
+            ).encode("utf-8"),
+            artifact_type=ArtifactType.LLM_REQUEST,
+            step=step,
+            task_v2=task_v2,
+            thought=thought,
+            ai_suggestion=ai_suggestion,
+        )
+        t_llm_request = time.perf_counter()
+        try:
+            response = await litellm.acompletion(
+                model=self.llm_config.model_name,
+                messages=messages,
+                tools=tools,
+                timeout=settings.LLM_CONFIG_TIMEOUT,
+                **active_parameters,
+            )
+            if use_message_history:
+                # only update message_history when the request is successful
+                self.message_history = messages
+        except litellm.exceptions.APIError as e:
+            raise LLMProviderErrorRetryableTask(self.llm_key) from e
+        except litellm.exceptions.ContextWindowExceededError as e:
+            LOG.exception(
+                "Context window exceeded",
+                llm_key=self.llm_key,
+                model=self.llm_config.model_name,
+            )
+            raise SkyvernContextWindowExceededError() from e
+        except CancelledError:
+            t_llm_cancelled = time.perf_counter()
+            LOG.error(
+                "LLM request got cancelled",
+                llm_key=self.llm_key,
+                model=self.llm_config.model_name,
+                duration=t_llm_cancelled - t_llm_request,
+            )
+            raise LLMProviderError(self.llm_key)
+        except Exception as e:
+            LOG.exception("LLM request failed unexpectedly", llm_key=self.llm_key)
+            raise LLMProviderError(self.llm_key) from e
+
+        await app.ARTIFACT_MANAGER.create_llm_artifact(
+            data=response.model_dump_json(indent=2).encode("utf-8"),
+            artifact_type=ArtifactType.LLM_RESPONSE,
+            step=step,
+            task_v2=task_v2,
+            thought=thought,
+            ai_suggestion=ai_suggestion,
+        )
+
+        if step or thought:
+            try:
+                llm_cost = litellm.completion_cost(completion_response=response)
+            except Exception as e:
+                LOG.debug("Failed to calculate LLM cost", error=str(e), exc_info=True)
+                llm_cost = 0
+            prompt_tokens = response.get("usage", {}).get("prompt_tokens", 0)
+            completion_tokens = response.get("usage", {}).get("completion_tokens", 0)
+            reasoning_tokens = 0
+            completion_token_detail = response.get("usage", {}).get("completion_tokens_details")
+            if completion_token_detail:
+                reasoning_tokens = completion_token_detail.reasoning_tokens or 0
+            cached_tokens = 0
+            cached_token_detail = response.get("usage", {}).get("prompt_tokens_details")
+            if cached_token_detail:
+                cached_tokens = cached_token_detail.cached_tokens or 0
+            if step:
+                await app.DATABASE.update_step(
+                    task_id=step.task_id,
+                    step_id=step.step_id,
+                    organization_id=step.organization_id,
+                    incremental_cost=llm_cost,
+                    incremental_input_tokens=prompt_tokens if prompt_tokens > 0 else None,
+                    incremental_output_tokens=completion_tokens if completion_tokens > 0 else None,
+                    incremental_reasoning_tokens=reasoning_tokens if reasoning_tokens > 0 else None,
+                    incremental_cached_tokens=cached_tokens if cached_tokens > 0 else None,
+                )
+            if thought:
+                await app.DATABASE.update_thought(
+                    thought_id=thought.observer_thought_id,
+                    organization_id=thought.organization_id,
+                    input_token_count=prompt_tokens if prompt_tokens > 0 else None,
+                    output_token_count=completion_tokens if completion_tokens > 0 else None,
+                    reasoning_token_count=reasoning_tokens if reasoning_tokens > 0 else None,
+                    cached_token_count=cached_tokens if cached_tokens > 0 else None,
+                    thought_cost=llm_cost,
+                )
+        parsed_response = parse_api_response(response, self.llm_config.add_assistant_prefix)
+        await app.ARTIFACT_MANAGER.create_llm_artifact(
+            data=json.dumps(parsed_response, indent=2).encode("utf-8"),
+            artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
+            step=step,
+            task_v2=task_v2,
+            thought=thought,
+            ai_suggestion=ai_suggestion,
+        )
+
+        if context and len(context.hashed_href_map) > 0:
+            llm_content = json.dumps(parsed_response)
+            rendered_content = Template(llm_content).render(context.hashed_href_map)
+            parsed_response = json.loads(rendered_content)
+            await app.ARTIFACT_MANAGER.create_llm_artifact(
+                data=json.dumps(parsed_response, indent=2).encode("utf-8"),
+                artifact_type=ArtifactType.LLM_RESPONSE_RENDERED,
+                step=step,
+                task_v2=task_v2,
+                thought=thought,
+                ai_suggestion=ai_suggestion,
+            )
+
+        # Track LLM API handler duration
+        duration_seconds = time.perf_counter() - start_time
+        LOG.info(
+            "LLM API handler duration metrics",
+            llm_key=self.llm_key,
+            prompt_name=prompt_name,
+            model=self.llm_config.model_name,
+            duration_seconds=duration_seconds,
+            step_id=step.step_id if step else None,
+            thought_id=thought.observer_thought_id if thought else None,
+            organization_id=step.organization_id if step else (thought.organization_id if thought else None),
+        )
+
+        return parsed_response
diff --git a/skyvern/forge/sdk/api/llm/config_registry.py b/skyvern/forge/sdk/api/llm/config_registry.py
index 6b3b36ec..b9a51c85 100644
--- a/skyvern/forge/sdk/api/llm/config_registry.py
+++ b/skyvern/forge/sdk/api/llm/config_registry.py
@@ -318,7 +318,7 @@ if settings.ENABLE_BEDROCK:
         ["AWS_REGION"],
         supports_vision=True,
         add_assistant_prefix=True,
-        max_completion_tokens=200000,
+        max_completion_tokens=64000,
     ),
 )
diff --git a/skyvern/forge/sdk/api/llm/utils.py b/skyvern/forge/sdk/api/llm/utils.py
index cbcb69c7..0bbb00a8 100644
--- a/skyvern/forge/sdk/api/llm/utils.py
+++ b/skyvern/forge/sdk/api/llm/utils.py
@@ -1,4 +1,5 @@
 import base64
+import copy
 import json
 import re
 from typing import Any
@@ -45,6 +46,36 @@ async def llm_messages_builder(
     return [{"role": "user", "content": messages}]
 
 
+async def llm_messages_builder_with_history(
+    prompt: str,
+    screenshots: list[bytes] | None = None,
+    message_history: list[dict[str, Any]] | None = None,
+) -> list[dict[str, Any]]:
+    messages: list[dict[str, Any]] = []
+    if message_history:
+        messages = copy.deepcopy(message_history)
+    current_user_messages: list[dict[str, Any]] = [
+        {
+            "type": "text",
+            "text": prompt,
+        }
+    ]
+
+    if screenshots:
+        for screenshot in screenshots:
+            encoded_image = base64.b64encode(screenshot).decode("utf-8")
+            current_user_messages.append(
+                {
+                    "type": "image_url",
+                    "image_url": {
+                        "url": f"data:image/png;base64,{encoded_image}",
+                    },
+                }
+            )
+    messages.append({"role": "user", "content": current_user_messages})
+    return messages
+
+
 def parse_api_response(response: litellm.ModelResponse, add_assistant_prefix: bool = False) -> dict[str, Any]:
     content = None
     try:
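
A minimal usage sketch for the new LLMCaller, assuming a fully configured Skyvern app context; the llm_key "OPENAI_GPT4O" and the prompt names below are illustrative, not taken from the diff:

    import asyncio

    from skyvern.forge.sdk.api.llm.api_handler_factory import LLMCaller


    async def main() -> None:
        # one caller instance per conversation; the history lives on the instance
        caller = LLMCaller(llm_key="OPENAI_GPT4O")  # hypothetical registered key

        # first turn: sent against an empty history
        first = await caller.call(
            prompt="Summarize the visible page.",
            prompt_name="summarize-page",
            use_message_history=True,
        )

        # second turn: llm_messages_builder_with_history deep-copies
        # caller.message_history and appends this prompt as a new user
        # message, so the model sees the first turn as prior context
        second = await caller.call(
            prompt="Now return only the page title.",
            prompt_name="extract-title",
            use_message_history=True,
        )
        print(first, second)


    asyncio.run(main())

Note that message_history is replaced with the request's messages only after a successful completion, and as written it accumulates user turns only; the assistant's parsed response is not appended back into the history.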