Ykeremy/workflow prompt block (#124)

This commit is contained in:
Kerem Yilmaz
2024-03-25 00:57:37 -07:00
committed by GitHub
parent 0b5456a4c6
commit c58aaba4bb
7 changed files with 129 additions and 5 deletions

2
poetry.lock generated
View File

@@ -6880,4 +6880,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
[metadata]
lock-version = "2.0"
python-versions = "^3.11,<3.12"
content-hash = "402c47a5e38eef5bbd38a63cc8116f11b4abf0d676c9b411a2d7f425748e9c4c"
content-hash = "9c2a8d3c2c9b239c6338f53485f9eace6d3eac112fa9246d9bc0a83c92f61a1d"

View File

@@ -47,6 +47,7 @@ curlify = "^2.2.1"
typer = "^0.9.0"
types-toml = "^0.10.8.7"
apscheduler = "^3.10.4"
httpx = "^0.27.0"
[tool.poetry.group.dev.dependencies]

View File

@@ -692,7 +692,9 @@ class AgentDB:
) -> Workflow:
try:
async with self.Session() as session:
if workflow := await session.scalars(select(WorkflowModel).filter_by(workflow_id=workflow_id).first()):
if workflow := (
await session.scalars(select(WorkflowModel).filter_by(workflow_id=workflow_id))
).first():
if title:
workflow.title = title
if description:

View File

@@ -96,3 +96,21 @@ class PromptEngine:
except Exception:
LOG.error("Failed to load prompt.", template=template, kwargs_keys=kwargs.keys(), exc_info=True)
raise
def load_prompt_from_string(self, template: str, **kwargs: Any) -> str:
"""
Load and populate the specified template from a string.
Args:
template (str): The template string to load.
**kwargs: The arguments to populate the template with.
Returns:
str: The populated template.
"""
try:
jinja_template = self.env.from_string(template)
return jinja_template.render(**kwargs)
except Exception:
LOG.error("Failed to load prompt from string.", template=template, kwargs_keys=kwargs.keys(), exc_info=True)
raise

View File

