forloop metadata variables (#1334)

This commit is contained in:
LawyZheng
2024-12-06 11:35:32 +08:00
committed by GitHub
parent 01e9678d27
commit db5b9d1dbd
2 changed files with 34 additions and 10 deletions

View File

@@ -25,6 +25,9 @@ if TYPE_CHECKING:
LOG = structlog.get_logger() LOG = structlog.get_logger()
BlockMetadata = dict[str, str | int | float | bool | dict | list]
class WorkflowRunContext: class WorkflowRunContext:
parameters: dict[str, PARAMETER_TYPE] parameters: dict[str, PARAMETER_TYPE]
values: dict[str, Any] values: dict[str, Any]
@@ -36,9 +39,12 @@ class WorkflowRunContext:
workflow_output_parameters: list[OutputParameter], workflow_output_parameters: list[OutputParameter],
context_parameters: list[ContextParameter], context_parameters: list[ContextParameter],
) -> None: ) -> None:
# key is label name
self.blocks_metadata: dict[str, BlockMetadata] = {}
self.parameters = {} self.parameters = {}
self.values = {} self.values = {}
self.secrets = {} self.secrets = {}
for parameter, run_parameter in workflow_parameter_tuples: for parameter, run_parameter in workflow_parameter_tuples:
if parameter.key in self.parameters: if parameter.key in self.parameters:
prev_value = self.parameters[parameter.key] prev_value = self.parameters[parameter.key]
@@ -81,6 +87,15 @@ class WorkflowRunContext:
def set_value(self, key: str, value: Any) -> None: def set_value(self, key: str, value: Any) -> None:
self.values[key] = value self.values[key] = value
def update_block_metadata(self, label: str, metadata: BlockMetadata) -> None:
if label in self.blocks_metadata:
self.blocks_metadata[label].update(metadata)
return
self.blocks_metadata[label] = metadata
def get_block_metadata(self, label: str) -> BlockMetadata:
return self.blocks_metadata.get(label, BlockMetadata())
def get_original_secret_value_or_none(self, secret_id_or_value: Any) -> Any: def get_original_secret_value_or_none(self, secret_id_or_value: Any) -> Any:
""" """
Get the original secret value from the secrets dict. If the secret id is not found, return None. Get the original secret value from the secrets dict. If the secret id is not found, return None.

View File

@@ -45,7 +45,7 @@ from skyvern.forge.sdk.api.files import (
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory
from skyvern.forge.sdk.db.enums import TaskType from skyvern.forge.sdk.db.enums import TaskType
from skyvern.forge.sdk.schemas.tasks import Task, TaskOutput, TaskStatus from skyvern.forge.sdk.schemas.tasks import Task, TaskOutput, TaskStatus
from skyvern.forge.sdk.workflow.context_manager import WorkflowRunContext from skyvern.forge.sdk.workflow.context_manager import BlockMetadata, WorkflowRunContext
from skyvern.forge.sdk.workflow.exceptions import ( from skyvern.forge.sdk.workflow.exceptions import (
InvalidEmailClientConfiguration, InvalidEmailClientConfiguration,
InvalidFileType, InvalidFileType,
@@ -140,6 +140,17 @@ class Block(BaseModel, abc.ABC):
status=status, status=status,
) )
def format_block_parameter_template_from_workflow_run_context(
self, potential_template: str, workflow_run_context: WorkflowRunContext
) -> str:
if not potential_template:
return potential_template
template = Template(potential_template)
template_data = workflow_run_context.values.copy()
template_data[self.label] = workflow_run_context.get_block_metadata(self.label)
return template.render(template_data)
@classmethod @classmethod
def get_subclasses(cls) -> tuple[type["Block"], ...]: def get_subclasses(cls) -> tuple[type["Block"], ...]:
return tuple(cls.__subclasses__()) return tuple(cls.__subclasses__())
@@ -152,15 +163,6 @@ class Block(BaseModel, abc.ABC):
def get_async_aws_client() -> AsyncAWSClient: def get_async_aws_client() -> AsyncAWSClient:
return app.WORKFLOW_CONTEXT_MANAGER.aws_client return app.WORKFLOW_CONTEXT_MANAGER.aws_client
@staticmethod
def format_block_parameter_template_from_workflow_run_context(
potential_template: str, workflow_run_context: WorkflowRunContext
) -> str:
if not potential_template:
return potential_template
template = Template(potential_template)
return template.render(workflow_run_context.values)
@abc.abstractmethod @abc.abstractmethod
async def execute(self, workflow_run_id: str, **kwargs: dict) -> BlockResult: async def execute(self, workflow_run_id: str, **kwargs: dict) -> BlockResult:
pass pass
@@ -659,8 +661,15 @@ class ForLoopBlock(Block):
context_parameters_with_value = self.get_loop_block_context_parameters(workflow_run_id, loop_over_value) context_parameters_with_value = self.get_loop_block_context_parameters(workflow_run_id, loop_over_value)
for context_parameter in context_parameters_with_value: for context_parameter in context_parameters_with_value:
workflow_run_context.set_value(context_parameter.key, context_parameter.value) workflow_run_context.set_value(context_parameter.key, context_parameter.value)
each_loop_output_values: list[dict[str, Any]] = [] each_loop_output_values: list[dict[str, Any]] = []
for block_idx, loop_block in enumerate(self.loop_blocks): for block_idx, loop_block in enumerate(self.loop_blocks):
metadata: BlockMetadata = {
"current_index": loop_idx,
"current_value": loop_over_value,
}
workflow_run_context.update_block_metadata(loop_block.label, metadata)
original_loop_block = loop_block original_loop_block = loop_block
loop_block = loop_block.copy() loop_block = loop_block.copy()
current_block = loop_block current_block = loop_block