Ykeremy/workflow prompt block (#124)

This commit is contained in:
Kerem Yilmaz
2024-03-25 00:57:37 -07:00
committed by GitHub
parent 0b5456a4c6
commit c58aaba4bb
7 changed files with 129 additions and 5 deletions

2
poetry.lock generated
View File

@@ -6880,4 +6880,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.11,<3.12" python-versions = "^3.11,<3.12"
content-hash = "402c47a5e38eef5bbd38a63cc8116f11b4abf0d676c9b411a2d7f425748e9c4c" content-hash = "9c2a8d3c2c9b239c6338f53485f9eace6d3eac112fa9246d9bc0a83c92f61a1d"

View File

@@ -47,6 +47,7 @@ curlify = "^2.2.1"
typer = "^0.9.0" typer = "^0.9.0"
types-toml = "^0.10.8.7" types-toml = "^0.10.8.7"
apscheduler = "^3.10.4" apscheduler = "^3.10.4"
httpx = "^0.27.0"
[tool.poetry.group.dev.dependencies] [tool.poetry.group.dev.dependencies]

View File

@@ -692,7 +692,9 @@ class AgentDB:
) -> Workflow: ) -> Workflow:
try: try:
async with self.Session() as session: async with self.Session() as session:
if workflow := await session.scalars(select(WorkflowModel).filter_by(workflow_id=workflow_id).first()): if workflow := (
await session.scalars(select(WorkflowModel).filter_by(workflow_id=workflow_id))
).first():
if title: if title:
workflow.title = title workflow.title = title
if description: if description:

View File

@@ -96,3 +96,21 @@ class PromptEngine:
except Exception: except Exception:
LOG.error("Failed to load prompt.", template=template, kwargs_keys=kwargs.keys(), exc_info=True) LOG.error("Failed to load prompt.", template=template, kwargs_keys=kwargs.keys(), exc_info=True)
raise raise
def load_prompt_from_string(self, template: str, **kwargs: Any) -> str:
"""
Load and populate the specified template from a string.
Args:
template (str): The template string to load.
**kwargs: The arguments to populate the template with.
Returns:
str: The populated template.
"""
try:
jinja_template = self.env.from_string(template)
return jinja_template.render(**kwargs)
except Exception:
LOG.error("Failed to load prompt from string.", template=template, kwargs_keys=kwargs.keys(), exc_info=True)
raise

View File