@@ -1,4 +1,5 @@
import abc
import json
from enum import StrEnum
from typing import Annotated, Any, Literal, Union
@@ -12,6 +13,8 @@ from skyvern.exceptions import (
UnexpectedTaskStatus,
)
from skyvern.forge import app
from skyvern.forge.prompts import prompt_engine
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory
from skyvern.forge.sdk.schemas.tasks import TaskStatus
from skyvern.forge.sdk.workflow.context_manager import WorkflowRunContext
from skyvern.forge.sdk.workflow.models.parameter import (
@@ -28,6 +31,7 @@ class BlockType(StrEnum):
TASK = "task"
FOR_LOOP = "for_loop"
CODE = "code"
TEXT_PROMPT = "text_prompt"
class Block(BaseModel, abc.ABC):
@@ -345,5 +349,73 @@ class CodeBlock(Block):
return None
BlockSubclasses = Union[ForLoopBlock, TaskBlock, CodeBlock]
class TextPromptBlock(Block):
block_type: Literal[BlockType.TEXT_PROMPT] = BlockType.TEXT_PROMPT
llm_key: str
prompt: str
parameters: list[PARAMETER_TYPE] = []
json_schema: dict[str, Any] | None = None
def get_all_parameters(
self,
) -> list[PARAMETER_TYPE]:
return self.parameters
async def send_prompt(self, prompt: str, parameter_values: dict[str, Any]) -> dict[str, Any]:
llm_api_handler = LLMAPIHandlerFactory.get_llm_api_handler(self.llm_key)
if not self.json_schema:
self.json_schema = {
"type": "object",
"properties": {
"llm_response": {
"type": "string",
"description": "Your response to the prompt",
}
},
}
prompt = prompt_engine.load_prompt_from_string(prompt, **parameter_values)
prompt += (
"\n\n"
+ "Please respond to the prompt above using the following JSON definition:\n\n"
+ "```json\n"
+ json.dumps(self.json_schema, indent=2)
+ "\n```\n\n"
)
LOG.info("TextPromptBlock: Sending prompt to LLM", prompt=prompt, llm_key=self.llm_key)
response = await llm_api_handler(prompt=prompt)
LOG.info("TextPromptBlock: Received response from LLM", response=response)
return response
async def execute(self, workflow_run_id: str, **kwargs: dict) -> OutputParameter | None:
# get workflow run context
workflow_run_context = self.get_workflow_run_context(workflow_run_id)
# get all parameters into a dictionary
parameter_values = {}
for parameter in self.parameters:
value = workflow_run_context.get_value(parameter.key)
secret_value = workflow_run_context.get_original_secret_value_or_none(value)
if secret_value is not None:
parameter_values[parameter.key] = secret_value
else:
parameter_values[parameter.key] = value
response = await self.send_prompt(self.prompt, parameter_values)
if self.output_parameter:
await workflow_run_context.register_output_parameter_value_post_execution(
parameter=self.output_parameter,
value=response,
)
await app.DATABASE.create_workflow_run_output_parameter(
workflow_run_id=workflow_run_id,
output_parameter_id=self.output_parameter.output_parameter_id,
value=response,
)
return self.output_parameter
return None
BlockSubclasses = Union[ForLoopBlock, TaskBlock, CodeBlock, TextPromptBlock]
BlockTypeVar = Annotated[BlockSubclasses, Field(discriminator="block_type")]

View File

@@ -94,10 +94,23 @@ class CodeBlockYAML(BlockYAML):
parameter_keys: list[str] | None = None
class TextPromptBlockYAML(BlockYAML):
# There is a mypy bug with Literal. Without the type: ignore, mypy will raise an error:
# Parameter 1 of Literal[...] cannot be of type "Any"
# This pattern already works in block.py but since the BlockType is not defined in this file, mypy is not able
# to infer the type of the parameter_type attribute.
block_type: Literal[BlockType.TEXT_PROMPT] = BlockType.TEXT_PROMPT # type: ignore
llm_key: str
prompt: str
parameter_keys: list[str] | None = None
json_schema: dict[str, Any] | None = None
PARAMETER_YAML_SUBCLASSES = AWSSecretParameterYAML | WorkflowParameterYAML | ContextParameterYAML | OutputParameterYAML
PARAMETER_YAML_TYPES = Annotated[PARAMETER_YAML_SUBCLASSES, Field(discriminator="parameter_type")]
BLOCK_YAML_SUBCLASSES = TaskBlockYAML | ForLoopBlockYAML | CodeBlockYAML
BLOCK_YAML_SUBCLASSES = TaskBlockYAML | ForLoopBlockYAML | CodeBlockYAML | TextPromptBlockYAML
BLOCK_YAML_TYPES = Annotated[BLOCK_YAML_SUBCLASSES, Field(discriminator="block_type")]

View File

@@ -21,7 +21,14 @@ from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
from skyvern.forge.sdk.workflow.exceptions import WorkflowDefinitionHasDuplicateParameterKeys
from skyvern.forge.sdk.workflow.models.block import BlockType, BlockTypeVar, CodeBlock, ForLoopBlock, TaskBlock
from skyvern.forge.sdk.workflow.models.block import (
BlockType,
BlockTypeVar,
CodeBlock,
ForLoopBlock,
TaskBlock,
TextPromptBlock,
)
from skyvern.forge.sdk.workflow.models.parameter import (
AWSSecretParameter,
OutputParameter,
@@ -714,4 +721,15 @@ class WorkflowService:
else [],
output_parameter=output_parameter,
)
elif block_yaml.block_type == BlockType.TEXT_PROMPT:
return TextPromptBlock(
label=block_yaml.label,
llm_key=block_yaml.llm_key,
prompt=block_yaml.prompt,
parameters=[parameters[parameter_key] for parameter_key in block_yaml.parameter_keys]
if block_yaml.parameter_keys
else [],
json_schema=block_yaml.json_schema,
output_parameter=output_parameter,
)
raise ValueError(f"Invalid block type {block_yaml.block_type}")