prefer secondary llm for text prompts (#4143)
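Summary: routes TextPromptBlock's LLM calls through a new per-prompt-type experiment helper. If the PostHog flag LLM_CONFIG_BY_PROMPT_TYPE maps the "text-prompt" type to an LLM config, that handler is used; otherwise the block prefers app.SECONDARY_LLM_API_HANDLER and only falls back to the primary app.LLM_API_HANDLER when no secondary handler is configured.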
skyvern/forge/sdk/experimentation/llm_prompt_config.py (new file, +92 lines)
@@ -0,0 +1,92 @@
+from __future__ import annotations
+
+import json
+
+import structlog
+
+from skyvern.forge import app
+from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory
+from skyvern.forge.sdk.api.llm.models import LLMAPIHandler
+
+LOG = structlog.get_logger()
+
+
+async def get_llm_config_by_prompt_type(distinct_id: str, organization_id: str | None = None) -> dict[str, str] | None:
+    """Return PostHog-configured LLM mapping for each prompt type."""
+    llm_config_experiment = await app.EXPERIMENTATION_PROVIDER.get_value_cached(
+        "LLM_CONFIG_BY_PROMPT_TYPE", distinct_id, properties={"organization_id": organization_id}
+    )
+    if llm_config_experiment in (False, "False") or not llm_config_experiment:
+        return None
+
+    payload = await app.EXPERIMENTATION_PROVIDER.get_payload_cached(
+        "LLM_CONFIG_BY_PROMPT_TYPE", distinct_id, properties={"organization_id": organization_id}
+    )
+    if not payload:
+        LOG.warning(
+            "No payload found for LLM config experiment",
+            distinct_id=distinct_id,
+            organization_id=organization_id,
+            variant=llm_config_experiment,
+        )
+        return None
+
+    try:
+        config = json.loads(payload)
+    except (json.JSONDecodeError, KeyError, TypeError) as exc:
+        LOG.warning(
+            "Failed to parse LLM config experiment payload",
+            distinct_id=distinct_id,
+            organization_id=organization_id,
+            variant=llm_config_experiment,
+            payload=payload,
+            error=str(exc),
+        )
+        return None
+
+    LOG.info(
+        "LLM config by prompt type experiment enabled",
+        distinct_id=distinct_id,
+        organization_id=organization_id,
+        variant=llm_config_experiment,
+        config=config,
+    )
+    return config
+
+
+async def get_llm_handler_for_prompt_type(
+    prompt_type: str, distinct_id: str, organization_id: str | None = None
+) -> LLMAPIHandler | None:
+    """Return initialized handler for prompt type from LLM_CONFIG_BY_PROMPT_TYPE flag."""
+    config = await get_llm_config_by_prompt_type(distinct_id, organization_id)
+    if not config or prompt_type not in config:
+        LOG.warning(
+            "No config found for prompt type",
+            prompt_type=prompt_type,
+            config=config,
+            distinct_id=distinct_id,
+            organization_id=organization_id,
+        )
+        return None
+
+    llm_config_name = config[prompt_type]
+    try:
+        handler = LLMAPIHandlerFactory.get_llm_api_handler(llm_config_name)
+        LOG.info(
+            "Using LLM handler for prompt type from experiment",
+            prompt_type=prompt_type,
+            llm_config_name=llm_config_name,
+            distinct_id=distinct_id,
+            organization_id=organization_id,
+        )
+        return handler
+    except Exception:
+        LOG.error(
+            "Failed to initialize LLM handler for prompt type",
+            prompt_type=prompt_type,
+            llm_config_name=llm_config_name,
+            distinct_id=distinct_id,
+            organization_id=organization_id,
+            exc_info=True,
+        )
+        return None
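For reference, the payload attached to the LLM_CONFIG_BY_PROMPT_TYPE flag is expected to be a JSON object mapping prompt types to LLM config names that LLMAPIHandlerFactory.get_llm_api_handler can resolve. A minimal sketch of the shape, with an illustrative config name that is not taken from this commit:

    import json

    # Hypothetical flag payload: prompt type -> LLM config name.
    # "OPENAI_GPT4O" is a stand-in; any name resolvable by
    # LLMAPIHandlerFactory.get_llm_api_handler would work.
    payload = '{"text-prompt": "OPENAI_GPT4O"}'

    config = json.loads(payload)
    assert config["text-prompt"] == "OPENAI_GPT4O"

The remaining hunks wire the helper into TextPromptBlock.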
@@ -57,10 +57,12 @@ from skyvern.forge.sdk.api.files import (
     get_path_for_workflow_download_directory,
 )
 from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory
+from skyvern.forge.sdk.api.llm.models import LLMAPIHandler
 from skyvern.forge.sdk.artifact.models import ArtifactType
 from skyvern.forge.sdk.core import skyvern_context
 from skyvern.forge.sdk.core.aiohttp_helper import aiohttp_request
 from skyvern.forge.sdk.db.enums import TaskType
+from skyvern.forge.sdk.experimentation.llm_prompt_config import get_llm_handler_for_prompt_type
 from skyvern.forge.sdk.schemas.files import FileInfo
 from skyvern.forge.sdk.schemas.task_v2 import TaskV2Status
 from skyvern.forge.sdk.schemas.tasks import Task, TaskOutput, TaskStatus
@@ -1822,9 +1824,16 @@ class TextPromptBlock(Block):
         )
         self.prompt = self.format_block_parameter_template_from_workflow_run_context(self.prompt, workflow_run_context)

-    async def send_prompt(self, prompt: str, parameter_values: dict[str, Any]) -> dict[str, Any]:
+    async def send_prompt(
+        self,
+        prompt: str,
+        parameter_values: dict[str, Any],
+        workflow_run_id: str,
+        organization_id: str | None = None,
+    ) -> dict[str, Any]:
+        default_llm_handler = await self._resolve_default_llm_handler(workflow_run_id, organization_id)
         llm_api_handler = LLMAPIHandlerFactory.get_override_llm_api_handler(
-            self.override_llm_key or self.llm_key, default=app.LLM_API_HANDLER
+            self.override_llm_key or self.llm_key, default=default_llm_handler
         )
         if not self.json_schema:
             self.json_schema = {
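Note the precedence this keeps intact: the resolved handler only fills the default slot of get_override_llm_api_handler, so a block with an explicit override_llm_key or llm_key presumably still gets that requested model rather than the experiment/secondary default.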
@@ -1854,6 +1863,22 @@ class TextPromptBlock(Block):
         LOG.info("TextPromptBlock: Received response from LLM", response=response)
         return response

+    async def _resolve_default_llm_handler(self, workflow_run_id: str, organization_id: str | None) -> LLMAPIHandler:
+        prompt_config_handler = await get_llm_handler_for_prompt_type("text-prompt", workflow_run_id, organization_id)
+        if prompt_config_handler:
+            return prompt_config_handler
+
+        secondary_handler = app.SECONDARY_LLM_API_HANDLER
+        if secondary_handler:
+            return secondary_handler
+
+        LOG.warning(
+            "Secondary LLM handler not configured; falling back to primary handler for TextPromptBlock",
+            workflow_run_id=workflow_run_id,
+            organization_id=organization_id,
+        )
+        return app.LLM_API_HANDLER
+
     async def execute(
         self,
         workflow_run_id: str,
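Condensed, _resolve_default_llm_handler is a three-step cascade: experiment handler, then secondary, then primary. A standalone sketch with plain stand-in values rather than Skyvern's real handler objects:

    # Stand-in sketch of the cascade above; arguments are plain values,
    # not real LLMAPIHandler instances.
    def resolve_default(experiment_handler, secondary_handler, primary_handler):
        # 1. Per-prompt-type handler chosen via the PostHog experiment, if any.
        if experiment_handler:
            return experiment_handler
        # 2. Otherwise prefer the secondary LLM when one is configured.
        if secondary_handler:
            return secondary_handler
        # 3. Last resort: the primary handler.
        return primary_handler

    assert resolve_default(None, "secondary", "primary") == "secondary"
    assert resolve_default(None, None, "primary") == "primary"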
@@ -1897,7 +1922,7 @@ class TextPromptBlock(Block):
         else:
             parameter_values[parameter.key] = value

-        response = await self.send_prompt(self.prompt, parameter_values)
+        response = await self.send_prompt(self.prompt, parameter_values, workflow_run_id, organization_id)
         await self.record_output_parameter_value(workflow_run_context, workflow_run_id, response)
         return await self.build_block_result(
             success=True,
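One detail worth noting: send_prompt passes workflow_run_id as the experiment's distinct_id, so bucketing for the text-prompt routing happens per workflow run, with organization_id attached only as a property.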