Record output of cached task run when there's extracted information (#4140)
This commit is contained in:
@@ -498,7 +498,6 @@ def _build_block_fn(block: dict[str, Any], actions: list[dict[str, Any]]) -> Fun
|
||||
name = _safe_name(block.get("label") or block.get("title") or f"block_{block.get('workflow_run_block_id')}")
|
||||
cache_key = block.get("label") or block.get("title") or f"block_{block.get('workflow_run_block_id')}"
|
||||
body_stmts: list[cst.BaseStatement] = []
|
||||
is_extraction_block = block.get("block_type") == "extraction"
|
||||
|
||||
if block.get("url"):
|
||||
body_stmts.append(cst.parse_statement(f"await page.goto({repr(block['url'])})"))
|
||||
@@ -508,7 +507,7 @@ def _build_block_fn(block: dict[str, Any], actions: list[dict[str, Any]]) -> Fun
|
||||
continue
|
||||
|
||||
# For extraction blocks, assign extract action results to output variable
|
||||
assign_to_output = is_extraction_block and act["action_type"] == "extract"
|
||||
assign_to_output = act["action_type"] == "extract"
|
||||
body_stmts.append(_action_to_stmt(act, block, assign_to_output=assign_to_output))
|
||||
|
||||
# add complete action
|
||||
@@ -518,7 +517,7 @@ def _build_block_fn(block: dict[str, Any], actions: list[dict[str, Any]]) -> Fun
|
||||
body_stmts.append(_action_to_stmt(complete_action, block))
|
||||
|
||||
# For extraction blocks, add return output statement if we have actions
|
||||
if is_extraction_block and any(
|
||||
if any(
|
||||
act["action_type"] == "extract"
|
||||
for act in actions
|
||||
if act["action_type"] not in [ActionType.COMPLETE, ActionType.TERMINATE, ActionType.NULL_ACTION]
|
||||
@@ -549,12 +548,18 @@ def _build_task_v2_block_fn(block: dict[str, Any], child_blocks: list[dict[str,
|
||||
body_stmts: list[cst.BaseStatement] = []
|
||||
|
||||
# Add calls to child workflow sub-tasks
|
||||
has_extract_block = False
|
||||
for child_block in child_blocks:
|
||||
stmt = _build_block_statement(child_block)
|
||||
is_extract_block = child_block.get("block_type") == "extraction"
|
||||
if is_extract_block:
|
||||
has_extract_block = True
|
||||
stmt = _build_block_statement(child_block, assign_output=is_extract_block)
|
||||
body_stmts.append(stmt)
|
||||
|
||||
if not body_stmts:
|
||||
body_stmts.append(cst.parse_statement("return None"))
|
||||
elif has_extract_block:
|
||||
body_stmts.append(cst.parse_statement("return output"))
|
||||
|
||||
return FunctionDef(
|
||||
name=Name(name),
|
||||
@@ -727,7 +732,10 @@ def _build_login_statement(
|
||||
|
||||
|
||||
def _build_extract_statement(
|
||||
block_title: str, block: dict[str, Any], data_variable_name: str | None = None
|
||||
block_title: str,
|
||||
block: dict[str, Any],
|
||||
data_variable_name: str | None = None,
|
||||
assign_output: bool = True,
|
||||
) -> cst.SimpleStatementLine:
|
||||
"""Build a skyvern.extract statement."""
|
||||
args = [
|
||||
@@ -781,7 +789,17 @@ def _build_extract_statement(
|
||||
),
|
||||
)
|
||||
|
||||
return cst.SimpleStatementLine([cst.Expr(cst.Await(call))])
|
||||
if assign_output:
|
||||
return cst.SimpleStatementLine(
|
||||
[
|
||||
cst.Assign(
|
||||
targets=[cst.AssignTarget(target=cst.Name("output"))],
|
||||
value=cst.Await(call),
|
||||
)
|
||||
]
|
||||
)
|
||||
else:
|
||||
return cst.SimpleStatementLine([cst.Expr(cst.Await(call))])
|
||||
|
||||
|
||||
def _build_navigate_statement(
|
||||
@@ -1687,7 +1705,9 @@ def __build_base_task_statement(
|
||||
# --------------------------------------------------------------------- #
|
||||
|
||||
|
||||
def _build_block_statement(block: dict[str, Any], data_variable_name: str | None = None) -> cst.SimpleStatementLine:
|
||||
def _build_block_statement(
|
||||
block: dict[str, Any], data_variable_name: str | None = None, assign_output: bool = False
|
||||
) -> cst.SimpleStatementLine:
|
||||
"""Build a block statement."""
|
||||
block_type = block.get("block_type")
|
||||
block_title = block.get("label") or block.get("title") or f"block_{block.get('workflow_run_block_id')}"
|
||||
@@ -1703,7 +1723,7 @@ def _build_block_statement(block: dict[str, Any], data_variable_name: str | None
|
||||
elif block_type == "login":
|
||||
stmt = _build_login_statement(block_title, block, data_variable_name)
|
||||
elif block_type == "extraction":
|
||||
stmt = _build_extract_statement(block_title, block, data_variable_name)
|
||||
stmt = _build_extract_statement(block_title, block, data_variable_name, assign_output)
|
||||
elif block_type == "navigation":
|
||||
stmt = _build_navigate_statement(block_title, block, data_variable_name)
|
||||
elif block_type == "validation":
|
||||
@@ -1748,7 +1768,7 @@ def _build_run_fn(blocks: list[dict[str, Any]], wf_req: dict[str, Any]) -> Funct
|
||||
]
|
||||
|
||||
for block in blocks:
|
||||
stmt = _build_block_statement(block)
|
||||
stmt = _build_block_statement(block, assign_output=False)
|
||||
body.append(stmt)
|
||||
|
||||
params = cst.Parameters(
|
||||
|
||||
@@ -1194,7 +1194,7 @@ async def run_task(
|
||||
cache_key: str | None = None,
|
||||
engine: RunEngine = RunEngine.skyvern_v1,
|
||||
model: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
) -> dict[str, Any] | list | str | None:
|
||||
cache_key = cache_key or label
|
||||
cached_fn = script_run_context_manager.get_cached_fn(cache_key)
|
||||
|
||||
@@ -1213,7 +1213,7 @@ async def run_task(
|
||||
context = skyvern_context.ensure_context()
|
||||
context.prompt = prompt
|
||||
try:
|
||||
await _run_cached_function(cached_fn)
|
||||
output = await _run_cached_function(cached_fn)
|
||||
|
||||
# Update block status to completed if workflow block was created
|
||||
if workflow_run_block_id:
|
||||
@@ -1221,9 +1221,11 @@ async def run_task(
|
||||
workflow_run_block_id,
|
||||
BlockStatus.completed,
|
||||
task_id=task_id,
|
||||
output=output,
|
||||
step_id=step_id,
|
||||
label=cache_key,
|
||||
)
|
||||
return output
|
||||
|
||||
except Exception as e:
|
||||
LOG.exception("Failed to run task block. Falling back to AI run.")
|
||||
@@ -1238,6 +1240,7 @@ async def run_task(
|
||||
error=e,
|
||||
workflow_run_block_id=workflow_run_block_id,
|
||||
)
|
||||
return None
|
||||
finally:
|
||||
# clear the prompt in the RunContext
|
||||
context.prompt = None
|
||||
@@ -1255,12 +1258,13 @@ async def run_task(
|
||||
engine=RunEngine.skyvern_v1,
|
||||
model=model,
|
||||
)
|
||||
await task_block.execute_safe(
|
||||
block_output = await task_block.execute_safe(
|
||||
workflow_run_id=block_validation_output.workflow_run_id,
|
||||
parent_workflow_run_block_id=block_validation_output.context.parent_workflow_run_block_id,
|
||||
organization_id=block_validation_output.organization_id,
|
||||
browser_session_id=block_validation_output.browser_session_id,
|
||||
)
|
||||
return block_output.output_parameter_value
|
||||
|
||||
|
||||
async def download(
|
||||
|
||||
Reference in New Issue
Block a user