cache task run uses block level model override when ai fallback happens (#4073)

This commit is contained in:
Shuchang Zheng
2025-11-21 22:48:20 -08:00
committed by GitHub
parent b52982d3c8
commit 7729d7cffe
3 changed files with 130 additions and 19 deletions

View File

@@ -673,16 +673,30 @@ def _build_action_statement(
last_line=cst.SimpleWhitespace(INDENT),
),
),
cst.Arg(
keyword=cst.Name("label"),
value=_value(block_title),
whitespace_after_arg=cst.ParenthesizedWhitespace(
indent=True,
),
comma=cst.Comma(),
),
]
if block.get("model"):
args.append(
cst.Arg(
keyword=cst.Name("model"),
value=_value(block.get("model")),
whitespace_after_arg=cst.ParenthesizedWhitespace(
indent=True,
last_line=cst.SimpleWhitespace(INDENT),
),
)
)
if block.get("label"):
args.append(
cst.Arg(
keyword=cst.Name("label"),
value=_value(block.get("label")),
whitespace_after_arg=cst.ParenthesizedWhitespace(
indent=True,
),
comma=cst.Comma(),
)
)
_mark_last_arg_as_comma(args)
call = cst.Call(
func=cst.Attribute(value=cst.Name("skyvern"), attr=cst.Name("action")),
args=args,
@@ -733,15 +747,30 @@ def _build_extract_statement(
last_line=cst.SimpleWhitespace(INDENT),
),
),
cst.Arg(
keyword=cst.Name("label"),
value=_value(block_title),
whitespace_after_arg=cst.ParenthesizedWhitespace(
indent=True,
),
comma=cst.Comma(),
),
]
if block.get("model"):
args.append(
cst.Arg(
keyword=cst.Name("model"),
value=_value(block.get("model")),
whitespace_after_arg=cst.ParenthesizedWhitespace(
indent=True,
last_line=cst.SimpleWhitespace(INDENT),
),
)
)
if block.get("label"):
args.append(
cst.Arg(
keyword=cst.Name("label"),
value=_value(block_title),
whitespace_after_arg=cst.ParenthesizedWhitespace(
indent=True,
),
comma=cst.Comma(),
)
)
_mark_last_arg_as_comma(args)
call = cst.Call(
func=cst.Attribute(value=cst.Name("skyvern"), attr=cst.Name("extract")),
@@ -882,6 +911,18 @@ def _build_validate_statement(
)
)
if block.get("model"):
args.append(
cst.Arg(
keyword=cst.Name("model"),
value=_value(block.get("model")),
whitespace_after_arg=cst.ParenthesizedWhitespace(
indent=True,
last_line=cst.SimpleWhitespace(INDENT),
),
)
)
# Add label if it exists
if block.get("label") is not None:
args.append(
@@ -1114,6 +1155,18 @@ def _build_pdf_parser_statement(block: dict[str, Any]) -> cst.SimpleStatementLin
)
)
if block.get("model"):
args.append(
cst.Arg(
keyword=cst.Name("model"),
value=_value(block.get("model")),
whitespace_after_arg=cst.ParenthesizedWhitespace(
indent=True,
last_line=cst.SimpleWhitespace(INDENT),
),
)
)
if block.get("label") is not None:
args.append(
cst.Arg(
@@ -1172,6 +1225,18 @@ def _build_file_url_parser_statement(block: dict[str, Any]) -> cst.SimpleStateme
)
)
if block.get("model"):
args.append(
cst.Arg(
keyword=cst.Name("model"),
value=_value(block.get("model")),
whitespace_after_arg=cst.ParenthesizedWhitespace(
indent=True,
last_line=cst.SimpleWhitespace(INDENT),
),
)
)
if block.get("label") is not None:
args.append(
cst.Arg(
@@ -1319,6 +1384,18 @@ def _build_prompt_statement(block: dict[str, Any]) -> cst.SimpleStatementLine:
)
)
if block.get("model"):
args.append(
cst.Arg(
keyword=cst.Name("model"),
value=_value(block.get("model")),
whitespace_after_arg=cst.ParenthesizedWhitespace(
indent=True,
last_line=cst.SimpleWhitespace(INDENT),
),
)
)
if block.get("label") is not None:
args.append(
cst.Arg(
@@ -1331,7 +1408,7 @@ def _build_prompt_statement(block: dict[str, Any]) -> cst.SimpleStatementLine:
)
)
if block.get("parameters") is not None:
if block.get("parameters"):
parameters = block.get("parameters", [])
parameter_list = [parameter["key"] for parameter in parameters]
args.append(
@@ -1340,10 +1417,12 @@ def _build_prompt_statement(block: dict[str, Any]) -> cst.SimpleStatementLine:
value=_value(parameter_list),
whitespace_after_arg=cst.ParenthesizedWhitespace(
indent=True,
last_line=cst.SimpleWhitespace(INDENT),
),
)
)
_mark_last_arg_as_comma(args)
call = cst.Call(
func=cst.Attribute(value=cst.Name("skyvern"), attr=cst.Name("prompt")),
args=args,
@@ -1568,6 +1647,17 @@ def __build_base_task_statement(
),
)
)
if block.get("model"):
args.append(
cst.Arg(
keyword=cst.Name("model"),
value=_value(block.get("model")),
whitespace_after_arg=cst.ParenthesizedWhitespace(
indent=True,
last_line=cst.SimpleWhitespace(INDENT),
),
)
)
if block.get("block_type") == "task_v2":
args.append(
cst.Arg(

View File

@@ -130,7 +130,6 @@ CacheInvalidationReason = Literal["updated_block", "new_block", "removed_block"]
BLOCK_TYPES_THAT_SHOULD_BE_CACHED = {
BlockType.TASK,
BlockType.TaskV2,
BlockType.VALIDATION,
BlockType.ACTION,
BlockType.NAVIGATION,
BlockType.EXTRACTION,

View File

@@ -399,6 +399,7 @@ async def _create_workflow_block_run_and_task(
schema: dict[str, Any] | list | str | None = None,
url: str | None = None,
label: str | None = None,
model: dict[str, Any] | None = None,
) -> tuple[str | None, str | None, str | None]:
"""
Create a workflow block run and optionally a task if workflow_run_id is available in context.
@@ -449,6 +450,7 @@ async def _create_workflow_block_run_and_task(
status="running",
organization_id=organization_id,
workflow_run_id=workflow_run_id,
model=model,
)
task_id = task.task_id
@@ -1210,6 +1212,7 @@ async def run_task(
prompt=prompt,
url=url,
label=cache_key,
model=model,
)
prompt = _render_template_with_label(prompt, cache_key)
# set the prompt in the RunContext
@@ -1256,6 +1259,7 @@ async def run_task(
totp_verification_url=totp_url,
include_action_history_in_verification=True,
engine=RunEngine.skyvern_v1,
model=model,
)
await task_block.execute_safe(
workflow_run_id=block_validation_output.workflow_run_id,
@@ -1275,6 +1279,7 @@ async def download(
totp_url: str | None = None,
label: str | None = None,
cache_key: str | None = None,
model: dict[str, Any] | None = None,
) -> None:
cache_key = cache_key or label
cached_fn = script_run_context_manager.get_cached_fn(cache_key)
@@ -1286,6 +1291,7 @@ async def download(
prompt=prompt,
url=url,
label=cache_key,
model=model,
)
prompt = _render_template_with_label(prompt, cache_key)
# set the prompt in the RunContext
@@ -1332,6 +1338,7 @@ async def download(
totp_verification_url=totp_url,
include_action_history_in_verification=True,
engine=RunEngine.skyvern_v1,
model=model,
)
await file_download_block.execute_safe(
workflow_run_id=block_validation_output.workflow_run_id,
@@ -1350,6 +1357,7 @@ async def action(
totp_url: str | None = None,
label: str | None = None,
cache_key: str | None = None,
model: dict[str, Any] | None = None,
) -> None:
context: skyvern_context.SkyvernContext | None
cache_key = cache_key or label
@@ -1361,6 +1369,7 @@ async def action(
prompt=prompt,
url=url,
label=cache_key,
model=model,
)
prompt = _render_template_with_label(prompt, cache_key)
# set the prompt in the RunContext
@@ -1406,6 +1415,7 @@ async def action(
max_steps_per_run=max_steps,
totp_identifier=totp_identifier,
totp_verification_url=totp_url,
model=model,
)
await action_block.execute_safe(
workflow_run_id=block_validation_output.workflow_run_id,
@@ -1423,6 +1433,7 @@ async def login(
totp_url: str | None = None,
label: str | None = None,
cache_key: str | None = None,
model: dict[str, Any] | None = None,
) -> None:
context: skyvern_context.SkyvernContext | None
cache_key = cache_key or label
@@ -1435,6 +1446,7 @@ async def login(
prompt=prompt,
url=url,
label=cache_key,
model=model,
)
prompt = _render_template_with_label(prompt, cache_key)
# set the prompt in the RunContext
@@ -1478,6 +1490,7 @@ async def login(
max_steps_per_run=max_steps,
totp_identifier=totp_identifier,
totp_verification_url=totp_url,
model=model,
)
await login_block.execute_safe(
workflow_run_id=block_validation_output.workflow_run_id,
@@ -1494,6 +1507,7 @@ async def extract(
max_steps: int | None = None,
label: str | None = None,
cache_key: str | None = None,
model: dict[str, Any] | None = None,
) -> dict[str, Any] | list | str | None:
output: dict[str, Any] | list | str | None = None
@@ -1508,6 +1522,7 @@ async def extract(
schema=schema,
url=url,
label=cache_key,
model=model,
)
prompt = _render_template_with_label(prompt, cache_key)
# set the prompt in the RunContext
@@ -1553,6 +1568,7 @@ async def extract(
max_steps_per_run=max_steps,
data_schema=schema,
output_parameter=block_validation_output.output_parameter,
model=model,
)
block_result = await extraction_block.execute_safe(
workflow_run_id=block_validation_output.workflow_run_id,
@@ -1568,6 +1584,7 @@ async def validate(
terminate_criterion: str | None = None,
error_code_mapping: dict[str, str] | None = None,
label: str | None = None,
model: dict[str, Any] | None = None,
) -> None:
"""Validate function that behaves like a ValidationBlock"""
if not complete_criterion and not terminate_criterion:
@@ -1582,6 +1599,7 @@ async def validate(
terminate_criterion=terminate_criterion,
error_code_mapping=error_code_mapping,
max_steps_per_run=2,
model=model,
)
result = await validation_block.execute_safe(
workflow_run_id=block_validation_output.workflow_run_id,
@@ -1927,6 +1945,7 @@ async def parse_file(
schema: dict[str, Any] | None = None,
label: str | None = None,
parameters: list[str] | None = None,
model: dict[str, Any] | None = None,
) -> None:
block_validation_output = await _validate_and_get_output_parameter(label, parameters)
file_url = _render_template_with_label(file_url, label)
@@ -1937,6 +1956,7 @@ async def parse_file(
label=block_validation_output.label,
output_parameter=block_validation_output.output_parameter,
parameters=block_validation_output.input_parameters,
model=model,
)
await file_parser_block.execute_safe(
workflow_run_id=block_validation_output.workflow_run_id,
@@ -2008,6 +2028,7 @@ async def prompt(
schema: dict[str, Any] | None = None,
label: str | None = None,
parameters: list[str] | None = None,
model: dict[str, Any] | None = None,
) -> dict[str, Any] | list | str | None:
block_validation_output = await _validate_and_get_output_parameter(label, parameters)
prompt = _render_template_with_label(prompt, label)
@@ -2017,6 +2038,7 @@ async def prompt(
label=block_validation_output.label,
output_parameter=block_validation_output.output_parameter,
parameters=block_validation_output.input_parameters,
model=model,
)
result = await prompt_block.execute_safe(
workflow_run_id=block_validation_output.workflow_run_id,