Anthropic CUA fix (#2238)

This commit is contained in:
Shuchang Zheng
2025-04-28 18:15:23 +08:00
committed by GitHub
parent d798b00409
commit 1530338cad
5 changed files with 10 additions and 5 deletions

View File

@@ -131,6 +131,7 @@ class Settings(BaseSettings):
 OPENAI_API_KEY: str | None = None
 # ANTHROPIC
 ANTHROPIC_API_KEY: str | None = None
+ANTHROPIC_CUA_LLM_KEY: str = "ANTHROPIC_CLAUDE3.7_SONNET"
 # OPENAI COMPATIBLE
 OPENAI_COMPATIBLE_MODEL_NAME: str | None = None

View File

@@ -385,7 +385,7 @@ class ForgeAgent:
 # llm_caller = LLMCaller(llm_key="BEDROCK_ANTHROPIC_CLAUDE3.5_SONNET_INFERENCE_PROFILE")
 llm_caller = LLMCallerManager.get_llm_caller(task.task_id)
 if not llm_caller:
-llm_caller = LLMCaller(llm_key="ANTHROPIC_CLAUDE3.5_SONNET")
+llm_caller = LLMCaller(llm_key=settings.ANTHROPIC_CUA_LLM_KEY)
 LLMCallerManager.set_llm_caller(task.task_id, llm_caller)
 step, detailed_output = await self.agent_step(
 task,

View File

@@ -6,6 +6,7 @@ from typing import Any
 import litellm
 import structlog
+from anthropic import NOT_GIVEN
 from anthropic.types.beta.beta_message import BetaMessage as AnthropicMessage
 from jinja2 import Template
 from litellm.utils import CustomStreamWrapper, ModelResponse
@@ -663,7 +664,7 @@ class LLMCaller:
 **active_parameters: dict[str, Any],
 ) -> ModelResponse | CustomStreamWrapper | AnthropicMessage:
 if self.llm_key and self.llm_key.startswith("ANTHROPIC"):
-return await self._call_anthropic(messages, tools, timeout)
+return await self._call_anthropic(messages, tools, timeout, **active_parameters)
 return await litellm.acompletion(
 model=self.llm_config.model_name, messages=messages, tools=tools, timeout=timeout, **active_parameters
@@ -678,15 +679,16 @@ class LLMCaller:
 ) -> AnthropicMessage:
 max_tokens = active_parameters.get("max_completion_tokens") or active_parameters.get("max_tokens") or 4096
 model_name = self.llm_config.model_name.replace("bedrock/", "").replace("anthropic/", "")
+betas = active_parameters.get("betas", NOT_GIVEN)
 response = await app.ANTHROPIC_CLIENT.beta.messages.create(
 max_tokens=max_tokens,
 messages=messages,
 model=model_name,
 tools=tools,
 timeout=timeout,
-betas=active_parameters.get("betas", None),
+betas=betas,
 )
-LOG.info("Anthropic response", response=response)
+LOG.info("Anthropic response", response=response, betas=betas, tools=tools, timeout=timeout)
 return response

View File

@@ -374,6 +374,7 @@ class ActionHandler:
 "content": {"result": "Tool execution failed"},
 }
 llm_caller.add_tool_result(tool_call_result)
+LOG.info("Tool call result", tool_call_result=tool_call_result, action=action)
 return actions_result
 if llm_caller and action.tool_call_id:
@@ -382,6 +383,7 @@ class ActionHandler:
 "tool_use_id": action.tool_call_id,
 "content": {"result": "Tool executed successfully"},
 }
+LOG.info("Tool call result", tool_call_result=tool_call_result, action=action)
 llm_caller.add_tool_result(tool_call_result)
 # do the teardown

View File

@@ -512,7 +512,7 @@ async def parse_anthropic_actions(
 task_id=task.task_id,
 step_id=step.step_id,
 step_order=step.order,
-action_order=idx - 1,
+action_order=idx,
 tool_call_id=tool_call_id,
 )
 )