diff --git a/skyvern/__init__.py b/skyvern/__init__.py index 6626be75..70c03417 100644 --- a/skyvern/__init__.py +++ b/skyvern/__init__.py @@ -35,6 +35,7 @@ from skyvern.services.script_service import ( # noqa: E402 extract, # noqa: E402 generate_text, # noqa: E402 login, # noqa: E402 + render_template, # noqa: E402 run_script, # noqa: E402 run_task, # noqa: E402 wait, # noqa: E402 @@ -51,6 +52,7 @@ __all__ = [ "extract", "generate_text", "login", + "render_template", "run_script", "run_task", "setup", diff --git a/skyvern/core/script_generations/generate_script.py b/skyvern/core/script_generations/generate_script.py index 0a489156..171043f0 100644 --- a/skyvern/core/script_generations/generate_script.py +++ b/skyvern/core/script_generations/generate_script.py @@ -102,11 +102,24 @@ def _value(value: Any) -> cst.BaseExpression: return cst.SimpleString(repr(str(value))) +def _prompt_value(prompt_text: str) -> cst.BaseExpression: + """Create a prompt value with template rendering logic if needed.""" + if "{{" in prompt_text and "}}" in prompt_text: + # Generate code for: render_template(prompt_text) + return cst.Call( + func=cst.Attribute(value=cst.Name("skyvern"), attr=cst.Name("render_template")), + args=[cst.Arg(value=_value(prompt_text))], + ) + else: + # Return the prompt as a simple string value + return _value(prompt_text) + + def _generate_text_call(text_value: str, intention: str, parameter_key: str) -> cst.BaseExpression: """Create a generate_text function call CST expression.""" return cst.Await( expression=cst.Call( - func=cst.Name("generate_text"), + func=cst.Attribute(value=cst.Name("skyvern"), attr=cst.Name("generate_text")), whitespace_before_args=cst.ParenthesizedWhitespace( indent=True, last_line=cst.SimpleWhitespace(DOUBLE_INDENT), @@ -433,7 +446,7 @@ def _build_run_task_statement(block_title: str, block: dict[str, Any]) -> cst.Si args = [ cst.Arg( keyword=cst.Name("prompt"), - value=_value(block.get("navigation_goal", "")), + value=_prompt_value(block.get("navigation_goal", "")), whitespace_after_arg=cst.ParenthesizedWhitespace( indent=True, last_line=cst.SimpleWhitespace(INDENT), @@ -474,7 +487,7 @@ def _build_download_statement(block_title: str, block: dict[str, Any]) -> cst.Si args = [ cst.Arg( keyword=cst.Name("prompt"), - value=_value(block.get("navigation_goal", "")), + value=_prompt_value(block.get("navigation_goal", "")), whitespace_after_arg=cst.ParenthesizedWhitespace( indent=True, last_line=cst.SimpleWhitespace(INDENT), @@ -523,7 +536,7 @@ def _build_action_statement(block_title: str, block: dict[str, Any]) -> cst.Simp args = [ cst.Arg( keyword=cst.Name("prompt"), - value=_value(block.get("navigation_goal", "")), + value=_prompt_value(block.get("navigation_goal", "")), whitespace_after_arg=cst.ParenthesizedWhitespace( indent=True, last_line=cst.SimpleWhitespace(INDENT), @@ -572,7 +585,7 @@ def _build_login_statement(block_title: str, block: dict[str, Any]) -> cst.Simpl ), cst.Arg( keyword=cst.Name("prompt"), - value=_value(block.get("navigation_goal", "")), + value=_prompt_value(block.get("navigation_goal", "")), whitespace_after_arg=cst.ParenthesizedWhitespace( indent=True, last_line=cst.SimpleWhitespace(INDENT), @@ -654,7 +667,7 @@ def _build_navigate_statement(block_title: str, block: dict[str, Any]) -> cst.Si args = [ cst.Arg( keyword=cst.Name("prompt"), - value=_value(block.get("navigation_goal", "")), + value=_prompt_value(block.get("navigation_goal", "")), whitespace_after_arg=cst.ParenthesizedWhitespace( indent=True, last_line=cst.SimpleWhitespace(INDENT), @@ -760,7 +773,7 @@ def _build_validate_statement(block: dict[str, Any]) -> cst.SimpleStatementLine: args = [ cst.Arg( keyword=cst.Name("prompt"), - value=_value(block.get("navigation_goal", "")), + value=_prompt_value(block.get("navigation_goal", "")), whitespace_after_arg=cst.ParenthesizedWhitespace( indent=True, ), @@ -810,7 +823,7 @@ def _build_for_loop_statement(block_title: str, block: dict[str, Any]) -> cst.Si args = [ cst.Arg( keyword=cst.Name("prompt"), - value=_value(block.get("navigation_goal", "")), + value=_prompt_value(block.get("navigation_goal", "")), whitespace_after_arg=cst.ParenthesizedWhitespace( indent=True, last_line=cst.SimpleWhitespace(INDENT), @@ -1065,7 +1078,6 @@ async def generate_workflow_script( names=[ cst.ImportAlias(cst.Name("RunContext")), cst.ImportAlias(cst.Name("SkyvernPage")), - cst.ImportAlias(cst.Name("generate_text")), ], ) ] diff --git a/skyvern/core/script_generations/transform_workflow_run.py b/skyvern/core/script_generations/transform_workflow_run.py index 830087d7..ef0017a3 100644 --- a/skyvern/core/script_generations/transform_workflow_run.py +++ b/skyvern/core/script_generations/transform_workflow_run.py @@ -40,36 +40,72 @@ async def transform_workflow_run_to_code_gen_input(workflow_run_id: str, organiz raise ValueError(f"Workflow {run_request.workflow_id} not found") workflow_json = workflow.model_dump() - # get the tasks - ## first, get all the workflow run blocks + # get the original workflow definition blocks (with templated information) + workflow_definition_blocks = workflow.workflow_definition.blocks + + # get workflow run blocks for task execution data workflow_run_blocks = await app.DATABASE.get_workflow_run_blocks( workflow_run_id=workflow_run_id, organization_id=organization_id ) workflow_run_blocks.sort(key=lambda x: x.created_at) + + # Create mapping from definition blocks by label for quick lookup + definition_blocks_by_label = {block.label: block for block in workflow_definition_blocks if block.label} + workflow_block_dump = [] - # Hydrate blocks with task data - # TODO: support task v2 actions_by_task = {} - for block in workflow_run_blocks: - block_dump = block.model_dump() - if block.block_type == BlockType.TaskV2: + + # Loop through workflow run blocks and match to original definition blocks by label + for run_block in workflow_run_blocks: + if run_block.block_type == BlockType.TaskV2: raise ValueError("TaskV2 blocks are not supported yet") - if block.block_type in SCRIPT_TASK_BLOCKS and block.task_id: - task = await app.DATABASE.get_task(task_id=block.task_id, organization_id=organization_id) - if not task: - LOG.warning(f"Task {block.task_id} not found") - continue - block_dump.update(task.model_dump()) - actions = await app.DATABASE.get_task_actions_hydrated( - task_id=block.task_id, organization_id=organization_id - ) - action_dumps = [] - for action in actions: - action_dump = action.model_dump() - action_dump["xpath"] = action.get_xpath() - action_dumps.append(action_dump) - actions_by_task[block.task_id] = action_dumps - workflow_block_dump.append(block_dump) + + # Find corresponding definition block by label to get templated information + definition_block = definition_blocks_by_label.get(run_block.label) if run_block.label else None + + if definition_block: + # Start with the original templated definition block + final_dump = definition_block.model_dump() + else: + # Fallback to run block data if no matching definition block found + final_dump = run_block.model_dump() + LOG.warning(f"No matching definition block found for run block with label: {run_block.label}") + + # For task blocks, add execution data while preserving templated information + if run_block.block_type in SCRIPT_TASK_BLOCKS and run_block.task_id: + task = await app.DATABASE.get_task(task_id=run_block.task_id, organization_id=organization_id) + if task: + # Add task execution data but preserve original templated fields + task_dump = task.model_dump() + # Update with execution data, but keep templated values from definition + if definition_block: + final_dump.update({k: v for k, v in task_dump.items() if k not in final_dump}) + else: + final_dump.update(task_dump) + + # Add run block execution metadata + final_dump.update( + { + "task_id": run_block.task_id, + "status": run_block.status, + "output": run_block.output, + } + ) + + # Get task actions + actions = await app.DATABASE.get_task_actions_hydrated( + task_id=run_block.task_id, organization_id=organization_id + ) + action_dumps = [] + for action in actions: + action_dump = action.model_dump() + action_dump["xpath"] = action.get_xpath() + action_dumps.append(action_dump) + actions_by_task[run_block.task_id] = action_dumps + else: + LOG.warning(f"Task {run_block.task_id} not found") + + workflow_block_dump.append(final_dump) return CodeGenInput( file_name=f"{workflow_run_id}.py", diff --git a/skyvern/forge/sdk/workflow/models/workflow.py b/skyvern/forge/sdk/workflow/models/workflow.py index 9fe41851..8cb2470b 100644 --- a/skyvern/forge/sdk/workflow/models/workflow.py +++ b/skyvern/forge/sdk/workflow/models/workflow.py @@ -9,7 +9,7 @@ from skyvern.forge.sdk.schemas.files import FileInfo from skyvern.forge.sdk.schemas.task_v2 import TaskV2 from skyvern.forge.sdk.workflow.exceptions import WorkflowDefinitionHasDuplicateBlockLabels from skyvern.forge.sdk.workflow.models.block import BlockTypeVar -from skyvern.forge.sdk.workflow.models.parameter import PARAMETER_TYPE +from skyvern.forge.sdk.workflow.models.parameter import PARAMETER_TYPE, OutputParameter from skyvern.schemas.runs import ProxyLocation from skyvern.schemas.workflows import WorkflowStatus from skyvern.utils.url_validators import validate_url @@ -83,6 +83,12 @@ class Workflow(BaseModel): modified_at: datetime deleted_at: datetime | None = None + def get_output_parameter(self, label: str) -> OutputParameter | None: + for block in self.workflow_definition.blocks: + if block.label == label: + return block.output_parameter + return None + class WorkflowRunStatus(StrEnum): created = "created" diff --git a/skyvern/forge/sdk/workflow/service.py b/skyvern/forge/sdk/workflow/service.py index 55f57081..f9337ccc 100644 --- a/skyvern/forge/sdk/workflow/service.py +++ b/skyvern/forge/sdk/workflow/service.py @@ -277,24 +277,7 @@ class WorkflowService: workflow_run = await self.get_workflow_run(workflow_run_id=workflow_run_id, organization_id=organization_id) workflow = await self.get_workflow_by_permanent_id(workflow_permanent_id=workflow_run.workflow_permanent_id) close_browser_on_completion = browser_session_id is None and not workflow_run.browser_address - - # Check if there's a related workflow script that should be used instead - workflow_script = await self._get_workflow_script(workflow, workflow_run) - if workflow_script is not None: - LOG.info( - "Found related workflow script, running script instead of workflow", - workflow_run_id=workflow_run_id, - workflow_id=workflow.workflow_id, - organization_id=organization_id, - workflow_script_id=workflow_script.script_id, - ) - return await self._execute_workflow_script( - script_id=workflow_script.script_id, - workflow=workflow, - workflow_run=workflow_run, - api_key=api_key, - organization=organization, - ) + skyvern_context.current() # Set workflow run status to running, create workflow run parameters workflow_run = await self.mark_workflow_run_as_running(workflow_run_id=workflow_run.workflow_run_id) @@ -357,6 +340,24 @@ class WorkflowService: ) return workflow_run + # Check if there's a related workflow script that should be used instead + workflow_script = await self._get_workflow_script(workflow, workflow_run) + if workflow_script is not None: + LOG.info( + "Found related workflow script, running script instead of workflow", + workflow_run_id=workflow_run_id, + workflow_id=workflow.workflow_id, + organization_id=organization_id, + workflow_script_id=workflow_script.script_id, + ) + return await self._execute_workflow_script( + script_id=workflow_script.script_id, + workflow=workflow, + workflow_run=workflow_run, + api_key=api_key, + organization=organization, + ) + top_level_blocks = workflow.workflow_definition.blocks all_blocks = get_all_blocks(top_level_blocks) @@ -2350,9 +2351,6 @@ class WorkflowService: """ try: - # Set workflow run status to running - workflow_run = await self.mark_workflow_run_as_running(workflow_run_id=workflow_run.workflow_run_id) - # Render the cache_key_value to find the right script parameter_tuples = await app.DATABASE.get_workflow_run_parameters( workflow_run_id=workflow_run.workflow_run_id @@ -2402,7 +2400,7 @@ class WorkflowService: workflow_run: WorkflowRun, ) -> None: cache_key = workflow.cache_key - rendered_cache_key_value = "" + rendered_cache_key_value = "default" # 1) Build cache_key_value from workflow run parameters via jinja if cache_key: parameter_tuples = await app.DATABASE.get_workflow_run_parameters( diff --git a/skyvern/services/script_service.py b/skyvern/services/script_service.py index ed6ac6d0..63ebb99f 100644 --- a/skyvern/services/script_service.py +++ b/skyvern/services/script_service.py @@ -9,6 +9,7 @@ from typing import Any, cast import structlog from fastapi import BackgroundTasks, HTTPException +from jinja2.sandbox import SandboxedEnvironment from skyvern.constants import GET_DOWNLOADED_FILES_TIMEOUT from skyvern.core.script_generations.constants import SCRIPT_TASK_BLOCKS @@ -19,10 +20,11 @@ from skyvern.forge.prompts import prompt_engine from skyvern.forge.sdk.core import skyvern_context from skyvern.forge.sdk.schemas.files import FileInfo from skyvern.forge.sdk.schemas.tasks import TaskOutput, TaskStatus -from skyvern.forge.sdk.workflow.models.block import BlockStatus, BlockType from skyvern.schemas.scripts import CreateScriptResponse, FileNode, ScriptFileCreate +from skyvern.schemas.workflows import BlockStatus, BlockType LOG = structlog.get_logger(__name__) +jinja_sandbox_env = SandboxedEnvironment() async def build_file_tree( @@ -320,20 +322,34 @@ async def _create_workflow_block_run_and_task( async def _record_output_parameter_value( workflow_run_id: str, + workflow_id: str, + organization_id: str, output: dict[str, Any] | list | str | None, + label: str | None = None, ) -> None: + if not label: + return # TODO support this in the future - # workflow_run_context = app.WORKFLOW_CONTEXT_MANAGER.get_workflow_run_context(workflow_run_id) - # await workflow_run_context.register_output_parameter_value_post_execution( - # parameter=self.output_parameter, - # value=value, - # ) - # await app.DATABASE.create_or_update_workflow_run_output_parameter( - # workflow_run_id=workflow_run_id, - # output_parameter_id=self.output_parameter.output_parameter_id, - # value=value, - # ) - return + workflow_run_context = app.WORKFLOW_CONTEXT_MANAGER.get_workflow_run_context(workflow_run_id) + # get the workflow + workflow = await app.DATABASE.get_workflow(workflow_id=workflow_id, organization_id=organization_id) + if not workflow: + return + + # get the output_paramter + output_parameter = workflow.get_output_parameter(label) + if not output_parameter: + return + + await workflow_run_context.register_output_parameter_value_post_execution( + parameter=output_parameter, + value=output, + ) + await app.DATABASE.create_or_update_workflow_run_output_parameter( + workflow_run_id=workflow_run_id, + output_parameter_id=output_parameter.output_parameter_id, + value=output, + ) async def _update_workflow_block( @@ -341,13 +357,14 @@ async def _update_workflow_block( status: BlockStatus, task_id: str | None = None, task_status: TaskStatus = TaskStatus.completed, + label: str | None = None, failure_reason: str | None = None, output: dict[str, Any] | list | str | None = None, ) -> None: """Update the status of a workflow run block.""" try: context = skyvern_context.current() - if not context or not context.organization_id or not context.workflow_run_id: + if not context or not context.organization_id or not context.workflow_run_id or not context.workflow_id: return final_output = output if task_id: @@ -385,7 +402,13 @@ async def _update_workflow_block( status=status, failure_reason=failure_reason, ) - await _record_output_parameter_value(context.workflow_run_id, final_output) + await _record_output_parameter_value( + context.workflow_run_id, + context.workflow_id, + context.organization_id, + final_output, + label, + ) except Exception as e: LOG.warning( @@ -429,7 +452,9 @@ async def run_task( # Update block status to completed if workflow block was created if workflow_run_block_id: - await _update_workflow_block(workflow_run_block_id, BlockStatus.completed, task_id=task_id) + await _update_workflow_block( + workflow_run_block_id, BlockStatus.completed, task_id=task_id, label=cache_key + ) except Exception as e: # TODO: fallback to AI run in case of error @@ -440,6 +465,7 @@ async def run_task( BlockStatus.failed, task_id=task_id, task_status=TaskStatus.failed, + label=cache_key, failure_reason=str(e), ) raise @@ -481,7 +507,9 @@ async def download( # Update block status to completed if workflow block was created if workflow_run_block_id: - await _update_workflow_block(workflow_run_block_id, BlockStatus.completed, task_id=task_id) + await _update_workflow_block( + workflow_run_block_id, BlockStatus.completed, task_id=task_id, label=cache_key + ) except Exception as e: # Update block status to failed if workflow block was created @@ -491,6 +519,7 @@ async def download( BlockStatus.failed, task_id=task_id, task_status=TaskStatus.failed, + label=cache_key, failure_reason=str(e), ) raise @@ -531,7 +560,9 @@ async def action( # Update block status to completed if workflow block was created if workflow_run_block_id: - await _update_workflow_block(workflow_run_block_id, BlockStatus.completed, task_id=task_id) + await _update_workflow_block( + workflow_run_block_id, BlockStatus.completed, task_id=task_id, label=cache_key + ) except Exception as e: # Update block status to failed if workflow block was created @@ -541,6 +572,7 @@ async def action( BlockStatus.failed, task_id=task_id, task_status=TaskStatus.failed, + label=cache_key, failure_reason=str(e), ) raise @@ -581,7 +613,9 @@ async def login( # Update block status to completed if workflow block was created if workflow_run_block_id: - await _update_workflow_block(workflow_run_block_id, BlockStatus.completed, task_id=task_id) + await _update_workflow_block( + workflow_run_block_id, BlockStatus.completed, task_id=task_id, label=cache_key + ) except Exception as e: # Update block status to failed if workflow block was created @@ -591,6 +625,7 @@ async def login( BlockStatus.failed, task_id=task_id, task_status=TaskStatus.failed, + label=cache_key, failure_reason=str(e), ) raise @@ -637,6 +672,7 @@ async def extract( BlockStatus.completed, task_id=task_id, output=output, + label=cache_key, ) return output except Exception as e: @@ -649,6 +685,7 @@ async def extract( task_status=TaskStatus.failed, failure_reason=str(e), output=output, + label=cache_key, ) raise finally: @@ -752,3 +789,20 @@ async def generate_text( # If anything goes wrong, fall back to the original text pass return new_text + + +def render_template(template: str, data: dict[str, Any] | None = None) -> str: + """ + Refer to Block.format_block_parameter_template_from_workflow_run_context + + TODO: complete this function so that block code shares the same template rendering logic + """ + template_data = data or {} + jinja_template = jinja_sandbox_env.from_string(template) + context = skyvern_context.current() + if context and context.workflow_run_id: + workflow_run_id = context.workflow_run_id + workflow_run_context = app.WORKFLOW_CONTEXT_MANAGER.get_workflow_run_context(workflow_run_id) + template_data.update(workflow_run_context.values) + + return jinja_template.render(template_data)