diff --git a/skyvern/forge/sdk/api/llm/api_handler_factory.py b/skyvern/forge/sdk/api/llm/api_handler_factory.py index 130e60fa..390cc4b4 100644 --- a/skyvern/forge/sdk/api/llm/api_handler_factory.py +++ b/skyvern/forge/sdk/api/llm/api_handler_factory.py @@ -81,7 +81,7 @@ class LLMAPIHandlerFactory: data=screenshot, ) - messages = await llm_messages_builder(prompt, screenshots) + messages = await llm_messages_builder(prompt, screenshots, llm_config.add_assistant_prefix) if step: await app.ARTIFACT_MANAGER.create_artifact( step=step, @@ -115,7 +115,7 @@ class LLMAPIHandlerFactory: organization_id=step.organization_id, incremental_cost=llm_cost, ) - parsed_response = parse_api_response(response) + parsed_response = parse_api_response(response, llm_config.add_assistant_prefix) if step: await app.ARTIFACT_MANAGER.create_artifact( step=step, @@ -159,7 +159,7 @@ class LLMAPIHandlerFactory: if not llm_config.supports_vision: screenshots = None - messages = await llm_messages_builder(prompt, screenshots) + messages = await llm_messages_builder(prompt, screenshots, llm_config.add_assistant_prefix) if step: await app.ARTIFACT_MANAGER.create_artifact( step=step, @@ -199,7 +199,7 @@ class LLMAPIHandlerFactory: organization_id=step.organization_id, incremental_cost=llm_cost, ) - parsed_response = parse_api_response(response) + parsed_response = parse_api_response(response, llm_config.add_assistant_prefix) if step: await app.ARTIFACT_MANAGER.create_artifact( step=step, diff --git a/skyvern/forge/sdk/api/llm/config_registry.py b/skyvern/forge/sdk/api/llm/config_registry.py index 0263fbe1..5513a01f 100644 --- a/skyvern/forge/sdk/api/llm/config_registry.py +++ b/skyvern/forge/sdk/api/llm/config_registry.py @@ -55,21 +55,38 @@ if not any( if SettingsManager.get_settings().ENABLE_OPENAI: - LLMConfigRegistry.register_config("OPENAI_GPT4_TURBO", LLMConfig("gpt-4-turbo", ["OPENAI_API_KEY"], True)) - LLMConfigRegistry.register_config("OPENAI_GPT4V", LLMConfig("gpt-4-turbo", 
["OPENAI_API_KEY"], True)) + LLMConfigRegistry.register_config( + "OPENAI_GPT4_TURBO", + LLMConfig("gpt-4-turbo", ["OPENAI_API_KEY"], supports_vision=False, add_assistant_prefix=False), + ) + LLMConfigRegistry.register_config( + "OPENAI_GPT4V", LLMConfig("gpt-4-turbo", ["OPENAI_API_KEY"], supports_vision=True, add_assistant_prefix=False) + ) if SettingsManager.get_settings().ENABLE_ANTHROPIC: LLMConfigRegistry.register_config( - "ANTHROPIC_CLAUDE3", LLMConfig("anthropic/claude-3-sonnet-20240229", ["ANTHROPIC_API_KEY"], True) + "ANTHROPIC_CLAUDE3", + LLMConfig( + "anthropic/claude-3-sonnet-20240229", ["ANTHROPIC_API_KEY"], supports_vision=True, add_assistant_prefix=True + ), ) LLMConfigRegistry.register_config( - "ANTHROPIC_CLAUDE3_OPUS", LLMConfig("anthropic/claude-3-opus-20240229", ["ANTHROPIC_API_KEY"], True) + "ANTHROPIC_CLAUDE3_OPUS", + LLMConfig( + "anthropic/claude-3-opus-20240229", ["ANTHROPIC_API_KEY"], supports_vision=True, add_assistant_prefix=True + ), ) LLMConfigRegistry.register_config( - "ANTHROPIC_CLAUDE3_SONNET", LLMConfig("anthropic/claude-3-sonnet-20240229", ["ANTHROPIC_API_KEY"], True) + "ANTHROPIC_CLAUDE3_SONNET", + LLMConfig( + "anthropic/claude-3-sonnet-20240229", ["ANTHROPIC_API_KEY"], supports_vision=True, add_assistant_prefix=True + ), ) LLMConfigRegistry.register_config( - "ANTHROPIC_CLAUDE3_HAIKU", LLMConfig("anthropic/claude-3-haiku-20240307", ["ANTHROPIC_API_KEY"], True) + "ANTHROPIC_CLAUDE3_HAIKU", + LLMConfig( + "anthropic/claude-3-haiku-20240307", ["ANTHROPIC_API_KEY"], supports_vision=True, add_assistant_prefix=True + ), ) if SettingsManager.get_settings().ENABLE_BEDROCK: @@ -79,7 +96,8 @@ if SettingsManager.get_settings().ENABLE_BEDROCK: LLMConfig( "bedrock/anthropic.claude-3-opus-20240229-v1:0", ["AWS_REGION"], - True, + supports_vision=True, + add_assistant_prefix=True, ), ) LLMConfigRegistry.register_config( @@ -87,7 +105,8 @@ if SettingsManager.get_settings().ENABLE_BEDROCK: LLMConfig( 
"bedrock/anthropic.claude-3-sonnet-20240229-v1:0", ["AWS_REGION"], - True, + supports_vision=True, + add_assistant_prefix=True, ), ) LLMConfigRegistry.register_config( @@ -95,7 +114,8 @@ if SettingsManager.get_settings().ENABLE_BEDROCK: LLMConfig( "bedrock/anthropic.claude-3-haiku-20240307-v1:0", ["AWS_REGION"], - True, + supports_vision=True, + add_assistant_prefix=True, ), ) @@ -105,6 +125,7 @@ if SettingsManager.get_settings().ENABLE_AZURE: LLMConfig( f"azure/{SettingsManager.get_settings().AZURE_DEPLOYMENT}", ["AZURE_DEPLOYMENT", "AZURE_API_KEY", "AZURE_API_BASE", "AZURE_API_VERSION"], - True, + supports_vision=True, + add_assistant_prefix=False, ), ) diff --git a/skyvern/forge/sdk/api/llm/models.py b/skyvern/forge/sdk/api/llm/models.py index 41c8d152..940b0330 100644 --- a/skyvern/forge/sdk/api/llm/models.py +++ b/skyvern/forge/sdk/api/llm/models.py @@ -10,6 +10,7 @@ class LLMConfig: model_name: str required_env_vars: list[str] supports_vision: bool + add_assistant_prefix: bool def get_missing_env_vars(self) -> list[str]: missing_env_vars = [] diff --git a/skyvern/forge/sdk/api/llm/utils.py b/skyvern/forge/sdk/api/llm/utils.py index 47dbf858..11e5ed14 100644 --- a/skyvern/forge/sdk/api/llm/utils.py +++ b/skyvern/forge/sdk/api/llm/utils.py @@ -1,4 +1,5 @@ import base64 +import re from typing import Any import commentjson @@ -10,6 +11,7 @@ from skyvern.forge.sdk.api.llm.exceptions import EmptyLLMResponseError, InvalidL async def llm_messages_builder( prompt: str, screenshots: list[bytes] | None = None, + add_assistant_prefix: bool = False, ) -> list[dict[str, Any]]: messages: list[dict[str, Any]] = [ { @@ -29,17 +31,37 @@ async def llm_messages_builder( }, } ) - + # Anthropic models seems to struggle to always output a valid json object so we need to prefill the response to force it: + if add_assistant_prefix: + return [{"role": "user", "content": messages}, {"role": "assistant", "content": "{"}] return [{"role": "user", "content": messages}] -def 
parse_api_response(response: litellm.ModelResponse) -> dict[str, str]: +def parse_api_response(response: litellm.ModelResponse, add_assistant_prefix: bool = False) -> dict[str, str]: try: content = response.choices[0].message.content - content = content.replace("```json", "") - content = content.replace("```", "") + # Since we prefilled the Anthropic response with "{", we need to add it back to the response to have a valid json object: + if add_assistant_prefix: + content = "{" + content + content = try_to_extract_json_from_markdown_format(content) + content = replace_useless_text_around_json(content) if not content: raise EmptyLLMResponseError(str(response)) return commentjson.loads(content) except Exception as e: raise InvalidLLMResponseFormat(str(response)) from e + + +def replace_useless_text_around_json(input_string: str) -> str: + first_occurrence_of_brace = input_string.find("{") + last_occurrence_of_brace = input_string.rfind("}") + return input_string[first_occurrence_of_brace : last_occurrence_of_brace + 1] + + +def try_to_extract_json_from_markdown_format(text: str) -> str: + pattern = r"```json\s*(.*?)\s*```" + match = re.search(pattern, text, re.DOTALL) + if match: + return match.group(1) + else: + return text