Pass base parameters to acompletion (#343)
This commit is contained in:
@@ -133,7 +133,7 @@ class LLMAPIHandlerFactory:
|
||||
return llm_api_handler_with_router_and_fallback
|
||||
|
||||
@staticmethod
|
||||
def get_llm_api_handler(llm_key: str) -> LLMAPIHandler:
|
||||
def get_llm_api_handler(llm_key: str, base_parameters: dict[str, Any] | None = None) -> LLMAPIHandler:
|
||||
llm_config = LLMConfigRegistry.get_config(llm_key)
|
||||
|
||||
if LLMConfigRegistry.is_router_config(llm_key):
|
||||
@@ -145,9 +145,12 @@ class LLMAPIHandlerFactory:
|
||||
screenshots: list[bytes] | None = None,
|
||||
parameters: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
active_parameters = base_parameters or {}
|
||||
if parameters is None:
|
||||
parameters = LLMAPIHandlerFactory.get_api_parameters()
|
||||
|
||||
active_parameters.update(parameters)
|
||||
|
||||
if step:
|
||||
await app.ARTIFACT_MANAGER.create_artifact(
|
||||
step=step,
|
||||
@@ -174,6 +177,7 @@ class LLMAPIHandlerFactory:
|
||||
{
|
||||
"model": llm_config.model_name,
|
||||
"messages": messages,
|
||||
# we're not using active_parameters here because it may contain sensitive information
|
||||
**parameters,
|
||||
}
|
||||
).encode("utf-8"),
|
||||
@@ -185,7 +189,7 @@ class LLMAPIHandlerFactory:
|
||||
response = await litellm.acompletion(
|
||||
model=llm_config.model_name,
|
||||
messages=messages,
|
||||
**parameters,
|
||||
**active_parameters,
|
||||
)
|
||||
except openai.OpenAIError as e:
|
||||
raise LLMProviderError(llm_key) from e
|
||||
|
||||
Reference in New Issue
Block a user