add LLMCaller that supports message history (#2204)
This commit is contained in:
@@ -19,7 +19,7 @@ from skyvern.forge.sdk.api.llm.exceptions import (
|
|||||||
LLMProviderErrorRetryableTask,
|
LLMProviderErrorRetryableTask,
|
||||||
)
|
)
|
||||||
from skyvern.forge.sdk.api.llm.models import LLMAPIHandler, LLMConfig, LLMRouterConfig, dummy_llm_api_handler
|
from skyvern.forge.sdk.api.llm.models import LLMAPIHandler, LLMConfig, LLMRouterConfig, dummy_llm_api_handler
|
||||||
from skyvern.forge.sdk.api.llm.utils import llm_messages_builder, parse_api_response
|
from skyvern.forge.sdk.api.llm.utils import llm_messages_builder, llm_messages_builder_with_history, parse_api_response
|
||||||
from skyvern.forge.sdk.artifact.models import ArtifactType
|
from skyvern.forge.sdk.artifact.models import ArtifactType
|
||||||
from skyvern.forge.sdk.core import skyvern_context
|
from skyvern.forge.sdk.core import skyvern_context
|
||||||
from skyvern.forge.sdk.models import Step
|
from skyvern.forge.sdk.models import Step
|
||||||
@@ -444,3 +444,199 @@ class LLMAPIHandlerFactory:
|
|||||||
if llm_key in cls._custom_handlers:
|
if llm_key in cls._custom_handlers:
|
||||||
raise DuplicateCustomLLMProviderError(llm_key)
|
raise DuplicateCustomLLMProviderError(llm_key)
|
||||||
cls._custom_handlers[llm_key] = handler
|
cls._custom_handlers[llm_key] = handler
|
||||||
|
|
||||||
|
|
||||||
|
class LLMCaller:
|
||||||
|
"""
|
||||||
|
An LLMCaller instance defines the LLM configs and keeps the chat history if needed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, llm_key: str, base_parameters: dict[str, Any] | None = None):
|
||||||
|
self.llm_key = llm_key
|
||||||
|
self.llm_config = LLMConfigRegistry.get_config(llm_key)
|
||||||
|
self.base_parameters = base_parameters
|
||||||
|
self.message_history: list[dict[str, Any]] = []
|
||||||
|
|
||||||
|
async def call(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
prompt_name: str,
|
||||||
|
step: Step | None = None,
|
||||||
|
task_v2: TaskV2 | None = None,
|
||||||
|
thought: Thought | None = None,
|
||||||
|
ai_suggestion: AISuggestion | None = None,
|
||||||
|
screenshots: list[bytes] | None = None,
|
||||||
|
parameters: dict[str, Any] | None = None,
|
||||||
|
tools: list | None = None,
|
||||||
|
use_message_history: bool = False,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
active_parameters = self.base_parameters or {}
|
||||||
|
if parameters is None:
|
||||||
|
parameters = LLMAPIHandlerFactory.get_api_parameters(self.llm_config)
|
||||||
|
|
||||||
|
active_parameters.update(parameters)
|
||||||
|
if self.llm_config.litellm_params: # type: ignore
|
||||||
|
active_parameters.update(self.llm_config.litellm_params) # type: ignore
|
||||||
|
|
||||||
|
context = skyvern_context.current()
|
||||||
|
if context and len(context.hashed_href_map) > 0:
|
||||||
|
await app.ARTIFACT_MANAGER.create_llm_artifact(
|
||||||
|
data=json.dumps(context.hashed_href_map, indent=2).encode("utf-8"),
|
||||||
|
artifact_type=ArtifactType.HASHED_HREF_MAP,
|
||||||
|
step=step,
|
||||||
|
task_v2=task_v2,
|
||||||
|
thought=thought,
|
||||||
|
ai_suggestion=ai_suggestion,
|
||||||
|
)
|
||||||
|
|
||||||
|
await app.ARTIFACT_MANAGER.create_llm_artifact(
|
||||||
|
data=prompt.encode("utf-8"),
|
||||||
|
artifact_type=ArtifactType.LLM_PROMPT,
|
||||||
|
screenshots=screenshots,
|
||||||
|
step=step,
|
||||||
|
task_v2=task_v2,
|
||||||
|
thought=thought,
|
||||||
|
ai_suggestion=ai_suggestion,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.llm_config.supports_vision:
|
||||||
|
screenshots = None
|
||||||
|
|
||||||
|
if use_message_history:
|
||||||
|
# self.message_history will be updated in place
|
||||||
|
messages = await llm_messages_builder_with_history(prompt, screenshots, self.message_history)
|
||||||
|
else:
|
||||||
|
messages = await llm_messages_builder_with_history(prompt, screenshots)
|
||||||
|
await app.ARTIFACT_MANAGER.create_llm_artifact(
|
||||||
|
data=json.dumps(
|
||||||
|
{
|
||||||
|
"model": self.llm_config.model_name,
|
||||||
|
"messages": messages,
|
||||||
|
# we're not using active_parameters here because it may contain sensitive information
|
||||||
|
**parameters,
|
||||||
|
}
|
||||||
|
).encode("utf-8"),
|
||||||
|
artifact_type=ArtifactType.LLM_REQUEST,
|
||||||
|
step=step,
|
||||||
|
task_v2=task_v2,
|
||||||
|
thought=thought,
|
||||||
|
ai_suggestion=ai_suggestion,
|
||||||
|
)
|
||||||
|
t_llm_request = time.perf_counter()
|
||||||
|
try:
|
||||||
|
response = await litellm.acompletion(
|
||||||
|
model=self.llm_config.model_name,
|
||||||
|
messages=messages,
|
||||||
|
tools=tools,
|
||||||
|
timeout=settings.LLM_CONFIG_TIMEOUT,
|
||||||
|
**active_parameters,
|
||||||
|
)
|
||||||
|
if use_message_history:
|
||||||
|
# only update message_history when the request is successful
|
||||||
|
self.message_history = messages
|
||||||
|
except litellm.exceptions.APIError as e:
|
||||||
|
raise LLMProviderErrorRetryableTask(self.llm_key) from e
|
||||||
|
except litellm.exceptions.ContextWindowExceededError as e:
|
||||||
|
LOG.exception(
|
||||||
|
"Context window exceeded",
|
||||||
|
llm_key=self.llm_key,
|
||||||
|
model=self.llm_config.model_name,
|
||||||
|
)
|
||||||
|
raise SkyvernContextWindowExceededError() from e
|
||||||
|
except CancelledError:
|
||||||
|
t_llm_cancelled = time.perf_counter()
|
||||||
|
LOG.error(
|
||||||
|
"LLM request got cancelled",
|
||||||
|
llm_key=self.llm_key,
|
||||||
|
model=self.llm_config.model_name,
|
||||||
|
duration=t_llm_cancelled - t_llm_request,
|
||||||
|
)
|
||||||
|
raise LLMProviderError(self.llm_key)
|
||||||
|
except Exception as e:
|
||||||
|
LOG.exception("LLM request failed unexpectedly", llm_key=self.llm_key)
|
||||||
|
raise LLMProviderError(self.llm_key) from e
|
||||||
|
|
||||||
|
await app.ARTIFACT_MANAGER.create_llm_artifact(
|
||||||
|
data=response.model_dump_json(indent=2).encode("utf-8"),
|
||||||
|
artifact_type=ArtifactType.LLM_RESPONSE,
|
||||||
|
step=step,
|
||||||
|
task_v2=task_v2,
|
||||||
|
thought=thought,
|
||||||
|
ai_suggestion=ai_suggestion,
|
||||||
|
)
|
||||||
|
|
||||||
|
if step or thought:
|
||||||
|
try:
|
||||||
|
llm_cost = litellm.completion_cost(completion_response=response)
|
||||||
|
except Exception as e:
|
||||||
|
LOG.debug("Failed to calculate LLM cost", error=str(e), exc_info=True)
|
||||||
|
llm_cost = 0
|
||||||
|
prompt_tokens = response.get("usage", {}).get("prompt_tokens", 0)
|
||||||
|
completion_tokens = response.get("usage", {}).get("completion_tokens", 0)
|
||||||
|
reasoning_tokens = 0
|
||||||
|
completion_token_detail = response.get("usage", {}).get("completion_tokens_details")
|
||||||
|
if completion_token_detail:
|
||||||
|
reasoning_tokens = completion_token_detail.reasoning_tokens or 0
|
||||||
|
cached_tokens = 0
|
||||||
|
cached_token_detail = response.get("usage", {}).get("prompt_tokens_details")
|
||||||
|
if cached_token_detail:
|
||||||
|
cached_tokens = cached_token_detail.cached_tokens or 0
|
||||||
|
if step:
|
||||||
|
await app.DATABASE.update_step(
|
||||||
|
task_id=step.task_id,
|
||||||
|
step_id=step.step_id,
|
||||||
|
organization_id=step.organization_id,
|
||||||
|
incremental_cost=llm_cost,
|
||||||
|
incremental_input_tokens=prompt_tokens if prompt_tokens > 0 else None,
|
||||||
|
incremental_output_tokens=completion_tokens if completion_tokens > 0 else None,
|
||||||
|
incremental_reasoning_tokens=reasoning_tokens if reasoning_tokens > 0 else None,
|
||||||
|
incremental_cached_tokens=cached_tokens if cached_tokens > 0 else None,
|
||||||
|
)
|
||||||
|
if thought:
|
||||||
|
await app.DATABASE.update_thought(
|
||||||
|
thought_id=thought.observer_thought_id,
|
||||||
|
organization_id=thought.organization_id,
|
||||||
|
input_token_count=prompt_tokens if prompt_tokens > 0 else None,
|
||||||
|
output_token_count=completion_tokens if completion_tokens > 0 else None,
|
||||||
|
reasoning_token_count=reasoning_tokens if reasoning_tokens > 0 else None,
|
||||||
|
cached_token_count=cached_tokens if cached_tokens > 0 else None,
|
||||||
|
thought_cost=llm_cost,
|
||||||
|
)
|
||||||
|
parsed_response = parse_api_response(response, self.llm_config.add_assistant_prefix)
|
||||||
|
await app.ARTIFACT_MANAGER.create_llm_artifact(
|
||||||
|
data=json.dumps(parsed_response, indent=2).encode("utf-8"),
|
||||||
|
artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
|
||||||
|
step=step,
|
||||||
|
task_v2=task_v2,
|
||||||
|
thought=thought,
|
||||||
|
ai_suggestion=ai_suggestion,
|
||||||
|
)
|
||||||
|
|
||||||
|
if context and len(context.hashed_href_map) > 0:
|
||||||
|
llm_content = json.dumps(parsed_response)
|
||||||
|
rendered_content = Template(llm_content).render(context.hashed_href_map)
|
||||||
|
parsed_response = json.loads(rendered_content)
|
||||||
|
await app.ARTIFACT_MANAGER.create_llm_artifact(
|
||||||
|
data=json.dumps(parsed_response, indent=2).encode("utf-8"),
|
||||||
|
artifact_type=ArtifactType.LLM_RESPONSE_RENDERED,
|
||||||
|
step=step,
|
||||||
|
task_v2=task_v2,
|
||||||
|
thought=thought,
|
||||||
|
ai_suggestion=ai_suggestion,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Track LLM API handler duration
|
||||||
|
duration_seconds = time.perf_counter() - start_time
|
||||||
|
LOG.info(
|
||||||
|
"LLM API handler duration metrics",
|
||||||
|
llm_key=self.llm_key,
|
||||||
|
prompt_name=prompt_name,
|
||||||
|
model=self.llm_config.model_name,
|
||||||
|
duration_seconds=duration_seconds,
|
||||||
|
step_id=step.step_id if step else None,
|
||||||
|
thought_id=thought.observer_thought_id if thought else None,
|
||||||
|
organization_id=step.organization_id if step else (thought.organization_id if thought else None),
|
||||||
|
)
|
||||||
|
|
||||||
|
return parsed_response
|
||||||
|
|||||||
@@ -318,7 +318,7 @@ if settings.ENABLE_BEDROCK:
|
|||||||
["AWS_REGION"],
|
["AWS_REGION"],
|
||||||
supports_vision=True,
|
supports_vision=True,
|
||||||
add_assistant_prefix=True,
|
add_assistant_prefix=True,
|
||||||
max_completion_tokens=200000,
|
max_completion_tokens=64000,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import base64
|
import base64
|
||||||
|
import copy
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -45,6 +46,36 @@ async def llm_messages_builder(
|
|||||||
return [{"role": "user", "content": messages}]
|
return [{"role": "user", "content": messages}]
|
||||||
|
|
||||||
|
|
||||||
|
async def llm_messages_builder_with_history(
|
||||||
|
prompt: str,
|
||||||
|
screenshots: list[bytes] | None = None,
|
||||||
|
message_history: list[dict[str, Any]] | None = None,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
messages: list[dict[str, Any]] = []
|
||||||
|
if message_history:
|
||||||
|
messages = copy.deepcopy(message_history)
|
||||||
|
current_user_messages: list[dict[str, Any]] = [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": prompt,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
if screenshots:
|
||||||
|
for screenshot in screenshots:
|
||||||
|
encoded_image = base64.b64encode(screenshot).decode("utf-8")
|
||||||
|
current_user_messages.append(
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": f"data:image/png;base64,{encoded_image}",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
messages.append({"role": "user", "content": current_user_messages})
|
||||||
|
return messages
|
||||||
|
|
||||||
|
|
||||||
def parse_api_response(response: litellm.ModelResponse, add_assistant_prefix: bool = False) -> dict[str, Any]:
|
def parse_api_response(response: litellm.ModelResponse, add_assistant_prefix: bool = False) -> dict[str, Any]:
|
||||||
content = None
|
content = None
|
||||||
try:
|
try:
|
||||||
|
|||||||
Reference in New Issue
Block a user