From d697023994b022bac521e514e3156119739b9663 Mon Sep 17 00:00:00 2001 From: LawyZheng Date: Fri, 29 Nov 2024 15:24:35 +0800 Subject: [PATCH] all blocks support jinja template (#1288) --- skyvern/forge/sdk/workflow/models/block.py | 143 ++++++++++++++++----- 1 file changed, 109 insertions(+), 34 deletions(-) diff --git a/skyvern/forge/sdk/workflow/models/block.py b/skyvern/forge/sdk/workflow/models/block.py index 0d5efb99..c2b42e47 100644 --- a/skyvern/forge/sdk/workflow/models/block.py +++ b/skyvern/forge/sdk/workflow/models/block.py @@ -154,10 +154,24 @@ class Block(BaseModel, abc.ABC): def get_async_aws_client() -> AsyncAWSClient: return app.WORKFLOW_CONTEXT_MANAGER.aws_client + @staticmethod + def format_block_parameter_template_from_workflow_run_context( + potential_template: str, workflow_run_context: WorkflowRunContext + ) -> str: + if not potential_template: + return potential_template + template = Template(potential_template) + return template.render(workflow_run_context.values) + @abc.abstractmethod async def execute(self, workflow_run_id: str, **kwargs: dict) -> BlockResult: pass + def format_potential_template_parameters(self, workflow_run_context: WorkflowRunContext) -> None: + self.label = self.format_block_parameter_template_from_workflow_run_context( + potential_template=self.label, workflow_run_context=workflow_run_context + ) + async def execute_safe(self, workflow_run_id: str, **kwargs: dict) -> BlockResult: try: return await self.execute(workflow_run_id, **kwargs) @@ -219,14 +233,48 @@ class BaseTaskBlock(Block): return parameters - @staticmethod - def format_task_block_parameter_template_from_workflow_run_context( - potential_template: str | None, workflow_run_context: WorkflowRunContext - ) -> str | None: - if not potential_template: - return potential_template - template = Template(potential_template) - return template.render(workflow_run_context.values) + def format_potential_template_parameters(self, workflow_run_context: 
WorkflowRunContext) -> None: + super().format_potential_template_parameters(workflow_run_context=workflow_run_context) + + self.title = self.format_block_parameter_template_from_workflow_run_context(self.title, workflow_run_context) + + if self.url: + self.url = self.format_block_parameter_template_from_workflow_run_context(self.url, workflow_run_context) + + if self.totp_identifier: + self.totp_identifier = self.format_block_parameter_template_from_workflow_run_context( + self.totp_identifier, workflow_run_context + ) + + if self.totp_verification_url: + self.totp_verification_url = self.format_block_parameter_template_from_workflow_run_context( + self.totp_verification_url, workflow_run_context + ) + + if self.download_suffix: + self.download_suffix = self.format_block_parameter_template_from_workflow_run_context( + self.download_suffix, workflow_run_context + ) + + if self.navigation_goal: + self.navigation_goal = self.format_block_parameter_template_from_workflow_run_context( + self.navigation_goal, workflow_run_context + ) + + if self.data_extraction_goal: + self.data_extraction_goal = self.format_block_parameter_template_from_workflow_run_context( + self.data_extraction_goal, workflow_run_context + ) + + if self.complete_criterion: + self.complete_criterion = self.format_block_parameter_template_from_workflow_run_context( + self.complete_criterion, workflow_run_context + ) + + if self.terminate_criterion: + self.terminate_criterion = self.format_block_parameter_template_from_workflow_run_context( + self.terminate_criterion, workflow_run_context + ) @staticmethod async def get_task_order(workflow_run_id: str, current_retry: int) -> tuple[int, int]: @@ -301,26 +349,7 @@ class BaseTaskBlock(Block): ) self.download_suffix = download_suffix_parameter_value - self.url = self.format_task_block_parameter_template_from_workflow_run_context(self.url, workflow_run_context) - self.totp_identifier = self.format_task_block_parameter_template_from_workflow_run_context( - 
self.totp_identifier, workflow_run_context - ) - self.download_suffix = self.format_task_block_parameter_template_from_workflow_run_context( - self.download_suffix, workflow_run_context - ) - self.navigation_goal = self.format_task_block_parameter_template_from_workflow_run_context( - self.navigation_goal, workflow_run_context - ) - self.data_extraction_goal = self.format_task_block_parameter_template_from_workflow_run_context( - self.data_extraction_goal, workflow_run_context - ) - self.complete_criterion = self.format_task_block_parameter_template_from_workflow_run_context( - self.complete_criterion, workflow_run_context - ) - self.terminate_criterion = self.format_task_block_parameter_template_from_workflow_run_context( - self.terminate_criterion, workflow_run_context - ) - + self.format_potential_template_parameters(workflow_run_context=workflow_run_context) # TODO (kerem) we should always retry on terminated. We should make a distinction between retriable and # non-retryable terminations while will_retry: @@ -698,6 +727,7 @@ class ForLoopBlock(Block): async def execute(self, workflow_run_id: str, **kwargs: dict) -> BlockResult: workflow_run_context = self.get_workflow_run_context(workflow_run_id) + self.format_potential_template_parameters(workflow_run_context) loop_over_values = self.get_loop_over_parameter_values(workflow_run_context) LOG.info( f"Number of loop_over values: {len(loop_over_values)}", @@ -772,10 +802,15 @@ class CodeBlock(Block): ) -> list[PARAMETER_TYPE]: return self.parameters + def format_potential_template_parameters(self, workflow_run_context: WorkflowRunContext) -> None: + super().format_potential_template_parameters(workflow_run_context) + self.code = self.format_block_parameter_template_from_workflow_run_context(self.code, workflow_run_context) + async def execute(self, workflow_run_id: str, **kwargs: dict) -> BlockResult: raise DisabledBlockExecutionError("CodeBlock is disabled") # get workflow run context workflow_run_context = 
self.get_workflow_run_context(workflow_run_id) + self.format_potential_template_parameters(workflow_run_context) # get all parameters into a dictionary parameter_values = {} @@ -836,6 +871,13 @@ class TextPromptBlock(Block): ) -> list[PARAMETER_TYPE]: return self.parameters + def format_potential_template_parameters(self, workflow_run_context: WorkflowRunContext) -> None: + super().format_potential_template_parameters(workflow_run_context) + self.llm_key = self.format_block_parameter_template_from_workflow_run_context( + self.llm_key, workflow_run_context + ) + self.prompt = self.format_block_parameter_template_from_workflow_run_context(self.prompt, workflow_run_context) + async def send_prompt(self, prompt: str, parameter_values: dict[str, Any]) -> dict[str, Any]: llm_key = self.llm_key or DEFAULT_TEXT_PROMPT_LLM_KEY llm_api_handler = LLMAPIHandlerFactory.get_llm_api_handler(llm_key) @@ -870,6 +912,7 @@ class TextPromptBlock(Block): async def execute(self, workflow_run_id: str, **kwargs: dict) -> BlockResult: # get workflow run context workflow_run_context = self.get_workflow_run_context(workflow_run_id) + self.format_potential_template_parameters(workflow_run_context) # get all parameters into a dictionary parameter_values = {} for parameter in self.parameters: @@ -903,6 +946,10 @@ class DownloadToS3Block(Block): return [] + def format_potential_template_parameters(self, workflow_run_context: WorkflowRunContext) -> None: + super().format_potential_template_parameters(workflow_run_context) + self.url = self.format_block_parameter_template_from_workflow_run_context(self.url, workflow_run_context) + async def _upload_file_to_s3(self, uri: str, file_path: str) -> None: try: client = self.get_async_aws_client() @@ -925,6 +972,8 @@ class DownloadToS3Block(Block): ) self.url = task_url_parameter_value + self.format_potential_template_parameters(workflow_run_context) + try: file_path = await download_file(self.url, max_size_mb=10) except Exception as e: @@ -963,6 
+1012,11 @@ class UploadToS3Block(Block): return [] + def format_potential_template_parameters(self, workflow_run_context: WorkflowRunContext) -> None: + super().format_potential_template_parameters(workflow_run_context) + if self.path: + self.path = self.format_block_parameter_template_from_workflow_run_context(self.path, workflow_run_context) + @staticmethod def _get_s3_uri(workflow_run_id: str, path: str) -> str: s3_bucket = SettingsManager.get_settings().AWS_S3_BUCKET_UPLOADS @@ -986,6 +1040,7 @@ class UploadToS3Block(Block): elif self.path == SettingsManager.get_settings().WORKFLOW_DOWNLOAD_DIRECTORY_PARAMETER_KEY: self.path = str(get_path_for_workflow_download_directory(workflow_run_id).absolute()) + self.format_potential_template_parameters(workflow_run_context) if not self.path or not os.path.exists(self.path): raise FileNotFoundError(f"UploadToS3Block: File not found at path: {self.path}") @@ -1061,6 +1116,16 @@ class SendEmailBlock(Block): return parameters + def format_potential_template_parameters(self, workflow_run_context: WorkflowRunContext) -> None: + super().format_potential_template_parameters(workflow_run_context) + self.sender = self.format_block_parameter_template_from_workflow_run_context(self.sender, workflow_run_context) + self.subject = self.format_block_parameter_template_from_workflow_run_context( + self.subject, workflow_run_context + ) + self.body = self.format_block_parameter_template_from_workflow_run_context(self.body, workflow_run_context) + # file_attachments are formatted in _get_file_paths() + # recipients are formatted in get_real_email_recipients() + def _decrypt_smtp_parameters(self, workflow_run_context: WorkflowRunContext) -> tuple[str, int, str, str]: obfuscated_smtp_host_value = workflow_run_context.get_value(self.smtp_host.key) obfuscated_smtp_port_value = workflow_run_context.get_value(self.smtp_port.key) @@ -1117,6 +1182,7 @@ class SendEmailBlock(Block): file_path=path, ) + path = 
self.format_block_parameter_template_from_workflow_run_context(path, workflow_run_context) # if the file path is a directory, add all files in the directory, skip directories, limit to 10 files if os.path.exists(path): if os.path.isdir(path): @@ -1157,6 +1223,7 @@ class SendEmailBlock(Block): else: maybe_recipient = recipient + maybe_recipient = self.format_block_parameter_template_from_workflow_run_context(maybe_recipient, workflow_run_context) # check if maybe_recipient is a valid email address try: validate_email(maybe_recipient) @@ -1269,6 +1336,7 @@ class SendEmailBlock(Block): async def execute(self, workflow_run_id: str, **kwargs: dict) -> BlockResult: workflow_run_context = self.get_workflow_run_context(workflow_run_id) + self.format_potential_template_parameters(workflow_run_context) smtp_host_value, smtp_port_value, smtp_username_value, smtp_password_value = self._decrypt_smtp_parameters( workflow_run_context ) @@ -1320,6 +1388,12 @@ class FileParserBlock(Block): return [workflow_run_context.get_parameter(self.file_url)] return [] + def format_potential_template_parameters(self, workflow_run_context: WorkflowRunContext) -> None: + super().format_potential_template_parameters(workflow_run_context) + self.file_url = self.format_block_parameter_template_from_workflow_run_context( + self.file_url, workflow_run_context + ) + def validate_file_type(self, file_url_used: str, file_path: str) -> None: if self.file_type == FileType.CSV: try: @@ -1330,7 +1404,6 @@ class FileParserBlock(Block): async def execute(self, workflow_run_id: str, **kwargs: dict) -> BlockResult: workflow_run_context = self.get_workflow_run_context(workflow_run_id) - file_url_to_use = self.file_url if ( self.file_url and workflow_run_context.has_parameter(self.file_url) @@ -1343,15 +1416,17 @@ class FileParserBlock(Block): file_url_parameter_value=file_url_parameter_value, file_url_parameter_key=self.file_url, ) - file_url_to_use = file_url_parameter_value + self.file_url = file_url_parameter_value + + 
self.format_potential_template_parameters(workflow_run_context) # Download the file - if file_url_to_use.startswith("s3://"): - file_path = await download_from_s3(self.get_async_aws_client(), file_url_to_use) + if self.file_url.startswith("s3://"): + file_path = await download_from_s3(self.get_async_aws_client(), self.file_url) else: - file_path = await download_file(file_url_to_use) + file_path = await download_file(self.file_url) # Validate the file type - self.validate_file_type(file_url_to_use, file_path) + self.validate_file_type(self.file_url, file_path) # Parse the file into a list of dictionaries where each dictionary represents a row in the file parsed_data = [] with open(file_path, "r") as file: