overhaul llm key override (#2677)
This commit is contained in:
@@ -45,6 +45,20 @@ class LLMCallStats(BaseModel):
|
||||
class LLMAPIHandlerFactory:
|
||||
_custom_handlers: dict[str, LLMAPIHandler] = {}
|
||||
|
||||
@staticmethod
|
||||
def get_override_llm_api_handler(override_llm_key: str | None, *, default: LLMAPIHandler) -> LLMAPIHandler:
|
||||
if not override_llm_key:
|
||||
return default
|
||||
try:
|
||||
return LLMAPIHandlerFactory.get_llm_api_handler(override_llm_key)
|
||||
except Exception:
|
||||
LOG.warning(
|
||||
"Failed to get override LLM API handler, going to use the default.",
|
||||
override_llm_key=override_llm_key,
|
||||
exc_info=True,
|
||||
)
|
||||
return default
|
||||
|
||||
@staticmethod
|
||||
def get_llm_api_handler_with_router(llm_key: str) -> LLMAPIHandler:
|
||||
llm_config = LLMConfigRegistry.get_config(llm_key)
|
||||
@@ -82,7 +96,6 @@ class LLMAPIHandlerFactory:
|
||||
ai_suggestion: AISuggestion | None = None,
|
||||
screenshots: list[bytes] | None = None,
|
||||
parameters: dict[str, Any] | None = None,
|
||||
llm_key_override: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Custom LLM API handler that utilizes the LiteLLM router and fallbacks to OpenAI GPT-4 Vision.
|
||||
@@ -96,18 +109,10 @@ class LLMAPIHandlerFactory:
|
||||
Returns:
|
||||
The response from the LLM router.
|
||||
"""
|
||||
nonlocal llm_config
|
||||
nonlocal llm_key
|
||||
|
||||
local_llm_config: LLMConfig | LLMRouterConfig = llm_config
|
||||
if llm_key_override:
|
||||
local_llm_config = LLMConfigRegistry.get_config(llm_key_override)
|
||||
|
||||
local_llm_key = llm_key_override or llm_key
|
||||
start_time = time.time()
|
||||
|
||||
if parameters is None:
|
||||
parameters = LLMAPIHandlerFactory.get_api_parameters(local_llm_config)
|
||||
parameters = LLMAPIHandlerFactory.get_api_parameters(llm_config)
|
||||
|
||||
context = skyvern_context.current()
|
||||
if context and len(context.hashed_href_map) > 0:
|
||||
@@ -128,12 +133,12 @@ class LLMAPIHandlerFactory:
|
||||
task_v2=task_v2,
|
||||
thought=thought,
|
||||
)
|
||||
messages = await llm_messages_builder(prompt, screenshots, local_llm_config.add_assistant_prefix)
|
||||
messages = await llm_messages_builder(prompt, screenshots, llm_config.add_assistant_prefix)
|
||||
|
||||
await app.ARTIFACT_MANAGER.create_llm_artifact(
|
||||
data=json.dumps(
|
||||
{
|
||||
"model": local_llm_key,
|
||||
"model": llm_key,
|
||||
"messages": messages,
|
||||
**parameters,
|
||||
}
|
||||
@@ -149,12 +154,12 @@ class LLMAPIHandlerFactory:
|
||||
model=main_model_group, messages=messages, timeout=settings.LLM_CONFIG_TIMEOUT, **parameters
|
||||
)
|
||||
except litellm.exceptions.APIError as e:
|
||||
raise LLMProviderErrorRetryableTask(local_llm_key) from e
|
||||
raise LLMProviderErrorRetryableTask(llm_key) from e
|
||||
except litellm.exceptions.ContextWindowExceededError as e:
|
||||
duration_seconds = time.time() - start_time
|
||||
LOG.exception(
|
||||
"Context window exceeded",
|
||||
llm_key=local_llm_key,
|
||||
llm_key=llm_key,
|
||||
model=main_model_group,
|
||||
prompt_name=prompt_name,
|
||||
duration_seconds=duration_seconds,
|
||||
@@ -164,22 +169,22 @@ class LLMAPIHandlerFactory:
|
||||
duration_seconds = time.time() - start_time
|
||||
LOG.exception(
|
||||
"LLM token limit exceeded",
|
||||
llm_key=local_llm_key,
|
||||
llm_key=llm_key,
|
||||
model=main_model_group,
|
||||
prompt_name=prompt_name,
|
||||
duration_seconds=duration_seconds,
|
||||
)
|
||||
raise LLMProviderErrorRetryableTask(local_llm_key) from e
|
||||
raise LLMProviderErrorRetryableTask(llm_key) from e
|
||||
except Exception as e:
|
||||
duration_seconds = time.time() - start_time
|
||||
LOG.exception(
|
||||
"LLM request failed unexpectedly",
|
||||
llm_key=local_llm_key,
|
||||
llm_key=llm_key,
|
||||
model=main_model_group,
|
||||
prompt_name=prompt_name,
|
||||
duration_seconds=duration_seconds,
|
||||
)
|
||||
raise LLMProviderError(local_llm_key) from e
|
||||
raise LLMProviderError(llm_key) from e
|
||||
|
||||
await app.ARTIFACT_MANAGER.create_llm_artifact(
|
||||
data=response.model_dump_json(indent=2).encode("utf-8"),
|
||||
@@ -226,7 +231,7 @@ class LLMAPIHandlerFactory:
|
||||
reasoning_token_count=reasoning_tokens if reasoning_tokens > 0 else None,
|
||||
cached_token_count=cached_tokens if cached_tokens > 0 else None,
|
||||
)
|
||||
parsed_response = parse_api_response(response, local_llm_config.add_assistant_prefix)
|
||||
parsed_response = parse_api_response(response, llm_config.add_assistant_prefix)
|
||||
await app.ARTIFACT_MANAGER.create_llm_artifact(
|
||||
data=json.dumps(parsed_response, indent=2).encode("utf-8"),
|
||||
artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
|
||||
@@ -253,7 +258,7 @@ class LLMAPIHandlerFactory:
|
||||
duration_seconds = time.time() - start_time
|
||||
LOG.info(
|
||||
"LLM API handler duration metrics",
|
||||
llm_key=local_llm_key,
|
||||
llm_key=llm_key,
|
||||
model=main_model_group,
|
||||
prompt_name=prompt_name,
|
||||
duration_seconds=duration_seconds,
|
||||
@@ -287,25 +292,15 @@ class LLMAPIHandlerFactory:
|
||||
ai_suggestion: AISuggestion | None = None,
|
||||
screenshots: list[bytes] | None = None,
|
||||
parameters: dict[str, Any] | None = None,
|
||||
llm_key_override: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
nonlocal llm_config
|
||||
nonlocal llm_key
|
||||
|
||||
local_llm_config: LLMConfig | LLMRouterConfig = llm_config
|
||||
if llm_key_override:
|
||||
local_llm_config = LLMConfigRegistry.get_config(llm_key_override)
|
||||
|
||||
local_llm_key = llm_key_override or llm_key
|
||||
|
||||
start_time = time.time()
|
||||
active_parameters = base_parameters or {}
|
||||
if parameters is None:
|
||||
parameters = LLMAPIHandlerFactory.get_api_parameters(local_llm_config)
|
||||
parameters = LLMAPIHandlerFactory.get_api_parameters(llm_config)
|
||||
|
||||
active_parameters.update(parameters)
|
||||
if local_llm_config.litellm_params: # type: ignore
|
||||
active_parameters.update(local_llm_config.litellm_params) # type: ignore
|
||||
if llm_config.litellm_params: # type: ignore
|
||||
active_parameters.update(llm_config.litellm_params) # type: ignore
|
||||
|
||||
context = skyvern_context.current()
|
||||
if context and len(context.hashed_href_map) > 0:
|
||||
@@ -328,12 +323,12 @@ class LLMAPIHandlerFactory:
|
||||
ai_suggestion=ai_suggestion,
|
||||
)
|
||||
|
||||
if not local_llm_config.supports_vision:
|
||||
if not llm_config.supports_vision:
|
||||
screenshots = None
|
||||
|
||||
model_name = local_llm_config.model_name
|
||||
model_name = llm_config.model_name
|
||||
|
||||
messages = await llm_messages_builder(prompt, screenshots, local_llm_config.add_assistant_prefix)
|
||||
messages = await llm_messages_builder(prompt, screenshots, llm_config.add_assistant_prefix)
|
||||
await app.ARTIFACT_MANAGER.create_llm_artifact(
|
||||
data=json.dumps(
|
||||
{
|
||||
@@ -361,12 +356,12 @@ class LLMAPIHandlerFactory:
|
||||
**active_parameters,
|
||||
)
|
||||
except litellm.exceptions.APIError as e:
|
||||
raise LLMProviderErrorRetryableTask(local_llm_key) from e
|
||||
raise LLMProviderErrorRetryableTask(llm_key) from e
|
||||
except litellm.exceptions.ContextWindowExceededError as e:
|
||||
duration_seconds = time.time() - start_time
|
||||
LOG.exception(
|
||||
"Context window exceeded",
|
||||
llm_key=local_llm_key,
|
||||
llm_key=llm_key,
|
||||
model=model_name,
|
||||
prompt_name=prompt_name,
|
||||
duration_seconds=duration_seconds,
|
||||
@@ -376,22 +371,22 @@ class LLMAPIHandlerFactory:
|
||||
t_llm_cancelled = time.perf_counter()
|
||||
LOG.error(
|
||||
"LLM request got cancelled",
|
||||
llm_key=local_llm_key,
|
||||
llm_key=llm_key,
|
||||
model=model_name,
|
||||
prompt_name=prompt_name,
|
||||
duration=t_llm_cancelled - t_llm_request,
|
||||
)
|
||||
raise LLMProviderError(local_llm_key)
|
||||
raise LLMProviderError(llm_key)
|
||||
except Exception as e:
|
||||
duration_seconds = time.time() - start_time
|
||||
LOG.exception(
|
||||
"LLM request failed unexpectedly",
|
||||
llm_key=local_llm_key,
|
||||
llm_key=llm_key,
|
||||
model=model_name,
|
||||
prompt_name=prompt_name,
|
||||
duration_seconds=duration_seconds,
|
||||
)
|
||||
raise LLMProviderError(local_llm_key) from e
|
||||
raise LLMProviderError(llm_key) from e
|
||||
|
||||
await app.ARTIFACT_MANAGER.create_llm_artifact(
|
||||
data=response.model_dump_json(indent=2).encode("utf-8"),
|
||||
@@ -439,7 +434,7 @@ class LLMAPIHandlerFactory:
|
||||
cached_token_count=cached_tokens if cached_tokens > 0 else None,
|
||||
thought_cost=llm_cost,
|
||||
)
|
||||
parsed_response = parse_api_response(response, local_llm_config.add_assistant_prefix)
|
||||
parsed_response = parse_api_response(response, llm_config.add_assistant_prefix)
|
||||
await app.ARTIFACT_MANAGER.create_llm_artifact(
|
||||
data=json.dumps(parsed_response, indent=2).encode("utf-8"),
|
||||
artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
|
||||
@@ -466,9 +461,9 @@ class LLMAPIHandlerFactory:
|
||||
duration_seconds = time.time() - start_time
|
||||
LOG.info(
|
||||
"LLM API handler duration metrics",
|
||||
llm_key=local_llm_key,
|
||||
llm_key=llm_key,
|
||||
prompt_name=prompt_name,
|
||||
model=local_llm_config.model_name,
|
||||
model=llm_config.model_name,
|
||||
duration_seconds=duration_seconds,
|
||||
step_id=step.step_id if step else None,
|
||||
thought_id=thought.observer_thought_id if thought else None,
|
||||
|
||||
@@ -94,7 +94,6 @@ class LLMAPIHandler(Protocol):
|
||||
ai_suggestion: AISuggestion | None = None,
|
||||
screenshots: list[bytes] | None = None,
|
||||
parameters: dict[str, Any] | None = None,
|
||||
llm_key_override: str | None = None,
|
||||
) -> Awaitable[dict[str, Any]]: ...
|
||||
|
||||
|
||||
@@ -107,6 +106,5 @@ async def dummy_llm_api_handler(
|
||||
ai_suggestion: AISuggestion | None = None,
|
||||
screenshots: list[bytes] | None = None,
|
||||
parameters: dict[str, Any] | None = None,
|
||||
llm_key_override: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
raise NotImplementedError("Your LLM provider is not configured. Please configure it in the .env file.")
|
||||
|
||||
Reference in New Issue
Block a user