diff --git a/skyvern/core/script_generations/generate_script.py b/skyvern/core/script_generations/generate_script.py index fdea2c24..66390b37 100644 --- a/skyvern/core/script_generations/generate_script.py +++ b/skyvern/core/script_generations/generate_script.py @@ -498,7 +498,6 @@ def _build_block_fn(block: dict[str, Any], actions: list[dict[str, Any]]) -> Fun name = _safe_name(block.get("label") or block.get("title") or f"block_{block.get('workflow_run_block_id')}") cache_key = block.get("label") or block.get("title") or f"block_{block.get('workflow_run_block_id')}" body_stmts: list[cst.BaseStatement] = [] - is_extraction_block = block.get("block_type") == "extraction" if block.get("url"): body_stmts.append(cst.parse_statement(f"await page.goto({repr(block['url'])})")) @@ -508,7 +507,7 @@ def _build_block_fn(block: dict[str, Any], actions: list[dict[str, Any]]) -> Fun continue # For extraction blocks, assign extract action results to output variable - assign_to_output = is_extraction_block and act["action_type"] == "extract" + assign_to_output = act["action_type"] == "extract" body_stmts.append(_action_to_stmt(act, block, assign_to_output=assign_to_output)) # add complete action @@ -518,7 +517,7 @@ def _build_block_fn(block: dict[str, Any], actions: list[dict[str, Any]]) -> Fun body_stmts.append(_action_to_stmt(complete_action, block)) # For extraction blocks, add return output statement if we have actions - if is_extraction_block and any( + if any( act["action_type"] == "extract" for act in actions if act["action_type"] not in [ActionType.COMPLETE, ActionType.TERMINATE, ActionType.NULL_ACTION] @@ -549,12 +548,18 @@ def _build_task_v2_block_fn(block: dict[str, Any], child_blocks: list[dict[str, body_stmts: list[cst.BaseStatement] = [] # Add calls to child workflow sub-tasks + has_extract_block = False for child_block in child_blocks: - stmt = _build_block_statement(child_block) + is_extract_block = child_block.get("block_type") == "extraction" + if is_extract_block: + has_extract_block = True + stmt = _build_block_statement(child_block, assign_output=is_extract_block) body_stmts.append(stmt) if not body_stmts: body_stmts.append(cst.parse_statement("return None")) + elif has_extract_block: + body_stmts.append(cst.parse_statement("return output")) return FunctionDef( name=Name(name), @@ -727,7 +732,10 @@ def _build_login_statement( def _build_extract_statement( - block_title: str, block: dict[str, Any], data_variable_name: str | None = None + block_title: str, + block: dict[str, Any], + data_variable_name: str | None = None, + assign_output: bool = True, ) -> cst.SimpleStatementLine: """Build a skyvern.extract statement.""" args = [ @@ -781,7 +789,17 @@ def _build_extract_statement( ), ) - return cst.SimpleStatementLine([cst.Expr(cst.Await(call))]) + if assign_output: + return cst.SimpleStatementLine( + [ + cst.Assign( + targets=[cst.AssignTarget(target=cst.Name("output"))], + value=cst.Await(call), + ) + ] + ) + else: + return cst.SimpleStatementLine([cst.Expr(cst.Await(call))]) def _build_navigate_statement( @@ -1687,7 +1705,9 @@ def __build_base_task_statement( # --------------------------------------------------------------------- # -def _build_block_statement(block: dict[str, Any], data_variable_name: str | None = None) -> cst.SimpleStatementLine: +def _build_block_statement( + block: dict[str, Any], data_variable_name: str | None = None, assign_output: bool = False +) -> cst.SimpleStatementLine: """Build a block statement.""" block_type = block.get("block_type") block_title = block.get("label") or block.get("title") or f"block_{block.get('workflow_run_block_id')}" @@ -1703,7 +1723,7 @@ def _build_block_statement(block: dict[str, Any], data_variable_name: str | None elif block_type == "login": stmt = _build_login_statement(block_title, block, data_variable_name) elif block_type == "extraction": - stmt = _build_extract_statement(block_title, block, data_variable_name) + stmt = _build_extract_statement(block_title, block, data_variable_name, assign_output) elif block_type == "navigation": stmt = _build_navigate_statement(block_title, block, data_variable_name) elif block_type == "validation": @@ -1748,7 +1768,7 @@ def _build_run_fn(blocks: list[dict[str, Any]], wf_req: dict[str, Any]) -> Funct ] for block in blocks: - stmt = _build_block_statement(block) + stmt = _build_block_statement(block, assign_output=False) body.append(stmt) params = cst.Parameters( diff --git a/skyvern/services/script_service.py b/skyvern/services/script_service.py index c1d62ddb..bcdefd59 100644 --- a/skyvern/services/script_service.py +++ b/skyvern/services/script_service.py @@ -1194,7 +1194,7 @@ async def run_task( cache_key: str | None = None, engine: RunEngine = RunEngine.skyvern_v1, model: dict[str, Any] | None = None, -) -> None: +) -> dict[str, Any] | list | str | None: cache_key = cache_key or label cached_fn = script_run_context_manager.get_cached_fn(cache_key) @@ -1213,7 +1213,7 @@ async def run_task( context = skyvern_context.ensure_context() context.prompt = prompt try: - await _run_cached_function(cached_fn) + output = await _run_cached_function(cached_fn) # Update block status to completed if workflow block was created if workflow_run_block_id: @@ -1221,9 +1221,11 @@ async def run_task( workflow_run_block_id, BlockStatus.completed, task_id=task_id, + output=output, step_id=step_id, label=cache_key, ) + return output except Exception as e: LOG.exception("Failed to run task block. Falling back to AI run.") @@ -1238,6 +1240,7 @@ async def run_task( error=e, workflow_run_block_id=workflow_run_block_id, ) + return None finally: # clear the prompt in the RunContext context.prompt = None @@ -1255,12 +1258,13 @@ async def run_task( engine=RunEngine.skyvern_v1, model=model, ) - await task_block.execute_safe( + block_output = await task_block.execute_safe( workflow_run_id=block_validation_output.workflow_run_id, parent_workflow_run_block_id=block_validation_output.context.parent_workflow_run_block_id, organization_id=block_validation_output.organization_id, browser_session_id=block_validation_output.browser_session_id, ) + return block_output.output_parameter_value async def download(