anthropic CUA (#2231)
Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
@@ -6,7 +6,9 @@ from typing import Any
 import litellm
 import structlog
+from anthropic.types.message import Message as AnthropicMessage
 from jinja2 import Template
+from litellm.utils import CustomStreamWrapper, ModelResponse

 from skyvern.config import settings
 from skyvern.exceptions import SkyvernContextWindowExceededError
@@ -456,11 +458,18 @@ class LLMCaller:
         self.llm_config = LLMConfigRegistry.get_config(llm_key)
         self.base_parameters = base_parameters
         self.message_history: list[dict[str, Any]] = []
+        self.current_tool_results: list[dict[str, Any]] = []
+
+    def add_tool_result(self, tool_result: dict[str, Any]) -> None:
+        self.current_tool_results.append(tool_result)
+
+    def clear_tool_results(self) -> None:
+        self.current_tool_results = []

     async def call(
         self,
-        prompt: str,
-        prompt_name: str,
+        prompt: str | None = None,
+        prompt_name: str | None = None,
         step: Step | None = None,
         task_v2: TaskV2 | None = None,
         thought: Thought | None = None,
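The new tool-result plumbing is an accumulate-then-flush buffer: results added between turns via add_tool_result are presumably folded into the next request by the caller, then cleared. A minimal, self-contained sketch of that pattern (plain Python; the class is a stand-in, not the real LLMCaller):

from typing import Any

class ToolResultBuffer:
    # Stand-in mirroring LLMCaller's new current_tool_results field.
    def __init__(self) -> None:
        self.current_tool_results: list[dict[str, Any]] = []

    def add_tool_result(self, tool_result: dict[str, Any]) -> None:
        self.current_tool_results.append(tool_result)

    def clear_tool_results(self) -> None:
        # Flush once the buffered results have been consumed.
        self.current_tool_results = []

buffer = ToolResultBuffer()
buffer.add_tool_result({"type": "tool_result", "tool_use_id": "tu_1", "content": "clicked"})
assert len(buffer.current_tool_results) == 1
buffer.clear_tool_results()
assert buffer.current_tool_results == []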
@@ -469,6 +478,8 @@ class LLMCaller:
         parameters: dict[str, Any] | None = None,
         tools: list | None = None,
         use_message_history: bool = False,
+        raw_response: bool = False,
+        **extra_parameters: Any,
     ) -> dict[str, Any]:
         start_time = time.perf_counter()
         active_parameters = self.base_parameters or {}
@@ -476,6 +487,8 @@ class LLMCaller:
             parameters = LLMAPIHandlerFactory.get_api_parameters(self.llm_config)

         active_parameters.update(parameters)
+        if extra_parameters:
+            active_parameters.update(extra_parameters)
         if self.llm_config.litellm_params:  # type: ignore
             active_parameters.update(self.llm_config.litellm_params)  # type: ignore

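Note the resulting precedence: base parameters are overlaid by the registry's API parameters, then by per-call extra_parameters, with litellm_params applied last. A runnable sketch of that ordering with illustrative values:

active_parameters = {"temperature": 0.0}          # base_parameters
active_parameters.update({"temperature": 0.3})    # registry API parameters
active_parameters.update({"max_tokens": 2048})    # extra_parameters (per call)
active_parameters.update({"temperature": 0.7})    # llm_config.litellm_params wins last
assert active_parameters == {"temperature": 0.7, "max_tokens": 2048}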
@@ -491,7 +504,7 @@ class LLMCaller:
         )

         await app.ARTIFACT_MANAGER.create_llm_artifact(
-            data=prompt.encode("utf-8"),
+            data=prompt.encode("utf-8") if prompt else b"",
             artifact_type=ArtifactType.LLM_PROMPT,
             screenshots=screenshots,
             step=step,
@@ -525,8 +538,7 @@ class LLMCaller:
         )
         t_llm_request = time.perf_counter()
         try:
-            response = await litellm.acompletion(
-                model=self.llm_config.model_name,
+            response = await self._dispatch_llm_call(
                 messages=messages,
                 tools=tools,
                 timeout=settings.LLM_CONFIG_TIMEOUT,
@@ -603,6 +615,21 @@ class LLMCaller:
                 cached_token_count=cached_tokens if cached_tokens > 0 else None,
                 thought_cost=llm_cost,
             )
+        # Track LLM API handler duration
+        duration_seconds = time.perf_counter() - start_time
+        LOG.info(
+            "LLM API handler duration metrics",
+            llm_key=self.llm_key,
+            prompt_name=prompt_name,
+            model=self.llm_config.model_name,
+            duration_seconds=duration_seconds,
+            step_id=step.step_id if step else None,
+            thought_id=thought.observer_thought_id if thought else None,
+            organization_id=step.organization_id if step else (thought.organization_id if thought else None),
+        )
+        if raw_response:
+            return response.model_dump()
+
         parsed_response = parse_api_response(response, self.llm_config.add_assistant_prefix)
         await app.ARTIFACT_MANAGER.create_llm_artifact(
             data=json.dumps(parsed_response, indent=2).encode("utf-8"),
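Moving the duration log above the raw_response early return means both exits are timed. A self-contained stub of the new control flow (the response dict here is fake; the real call returns a litellm or Anthropic response object):

import asyncio
import time

async def call(raw_response: bool = False) -> dict:
    start_time = time.perf_counter()
    response = {"raw": {"choices": []}, "parsed": {"action": "click"}}  # stub LLM result
    duration_seconds = time.perf_counter() - start_time
    print(f"LLM API handler duration metrics: {duration_seconds:.6f}s")
    if raw_response:
        return response["raw"]  # mirrors response.model_dump()
    return response["parsed"]  # mirrors parse_api_response(...)

print(asyncio.run(call(raw_response=True)))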
@@ -626,17 +653,53 @@ class LLMCaller:
             ai_suggestion=ai_suggestion,
         )

-        # Track LLM API handler duration
-        duration_seconds = time.perf_counter() - start_time
-        LOG.info(
-            "LLM API handler duration metrics",
-            llm_key=self.llm_key,
-            prompt_name=prompt_name,
-            model=self.llm_config.model_name,
-            duration_seconds=duration_seconds,
-            step_id=step.step_id if step else None,
-            thought_id=thought.observer_thought_id if thought else None,
-            organization_id=step.organization_id if step else (thought.organization_id if thought else None),
-        )
         return parsed_response

+    async def _dispatch_llm_call(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list | None = None,
+        timeout: float = settings.LLM_CONFIG_TIMEOUT,
+        **active_parameters: dict[str, Any],
+    ) -> ModelResponse | CustomStreamWrapper | AnthropicMessage:
+        if self.llm_key and self.llm_key.startswith("ANTHROPIC"):
+            return await self._call_anthropic(messages, tools, timeout)
+
+        return await litellm.acompletion(
+            model=self.llm_config.model_name, messages=messages, tools=tools, timeout=timeout, **active_parameters
+        )
+
+    async def _call_anthropic(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list | None = None,
+        timeout: float = settings.LLM_CONFIG_TIMEOUT,
+        **active_parameters: dict[str, Any],
+    ) -> AnthropicMessage:
+        max_tokens = active_parameters.get("max_completion_tokens") or active_parameters.get("max_tokens") or 4096
+        model_name = self.llm_config.model_name.replace("bedrock/", "").replace("anthropic/", "")
+        return await app.ANTHROPIC_CLIENT.messages.create(
+            max_tokens=max_tokens,
+            messages=messages,
+            model=model_name,
+            tools=tools,
+            timeout=timeout,
+            betas=active_parameters.get("betas", None),
+        )
+
+
+class LLMCallerManager:
+    _llm_callers: dict[str, LLMCaller] = {}
+
+    @classmethod
+    def get_llm_caller(cls, uid: str) -> LLMCaller | None:
+        return cls._llm_callers.get(uid)
+
+    @classmethod
+    def set_llm_caller(cls, uid: str, llm_caller: LLMCaller) -> None:
+        cls._llm_callers[uid] = llm_caller
+
+    @classmethod
+    def clear_llm_caller(cls, uid: str) -> None:
+        if uid in cls._llm_callers:
+            del cls._llm_callers[uid]
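The dispatch seam above is the heart of the change: ANTHROPIC-prefixed llm_keys are routed to the native client (where the diff forwards a betas option), while everything else stays on litellm.acompletion. A self-contained sketch of the routing plus the per-UID registry, with stub coroutines in place of the real clients (the key name is illustrative, not a real registry entry):

import asyncio
from typing import Any

async def call_anthropic(messages: list[dict[str, Any]]) -> str:
    return "anthropic-native"  # stands in for app.ANTHROPIC_CLIENT.messages.create(...)

async def call_litellm(messages: list[dict[str, Any]]) -> str:
    return "litellm"  # stands in for litellm.acompletion(...)

async def dispatch_llm_call(llm_key: str | None, messages: list[dict[str, Any]]) -> str:
    # Same predicate as _dispatch_llm_call: route on the key prefix.
    if llm_key and llm_key.startswith("ANTHROPIC"):
        return await call_anthropic(messages)
    return await call_litellm(messages)

# Per-UID registry mirroring LLMCallerManager (stub strings instead of LLMCaller objects).
_llm_callers: dict[str, str] = {}
_llm_callers["task_abc"] = "caller"         # set_llm_caller
assert _llm_callers.get("task_abc")         # get_llm_caller
_llm_callers.pop("task_abc", None)          # clear_llm_caller

print(asyncio.run(dispatch_llm_call("ANTHROPIC_CLAUDE", [])))  # -> anthropic-native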
@@ -47,19 +47,22 @@ async def llm_messages_builder(


 async def llm_messages_builder_with_history(
-    prompt: str,
+    prompt: str | None = None,
     screenshots: list[bytes] | None = None,
     message_history: list[dict[str, Any]] | None = None,
 ) -> list[dict[str, Any]]:
     messages: list[dict[str, Any]] = []
     if message_history:
         messages = copy.deepcopy(message_history)
-    current_user_messages: list[dict[str, Any]] = [
-        {
-            "type": "text",
-            "text": prompt,
-        }
-    ]
+    current_user_messages: list[dict[str, Any]] = []
+    if prompt:
+        current_user_messages.append(
+            {
+                "type": "text",
+                "text": prompt,
+            }
+        )

     if screenshots:
         for screenshot in screenshots:
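With prompt now optional, a history-only turn (for example one carrying only tool results or screenshots) no longer emits an empty text block. A runnable sketch of the builder's new branching, with the image encoding reduced to a placeholder:

import copy
from typing import Any

def build_messages(
    prompt: str | None = None,
    screenshots: list[bytes] | None = None,
    message_history: list[dict[str, Any]] | None = None,
) -> list[dict[str, Any]]:
    messages = copy.deepcopy(message_history) if message_history else []
    content: list[dict[str, Any]] = []
    if prompt:  # text block only when a prompt is given
        content.append({"type": "text", "text": prompt})
    for screenshot in screenshots or []:
        content.append({"type": "image", "size": len(screenshot)})  # placeholder, not the real encoding
    if content:
        messages.append({"role": "user", "content": content})
    return messages

print(build_messages(prompt=None, screenshots=[b"\x89PNG"], message_history=[]))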