weave LLMCaller.llm_key through to api handler/agent (#2524)

This commit is contained in:
Shuchang Zheng
2025-05-29 13:49:59 -07:00
committed by GitHub
parent 7f6b65ba61
commit ea5620acd2
3 changed files with 34 additions and 18 deletions

View File

@@ -864,6 +864,7 @@ class ForgeAgent:
prompt_name="extract-actions", prompt_name="extract-actions",
step=step, step=step,
screenshots=scraped_page.screenshots, screenshots=scraped_page.screenshots,
llm_key_override=llm_caller.llm_key if llm_caller else None,
) )
try: try:
json_response = await self.handle_potential_verification_code( json_response = await self.handle_potential_verification_code(

View File

@@ -82,6 +82,7 @@ class LLMAPIHandlerFactory:
ai_suggestion: AISuggestion | None = None, ai_suggestion: AISuggestion | None = None,
screenshots: list[bytes] | None = None, screenshots: list[bytes] | None = None,
parameters: dict[str, Any] | None = None, parameters: dict[str, Any] | None = None,
llm_key_override: str | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
""" """
Custom LLM API handler that utilizes the LiteLLM router and falls back to OpenAI GPT-4 Vision. Custom LLM API handler that utilizes the LiteLLM router and falls back to OpenAI GPT-4 Vision.
@@ -267,15 +268,25 @@ class LLMAPIHandlerFactory:
ai_suggestion: AISuggestion | None = None, ai_suggestion: AISuggestion | None = None,
screenshots: list[bytes] | None = None, screenshots: list[bytes] | None = None,
parameters: dict[str, Any] | None = None, parameters: dict[str, Any] | None = None,
llm_key_override: str | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
nonlocal llm_config
nonlocal llm_key
local_llm_config: LLMConfig | LLMRouterConfig = llm_config
if llm_key_override:
local_llm_config = LLMConfigRegistry.get_config(llm_key_override)
local_llm_key = llm_key_override or llm_key
start_time = time.time() start_time = time.time()
active_parameters = base_parameters or {} active_parameters = base_parameters or {}
if parameters is None: if parameters is None:
parameters = LLMAPIHandlerFactory.get_api_parameters(llm_config) parameters = LLMAPIHandlerFactory.get_api_parameters(local_llm_config)
active_parameters.update(parameters) active_parameters.update(parameters)
if llm_config.litellm_params: # type: ignore if local_llm_config.litellm_params: # type: ignore
active_parameters.update(llm_config.litellm_params) # type: ignore active_parameters.update(local_llm_config.litellm_params) # type: ignore
context = skyvern_context.current() context = skyvern_context.current()
if context and len(context.hashed_href_map) > 0: if context and len(context.hashed_href_map) > 0:
@@ -298,14 +309,16 @@ class LLMAPIHandlerFactory:
ai_suggestion=ai_suggestion, ai_suggestion=ai_suggestion,
) )
if not llm_config.supports_vision: if not local_llm_config.supports_vision:
screenshots = None screenshots = None
messages = await llm_messages_builder(prompt, screenshots, llm_config.add_assistant_prefix) model_name = local_llm_config.model_name
messages = await llm_messages_builder(prompt, screenshots, local_llm_config.add_assistant_prefix)
await app.ARTIFACT_MANAGER.create_llm_artifact( await app.ARTIFACT_MANAGER.create_llm_artifact(
data=json.dumps( data=json.dumps(
{ {
"model": llm_config.model_name, "model": model_name,
"messages": messages, "messages": messages,
# we're not using active_parameters here because it may contain sensitive information # we're not using active_parameters here because it may contain sensitive information
**parameters, **parameters,
@@ -323,32 +336,32 @@ class LLMAPIHandlerFactory:
# TODO (kerem): add a retry mechanism to this call (acompletion_with_retries) # TODO (kerem): add a retry mechanism to this call (acompletion_with_retries)
# TODO (kerem): use litellm fallbacks? https://litellm.vercel.app/docs/tutorials/fallbacks#how-does-completion_with_fallbacks-work # TODO (kerem): use litellm fallbacks? https://litellm.vercel.app/docs/tutorials/fallbacks#how-does-completion_with_fallbacks-work
response = await litellm.acompletion( response = await litellm.acompletion(
model=llm_config.model_name, model=model_name,
messages=messages, messages=messages,
timeout=settings.LLM_CONFIG_TIMEOUT, timeout=settings.LLM_CONFIG_TIMEOUT,
**active_parameters, **active_parameters,
) )
except litellm.exceptions.APIError as e: except litellm.exceptions.APIError as e:
raise LLMProviderErrorRetryableTask(llm_key) from e raise LLMProviderErrorRetryableTask(local_llm_key) from e
except litellm.exceptions.ContextWindowExceededError as e: except litellm.exceptions.ContextWindowExceededError as e:
LOG.exception( LOG.exception(
"Context window exceeded", "Context window exceeded",
llm_key=llm_key, llm_key=local_llm_key,
model=llm_config.model_name, model=model_name,
) )
raise SkyvernContextWindowExceededError() from e raise SkyvernContextWindowExceededError() from e
except CancelledError: except CancelledError:
t_llm_cancelled = time.perf_counter() t_llm_cancelled = time.perf_counter()
LOG.error( LOG.error(
"LLM request got cancelled", "LLM request got cancelled",
llm_key=llm_key, llm_key=local_llm_key,
model=llm_config.model_name, model=model_name,
duration=t_llm_cancelled - t_llm_request, duration=t_llm_cancelled - t_llm_request,
) )
raise LLMProviderError(llm_key) raise LLMProviderError(local_llm_key)
except Exception as e: except Exception as e:
LOG.exception("LLM request failed unexpectedly", llm_key=llm_key) LOG.exception("LLM request failed unexpectedly", llm_key=local_llm_key)
raise LLMProviderError(llm_key) from e raise LLMProviderError(local_llm_key) from e
await app.ARTIFACT_MANAGER.create_llm_artifact( await app.ARTIFACT_MANAGER.create_llm_artifact(
data=response.model_dump_json(indent=2).encode("utf-8"), data=response.model_dump_json(indent=2).encode("utf-8"),
@@ -396,7 +409,7 @@ class LLMAPIHandlerFactory:
cached_token_count=cached_tokens if cached_tokens > 0 else None, cached_token_count=cached_tokens if cached_tokens > 0 else None,
thought_cost=llm_cost, thought_cost=llm_cost,
) )
parsed_response = parse_api_response(response, llm_config.add_assistant_prefix) parsed_response = parse_api_response(response, local_llm_config.add_assistant_prefix)
await app.ARTIFACT_MANAGER.create_llm_artifact( await app.ARTIFACT_MANAGER.create_llm_artifact(
data=json.dumps(parsed_response, indent=2).encode("utf-8"), data=json.dumps(parsed_response, indent=2).encode("utf-8"),
artifact_type=ArtifactType.LLM_RESPONSE_PARSED, artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
@@ -423,9 +436,9 @@ class LLMAPIHandlerFactory:
duration_seconds = time.time() - start_time duration_seconds = time.time() - start_time
LOG.info( LOG.info(
"LLM API handler duration metrics", "LLM API handler duration metrics",
llm_key=llm_key, llm_key=local_llm_key,
prompt_name=prompt_name, prompt_name=prompt_name,
model=llm_config.model_name, model=local_llm_config.model_name,
duration_seconds=duration_seconds, duration_seconds=duration_seconds,
step_id=step.step_id if step else None, step_id=step.step_id if step else None,
thought_id=thought.observer_thought_id if thought else None, thought_id=thought.observer_thought_id if thought else None,

View File

@@ -94,6 +94,7 @@ class LLMAPIHandler(Protocol):
ai_suggestion: AISuggestion | None = None, ai_suggestion: AISuggestion | None = None,
screenshots: list[bytes] | None = None, screenshots: list[bytes] | None = None,
parameters: dict[str, Any] | None = None, parameters: dict[str, Any] | None = None,
llm_key_override: str | None = None,
) -> Awaitable[dict[str, Any]]: ... ) -> Awaitable[dict[str, Any]]: ...
@@ -106,5 +107,6 @@ async def dummy_llm_api_handler(
ai_suggestion: AISuggestion | None = None, ai_suggestion: AISuggestion | None = None,
screenshots: list[bytes] | None = None, screenshots: list[bytes] | None = None,
parameters: dict[str, Any] | None = None, parameters: dict[str, Any] | None = None,
llm_key_override: str | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
raise NotImplementedError("Your LLM provider is not configured. Please configure it in the .env file.") raise NotImplementedError("Your LLM provider is not configured. Please configure it in the .env file.")