Ykeremy/workflow prompt block (#124)
This commit is contained in:
2
poetry.lock
generated
2
poetry.lock
generated
@@ -6880,4 +6880,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.11,<3.12"
|
||||
content-hash = "402c47a5e38eef5bbd38a63cc8116f11b4abf0d676c9b411a2d7f425748e9c4c"
|
||||
content-hash = "9c2a8d3c2c9b239c6338f53485f9eace6d3eac112fa9246d9bc0a83c92f61a1d"
|
||||
|
||||
@@ -47,6 +47,7 @@ curlify = "^2.2.1"
|
||||
typer = "^0.9.0"
|
||||
types-toml = "^0.10.8.7"
|
||||
apscheduler = "^3.10.4"
|
||||
httpx = "^0.27.0"
|
||||
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
|
||||
@@ -692,7 +692,9 @@ class AgentDB:
|
||||
) -> Workflow:
|
||||
try:
|
||||
async with self.Session() as session:
|
||||
if workflow := await session.scalars(select(WorkflowModel).filter_by(workflow_id=workflow_id).first()):
|
||||
if workflow := (
|
||||
await session.scalars(select(WorkflowModel).filter_by(workflow_id=workflow_id))
|
||||
).first():
|
||||
if title:
|
||||
workflow.title = title
|
||||
if description:
|
||||
|
||||
@@ -96,3 +96,21 @@ class PromptEngine:
|
||||
except Exception:
|
||||
LOG.error("Failed to load prompt.", template=template, kwargs_keys=kwargs.keys(), exc_info=True)
|
||||
raise
|
||||
|
||||
def load_prompt_from_string(self, template: str, **kwargs: Any) -> str:
|
||||
"""
|
||||
Load and populate the specified template from a string.
|
||||
|
||||
Args:
|
||||
template (str): The template string to load.
|
||||
**kwargs: The arguments to populate the template with.
|
||||
|
||||
Returns:
|
||||
str: The populated template.
|
||||
"""
|
||||
try:
|
||||
jinja_template = self.env.from_string(template)
|
||||
return jinja_template.render(**kwargs)
|
||||
except Exception:
|
||||
LOG.error("Failed to load prompt from string.", template=template, kwargs_keys=kwargs.keys(), exc_info=True)
|
||||
raise
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import abc
|
||||
import json
|
||||
from enum import StrEnum
|
||||
from typing import Annotated, Any, Literal, Union
|
||||
|
||||
@@ -12,6 +13,8 @@ from skyvern.exceptions import (
|
||||
UnexpectedTaskStatus,
|
||||
)
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.prompts import prompt_engine
|
||||
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory
|
||||
from skyvern.forge.sdk.schemas.tasks import TaskStatus
|
||||
from skyvern.forge.sdk.workflow.context_manager import WorkflowRunContext
|
||||
from skyvern.forge.sdk.workflow.models.parameter import (
|
||||
@@ -28,6 +31,7 @@ class BlockType(StrEnum):
|
||||
TASK = "task"
|
||||
FOR_LOOP = "for_loop"
|
||||
CODE = "code"
|
||||
TEXT_PROMPT = "text_prompt"
|
||||
|
||||
|
||||
class Block(BaseModel, abc.ABC):
|
||||
@@ -345,5 +349,73 @@ class CodeBlock(Block):
|
||||
return None
|
||||
|
||||
|
||||
BlockSubclasses = Union[ForLoopBlock, TaskBlock, CodeBlock]
|
||||
class TextPromptBlock(Block):
|
||||
block_type: Literal[BlockType.TEXT_PROMPT] = BlockType.TEXT_PROMPT
|
||||
|
||||
llm_key: str
|
||||
prompt: str
|
||||
parameters: list[PARAMETER_TYPE] = []
|
||||
json_schema: dict[str, Any] | None = None
|
||||
|
||||
def get_all_parameters(
|
||||
self,
|
||||
) -> list[PARAMETER_TYPE]:
|
||||
return self.parameters
|
||||
|
||||
async def send_prompt(self, prompt: str, parameter_values: dict[str, Any]) -> dict[str, Any]:
|
||||
llm_api_handler = LLMAPIHandlerFactory.get_llm_api_handler(self.llm_key)
|
||||
if not self.json_schema:
|
||||
self.json_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"llm_response": {
|
||||
"type": "string",
|
||||
"description": "Your response to the prompt",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
prompt = prompt_engine.load_prompt_from_string(prompt, **parameter_values)
|
||||
prompt += (
|
||||
"\n\n"
|
||||
+ "Please respond to the prompt above using the following JSON definition:\n\n"
|
||||
+ "```json\n"
|
||||
+ json.dumps(self.json_schema, indent=2)
|
||||
+ "\n```\n\n"
|
||||
)
|
||||
LOG.info("TextPromptBlock: Sending prompt to LLM", prompt=prompt, llm_key=self.llm_key)
|
||||
response = await llm_api_handler(prompt=prompt)
|
||||
LOG.info("TextPromptBlock: Received response from LLM", response=response)
|
||||
return response
|
||||
|
||||
async def execute(self, workflow_run_id: str, **kwargs: dict) -> OutputParameter | None:
|
||||
# get workflow run context
|
||||
workflow_run_context = self.get_workflow_run_context(workflow_run_id)
|
||||
# get all parameters into a dictionary
|
||||
parameter_values = {}
|
||||
for parameter in self.parameters:
|
||||
value = workflow_run_context.get_value(parameter.key)
|
||||
secret_value = workflow_run_context.get_original_secret_value_or_none(value)
|
||||
if secret_value is not None:
|
||||
parameter_values[parameter.key] = secret_value
|
||||
else:
|
||||
parameter_values[parameter.key] = value
|
||||
|
||||
response = await self.send_prompt(self.prompt, parameter_values)
|
||||
if self.output_parameter:
|
||||
await workflow_run_context.register_output_parameter_value_post_execution(
|
||||
parameter=self.output_parameter,
|
||||
value=response,
|
||||
)
|
||||
await app.DATABASE.create_workflow_run_output_parameter(
|
||||
workflow_run_id=workflow_run_id,
|
||||
output_parameter_id=self.output_parameter.output_parameter_id,
|
||||
value=response,
|
||||
)
|
||||
return self.output_parameter
|
||||
|
||||
return None
|
||||
|
||||
|
||||
BlockSubclasses = Union[ForLoopBlock, TaskBlock, CodeBlock, TextPromptBlock]
|
||||
BlockTypeVar = Annotated[BlockSubclasses, Field(discriminator="block_type")]
|
||||
|
||||
@@ -94,10 +94,23 @@ class CodeBlockYAML(BlockYAML):
|
||||
parameter_keys: list[str] | None = None
|
||||
|
||||
|
||||
class TextPromptBlockYAML(BlockYAML):
|
||||
# There is a mypy bug with Literal. Without the type: ignore, mypy will raise an error:
|
||||
# Parameter 1 of Literal[...] cannot be of type "Any"
|
||||
# This pattern already works in block.py but since the BlockType is not defined in this file, mypy is not able
|
||||
# to infer the type of the parameter_type attribute.
|
||||
block_type: Literal[BlockType.TEXT_PROMPT] = BlockType.TEXT_PROMPT # type: ignore
|
||||
|
||||
llm_key: str
|
||||
prompt: str
|
||||
parameter_keys: list[str] | None = None
|
||||
json_schema: dict[str, Any] | None = None
|
||||
|
||||
|
||||
PARAMETER_YAML_SUBCLASSES = AWSSecretParameterYAML | WorkflowParameterYAML | ContextParameterYAML | OutputParameterYAML
|
||||
PARAMETER_YAML_TYPES = Annotated[PARAMETER_YAML_SUBCLASSES, Field(discriminator="parameter_type")]
|
||||
|
||||
BLOCK_YAML_SUBCLASSES = TaskBlockYAML | ForLoopBlockYAML | CodeBlockYAML
|
||||
BLOCK_YAML_SUBCLASSES = TaskBlockYAML | ForLoopBlockYAML | CodeBlockYAML | TextPromptBlockYAML
|
||||
BLOCK_YAML_TYPES = Annotated[BLOCK_YAML_SUBCLASSES, Field(discriminator="block_type")]
|
||||
|
||||
|
||||
|
||||
@@ -21,7 +21,14 @@ from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
|
||||
from skyvern.forge.sdk.models import Step
|
||||
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
|
||||
from skyvern.forge.sdk.workflow.exceptions import WorkflowDefinitionHasDuplicateParameterKeys
|
||||
from skyvern.forge.sdk.workflow.models.block import BlockType, BlockTypeVar, CodeBlock, ForLoopBlock, TaskBlock
|
||||
from skyvern.forge.sdk.workflow.models.block import (
|
||||
BlockType,
|
||||
BlockTypeVar,
|
||||
CodeBlock,
|
||||
ForLoopBlock,
|
||||
TaskBlock,
|
||||
TextPromptBlock,
|
||||
)
|
||||
from skyvern.forge.sdk.workflow.models.parameter import (
|
||||
AWSSecretParameter,
|
||||
OutputParameter,
|
||||
@@ -714,4 +721,15 @@ class WorkflowService:
|
||||
else [],
|
||||
output_parameter=output_parameter,
|
||||
)
|
||||
elif block_yaml.block_type == BlockType.TEXT_PROMPT:
|
||||
return TextPromptBlock(
|
||||
label=block_yaml.label,
|
||||
llm_key=block_yaml.llm_key,
|
||||
prompt=block_yaml.prompt,
|
||||
parameters=[parameters[parameter_key] for parameter_key in block_yaml.parameter_keys]
|
||||
if block_yaml.parameter_keys
|
||||
else [],
|
||||
json_schema=block_yaml.json_schema,
|
||||
output_parameter=output_parameter,
|
||||
)
|
||||
raise ValueError(f"Invalid block type {block_yaml.block_type}")
|
||||
|
||||
Reference in New Issue
Block a user