From c58aaba4bb46c34f804f4105c452dc6d06546e73 Mon Sep 17 00:00:00 2001 From: Kerem Yilmaz Date: Mon, 25 Mar 2024 00:57:37 -0700 Subject: [PATCH] Ykeremy/workflow prompt block (#124) --- poetry.lock | 2 +- pyproject.toml | 1 + skyvern/forge/sdk/db/client.py | 4 +- skyvern/forge/sdk/prompting.py | 18 ++++++ skyvern/forge/sdk/workflow/models/block.py | 74 +++++++++++++++++++++- skyvern/forge/sdk/workflow/models/yaml.py | 15 ++++- skyvern/forge/sdk/workflow/service.py | 20 +++++- 7 files changed, 129 insertions(+), 5 deletions(-) diff --git a/poetry.lock b/poetry.lock index c10bedc1..a767b3dc 100644 --- a/poetry.lock +++ b/poetry.lock @@ -6880,4 +6880,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.11,<3.12" -content-hash = "402c47a5e38eef5bbd38a63cc8116f11b4abf0d676c9b411a2d7f425748e9c4c" +content-hash = "9c2a8d3c2c9b239c6338f53485f9eace6d3eac112fa9246d9bc0a83c92f61a1d" diff --git a/pyproject.toml b/pyproject.toml index 74bcc3bd..cd2cd932 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ curlify = "^2.2.1" typer = "^0.9.0" types-toml = "^0.10.8.7" apscheduler = "^3.10.4" +httpx = "^0.27.0" [tool.poetry.group.dev.dependencies] diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py index a8128b42..645004f3 100644 --- a/skyvern/forge/sdk/db/client.py +++ b/skyvern/forge/sdk/db/client.py @@ -692,7 +692,9 @@ class AgentDB: ) -> Workflow: try: async with self.Session() as session: - if workflow := await session.scalars(select(WorkflowModel).filter_by(workflow_id=workflow_id).first()): + if workflow := ( + await session.scalars(select(WorkflowModel).filter_by(workflow_id=workflow_id)) + ).first(): if title: workflow.title = title if description: diff --git a/skyvern/forge/sdk/prompting.py b/skyvern/forge/sdk/prompting.py index a23f8c67..9aa9f947 100644 --- a/skyvern/forge/sdk/prompting.py +++ b/skyvern/forge/sdk/prompting.py @@ -96,3 +96,21 @@ class PromptEngine: except Exception: LOG.error("Failed to load prompt.", template=template, kwargs_keys=kwargs.keys(), exc_info=True) raise + + def load_prompt_from_string(self, template: str, **kwargs: Any) -> str: + """ + Load and populate the specified template from a string. + + Args: + template (str): The template string to load. + **kwargs: The arguments to populate the template with. + + Returns: + str: The populated template. + """ + try: + jinja_template = self.env.from_string(template) + return jinja_template.render(**kwargs) + except Exception: + LOG.error("Failed to load prompt from string.", template=template, kwargs_keys=kwargs.keys(), exc_info=True) + raise diff --git a/skyvern/forge/sdk/workflow/models/block.py b/skyvern/forge/sdk/workflow/models/block.py index 00ba04ef..0a09ccdd 100644 --- a/skyvern/forge/sdk/workflow/models/block.py +++ b/skyvern/forge/sdk/workflow/models/block.py @@ -1,4 +1,5 @@ import abc +import json from enum import StrEnum from typing import Annotated, Any, Literal, Union @@ -12,6 +13,8 @@ from skyvern.exceptions import ( UnexpectedTaskStatus, ) from skyvern.forge import app +from skyvern.forge.prompts import prompt_engine +from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory from skyvern.forge.sdk.schemas.tasks import TaskStatus from skyvern.forge.sdk.workflow.context_manager import WorkflowRunContext from skyvern.forge.sdk.workflow.models.parameter import ( @@ -28,6 +31,7 @@ class BlockType(StrEnum): TASK = "task" FOR_LOOP = "for_loop" CODE = "code" + TEXT_PROMPT = "text_prompt" class Block(BaseModel, abc.ABC): @@ -345,5 +349,73 @@ class CodeBlock(Block): return None -BlockSubclasses = Union[ForLoopBlock, TaskBlock, CodeBlock] +class TextPromptBlock(Block): + block_type: Literal[BlockType.TEXT_PROMPT] = BlockType.TEXT_PROMPT + + llm_key: str + prompt: str + parameters: list[PARAMETER_TYPE] = [] + json_schema: dict[str, Any] | None = None + + def get_all_parameters( + self, + ) -> list[PARAMETER_TYPE]: + return self.parameters + + async def send_prompt(self, prompt: str, parameter_values: dict[str, Any]) -> dict[str, Any]: + llm_api_handler = LLMAPIHandlerFactory.get_llm_api_handler(self.llm_key) + if not self.json_schema: + self.json_schema = { + "type": "object", + "properties": { + "llm_response": { + "type": "string", + "description": "Your response to the prompt", + } + }, + } + + prompt = prompt_engine.load_prompt_from_string(prompt, **parameter_values) + prompt += ( + "\n\n" + + "Please respond to the prompt above using the following JSON definition:\n\n" + + "```json\n" + + json.dumps(self.json_schema, indent=2) + + "\n```\n\n" + ) + LOG.info("TextPromptBlock: Sending prompt to LLM", prompt=prompt, llm_key=self.llm_key) + response = await llm_api_handler(prompt=prompt) + LOG.info("TextPromptBlock: Received response from LLM", response=response) + return response + + async def execute(self, workflow_run_id: str, **kwargs: dict) -> OutputParameter | None: + # get workflow run context + workflow_run_context = self.get_workflow_run_context(workflow_run_id) + # get all parameters into a dictionary + parameter_values = {} + for parameter in self.parameters: + value = workflow_run_context.get_value(parameter.key) + secret_value = workflow_run_context.get_original_secret_value_or_none(value) + if secret_value is not None: + parameter_values[parameter.key] = secret_value + else: + parameter_values[parameter.key] = value + + response = await self.send_prompt(self.prompt, parameter_values) + if self.output_parameter: + await workflow_run_context.register_output_parameter_value_post_execution( + parameter=self.output_parameter, + value=response, + ) + await app.DATABASE.create_workflow_run_output_parameter( + workflow_run_id=workflow_run_id, + output_parameter_id=self.output_parameter.output_parameter_id, + value=response, + ) + return self.output_parameter + + return None + + +BlockSubclasses = Union[ForLoopBlock, TaskBlock, CodeBlock, TextPromptBlock] BlockTypeVar = Annotated[BlockSubclasses, Field(discriminator="block_type")] diff --git a/skyvern/forge/sdk/workflow/models/yaml.py b/skyvern/forge/sdk/workflow/models/yaml.py index df5f35e9..fefb9ad3 100644 --- a/skyvern/forge/sdk/workflow/models/yaml.py +++ b/skyvern/forge/sdk/workflow/models/yaml.py @@ -94,10 +94,23 @@ class CodeBlockYAML(BlockYAML): parameter_keys: list[str] | None = None +class TextPromptBlockYAML(BlockYAML): + # There is a mypy bug with Literal. Without the type: ignore, mypy will raise an error: + # Parameter 1 of Literal[...] cannot be of type "Any" + # This pattern already works in block.py but since the BlockType is not defined in this file, mypy is not able + # to infer the type of the parameter_type attribute. + block_type: Literal[BlockType.TEXT_PROMPT] = BlockType.TEXT_PROMPT # type: ignore + + llm_key: str + prompt: str + parameter_keys: list[str] | None = None + json_schema: dict[str, Any] | None = None + + PARAMETER_YAML_SUBCLASSES = AWSSecretParameterYAML | WorkflowParameterYAML | ContextParameterYAML | OutputParameterYAML PARAMETER_YAML_TYPES = Annotated[PARAMETER_YAML_SUBCLASSES, Field(discriminator="parameter_type")] -BLOCK_YAML_SUBCLASSES = TaskBlockYAML | ForLoopBlockYAML | CodeBlockYAML +BLOCK_YAML_SUBCLASSES = TaskBlockYAML | ForLoopBlockYAML | CodeBlockYAML | TextPromptBlockYAML BLOCK_YAML_TYPES = Annotated[BLOCK_YAML_SUBCLASSES, Field(discriminator="block_type")] diff --git a/skyvern/forge/sdk/workflow/service.py b/skyvern/forge/sdk/workflow/service.py index 352f79b9..c2c9c160 100644 --- a/skyvern/forge/sdk/workflow/service.py +++ b/skyvern/forge/sdk/workflow/service.py @@ -21,7 +21,14 @@ from skyvern.forge.sdk.core.skyvern_context import SkyvernContext from skyvern.forge.sdk.models import Step from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus from skyvern.forge.sdk.workflow.exceptions import WorkflowDefinitionHasDuplicateParameterKeys -from skyvern.forge.sdk.workflow.models.block import BlockType, BlockTypeVar, CodeBlock, ForLoopBlock, TaskBlock +from skyvern.forge.sdk.workflow.models.block import ( + BlockType, + BlockTypeVar, + CodeBlock, + ForLoopBlock, + TaskBlock, + TextPromptBlock, +) from skyvern.forge.sdk.workflow.models.parameter import ( AWSSecretParameter, OutputParameter, @@ -714,4 +721,15 @@ class WorkflowService: else [], output_parameter=output_parameter, ) + elif block_yaml.block_type == BlockType.TEXT_PROMPT: + return TextPromptBlock( + label=block_yaml.label, + llm_key=block_yaml.llm_key, + prompt=block_yaml.prompt, + parameters=[parameters[parameter_key] for parameter_key in block_yaml.parameter_keys] + if block_yaml.parameter_keys + else [], + json_schema=block_yaml.json_schema, + output_parameter=output_parameter, + ) raise ValueError(f"Invalid block type {block_yaml.block_type}")