cache task run uses block level model override when ai fallback happens (#4073)
This commit is contained in:
@@ -673,16 +673,30 @@ def _build_action_statement(
|
||||
last_line=cst.SimpleWhitespace(INDENT),
|
||||
),
|
||||
),
|
||||
cst.Arg(
|
||||
keyword=cst.Name("label"),
|
||||
value=_value(block_title),
|
||||
whitespace_after_arg=cst.ParenthesizedWhitespace(
|
||||
indent=True,
|
||||
),
|
||||
comma=cst.Comma(),
|
||||
),
|
||||
]
|
||||
|
||||
if block.get("model"):
|
||||
args.append(
|
||||
cst.Arg(
|
||||
keyword=cst.Name("model"),
|
||||
value=_value(block.get("model")),
|
||||
whitespace_after_arg=cst.ParenthesizedWhitespace(
|
||||
indent=True,
|
||||
last_line=cst.SimpleWhitespace(INDENT),
|
||||
),
|
||||
)
|
||||
)
|
||||
if block.get("label"):
|
||||
args.append(
|
||||
cst.Arg(
|
||||
keyword=cst.Name("label"),
|
||||
value=_value(block.get("label")),
|
||||
whitespace_after_arg=cst.ParenthesizedWhitespace(
|
||||
indent=True,
|
||||
),
|
||||
comma=cst.Comma(),
|
||||
)
|
||||
)
|
||||
_mark_last_arg_as_comma(args)
|
||||
call = cst.Call(
|
||||
func=cst.Attribute(value=cst.Name("skyvern"), attr=cst.Name("action")),
|
||||
args=args,
|
||||
@@ -733,15 +747,30 @@ def _build_extract_statement(
|
||||
last_line=cst.SimpleWhitespace(INDENT),
|
||||
),
|
||||
),
|
||||
cst.Arg(
|
||||
keyword=cst.Name("label"),
|
||||
value=_value(block_title),
|
||||
whitespace_after_arg=cst.ParenthesizedWhitespace(
|
||||
indent=True,
|
||||
),
|
||||
comma=cst.Comma(),
|
||||
),
|
||||
]
|
||||
if block.get("model"):
|
||||
args.append(
|
||||
cst.Arg(
|
||||
keyword=cst.Name("model"),
|
||||
value=_value(block.get("model")),
|
||||
whitespace_after_arg=cst.ParenthesizedWhitespace(
|
||||
indent=True,
|
||||
last_line=cst.SimpleWhitespace(INDENT),
|
||||
),
|
||||
)
|
||||
)
|
||||
if block.get("label"):
|
||||
args.append(
|
||||
cst.Arg(
|
||||
keyword=cst.Name("label"),
|
||||
value=_value(block_title),
|
||||
whitespace_after_arg=cst.ParenthesizedWhitespace(
|
||||
indent=True,
|
||||
),
|
||||
comma=cst.Comma(),
|
||||
)
|
||||
)
|
||||
_mark_last_arg_as_comma(args)
|
||||
|
||||
call = cst.Call(
|
||||
func=cst.Attribute(value=cst.Name("skyvern"), attr=cst.Name("extract")),
|
||||
@@ -882,6 +911,18 @@ def _build_validate_statement(
|
||||
)
|
||||
)
|
||||
|
||||
if block.get("model"):
|
||||
args.append(
|
||||
cst.Arg(
|
||||
keyword=cst.Name("model"),
|
||||
value=_value(block.get("model")),
|
||||
whitespace_after_arg=cst.ParenthesizedWhitespace(
|
||||
indent=True,
|
||||
last_line=cst.SimpleWhitespace(INDENT),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# Add label if it exists
|
||||
if block.get("label") is not None:
|
||||
args.append(
|
||||
@@ -1114,6 +1155,18 @@ def _build_pdf_parser_statement(block: dict[str, Any]) -> cst.SimpleStatementLin
|
||||
)
|
||||
)
|
||||
|
||||
if block.get("model"):
|
||||
args.append(
|
||||
cst.Arg(
|
||||
keyword=cst.Name("model"),
|
||||
value=_value(block.get("model")),
|
||||
whitespace_after_arg=cst.ParenthesizedWhitespace(
|
||||
indent=True,
|
||||
last_line=cst.SimpleWhitespace(INDENT),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if block.get("label") is not None:
|
||||
args.append(
|
||||
cst.Arg(
|
||||
@@ -1172,6 +1225,18 @@ def _build_file_url_parser_statement(block: dict[str, Any]) -> cst.SimpleStateme
|
||||
)
|
||||
)
|
||||
|
||||
if block.get("model"):
|
||||
args.append(
|
||||
cst.Arg(
|
||||
keyword=cst.Name("model"),
|
||||
value=_value(block.get("model")),
|
||||
whitespace_after_arg=cst.ParenthesizedWhitespace(
|
||||
indent=True,
|
||||
last_line=cst.SimpleWhitespace(INDENT),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if block.get("label") is not None:
|
||||
args.append(
|
||||
cst.Arg(
|
||||
@@ -1319,6 +1384,18 @@ def _build_prompt_statement(block: dict[str, Any]) -> cst.SimpleStatementLine:
|
||||
)
|
||||
)
|
||||
|
||||
if block.get("model"):
|
||||
args.append(
|
||||
cst.Arg(
|
||||
keyword=cst.Name("model"),
|
||||
value=_value(block.get("model")),
|
||||
whitespace_after_arg=cst.ParenthesizedWhitespace(
|
||||
indent=True,
|
||||
last_line=cst.SimpleWhitespace(INDENT),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if block.get("label") is not None:
|
||||
args.append(
|
||||
cst.Arg(
|
||||
@@ -1331,7 +1408,7 @@ def _build_prompt_statement(block: dict[str, Any]) -> cst.SimpleStatementLine:
|
||||
)
|
||||
)
|
||||
|
||||
if block.get("parameters") is not None:
|
||||
if block.get("parameters"):
|
||||
parameters = block.get("parameters", [])
|
||||
parameter_list = [parameter["key"] for parameter in parameters]
|
||||
args.append(
|
||||
@@ -1340,10 +1417,12 @@ def _build_prompt_statement(block: dict[str, Any]) -> cst.SimpleStatementLine:
|
||||
value=_value(parameter_list),
|
||||
whitespace_after_arg=cst.ParenthesizedWhitespace(
|
||||
indent=True,
|
||||
last_line=cst.SimpleWhitespace(INDENT),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
_mark_last_arg_as_comma(args)
|
||||
call = cst.Call(
|
||||
func=cst.Attribute(value=cst.Name("skyvern"), attr=cst.Name("prompt")),
|
||||
args=args,
|
||||
@@ -1568,6 +1647,17 @@ def __build_base_task_statement(
|
||||
),
|
||||
)
|
||||
)
|
||||
if block.get("model"):
|
||||
args.append(
|
||||
cst.Arg(
|
||||
keyword=cst.Name("model"),
|
||||
value=_value(block.get("model")),
|
||||
whitespace_after_arg=cst.ParenthesizedWhitespace(
|
||||
indent=True,
|
||||
last_line=cst.SimpleWhitespace(INDENT),
|
||||
),
|
||||
)
|
||||
)
|
||||
if block.get("block_type") == "task_v2":
|
||||
args.append(
|
||||
cst.Arg(
|
||||
|
||||
@@ -130,7 +130,6 @@ CacheInvalidationReason = Literal["updated_block", "new_block", "removed_block"]
|
||||
BLOCK_TYPES_THAT_SHOULD_BE_CACHED = {
|
||||
BlockType.TASK,
|
||||
BlockType.TaskV2,
|
||||
BlockType.VALIDATION,
|
||||
BlockType.ACTION,
|
||||
BlockType.NAVIGATION,
|
||||
BlockType.EXTRACTION,
|
||||
|
||||
@@ -399,6 +399,7 @@ async def _create_workflow_block_run_and_task(
|
||||
schema: dict[str, Any] | list | str | None = None,
|
||||
url: str | None = None,
|
||||
label: str | None = None,
|
||||
model: dict[str, Any] | None = None,
|
||||
) -> tuple[str | None, str | None, str | None]:
|
||||
"""
|
||||
Create a workflow block run and optionally a task if workflow_run_id is available in context.
|
||||
@@ -449,6 +450,7 @@ async def _create_workflow_block_run_and_task(
|
||||
status="running",
|
||||
organization_id=organization_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
model=model,
|
||||
)
|
||||
|
||||
task_id = task.task_id
|
||||
@@ -1210,6 +1212,7 @@ async def run_task(
|
||||
prompt=prompt,
|
||||
url=url,
|
||||
label=cache_key,
|
||||
model=model,
|
||||
)
|
||||
prompt = _render_template_with_label(prompt, cache_key)
|
||||
# set the prompt in the RunContext
|
||||
@@ -1256,6 +1259,7 @@ async def run_task(
|
||||
totp_verification_url=totp_url,
|
||||
include_action_history_in_verification=True,
|
||||
engine=RunEngine.skyvern_v1,
|
||||
model=model,
|
||||
)
|
||||
await task_block.execute_safe(
|
||||
workflow_run_id=block_validation_output.workflow_run_id,
|
||||
@@ -1275,6 +1279,7 @@ async def download(
|
||||
totp_url: str | None = None,
|
||||
label: str | None = None,
|
||||
cache_key: str | None = None,
|
||||
model: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
cache_key = cache_key or label
|
||||
cached_fn = script_run_context_manager.get_cached_fn(cache_key)
|
||||
@@ -1286,6 +1291,7 @@ async def download(
|
||||
prompt=prompt,
|
||||
url=url,
|
||||
label=cache_key,
|
||||
model=model,
|
||||
)
|
||||
prompt = _render_template_with_label(prompt, cache_key)
|
||||
# set the prompt in the RunContext
|
||||
@@ -1332,6 +1338,7 @@ async def download(
|
||||
totp_verification_url=totp_url,
|
||||
include_action_history_in_verification=True,
|
||||
engine=RunEngine.skyvern_v1,
|
||||
model=model,
|
||||
)
|
||||
await file_download_block.execute_safe(
|
||||
workflow_run_id=block_validation_output.workflow_run_id,
|
||||
@@ -1350,6 +1357,7 @@ async def action(
|
||||
totp_url: str | None = None,
|
||||
label: str | None = None,
|
||||
cache_key: str | None = None,
|
||||
model: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
context: skyvern_context.SkyvernContext | None
|
||||
cache_key = cache_key or label
|
||||
@@ -1361,6 +1369,7 @@ async def action(
|
||||
prompt=prompt,
|
||||
url=url,
|
||||
label=cache_key,
|
||||
model=model,
|
||||
)
|
||||
prompt = _render_template_with_label(prompt, cache_key)
|
||||
# set the prompt in the RunContext
|
||||
@@ -1406,6 +1415,7 @@ async def action(
|
||||
max_steps_per_run=max_steps,
|
||||
totp_identifier=totp_identifier,
|
||||
totp_verification_url=totp_url,
|
||||
model=model,
|
||||
)
|
||||
await action_block.execute_safe(
|
||||
workflow_run_id=block_validation_output.workflow_run_id,
|
||||
@@ -1423,6 +1433,7 @@ async def login(
|
||||
totp_url: str | None = None,
|
||||
label: str | None = None,
|
||||
cache_key: str | None = None,
|
||||
model: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
context: skyvern_context.SkyvernContext | None
|
||||
cache_key = cache_key or label
|
||||
@@ -1435,6 +1446,7 @@ async def login(
|
||||
prompt=prompt,
|
||||
url=url,
|
||||
label=cache_key,
|
||||
model=model,
|
||||
)
|
||||
prompt = _render_template_with_label(prompt, cache_key)
|
||||
# set the prompt in the RunContext
|
||||
@@ -1478,6 +1490,7 @@ async def login(
|
||||
max_steps_per_run=max_steps,
|
||||
totp_identifier=totp_identifier,
|
||||
totp_verification_url=totp_url,
|
||||
model=model,
|
||||
)
|
||||
await login_block.execute_safe(
|
||||
workflow_run_id=block_validation_output.workflow_run_id,
|
||||
@@ -1494,6 +1507,7 @@ async def extract(
|
||||
max_steps: int | None = None,
|
||||
label: str | None = None,
|
||||
cache_key: str | None = None,
|
||||
model: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any] | list | str | None:
|
||||
output: dict[str, Any] | list | str | None = None
|
||||
|
||||
@@ -1508,6 +1522,7 @@ async def extract(
|
||||
schema=schema,
|
||||
url=url,
|
||||
label=cache_key,
|
||||
model=model,
|
||||
)
|
||||
prompt = _render_template_with_label(prompt, cache_key)
|
||||
# set the prompt in the RunContext
|
||||
@@ -1553,6 +1568,7 @@ async def extract(
|
||||
max_steps_per_run=max_steps,
|
||||
data_schema=schema,
|
||||
output_parameter=block_validation_output.output_parameter,
|
||||
model=model,
|
||||
)
|
||||
block_result = await extraction_block.execute_safe(
|
||||
workflow_run_id=block_validation_output.workflow_run_id,
|
||||
@@ -1568,6 +1584,7 @@ async def validate(
|
||||
terminate_criterion: str | None = None,
|
||||
error_code_mapping: dict[str, str] | None = None,
|
||||
label: str | None = None,
|
||||
model: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Validate function that behaves like a ValidationBlock"""
|
||||
if not complete_criterion and not terminate_criterion:
|
||||
@@ -1582,6 +1599,7 @@ async def validate(
|
||||
terminate_criterion=terminate_criterion,
|
||||
error_code_mapping=error_code_mapping,
|
||||
max_steps_per_run=2,
|
||||
model=model,
|
||||
)
|
||||
result = await validation_block.execute_safe(
|
||||
workflow_run_id=block_validation_output.workflow_run_id,
|
||||
@@ -1927,6 +1945,7 @@ async def parse_file(
|
||||
schema: dict[str, Any] | None = None,
|
||||
label: str | None = None,
|
||||
parameters: list[str] | None = None,
|
||||
model: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
block_validation_output = await _validate_and_get_output_parameter(label, parameters)
|
||||
file_url = _render_template_with_label(file_url, label)
|
||||
@@ -1937,6 +1956,7 @@ async def parse_file(
|
||||
label=block_validation_output.label,
|
||||
output_parameter=block_validation_output.output_parameter,
|
||||
parameters=block_validation_output.input_parameters,
|
||||
model=model,
|
||||
)
|
||||
await file_parser_block.execute_safe(
|
||||
workflow_run_id=block_validation_output.workflow_run_id,
|
||||
@@ -2008,6 +2028,7 @@ async def prompt(
|
||||
schema: dict[str, Any] | None = None,
|
||||
label: str | None = None,
|
||||
parameters: list[str] | None = None,
|
||||
model: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any] | list | str | None:
|
||||
block_validation_output = await _validate_and_get_output_parameter(label, parameters)
|
||||
prompt = _render_template_with_label(prompt, label)
|
||||
@@ -2017,6 +2038,7 @@ async def prompt(
|
||||
label=block_validation_output.label,
|
||||
output_parameter=block_validation_output.output_parameter,
|
||||
parameters=block_validation_output.input_parameters,
|
||||
model=model,
|
||||
)
|
||||
result = await prompt_block.execute_safe(
|
||||
workflow_run_id=block_validation_output.workflow_run_id,
|
||||
|
||||
Reference in New Issue
Block a user