Record output of cached task run when there's extracted information (#4140)

This commit is contained in:
Shuchang Zheng
2025-11-28 19:05:10 -08:00
committed by GitHub
parent 1802435ed5
commit 0ad149d905
2 changed files with 36 additions and 12 deletions

View File

@@ -498,7 +498,6 @@ def _build_block_fn(block: dict[str, Any], actions: list[dict[str, Any]]) -> Fun
name = _safe_name(block.get("label") or block.get("title") or f"block_{block.get('workflow_run_block_id')}")
cache_key = block.get("label") or block.get("title") or f"block_{block.get('workflow_run_block_id')}"
body_stmts: list[cst.BaseStatement] = []
is_extraction_block = block.get("block_type") == "extraction"
if block.get("url"):
body_stmts.append(cst.parse_statement(f"await page.goto({repr(block['url'])})"))
@@ -508,7 +507,7 @@ def _build_block_fn(block: dict[str, Any], actions: list[dict[str, Any]]) -> Fun
continue
# For extraction blocks, assign extract action results to output variable
assign_to_output = is_extraction_block and act["action_type"] == "extract"
assign_to_output = act["action_type"] == "extract"
body_stmts.append(_action_to_stmt(act, block, assign_to_output=assign_to_output))
# add complete action
@@ -518,7 +517,7 @@ def _build_block_fn(block: dict[str, Any], actions: list[dict[str, Any]]) -> Fun
body_stmts.append(_action_to_stmt(complete_action, block))
# For extraction blocks, add return output statement if we have actions
if is_extraction_block and any(
if any(
act["action_type"] == "extract"
for act in actions
if act["action_type"] not in [ActionType.COMPLETE, ActionType.TERMINATE, ActionType.NULL_ACTION]
@@ -549,12 +548,18 @@ def _build_task_v2_block_fn(block: dict[str, Any], child_blocks: list[dict[str,
body_stmts: list[cst.BaseStatement] = []
# Add calls to child workflow sub-tasks
has_extract_block = False
for child_block in child_blocks:
stmt = _build_block_statement(child_block)
is_extract_block = child_block.get("block_type") == "extraction"
if is_extract_block:
has_extract_block = True
stmt = _build_block_statement(child_block, assign_output=is_extract_block)
body_stmts.append(stmt)
if not body_stmts:
body_stmts.append(cst.parse_statement("return None"))
elif has_extract_block:
body_stmts.append(cst.parse_statement("return output"))
return FunctionDef(
name=Name(name),
@@ -727,7 +732,10 @@ def _build_login_statement(
def _build_extract_statement(
block_title: str, block: dict[str, Any], data_variable_name: str | None = None
block_title: str,
block: dict[str, Any],
data_variable_name: str | None = None,
assign_output: bool = True,
) -> cst.SimpleStatementLine:
"""Build a skyvern.extract statement."""
args = [
@@ -781,7 +789,17 @@ def _build_extract_statement(
),
)
return cst.SimpleStatementLine([cst.Expr(cst.Await(call))])
if assign_output:
return cst.SimpleStatementLine(
[
cst.Assign(
targets=[cst.AssignTarget(target=cst.Name("output"))],
value=cst.Await(call),
)
]
)
else:
return cst.SimpleStatementLine([cst.Expr(cst.Await(call))])
def _build_navigate_statement(
@@ -1687,7 +1705,9 @@ def __build_base_task_statement(
# --------------------------------------------------------------------- #
def _build_block_statement(block: dict[str, Any], data_variable_name: str | None = None) -> cst.SimpleStatementLine:
def _build_block_statement(
block: dict[str, Any], data_variable_name: str | None = None, assign_output: bool = False
) -> cst.SimpleStatementLine:
"""Build a block statement."""
block_type = block.get("block_type")
block_title = block.get("label") or block.get("title") or f"block_{block.get('workflow_run_block_id')}"
@@ -1703,7 +1723,7 @@ def _build_block_statement(block: dict[str, Any], data_variable_name: str | None
elif block_type == "login":
stmt = _build_login_statement(block_title, block, data_variable_name)
elif block_type == "extraction":
stmt = _build_extract_statement(block_title, block, data_variable_name)
stmt = _build_extract_statement(block_title, block, data_variable_name, assign_output)
elif block_type == "navigation":
stmt = _build_navigate_statement(block_title, block, data_variable_name)
elif block_type == "validation":
@@ -1748,7 +1768,7 @@ def _build_run_fn(blocks: list[dict[str, Any]], wf_req: dict[str, Any]) -> Funct
]
for block in blocks:
stmt = _build_block_statement(block)
stmt = _build_block_statement(block, assign_output=False)
body.append(stmt)
params = cst.Parameters(