Force Claude 3 models to output JSON object and parse it more reliably (#293)
Co-authored-by: otmane <otmanebenazzou.pro@gmail.com>
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import base64
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
import commentjson
|
||||
@@ -10,6 +11,7 @@ from skyvern.forge.sdk.api.llm.exceptions import EmptyLLMResponseError, InvalidL
|
||||
async def llm_messages_builder(
|
||||
prompt: str,
|
||||
screenshots: list[bytes] | None = None,
|
||||
add_assistant_prefix: bool = False,
|
||||
) -> list[dict[str, Any]]:
|
||||
messages: list[dict[str, Any]] = [
|
||||
{
|
||||
@@ -29,17 +31,37 @@ async def llm_messages_builder(
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# Anthropic models seem to struggle to always output a valid JSON object, so we prefill the assistant response with "{" to force it:
|
||||
if add_assistant_prefix:
|
||||
return [{"role": "user", "content": messages}, {"role": "assistant", "content": "{"}]
|
||||
return [{"role": "user", "content": messages}]
|
||||
|
||||
|
||||
def parse_api_response(response: litellm.ModelResponse, add_assistant_prefix: bool = False) -> dict[str, Any]:
    """Extract and parse the JSON object contained in an LLM completion.

    Args:
        response: Completion whose first choice carries the model's text.
        add_assistant_prefix: Set when the request prefilled the assistant
            turn with "{"; that brace is not part of the returned text and
            is added back before parsing.

    Returns:
        The object decoded by ``commentjson.loads``.

    Raises:
        InvalidLLMResponseFormat: On any failure to obtain or decode the
            JSON. The empty-content case raises ``EmptyLLMResponseError``
            inside the ``try`` and is re-raised here as this error, with the
            original exception attached as the cause.
    """
    try:
        # A missing (None) content is treated like an empty string so it takes
        # the explicit empty-response path below instead of failing with an
        # incidental TypeError.
        content = response.choices[0].message.content or ""
        # Since we prefilled the Anthropic response with "{", add it back so
        # the text is a complete JSON object again.
        if add_assistant_prefix:
            content = "{" + content
        # Fences must still be present when the markdown extractor runs, so
        # they are NOT stripped beforehand; stripping them first would make
        # the extractor a no-op and would also mangle backticks inside JSON
        # string values. Any leftover fence or prose is removed by the brace
        # trimming that follows.
        content = try_to_extract_json_from_markdown_format(content)
        content = replace_useless_text_around_json(content)
        if not content:
            raise EmptyLLMResponseError(str(response))
        return commentjson.loads(content)
    except Exception as e:
        raise InvalidLLMResponseFormat(str(response)) from e
|
||||
|
||||
|
||||
def replace_useless_text_around_json(input_string: str) -> str:
    """Trim everything outside the outermost ``{ ... }`` span.

    Args:
        input_string: Text that may contain a JSON object surrounded by
            extra characters (prose, code fences, whitespace).

    Returns:
        The substring from the first "{" through the last "}" inclusive, or
        an empty string when there is no such span (a brace is missing, or
        the last "}" comes before the first "{").
    """
    first_brace = input_string.find("{")
    last_brace = input_string.rfind("}")
    # str.find returns -1 when nothing is found; using that directly as a
    # slice start would index from the end of the string and could return a
    # stray "}" for input that has no opening brace at all.
    if first_brace == -1 or last_brace < first_brace:
        return ""
    return input_string[first_brace : last_brace + 1]
|
||||
|
||||
|
||||
def try_to_extract_json_from_markdown_format(text: str) -> str:
    """Return the body of the first ```json fenced block in ``text``.

    The text between the opening ```json marker and the next closing ```
    is returned with the whitespace adjacent to the fences removed. When
    ``text`` contains no such fenced block it is returned unchanged.
    """
    fenced_block = re.search(r"```json\s*(.*?)\s*```", text, re.DOTALL)
    if fenced_block is None:
        # Nothing fenced: hand the input back as-is.
        return text
    return fenced_block.group(1)
|
||||
|
||||
Reference in New Issue
Block a user