@@ -1,4 +1,5 @@
import abc import abc
import json
from enum import StrEnum from enum import StrEnum
from typing import Annotated, Any, Literal, Union from typing import Annotated, Any, Literal, Union
@@ -12,6 +13,8 @@ from skyvern.exceptions import (
UnexpectedTaskStatus, UnexpectedTaskStatus,
) )
from skyvern.forge import app from skyvern.forge import app
from skyvern.forge.prompts import prompt_engine
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory
from skyvern.forge.sdk.schemas.tasks import TaskStatus from skyvern.forge.sdk.schemas.tasks import TaskStatus
from skyvern.forge.sdk.workflow.context_manager import WorkflowRunContext from skyvern.forge.sdk.workflow.context_manager import WorkflowRunContext
from skyvern.forge.sdk.workflow.models.parameter import ( from skyvern.forge.sdk.workflow.models.parameter import (
@@ -28,6 +31,7 @@ class BlockType(StrEnum):
TASK = "task" TASK = "task"
FOR_LOOP = "for_loop" FOR_LOOP = "for_loop"
CODE = "code" CODE = "code"
TEXT_PROMPT = "text_prompt"
class Block(BaseModel, abc.ABC): class Block(BaseModel, abc.ABC):
@@ -345,5 +349,73 @@ class CodeBlock(Block):
return None return None
BlockSubclasses = Union[ForLoopBlock, TaskBlock, CodeBlock] class TextPromptBlock(Block):
block_type: Literal[BlockType.TEXT_PROMPT] = BlockType.TEXT_PROMPT
llm_key: str
prompt: str
parameters: list[PARAMETER_TYPE] = []
json_schema: dict[str, Any] | None = None
def get_all_parameters(
self,
) -> list[PARAMETER_TYPE]:
return self.parameters
async def send_prompt(self, prompt: str, parameter_values: dict[str, Any]) -> dict[str, Any]:
llm_api_handler = LLMAPIHandlerFactory.get_llm_api_handler(self.llm_key)
if not self.json_schema:
self.json_schema = {
"type": "object",
"properties": {
"llm_response": {
"type": "string",
"description": "Your response to the prompt",
}
},
}
prompt = prompt_engine.load_prompt_from_string(prompt, **parameter_values)
prompt += (
"\n\n"
+ "Please respond to the prompt above using the following JSON definition:\n\n"
+ "```json\n"
+ json.dumps(self.json_schema, indent=2)
+ "\n```\n\n"
)
LOG.info("TextPromptBlock: Sending prompt to LLM", prompt=prompt, llm_key=self.llm_key)
response = await llm_api_handler(prompt=prompt)
LOG.info("TextPromptBlock: Received response from LLM", response=response)
return response
async def execute(self, workflow_run_id: str, **kwargs: dict) -> OutputParameter | None:
# get workflow run context
workflow_run_context = self.get_workflow_run_context(workflow_run_id)
# get all parameters into a dictionary
parameter_values = {}
for parameter in self.parameters:
value = workflow_run_context.get_value(parameter.key)
secret_value = workflow_run_context.get_original_secret_value_or_none(value)
if secret_value is not None:
parameter_values[parameter.key] = secret_value
else:
parameter_values[parameter.key] = value
response = await self.send_prompt(self.prompt, parameter_values)
if self.output_parameter:
await workflow_run_context.register_output_parameter_value_post_execution(
parameter=self.output_parameter,
value=response,
)
await app.DATABASE.create_workflow_run_output_parameter(
workflow_run_id=workflow_run_id,
output_parameter_id=self.output_parameter.output_parameter_id,
value=response,
)
return self.output_parameter
return None
BlockSubclasses = Union[ForLoopBlock, TaskBlock, CodeBlock, TextPromptBlock]
BlockTypeVar = Annotated[BlockSubclasses, Field(discriminator="block_type")] BlockTypeVar = Annotated[BlockSubclasses, Field(discriminator="block_type")]

View File

@@ -94,10 +94,23 @@ class CodeBlockYAML(BlockYAML):
parameter_keys: list[str] | None = None parameter_keys: list[str] | None = None
class TextPromptBlockYAML(BlockYAML):
# There is a mypy bug with Literal. Without the type: ignore, mypy will raise an error:
# Parameter 1 of Literal[...] cannot be of type "Any"
# This pattern already works in block.py but since the BlockType is not defined in this file, mypy is not able
# to infer the type of the parameter_type attribute.
block_type: Literal[BlockType.TEXT_PROMPT] = BlockType.TEXT_PROMPT # type: ignore
llm_key: str
prompt: str
parameter_keys: list[str] | None = None
json_schema: dict[str, Any] | None = None
PARAMETER_YAML_SUBCLASSES = AWSSecretParameterYAML | WorkflowParameterYAML | ContextParameterYAML | OutputParameterYAML PARAMETER_YAML_SUBCLASSES = AWSSecretParameterYAML | WorkflowParameterYAML | ContextParameterYAML | OutputParameterYAML
PARAMETER_YAML_TYPES = Annotated[PARAMETER_YAML_SUBCLASSES, Field(discriminator="parameter_type")] PARAMETER_YAML_TYPES = Annotated[PARAMETER_YAML_SUBCLASSES, Field(discriminator="parameter_type")]
BLOCK_YAML_SUBCLASSES = TaskBlockYAML | ForLoopBlockYAML | CodeBlockYAML BLOCK_YAML_SUBCLASSES = TaskBlockYAML | ForLoopBlockYAML | CodeBlockYAML | TextPromptBlockYAML
BLOCK_YAML_TYPES = Annotated[BLOCK_YAML_SUBCLASSES, Field(discriminator="block_type")] BLOCK_YAML_TYPES = Annotated[BLOCK_YAML_SUBCLASSES, Field(discriminator="block_type")]

View File

@@ -21,7 +21,14 @@ from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
from skyvern.forge.sdk.models import Step from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
from skyvern.forge.sdk.workflow.exceptions import WorkflowDefinitionHasDuplicateParameterKeys from skyvern.forge.sdk.workflow.exceptions import WorkflowDefinitionHasDuplicateParameterKeys
from skyvern.forge.sdk.workflow.models.block import BlockType, BlockTypeVar, CodeBlock, ForLoopBlock, TaskBlock from skyvern.forge.sdk.workflow.models.block import (
BlockType,
BlockTypeVar,
CodeBlock,
ForLoopBlock,
TaskBlock,
TextPromptBlock,
)
from skyvern.forge.sdk.workflow.models.parameter import ( from skyvern.forge.sdk.workflow.models.parameter import (
AWSSecretParameter, AWSSecretParameter,
OutputParameter, OutputParameter,
@@ -714,4 +721,15 @@ class WorkflowService:
else [], else [],
output_parameter=output_parameter, output_parameter=output_parameter,
) )
elif block_yaml.block_type == BlockType.TEXT_PROMPT:
return TextPromptBlock(
label=block_yaml.label,
llm_key=block_yaml.llm_key,
prompt=block_yaml.prompt,
parameters=[parameters[parameter_key] for parameter_key in block_yaml.parameter_keys]
if block_yaml.parameter_keys
else [],
json_schema=block_yaml.json_schema,
output_parameter=output_parameter,
)
raise ValueError(f"Invalid block type {block_yaml.block_type}") raise ValueError(f"Invalid block type {block_yaml.block_type}")