forloop metadata variables (#1334)

This commit is contained in:
LawyZheng
2024-12-06 11:35:32 +08:00
committed by GitHub
parent 01e9678d27
commit db5b9d1dbd
2 changed files with 34 additions and 10 deletions

View File

@@ -25,6 +25,9 @@ if TYPE_CHECKING:
LOG = structlog.get_logger()
BlockMetadata = dict[str, str | int | float | bool | dict | list]
class WorkflowRunContext:
parameters: dict[str, PARAMETER_TYPE]
values: dict[str, Any]
@@ -36,9 +39,12 @@ class WorkflowRunContext:
workflow_output_parameters: list[OutputParameter],
context_parameters: list[ContextParameter],
) -> None:
# key is label name
self.blocks_metadata: dict[str, BlockMetadata] = {}
self.parameters = {}
self.values = {}
self.secrets = {}
for parameter, run_parameter in workflow_parameter_tuples:
if parameter.key in self.parameters:
prev_value = self.parameters[parameter.key]
@@ -81,6 +87,15 @@ class WorkflowRunContext:
def set_value(self, key: str, value: Any) -> None:
self.values[key] = value
def update_block_metadata(self, label: str, metadata: BlockMetadata) -> None:
if label in self.blocks_metadata:
self.blocks_metadata[label].update(metadata)
return
self.blocks_metadata[label] = metadata
def get_block_metadata(self, label: str) -> BlockMetadata:
return self.blocks_metadata.get(label, BlockMetadata())
def get_original_secret_value_or_none(self, secret_id_or_value: Any) -> Any:
"""
Get the original secret value from the secrets dict. If the secret id is not found, return None.

View File

@@ -45,7 +45,7 @@ from skyvern.forge.sdk.api.files import (
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory
from skyvern.forge.sdk.db.enums import TaskType
from skyvern.forge.sdk.schemas.tasks import Task, TaskOutput, TaskStatus
from skyvern.forge.sdk.workflow.context_manager import WorkflowRunContext
from skyvern.forge.sdk.workflow.context_manager import BlockMetadata, WorkflowRunContext
from skyvern.forge.sdk.workflow.exceptions import (
InvalidEmailClientConfiguration,
InvalidFileType,
@@ -140,6 +140,17 @@ class Block(BaseModel, abc.ABC):
status=status,
)
def format_block_parameter_template_from_workflow_run_context(
self, potential_template: str, workflow_run_context: WorkflowRunContext
) -> str:
if not potential_template:
return potential_template
template = Template(potential_template)
template_data = workflow_run_context.values.copy()
template_data[self.label] = workflow_run_context.get_block_metadata(self.label)
return template.render(template_data)
@classmethod
def get_subclasses(cls) -> tuple[type["Block"], ...]:
return tuple(cls.__subclasses__())
@@ -152,15 +163,6 @@ class Block(BaseModel, abc.ABC):
def get_async_aws_client() -> AsyncAWSClient:
return app.WORKFLOW_CONTEXT_MANAGER.aws_client
@staticmethod
def format_block_parameter_template_from_workflow_run_context(
potential_template: str, workflow_run_context: WorkflowRunContext
) -> str:
if not potential_template:
return potential_template
template = Template(potential_template)
return template.render(workflow_run_context.values)
@abc.abstractmethod
async def execute(self, workflow_run_id: str, **kwargs: dict) -> BlockResult:
pass
@@ -659,8 +661,15 @@ class ForLoopBlock(Block):
context_parameters_with_value = self.get_loop_block_context_parameters(workflow_run_id, loop_over_value)
for context_parameter in context_parameters_with_value:
workflow_run_context.set_value(context_parameter.key, context_parameter.value)
each_loop_output_values: list[dict[str, Any]] = []
for block_idx, loop_block in enumerate(self.loop_blocks):
metadata: BlockMetadata = {
"current_index": loop_idx,
"current_value": loop_over_value,
}
workflow_run_context.update_block_metadata(loop_block.label, metadata)
original_loop_block = loop_block
loop_block = loop_block.copy()
current_block = loop_block