diff --git a/skyvern/forge/sdk/experimentation/llm_prompt_config.py b/skyvern/forge/sdk/experimentation/llm_prompt_config.py
new file mode 100644
index 00000000..5ba7c6de
--- /dev/null
+++ b/skyvern/forge/sdk/experimentation/llm_prompt_config.py
@@ -0,0 +1,92 @@
+from __future__ import annotations
+
+import json
+
+import structlog
+
+from skyvern.forge import app
+from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory
+from skyvern.forge.sdk.api.llm.models import LLMAPIHandler
+
+LOG = structlog.get_logger()
+
+
+async def get_llm_config_by_prompt_type(distinct_id: str, organization_id: str | None = None) -> dict[str, str] | None:
+    """Return PostHog-configured LLM mapping for each prompt type."""
+    llm_config_experiment = await app.EXPERIMENTATION_PROVIDER.get_value_cached(
+        "LLM_CONFIG_BY_PROMPT_TYPE", distinct_id, properties={"organization_id": organization_id}
+    )
+    if llm_config_experiment in (False, "False") or not llm_config_experiment:
+        return None
+
+    payload = await app.EXPERIMENTATION_PROVIDER.get_payload_cached(
+        "LLM_CONFIG_BY_PROMPT_TYPE", distinct_id, properties={"organization_id": organization_id}
+    )
+    if not payload:
+        LOG.warning(
+            "No payload found for LLM config experiment",
+            distinct_id=distinct_id,
+            organization_id=organization_id,
+            variant=llm_config_experiment,
+        )
+        return None
+
+    try:
+        config = json.loads(payload)
+    except (json.JSONDecodeError, KeyError, TypeError) as exc:
+        LOG.warning(
+            "Failed to parse LLM config experiment payload",
+            distinct_id=distinct_id,
+            organization_id=organization_id,
+            variant=llm_config_experiment,
+            payload=payload,
+            error=str(exc),
+        )
+        return None
+
+    LOG.info(
+        "LLM config by prompt type experiment enabled",
+        distinct_id=distinct_id,
+        organization_id=organization_id,
+        variant=llm_config_experiment,
+        config=config,
+    )
+    return config
+
+
+async def get_llm_handler_for_prompt_type(
+    prompt_type: str, distinct_id: str, organization_id: str | None = None
+) -> LLMAPIHandler | None:
+    """Return initialized handler for prompt type from LLM_CONFIG_BY_PROMPT_TYPE flag."""
+    config = await get_llm_config_by_prompt_type(distinct_id, organization_id)
+    if not config or prompt_type not in config:
+        LOG.warning(
+            "No config found for prompt type",
+            prompt_type=prompt_type,
+            config=config,
+            distinct_id=distinct_id,
+            organization_id=organization_id,
+        )
+        return None
+
+    llm_config_name = config[prompt_type]
+    try:
+        handler = LLMAPIHandlerFactory.get_llm_api_handler(llm_config_name)
+        LOG.info(
+            "Using LLM handler for prompt type from experiment",
+            prompt_type=prompt_type,
+            llm_config_name=llm_config_name,
+            distinct_id=distinct_id,
+            organization_id=organization_id,
+        )
+        return handler
+    except Exception:
+        LOG.error(
+            "Failed to initialize LLM handler for prompt type",
+            prompt_type=prompt_type,
+            llm_config_name=llm_config_name,
+            distinct_id=distinct_id,
+            organization_id=organization_id,
+            exc_info=True,
+        )
+        return None
diff --git a/skyvern/forge/sdk/workflow/models/block.py b/skyvern/forge/sdk/workflow/models/block.py
index 0e139f15..102c71ac 100644
--- a/skyvern/forge/sdk/workflow/models/block.py
+++ b/skyvern/forge/sdk/workflow/models/block.py
@@ -57,10 +57,12 @@ from skyvern.forge.sdk.api.files import (
     get_path_for_workflow_download_directory,
 )
 from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory
+from skyvern.forge.sdk.api.llm.models import LLMAPIHandler
 from skyvern.forge.sdk.artifact.models import ArtifactType
 from skyvern.forge.sdk.core import skyvern_context
 from skyvern.forge.sdk.core.aiohttp_helper import aiohttp_request
 from skyvern.forge.sdk.db.enums import TaskType
+from skyvern.forge.sdk.experimentation.llm_prompt_config import get_llm_handler_for_prompt_type
 from skyvern.forge.sdk.schemas.files import FileInfo
 from skyvern.forge.sdk.schemas.task_v2 import TaskV2Status
 from skyvern.forge.sdk.schemas.tasks import Task, TaskOutput, TaskStatus
@@ -1822,9 +1824,16 @@ class TextPromptBlock(Block):
         )
         self.prompt = self.format_block_parameter_template_from_workflow_run_context(self.prompt, workflow_run_context)
 
-    async def send_prompt(self, prompt: str, parameter_values: dict[str, Any]) -> dict[str, Any]:
+    async def send_prompt(
+        self,
+        prompt: str,
+        parameter_values: dict[str, Any],
+        workflow_run_id: str,
+        organization_id: str | None = None,
+    ) -> dict[str, Any]:
+        default_llm_handler = await self._resolve_default_llm_handler(workflow_run_id, organization_id)
         llm_api_handler = LLMAPIHandlerFactory.get_override_llm_api_handler(
-            self.override_llm_key or self.llm_key, default=app.LLM_API_HANDLER
+            self.override_llm_key or self.llm_key, default=default_llm_handler
         )
         if not self.json_schema:
             self.json_schema = {
@@ -1854,6 +1863,22 @@ class TextPromptBlock(Block):
         LOG.info("TextPromptBlock: Received response from LLM", response=response)
         return response
 
+    async def _resolve_default_llm_handler(self, workflow_run_id: str, organization_id: str | None) -> LLMAPIHandler:
+        prompt_config_handler = await get_llm_handler_for_prompt_type("text-prompt", workflow_run_id, organization_id)
+        if prompt_config_handler:
+            return prompt_config_handler
+
+        secondary_handler = app.SECONDARY_LLM_API_HANDLER
+        if secondary_handler:
+            return secondary_handler
+
+        LOG.warning(
+            "Secondary LLM handler not configured; falling back to primary handler for TextPromptBlock",
+            workflow_run_id=workflow_run_id,
+            organization_id=organization_id,
+        )
+        return app.LLM_API_HANDLER
+
     async def execute(
         self,
         workflow_run_id: str,
@@ -1897,7 +1922,7 @@ class TextPromptBlock(Block):
         else:
             parameter_values[parameter.key] = value
 
-        response = await self.send_prompt(self.prompt, parameter_values)
+        response = await self.send_prompt(self.prompt, parameter_values, workflow_run_id, organization_id)
         await self.record_output_parameter_value(workflow_run_context, workflow_run_id, response)
         return await self.build_block_result(
             success=True,
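Reviewer note: get_llm_config_by_prompt_type json-decodes the flag payload into a flat mapping from prompt type to LLM config name. A minimal sketch of a payload this code would accept follows; the config name is illustrative (any name resolvable by LLMAPIHandlerFactory.get_llm_api_handler would do), and only the "text-prompt" key is consumed by this diff.

import json

# Hypothetical LLM_CONFIG_BY_PROMPT_TYPE payload as it might be stored in
# PostHog; "OPENAI_GPT4O" is an illustrative config name, not one this
# diff requires.
payload = '{"text-prompt": "OPENAI_GPT4O"}'

config = json.loads(payload)  # the dict get_llm_config_by_prompt_type returns
assert config["text-prompt"] == "OPENAI_GPT4O"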
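The net effect for TextPromptBlock is a three-level default-handler cascade. A condensed sketch of the resolution order implemented by _resolve_default_llm_handler above, with the per-step logging omitted (assumes an async context; all names are from this diff):

# 1. experiment-provided handler for the "text-prompt" prompt type,
# 2. app.SECONDARY_LLM_API_HANDLER when configured,
# 3. app.LLM_API_HANDLER (the real method also logs a warning at this step).
handler = (
    await get_llm_handler_for_prompt_type("text-prompt", workflow_run_id, organization_id)
    or app.SECONDARY_LLM_API_HANDLER
    or app.LLM_API_HANDLER
)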