Ykeremy/workflow prompt block (#124)

This commit is contained in:
Kerem Yilmaz
2024-03-25 00:57:37 -07:00
committed by GitHub
parent 0b5456a4c6
commit c58aaba4bb
7 changed files with 129 additions and 5 deletions

View File

@@ -1,4 +1,5 @@
import abc
import json
from enum import StrEnum
from typing import Annotated, Any, Literal, Union
@@ -12,6 +13,8 @@ from skyvern.exceptions import (
UnexpectedTaskStatus,
)
from skyvern.forge import app
from skyvern.forge.prompts import prompt_engine
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory
from skyvern.forge.sdk.schemas.tasks import TaskStatus
from skyvern.forge.sdk.workflow.context_manager import WorkflowRunContext
from skyvern.forge.sdk.workflow.models.parameter import (
@@ -28,6 +31,7 @@ class BlockType(StrEnum):
TASK = "task"
FOR_LOOP = "for_loop"
CODE = "code"
TEXT_PROMPT = "text_prompt"
class Block(BaseModel, abc.ABC):
@@ -345,5 +349,73 @@ class CodeBlock(Block):
return None
BlockSubclasses = Union[ForLoopBlock, TaskBlock, CodeBlock]
class TextPromptBlock(Block):
block_type: Literal[BlockType.TEXT_PROMPT] = BlockType.TEXT_PROMPT
llm_key: str
prompt: str
parameters: list[PARAMETER_TYPE] = []
json_schema: dict[str, Any] | None = None
def get_all_parameters(
self,
) -> list[PARAMETER_TYPE]:
return self.parameters
async def send_prompt(self, prompt: str, parameter_values: dict[str, Any]) -> dict[str, Any]:
llm_api_handler = LLMAPIHandlerFactory.get_llm_api_handler(self.llm_key)
if not self.json_schema:
self.json_schema = {
"type": "object",
"properties": {
"llm_response": {
"type": "string",
"description": "Your response to the prompt",
}
},
}
prompt = prompt_engine.load_prompt_from_string(prompt, **parameter_values)
prompt += (
"\n\n"
+ "Please respond to the prompt above using the following JSON definition:\n\n"
+ "```json\n"
+ json.dumps(self.json_schema, indent=2)
+ "\n```\n\n"
)
LOG.info("TextPromptBlock: Sending prompt to LLM", prompt=prompt, llm_key=self.llm_key)
response = await llm_api_handler(prompt=prompt)
LOG.info("TextPromptBlock: Received response from LLM", response=response)
return response
async def execute(self, workflow_run_id: str, **kwargs: dict) -> OutputParameter | None:
# get workflow run context
workflow_run_context = self.get_workflow_run_context(workflow_run_id)
# get all parameters into a dictionary
parameter_values = {}
for parameter in self.parameters:
value = workflow_run_context.get_value(parameter.key)
secret_value = workflow_run_context.get_original_secret_value_or_none(value)
if secret_value is not None:
parameter_values[parameter.key] = secret_value
else:
parameter_values[parameter.key] = value
response = await self.send_prompt(self.prompt, parameter_values)
if self.output_parameter:
await workflow_run_context.register_output_parameter_value_post_execution(
parameter=self.output_parameter,
value=response,
)
await app.DATABASE.create_workflow_run_output_parameter(
workflow_run_id=workflow_run_id,
output_parameter_id=self.output_parameter.output_parameter_id,
value=response,
)
return self.output_parameter
return None
BlockSubclasses = Union[ForLoopBlock, TaskBlock, CodeBlock, TextPromptBlock]
BlockTypeVar = Annotated[BlockSubclasses, Field(discriminator="block_type")]

View File

@@ -94,10 +94,23 @@ class CodeBlockYAML(BlockYAML):
parameter_keys: list[str] | None = None
class TextPromptBlockYAML(BlockYAML):
# There is a mypy bug with Literal. Without the type: ignore, mypy will raise an error:
# Parameter 1 of Literal[...] cannot be of type "Any"
# This pattern already works in block.py but since the BlockType is not defined in this file, mypy is not able
# to infer the type of the parameter_type attribute.
block_type: Literal[BlockType.TEXT_PROMPT] = BlockType.TEXT_PROMPT # type: ignore
llm_key: str
prompt: str
parameter_keys: list[str] | None = None
json_schema: dict[str, Any] | None = None
PARAMETER_YAML_SUBCLASSES = AWSSecretParameterYAML | WorkflowParameterYAML | ContextParameterYAML | OutputParameterYAML
PARAMETER_YAML_TYPES = Annotated[PARAMETER_YAML_SUBCLASSES, Field(discriminator="parameter_type")]
BLOCK_YAML_SUBCLASSES = TaskBlockYAML | ForLoopBlockYAML | CodeBlockYAML
BLOCK_YAML_SUBCLASSES = TaskBlockYAML | ForLoopBlockYAML | CodeBlockYAML | TextPromptBlockYAML
BLOCK_YAML_TYPES = Annotated[BLOCK_YAML_SUBCLASSES, Field(discriminator="block_type")]