Workflow: Output Parameters & Code Blocks (#117)
This commit is contained in:
@@ -14,7 +14,12 @@ from skyvern.exceptions import (
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.schemas.tasks import TaskStatus
|
||||
from skyvern.forge.sdk.workflow.context_manager import WorkflowRunContext
|
||||
from skyvern.forge.sdk.workflow.models.parameter import PARAMETER_TYPE, ContextParameter, WorkflowParameter
|
||||
from skyvern.forge.sdk.workflow.models.parameter import (
|
||||
PARAMETER_TYPE,
|
||||
ContextParameter,
|
||||
OutputParameter,
|
||||
WorkflowParameter,
|
||||
)
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
@@ -22,12 +27,14 @@ LOG = structlog.get_logger()
|
||||
class BlockType(StrEnum):
|
||||
TASK = "task"
|
||||
FOR_LOOP = "for_loop"
|
||||
CODE = "code"
|
||||
|
||||
|
||||
class Block(BaseModel, abc.ABC):
|
||||
# Must be unique within workflow definition
|
||||
label: str
|
||||
block_type: BlockType
|
||||
parent_block_id: str | None = None
|
||||
next_block_id: str | None = None
|
||||
output_parameter: OutputParameter | None = None
|
||||
|
||||
@classmethod
|
||||
def get_subclasses(cls) -> tuple[type["Block"], ...]:
|
||||
@@ -38,7 +45,7 @@ class Block(BaseModel, abc.ABC):
|
||||
return app.WORKFLOW_CONTEXT_MANAGER.get_workflow_run_context(workflow_run_id)
|
||||
|
||||
@abc.abstractmethod
|
||||
async def execute(self, workflow_run_id: str, **kwargs: dict) -> Any:
|
||||
async def execute(self, workflow_run_id: str, **kwargs: dict) -> OutputParameter | None:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -96,7 +103,7 @@ class TaskBlock(Block):
|
||||
|
||||
return order, retry + 1
|
||||
|
||||
async def execute(self, workflow_run_id: str, **kwargs: dict) -> Any:
|
||||
async def execute(self, workflow_run_id: str, **kwargs: dict) -> OutputParameter | None:
|
||||
workflow_run_context = self.get_workflow_run_context(workflow_run_id)
|
||||
current_retry = 0
|
||||
# initial value for will_retry is True, so that the loop runs at least once
|
||||
@@ -158,6 +165,32 @@ class TaskBlock(Block):
|
||||
raise UnexpectedTaskStatus(task_id=updated_task.task_id, status=updated_task.status)
|
||||
if updated_task.status == TaskStatus.completed:
|
||||
will_retry = False
|
||||
LOG.info(
|
||||
f"Task completed",
|
||||
task_id=updated_task.task_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
workflow_id=workflow.workflow_id,
|
||||
organization_id=workflow.organization_id,
|
||||
)
|
||||
if self.output_parameter:
|
||||
await workflow_run_context.register_output_parameter_value_post_execution(
|
||||
parameter=self.output_parameter,
|
||||
value=updated_task.extracted_information,
|
||||
)
|
||||
await app.DATABASE.create_workflow_run_output_parameter(
|
||||
workflow_run_id=workflow_run_id,
|
||||
output_parameter_id=self.output_parameter.output_parameter_id,
|
||||
value=updated_task.extracted_information,
|
||||
)
|
||||
LOG.info(
|
||||
f"Registered output parameter value",
|
||||
output_parameter_id=self.output_parameter.output_parameter_id,
|
||||
value=updated_task.extracted_information,
|
||||
workflow_run_id=workflow_run_id,
|
||||
workflow_id=workflow.workflow_id,
|
||||
task_id=updated_task.task_id,
|
||||
)
|
||||
return self.output_parameter
|
||||
else:
|
||||
current_retry += 1
|
||||
will_retry = current_retry <= self.max_retries
|
||||
@@ -172,6 +205,7 @@ class TaskBlock(Block):
|
||||
current_retry=current_retry,
|
||||
max_retries=self.max_retries,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
class ForLoopBlock(Block):
|
||||
@@ -216,9 +250,10 @@ class ForLoopBlock(Block):
|
||||
return [parameter_value]
|
||||
else:
|
||||
# TODO (kerem): Implement this for context parameters
|
||||
# TODO (kerem): Implement this for output parameters
|
||||
raise NotImplementedError
|
||||
|
||||
async def execute(self, workflow_run_id: str, **kwargs: dict) -> Any:
|
||||
async def execute(self, workflow_run_id: str, **kwargs: dict) -> OutputParameter | None:
|
||||
workflow_run_context = self.get_workflow_run_context(workflow_run_id)
|
||||
loop_over_values = self.get_loop_over_parameter_values(workflow_run_context)
|
||||
LOG.info(
|
||||
@@ -227,14 +262,77 @@ class ForLoopBlock(Block):
|
||||
workflow_run_id=workflow_run_id,
|
||||
num_loop_over_values=len(loop_over_values),
|
||||
)
|
||||
outputs_with_loop_values = []
|
||||
for loop_over_value in loop_over_values:
|
||||
context_parameters_with_value = self.get_loop_block_context_parameters(workflow_run_id, loop_over_value)
|
||||
for context_parameter in context_parameters_with_value:
|
||||
workflow_run_context.set_value(context_parameter.key, context_parameter.value)
|
||||
await self.loop_block.execute(workflow_run_id=workflow_run_id)
|
||||
if self.loop_block.output_parameter:
|
||||
outputs_with_loop_values.append(
|
||||
{
|
||||
"loop_value": loop_over_value,
|
||||
"output_parameter": self.loop_block.output_parameter,
|
||||
"output_value": workflow_run_context.get_value(self.loop_block.output_parameter.key),
|
||||
}
|
||||
)
|
||||
|
||||
if self.output_parameter:
|
||||
await workflow_run_context.register_output_parameter_value_post_execution(
|
||||
parameter=self.output_parameter,
|
||||
value=outputs_with_loop_values,
|
||||
)
|
||||
await app.DATABASE.create_workflow_run_output_parameter(
|
||||
workflow_run_id=workflow_run_id,
|
||||
output_parameter_id=self.output_parameter.output_parameter_id,
|
||||
value=outputs_with_loop_values,
|
||||
)
|
||||
return self.output_parameter
|
||||
|
||||
return None
|
||||
|
||||
|
||||
BlockSubclasses = Union[ForLoopBlock, TaskBlock]
|
||||
class CodeBlock(Block):
|
||||
block_type: Literal[BlockType.CODE] = BlockType.CODE
|
||||
|
||||
code: str
|
||||
parameters: list[PARAMETER_TYPE] = []
|
||||
|
||||
def get_all_parameters(
|
||||
self,
|
||||
) -> list[PARAMETER_TYPE]:
|
||||
return self.parameters
|
||||
|
||||
async def execute(self, workflow_run_id: str, **kwargs: dict) -> OutputParameter | None:
|
||||
# get workflow run context
|
||||
workflow_run_context = self.get_workflow_run_context(workflow_run_id)
|
||||
# get all parameters into a dictionary
|
||||
parameter_values = {}
|
||||
for parameter in self.parameters:
|
||||
value = workflow_run_context.get_value(parameter.key)
|
||||
secret_value = workflow_run_context.get_original_secret_value_or_none(value)
|
||||
if secret_value is not None:
|
||||
parameter_values[parameter.key] = secret_value
|
||||
else:
|
||||
parameter_values[parameter.key] = value
|
||||
|
||||
local_variables: dict[str, Any] = {}
|
||||
exec(self.code, parameter_values, local_variables)
|
||||
result = {"result": local_variables.get("result")}
|
||||
if self.output_parameter:
|
||||
await workflow_run_context.register_output_parameter_value_post_execution(
|
||||
parameter=self.output_parameter,
|
||||
value=result,
|
||||
)
|
||||
await app.DATABASE.create_workflow_run_output_parameter(
|
||||
workflow_run_id=workflow_run_id,
|
||||
output_parameter_id=self.output_parameter.output_parameter_id,
|
||||
value=result,
|
||||
)
|
||||
return self.output_parameter
|
||||
|
||||
return None
|
||||
|
||||
|
||||
BlockSubclasses = Union[ForLoopBlock, TaskBlock, CodeBlock]
|
||||
BlockTypeVar = Annotated[BlockSubclasses, Field(discriminator="block_type")]
|
||||
|
||||
@@ -11,6 +11,7 @@ class ParameterType(StrEnum):
|
||||
WORKFLOW = "workflow"
|
||||
CONTEXT = "context"
|
||||
AWS_SECRET = "aws_secret"
|
||||
OUTPUT = "output"
|
||||
|
||||
|
||||
class Parameter(BaseModel, abc.ABC):
|
||||
@@ -80,5 +81,16 @@ class ContextParameter(Parameter):
|
||||
value: str | int | float | bool | dict | list | None = None
|
||||
|
||||
|
||||
ParameterSubclasses = Union[WorkflowParameter, ContextParameter, AWSSecretParameter]
|
||||
class OutputParameter(Parameter):
|
||||
parameter_type: Literal[ParameterType.OUTPUT] = ParameterType.OUTPUT
|
||||
|
||||
output_parameter_id: str
|
||||
workflow_id: str
|
||||
|
||||
created_at: datetime
|
||||
modified_at: datetime
|
||||
deleted_at: datetime | None = None
|
||||
|
||||
|
||||
ParameterSubclasses = Union[WorkflowParameter, ContextParameter, AWSSecretParameter, OutputParameter]
|
||||
PARAMETER_TYPE = Annotated[ParameterSubclasses, Field(discriminator="parameter_type")]
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Any, List
|
||||
from pydantic import BaseModel
|
||||
|
||||
from skyvern.forge.sdk.schemas.tasks import ProxyLocation
|
||||
from skyvern.forge.sdk.workflow.exceptions import WorkflowDefinitionHasDuplicateBlockLabels
|
||||
from skyvern.forge.sdk.workflow.models.block import BlockTypeVar
|
||||
|
||||
|
||||
@@ -22,6 +23,18 @@ class RunWorkflowResponse(BaseModel):
|
||||
class WorkflowDefinition(BaseModel):
|
||||
blocks: List[BlockTypeVar]
|
||||
|
||||
def validate(self) -> None:
|
||||
labels: set[str] = set()
|
||||
duplicate_labels: set[str] = set()
|
||||
for block in self.blocks:
|
||||
if block.label in labels:
|
||||
duplicate_labels.add(block.label)
|
||||
else:
|
||||
labels.add(block.label)
|
||||
|
||||
if duplicate_labels:
|
||||
raise WorkflowDefinitionHasDuplicateBlockLabels(duplicate_labels)
|
||||
|
||||
|
||||
class Workflow(BaseModel):
|
||||
workflow_id: str
|
||||
@@ -61,6 +74,13 @@ class WorkflowRunParameter(BaseModel):
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class WorkflowRunOutputParameter(BaseModel):
|
||||
workflow_run_id: str
|
||||
output_parameter_id: str
|
||||
value: dict[str, Any] | list | str | None
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class WorkflowRunStatusResponse(BaseModel):
|
||||
workflow_id: str
|
||||
workflow_run_id: str
|
||||
|
||||
Reference in New Issue
Block a user