Automatically create output parameters (#327)
This commit is contained in:
@@ -54,7 +54,7 @@ class BlockType(StrEnum):
|
||||
@dataclass(frozen=True)
|
||||
class BlockResult:
|
||||
success: bool
|
||||
output_parameter: OutputParameter | None = None
|
||||
output_parameter: OutputParameter
|
||||
output_parameter_value: dict[str, Any] | list | str | None = None
|
||||
|
||||
|
||||
@@ -62,9 +62,37 @@ class Block(BaseModel, abc.ABC):
|
||||
# Must be unique within workflow definition
|
||||
label: str
|
||||
block_type: BlockType
|
||||
output_parameter: OutputParameter | None = None
|
||||
output_parameter: OutputParameter
|
||||
continue_on_failure: bool = False
|
||||
|
||||
async def record_output_parameter_value(
|
||||
self,
|
||||
workflow_run_context: WorkflowRunContext,
|
||||
workflow_run_id: str,
|
||||
value: dict[str, Any] | list | str | None = None,
|
||||
) -> None:
|
||||
await workflow_run_context.register_output_parameter_value_post_execution(
|
||||
parameter=self.output_parameter,
|
||||
value=value,
|
||||
)
|
||||
await app.DATABASE.create_workflow_run_output_parameter(
|
||||
workflow_run_id=workflow_run_id,
|
||||
output_parameter_id=self.output_parameter.output_parameter_id,
|
||||
value=value,
|
||||
)
|
||||
LOG.info(
|
||||
f"Registered output parameter value",
|
||||
output_parameter_id=self.output_parameter.output_parameter_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
)
|
||||
|
||||
def build_block_result(
|
||||
self, success: bool, output_parameter_value: dict[str, Any] | list | str | None = None
|
||||
) -> BlockResult:
|
||||
return BlockResult(
|
||||
success=success, output_parameter=self.output_parameter, output_parameter_value=output_parameter_value
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_subclasses(cls) -> tuple[type["Block"], ...]:
|
||||
return tuple(cls.__subclasses__())
|
||||
@@ -91,7 +119,11 @@ class Block(BaseModel, abc.ABC):
|
||||
block_label=self.label,
|
||||
block_type=self.block_type,
|
||||
)
|
||||
return BlockResult(success=False)
|
||||
# Record output parameter value if it hasn't been recorded yet
|
||||
workflow_run_context = self.get_workflow_run_context(workflow_run_id)
|
||||
if not workflow_run_context.has_value(self.output_parameter.key):
|
||||
await self.record_output_parameter_value(workflow_run_context, workflow_run_id)
|
||||
return self.build_block_result(success=False)
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_all_parameters(
|
||||
@@ -238,6 +270,7 @@ class TaskBlock(Block):
|
||||
raise TaskNotFound(task.task_id)
|
||||
if not updated_task.status.is_final():
|
||||
raise UnexpectedTaskStatus(task_id=updated_task.task_id, status=updated_task.status)
|
||||
|
||||
if updated_task.status == TaskStatus.completed or updated_task.status == TaskStatus.terminated:
|
||||
LOG.info(
|
||||
f"Task completed",
|
||||
@@ -249,30 +282,9 @@ class TaskBlock(Block):
|
||||
)
|
||||
success = updated_task.status == TaskStatus.completed
|
||||
task_output = TaskOutput.from_task(updated_task)
|
||||
if self.output_parameter:
|
||||
await workflow_run_context.register_output_parameter_value_post_execution(
|
||||
parameter=self.output_parameter,
|
||||
value=task_output.model_dump(),
|
||||
)
|
||||
await app.DATABASE.create_workflow_run_output_parameter(
|
||||
workflow_run_id=workflow_run_id,
|
||||
output_parameter_id=self.output_parameter.output_parameter_id,
|
||||
value=task_output.model_dump(),
|
||||
)
|
||||
LOG.info(
|
||||
f"Registered output parameter value",
|
||||
output_parameter_id=self.output_parameter.output_parameter_id,
|
||||
value=updated_task.extracted_information,
|
||||
workflow_run_id=workflow_run_id,
|
||||
workflow_id=workflow.workflow_id,
|
||||
task_id=updated_task.task_id,
|
||||
)
|
||||
return BlockResult(
|
||||
success=success,
|
||||
output_parameter=self.output_parameter,
|
||||
output_parameter_value=task_output.model_dump(),
|
||||
)
|
||||
return BlockResult(success=success)
|
||||
output_parameter_value = task_output.model_dump()
|
||||
await self.record_output_parameter_value(workflow_run_context, workflow_run_id, output_parameter_value)
|
||||
return self.build_block_result(success=success, output_parameter_value=output_parameter_value)
|
||||
else:
|
||||
current_retry += 1
|
||||
will_retry = current_retry <= self.max_retries
|
||||
@@ -289,14 +301,20 @@ class TaskBlock(Block):
|
||||
max_retries=self.max_retries,
|
||||
task_output=task_output.model_dump_json(),
|
||||
)
|
||||
if not will_retry:
|
||||
output_parameter_value = task_output.model_dump()
|
||||
await self.record_output_parameter_value(
|
||||
workflow_run_context, workflow_run_id, output_parameter_value
|
||||
)
|
||||
return self.build_block_result(success=False, output_parameter_value=output_parameter_value)
|
||||
|
||||
return BlockResult(success=False)
|
||||
await self.record_output_parameter_value(workflow_run_context, workflow_run_id)
|
||||
return self.build_block_result(success=False)
|
||||
|
||||
|
||||
class ForLoopBlock(Block):
|
||||
block_type: Literal[BlockType.FOR_LOOP] = BlockType.FOR_LOOP
|
||||
|
||||
# TODO (kerem): Add support for ContextParameter
|
||||
loop_over: PARAMETER_TYPE
|
||||
loop_blocks: list["BlockTypeVar"]
|
||||
|
||||
@@ -370,6 +388,7 @@ class ForLoopBlock(Block):
|
||||
return [parameter_value]
|
||||
|
||||
async def execute(self, workflow_run_id: str, **kwargs: dict) -> BlockResult:
|
||||
outputs_with_loop_values = []
|
||||
success = False
|
||||
workflow_run_context = self.get_workflow_run_context(workflow_run_id)
|
||||
loop_over_values = self.get_loop_over_parameter_values(workflow_run_context)
|
||||
@@ -386,30 +405,27 @@ class ForLoopBlock(Block):
|
||||
workflow_run_id=workflow_run_id,
|
||||
num_loop_over_values=len(loop_over_values),
|
||||
)
|
||||
return BlockResult(success=success)
|
||||
outputs_with_loop_values = []
|
||||
await self.record_output_parameter_value(workflow_run_context, workflow_run_id, [])
|
||||
return self.build_block_result(success=False)
|
||||
for loop_idx, loop_over_value in enumerate(loop_over_values):
|
||||
context_parameters_with_value = self.get_loop_block_context_parameters(workflow_run_id, loop_over_value)
|
||||
for context_parameter in context_parameters_with_value:
|
||||
workflow_run_context.set_value(context_parameter.key, context_parameter.value)
|
||||
try:
|
||||
block_outputs = []
|
||||
for block_idx, loop_block in enumerate(self.loop_blocks):
|
||||
block_output = await loop_block.execute_safe(workflow_run_id=workflow_run_id)
|
||||
block_outputs.append(block_output)
|
||||
if not block_output.success and not loop_block.continue_on_failure:
|
||||
LOG.info(
|
||||
f"ForLoopBlock: Encountered an failure processing block {block_idx} during loop {loop_idx}, terminating early",
|
||||
block_outputs=block_outputs,
|
||||
loop_idx=loop_idx,
|
||||
block_idx=block_idx,
|
||||
loop_over_value=loop_over_value,
|
||||
loop_block_continue_on_failure=loop_block.continue_on_failure,
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
LOG.error("ForLoopBlock: Failed to execute loop block", exc_info=True)
|
||||
raise e
|
||||
block_outputs = []
|
||||
for block_idx, loop_block in enumerate(self.loop_blocks):
|
||||
block_output = await loop_block.execute_safe(workflow_run_id=workflow_run_id)
|
||||
block_outputs.append(block_output)
|
||||
if not block_output.success and not loop_block.continue_on_failure:
|
||||
LOG.info(
|
||||
f"ForLoopBlock: Encountered an failure processing block {block_idx} during loop {loop_idx}, terminating early",
|
||||
block_outputs=block_outputs,
|
||||
loop_idx=loop_idx,
|
||||
block_idx=block_idx,
|
||||
loop_over_value=loop_over_value,
|
||||
loop_block_continue_on_failure=loop_block.continue_on_failure,
|
||||
)
|
||||
break
|
||||
|
||||
outputs_with_loop_values.append(
|
||||
[
|
||||
{
|
||||
@@ -427,27 +443,13 @@ class ForLoopBlock(Block):
|
||||
success = all([block_output.success for block_output in block_outputs])
|
||||
if not success and not self.continue_on_failure:
|
||||
LOG.info(
|
||||
"ForLoopBlock: Encountered an failure processing block, terminating early",
|
||||
block_outputs=block_outputs,
|
||||
continue_on_failure=self.continue_on_failure,
|
||||
f"ForLoopBlock: Encountered an failure processing loop {loop_idx}, won't continue to the next loop. Total number of loops: {len(loop_over_values)}",
|
||||
for_loop_continue_on_failure=self.continue_on_failure,
|
||||
)
|
||||
break
|
||||
|
||||
if self.output_parameter:
|
||||
await workflow_run_context.register_output_parameter_value_post_execution(
|
||||
parameter=self.output_parameter,
|
||||
value=outputs_with_loop_values,
|
||||
)
|
||||
await app.DATABASE.create_workflow_run_output_parameter(
|
||||
workflow_run_id=workflow_run_id,
|
||||
output_parameter_id=self.output_parameter.output_parameter_id,
|
||||
value=outputs_with_loop_values,
|
||||
)
|
||||
return BlockResult(
|
||||
success=success, output_parameter=self.output_parameter, output_parameter_value=outputs_with_loop_values
|
||||
)
|
||||
|
||||
return BlockResult(success=success)
|
||||
await self.record_output_parameter_value(workflow_run_context, workflow_run_id, outputs_with_loop_values)
|
||||
return self.build_block_result(success=success, output_parameter_value=outputs_with_loop_values)
|
||||
|
||||
|
||||
class CodeBlock(Block):
|
||||
@@ -478,19 +480,8 @@ class CodeBlock(Block):
|
||||
local_variables: dict[str, Any] = {}
|
||||
exec(self.code, parameter_values, local_variables)
|
||||
result = {"result": local_variables.get("result")}
|
||||
if self.output_parameter:
|
||||
await workflow_run_context.register_output_parameter_value_post_execution(
|
||||
parameter=self.output_parameter,
|
||||
value=result,
|
||||
)
|
||||
await app.DATABASE.create_workflow_run_output_parameter(
|
||||
workflow_run_id=workflow_run_id,
|
||||
output_parameter_id=self.output_parameter.output_parameter_id,
|
||||
value=result,
|
||||
)
|
||||
return BlockResult(success=True, output_parameter=self.output_parameter, output_parameter_value=result)
|
||||
|
||||
return BlockResult(success=True)
|
||||
await self.record_output_parameter_value(workflow_run_context, workflow_run_id, result)
|
||||
return self.build_block_result(success=True, output_parameter_value=result)
|
||||
|
||||
|
||||
class TextPromptBlock(Block):
|
||||
@@ -547,19 +538,8 @@ class TextPromptBlock(Block):
|
||||
parameter_values[parameter.key] = value
|
||||
|
||||
response = await self.send_prompt(self.prompt, parameter_values)
|
||||
if self.output_parameter:
|
||||
await workflow_run_context.register_output_parameter_value_post_execution(
|
||||
parameter=self.output_parameter,
|
||||
value=response,
|
||||
)
|
||||
await app.DATABASE.create_workflow_run_output_parameter(
|
||||
workflow_run_id=workflow_run_id,
|
||||
output_parameter_id=self.output_parameter.output_parameter_id,
|
||||
value=response,
|
||||
)
|
||||
return BlockResult(success=True, output_parameter=self.output_parameter, output_parameter_value=response)
|
||||
|
||||
return BlockResult(success=True)
|
||||
await self.record_output_parameter_value(workflow_run_context, workflow_run_id, response)
|
||||
return self.build_block_result(success=True, output_parameter_value=response)
|
||||
|
||||
|
||||
class DownloadToS3Block(Block):
|
||||
@@ -615,21 +595,8 @@ class DownloadToS3Block(Block):
|
||||
raise e
|
||||
|
||||
LOG.info("DownloadToS3Block: File downloaded and uploaded to S3", uri=uri)
|
||||
if self.output_parameter:
|
||||
LOG.info("DownloadToS3Block: Output parameter defined, registering output parameter value")
|
||||
await workflow_run_context.register_output_parameter_value_post_execution(
|
||||
parameter=self.output_parameter,
|
||||
value=uri,
|
||||
)
|
||||
await app.DATABASE.create_workflow_run_output_parameter(
|
||||
workflow_run_id=workflow_run_id,
|
||||
output_parameter_id=self.output_parameter.output_parameter_id,
|
||||
value=uri,
|
||||
)
|
||||
return BlockResult(success=True, output_parameter=self.output_parameter, output_parameter_value=uri)
|
||||
|
||||
LOG.info("DownloadToS3Block: No output parameter defined, returning None")
|
||||
return BlockResult(success=True)
|
||||
await self.record_output_parameter_value(workflow_run_context, workflow_run_id, uri)
|
||||
return self.build_block_result(success=True, output_parameter_value=uri)
|
||||
|
||||
|
||||
class UploadToS3Block(Block):
|
||||
@@ -675,6 +642,7 @@ class UploadToS3Block(Block):
|
||||
if not self.path or not os.path.exists(self.path):
|
||||
raise FileNotFoundError(f"UploadToS3Block: File not found at path: {self.path}")
|
||||
|
||||
s3_uris = []
|
||||
try:
|
||||
client = self.get_async_aws_client()
|
||||
# is the file path a file or a directory?
|
||||
@@ -689,19 +657,20 @@ class UploadToS3Block(Block):
|
||||
LOG.warning("UploadToS3Block: Skipping directory", file=file)
|
||||
continue
|
||||
file_path = os.path.join(self.path, file)
|
||||
await client.upload_file_from_path(
|
||||
uri=self._get_s3_uri(workflow_run_id, file_path), file_path=file_path
|
||||
)
|
||||
s3_uri = self._get_s3_uri(workflow_run_id, file_path)
|
||||
s3_uris.append(s3_uri)
|
||||
await client.upload_file_from_path(uri=s3_uri, file_path=file_path)
|
||||
else:
|
||||
await client.upload_file_from_path(
|
||||
uri=self._get_s3_uri(workflow_run_id, self.path), file_path=self.path
|
||||
)
|
||||
s3_uri = self._get_s3_uri(workflow_run_id, self.path)
|
||||
s3_uris.append(s3_uri)
|
||||
await client.upload_file_from_path(uri=s3_uri, file_path=self.path)
|
||||
except Exception as e:
|
||||
LOG.exception("UploadToS3Block: Failed to upload file to S3", file_path=self.path)
|
||||
raise e
|
||||
|
||||
LOG.info("UploadToS3Block: File(s) uploaded to S3", file_path=self.path)
|
||||
return BlockResult(success=True)
|
||||
await self.record_output_parameter_value(workflow_run_context, workflow_run_id, s3_uris)
|
||||
return self.build_block_result(success=True, output_parameter_value=s3_uris)
|
||||
|
||||
|
||||
class SendEmailBlock(Block):
|
||||
@@ -902,39 +871,16 @@ class SendEmailBlock(Block):
|
||||
LOG.info("SendEmailBlock: Email sent")
|
||||
except Exception as e:
|
||||
LOG.error("SendEmailBlock: Failed to send email", exc_info=True)
|
||||
if self.output_parameter:
|
||||
result_dict = {"success": False, "error": str(e)}
|
||||
await workflow_run_context.register_output_parameter_value_post_execution(
|
||||
parameter=self.output_parameter,
|
||||
value=result_dict,
|
||||
)
|
||||
await app.DATABASE.create_workflow_run_output_parameter(
|
||||
workflow_run_id=workflow_run_id,
|
||||
output_parameter_id=self.output_parameter.output_parameter_id,
|
||||
value=result_dict,
|
||||
)
|
||||
return BlockResult(
|
||||
success=False, output_parameter=self.output_parameter, output_parameter_value=result_dict
|
||||
)
|
||||
raise e
|
||||
result_dict = {"success": False, "error": str(e)}
|
||||
await self.record_output_parameter_value(workflow_run_context, workflow_run_id, result_dict)
|
||||
return self.build_block_result(success=False, output_parameter_value=result_dict)
|
||||
finally:
|
||||
if smtp_host:
|
||||
smtp_host.quit()
|
||||
|
||||
result_dict = {"success": True}
|
||||
if self.output_parameter:
|
||||
await workflow_run_context.register_output_parameter_value_post_execution(
|
||||
parameter=self.output_parameter,
|
||||
value=result_dict,
|
||||
)
|
||||
await app.DATABASE.create_workflow_run_output_parameter(
|
||||
workflow_run_id=workflow_run_id,
|
||||
output_parameter_id=self.output_parameter.output_parameter_id,
|
||||
value=result_dict,
|
||||
)
|
||||
return BlockResult(success=True, output_parameter=self.output_parameter, output_parameter_value=result_dict)
|
||||
|
||||
return BlockResult(success=True)
|
||||
await self.record_output_parameter_value(workflow_run_context, workflow_run_id, result_dict)
|
||||
return self.build_block_result(success=True, output_parameter_value=result_dict)
|
||||
|
||||
|
||||
BlockSubclasses = Union[
|
||||
|
||||
Reference in New Issue
Block a user