anthropic CUA (#2231)

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
Author: Shuchang Zheng
Date: 2025-04-28 09:49:44 +08:00
Committed by: GitHub
Parent: 5582998490
Commit: 0a0228b341
18 changed files with 378 additions and 45 deletions

View File

@@ -6,7 +6,9 @@ from typing import Any
 import litellm
 import structlog
+from anthropic.types.message import Message as AnthropicMessage
 from jinja2 import Template
+from litellm.utils import CustomStreamWrapper, ModelResponse
 from skyvern.config import settings
 from skyvern.exceptions import SkyvernContextWindowExceededError
@@ -456,11 +458,18 @@ class LLMCaller:
         self.llm_config = LLMConfigRegistry.get_config(llm_key)
         self.base_parameters = base_parameters
         self.message_history: list[dict[str, Any]] = []
+        self.current_tool_results: list[dict[str, Any]] = []
+
+    def add_tool_result(self, tool_result: dict[str, Any]) -> None:
+        self.current_tool_results.append(tool_result)
+
+    def clear_tool_results(self) -> None:
+        self.current_tool_results = []
 
     async def call(
         self,
-        prompt: str,
-        prompt_name: str,
+        prompt: str | None = None,
+        prompt_name: str | None = None,
         step: Step | None = None,
         task_v2: TaskV2 | None = None,
         thought: Thought | None = None,
@@ -469,6 +478,8 @@ class LLMCaller:
         parameters: dict[str, Any] | None = None,
         tools: list | None = None,
         use_message_history: bool = False,
+        raw_response: bool = False,
+        **extra_parameters: Any,
     ) -> dict[str, Any]:
         start_time = time.perf_counter()
         active_parameters = self.base_parameters or {}
@@ -476,6 +487,8 @@ class LLMCaller:
             parameters = LLMAPIHandlerFactory.get_api_parameters(self.llm_config)
 
         active_parameters.update(parameters)
+        if extra_parameters:
+            active_parameters.update(extra_parameters)
 
         if self.llm_config.litellm_params:  # type: ignore
             active_parameters.update(self.llm_config.litellm_params)  # type: ignore
@@ -491,7 +504,7 @@ class LLMCaller:
         )
         await app.ARTIFACT_MANAGER.create_llm_artifact(
-            data=prompt.encode("utf-8"),
+            data=prompt.encode("utf-8") if prompt else b"",
             artifact_type=ArtifactType.LLM_PROMPT,
             screenshots=screenshots,
             step=step,
@@ -525,8 +538,7 @@ class LLMCaller:
         )
         t_llm_request = time.perf_counter()
         try:
-            response = await litellm.acompletion(
-                model=self.llm_config.model_name,
+            response = await self._dispatch_llm_call(
                 messages=messages,
                 tools=tools,
                 timeout=settings.LLM_CONFIG_TIMEOUT,
@@ -603,6 +615,21 @@ class LLMCaller:
             cached_token_count=cached_tokens if cached_tokens > 0 else None,
             thought_cost=llm_cost,
         )
+
+        # Track LLM API handler duration
+        duration_seconds = time.perf_counter() - start_time
+        LOG.info(
+            "LLM API handler duration metrics",
+            llm_key=self.llm_key,
+            prompt_name=prompt_name,
+            model=self.llm_config.model_name,
+            duration_seconds=duration_seconds,
+            step_id=step.step_id if step else None,
+            thought_id=thought.observer_thought_id if thought else None,
+            organization_id=step.organization_id if step else (thought.organization_id if thought else None),
+        )
+        if raw_response:
+            return response.model_dump()
         parsed_response = parse_api_response(response, self.llm_config.add_assistant_prefix)
         await app.ARTIFACT_MANAGER.create_llm_artifact(
             data=json.dumps(parsed_response, indent=2).encode("utf-8"),
@@ -626,17 +653,53 @@ class LLMCaller:
             ai_suggestion=ai_suggestion,
         )
 
-        # Track LLM API handler duration
-        duration_seconds = time.perf_counter() - start_time
-        LOG.info(
-            "LLM API handler duration metrics",
-            llm_key=self.llm_key,
-            prompt_name=prompt_name,
-            model=self.llm_config.model_name,
-            duration_seconds=duration_seconds,
-            step_id=step.step_id if step else None,
-            thought_id=thought.observer_thought_id if thought else None,
-            organization_id=step.organization_id if step else (thought.organization_id if thought else None),
-        )
-
         return parsed_response
 
+    async def _dispatch_llm_call(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list | None = None,
+        timeout: float = settings.LLM_CONFIG_TIMEOUT,
+        **active_parameters: dict[str, Any],
+    ) -> ModelResponse | CustomStreamWrapper | AnthropicMessage:
+        if self.llm_key and self.llm_key.startswith("ANTHROPIC"):
+            return await self._call_anthropic(messages, tools, timeout)
+        return await litellm.acompletion(
+            model=self.llm_config.model_name, messages=messages, tools=tools, timeout=timeout, **active_parameters
+        )
+
+    async def _call_anthropic(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list | None = None,
+        timeout: float = settings.LLM_CONFIG_TIMEOUT,
+        **active_parameters: dict[str, Any],
+    ) -> AnthropicMessage:
+        max_tokens = active_parameters.get("max_completion_tokens") or active_parameters.get("max_tokens") or 4096
+        model_name = self.llm_config.model_name.replace("bedrock/", "").replace("anthropic/", "")
+        return await app.ANTHROPIC_CLIENT.messages.create(
+            max_tokens=max_tokens,
+            messages=messages,
+            model=model_name,
+            tools=tools,
+            timeout=timeout,
+            betas=active_parameters.get("betas", None),
+        )
+
+
+class LLMCallerManager:
+    _llm_callers: dict[str, LLMCaller] = {}
+
+    @classmethod
+    def get_llm_caller(cls, uid: str) -> LLMCaller | None:
+        return cls._llm_callers.get(uid)
+
+    @classmethod
+    def set_llm_caller(cls, uid: str, llm_caller: LLMCaller) -> None:
+        cls._llm_callers[uid] = llm_caller
+
+    @classmethod
+    def clear_llm_caller(cls, uid: str) -> None:
+        if uid in cls._llm_callers:
+            del cls._llm_callers[uid]
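The pieces above are designed to work together: raw_response=True hands back the provider payload untouched so Anthropic tool_use blocks survive, add_tool_result buffers tool outcomes between turns, and LLMCallerManager keeps one stateful caller per task. A minimal sketch of that loop follows; the import path, the llm_key value, and the execute_tool helper are assumptions for illustration and do not appear in this diff:

    from typing import Any

    from skyvern.forge.sdk.api.llm.api_handler_factory import LLMCaller, LLMCallerManager  # path assumed


    async def execute_tool(block: dict[str, Any]) -> str:
        """Hypothetical tool runner; stands in for Skyvern's real action execution."""
        raise NotImplementedError


    async def run_cua_turn(task_id: str, screenshots: list[bytes]) -> None:
        # Reuse the stateful caller registered for this task, or create one.
        llm_caller = LLMCallerManager.get_llm_caller(task_id)
        if llm_caller is None:
            llm_caller = LLMCaller(llm_key="ANTHROPIC_CLAUDE3.7_SONNET")  # key name assumed
            LLMCallerManager.set_llm_caller(task_id, llm_caller)

        # raw_response=True skips parse_api_response(), so the returned dict is
        # response.model_dump() with Anthropic tool_use blocks intact.
        response = await llm_caller.call(
            prompt="Open the settings page",
            screenshots=screenshots,
            use_message_history=True,
            raw_response=True,
        )

        # Buffer each tool result; clear_tool_results() resets the buffer once
        # the results have been folded into the next request's message history.
        for block in response.get("content", []):
            if block.get("type") == "tool_use":
                result = await execute_tool(block)
                llm_caller.add_tool_result(
                    {"type": "tool_result", "tool_use_id": block["id"], "content": result}
                )

One detail worth noting in the diff itself: _dispatch_llm_call forwards only messages, tools, and timeout to _call_anthropic, so betas and any explicit max_tokens in active_parameters fall back to the defaults inside _call_anthropic on the Anthropic path.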

View File

@@ -47,19 +47,22 @@ async def llm_messages_builder(
 async def llm_messages_builder_with_history(
-    prompt: str,
+    prompt: str | None = None,
     screenshots: list[bytes] | None = None,
     message_history: list[dict[str, Any]] | None = None,
 ) -> list[dict[str, Any]]:
     messages: list[dict[str, Any]] = []
     if message_history:
         messages = copy.deepcopy(message_history)
-    current_user_messages: list[dict[str, Any]] = [
-        {
-            "type": "text",
-            "text": prompt,
-        }
-    ]
+    current_user_messages: list[dict[str, Any]] = []
+    if prompt:
+        current_user_messages.append(
+            {
+                "type": "text",
+                "text": prompt,
+            }
+        )
     if screenshots:
         for screenshot in screenshots:
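Previously the text part was built unconditionally, which is why prompt could not be omitted; with prompt optional, a screenshot-only turn (for example, one that carries just tool results and a fresh page image) no longer injects an empty text part into the user message. A minimal sketch, assuming this module's llm_messages_builder_with_history is in scope and the history shapes are illustrative:

    from typing import Any


    async def build_screenshot_only_turn(
        png_bytes: bytes,
        previous_messages: list[dict[str, Any]],
    ) -> list[dict[str, Any]]:
        # prompt=None is now valid: the text part is appended only when a
        # prompt is given, so this turn carries only the screenshot.
        return await llm_messages_builder_with_history(
            prompt=None,
            screenshots=[png_bytes],
            message_history=previous_messages,
        )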