From bf06fcfeb7592cd5d41d32a58d52d9ae9503f798 Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Sat, 2 Nov 2024 21:46:55 -0700 Subject: [PATCH] Update max output tokens to 16K (#1110) --- skyvern/forge/sdk/api/llm/api_handler_factory.py | 8 ++++---- skyvern/forge/sdk/api/llm/config_registry.py | 15 +++++++++++++-- skyvern/forge/sdk/api/llm/models.py | 2 ++ 3 files changed, 19 insertions(+), 6 deletions(-) diff --git a/skyvern/forge/sdk/api/llm/api_handler_factory.py b/skyvern/forge/sdk/api/llm/api_handler_factory.py index d73519b5..55b458fa 100644 --- a/skyvern/forge/sdk/api/llm/api_handler_factory.py +++ b/skyvern/forge/sdk/api/llm/api_handler_factory.py @@ -74,7 +74,7 @@ class LLMAPIHandlerFactory: The response from the LLM router. """ if parameters is None: - parameters = LLMAPIHandlerFactory.get_api_parameters() + parameters = LLMAPIHandlerFactory.get_api_parameters(llm_config) if step: await app.ARTIFACT_MANAGER.create_artifact( @@ -168,7 +168,7 @@ class LLMAPIHandlerFactory: ) -> dict[str, Any]: active_parameters = base_parameters or {} if parameters is None: - parameters = LLMAPIHandlerFactory.get_api_parameters() + parameters = LLMAPIHandlerFactory.get_api_parameters(llm_config) active_parameters.update(parameters) if llm_config.litellm_params: # type: ignore @@ -261,9 +261,9 @@ class LLMAPIHandlerFactory: return llm_api_handler @staticmethod - def get_api_parameters() -> dict[str, Any]: + def get_api_parameters(llm_config: LLMConfig | LLMRouterConfig) -> dict[str, Any]: return { - "max_tokens": SettingsManager.get_settings().LLM_CONFIG_MAX_TOKENS, + "max_tokens": llm_config.max_output_tokens, "temperature": SettingsManager.get_settings().LLM_CONFIG_TEMPERATURE, } diff --git a/skyvern/forge/sdk/api/llm/config_registry.py b/skyvern/forge/sdk/api/llm/config_registry.py index 1507ebab..1f4dc6af 100644 --- a/skyvern/forge/sdk/api/llm/config_registry.py +++ b/skyvern/forge/sdk/api/llm/config_registry.py @@ -76,7 +76,10 @@ if SettingsManager.get_settings().ENABLE_OPENAI: ), ) LLMConfigRegistry.register_config( - "OPENAI_GPT4O", LLMConfig("gpt-4o", ["OPENAI_API_KEY"], supports_vision=True, add_assistant_prefix=False) + "OPENAI_GPT4O", + LLMConfig( + "gpt-4o", ["OPENAI_API_KEY"], supports_vision=True, add_assistant_prefix=False, max_output_tokens=16384 + ), ) LLMConfigRegistry.register_config( "OPENAI_GPT4O_MINI", @@ -85,11 +88,18 @@ if SettingsManager.get_settings().ENABLE_OPENAI: ["OPENAI_API_KEY"], supports_vision=True, add_assistant_prefix=False, + max_output_tokens=16384, ), ) LLMConfigRegistry.register_config( "OPENAI_GPT-4O-2024-08-06", - LLMConfig("gpt-4o-2024-08-06", ["OPENAI_API_KEY"], supports_vision=True, add_assistant_prefix=False), + LLMConfig( + "gpt-4o-2024-08-06", + ["OPENAI_API_KEY"], + supports_vision=True, + add_assistant_prefix=False, + max_output_tokens=16384, + ), ) @@ -137,6 +147,7 @@ if SettingsManager.get_settings().ENABLE_ANTHROPIC: ["ANTHROPIC_API_KEY"], supports_vision=True, add_assistant_prefix=True, + max_output_tokens=8192, ), ) diff --git a/skyvern/forge/sdk/api/llm/models.py b/skyvern/forge/sdk/api/llm/models.py index 015ad574..20b99c17 100644 --- a/skyvern/forge/sdk/api/llm/models.py +++ b/skyvern/forge/sdk/api/llm/models.py @@ -34,6 +34,7 @@ class LLMConfigBase: @dataclass(frozen=True) class LLMConfig(LLMConfigBase): litellm_params: Optional[LiteLLMParams] = field(default=None) + max_output_tokens: int = SettingsManager.get_settings().LLM_CONFIG_MAX_TOKENS @dataclass(frozen=True) @@ -69,6 +70,7 @@ class LLMRouterConfig(LLMConfigBase): allowed_fails: int | None = None allowed_fails_policy: AllowedFailsPolicy | None = None cooldown_time: float | None = None + max_output_tokens: int = SettingsManager.get_settings().LLM_CONFIG_MAX_TOKENS class LLMAPIHandler(Protocol):