cache task run uses block level model override when ai fallback happens (#4073)
This commit is contained in:
@@ -673,16 +673,30 @@ def _build_action_statement(
|
|||||||
last_line=cst.SimpleWhitespace(INDENT),
|
last_line=cst.SimpleWhitespace(INDENT),
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
cst.Arg(
|
|
||||||
keyword=cst.Name("label"),
|
|
||||||
value=_value(block_title),
|
|
||||||
whitespace_after_arg=cst.ParenthesizedWhitespace(
|
|
||||||
indent=True,
|
|
||||||
),
|
|
||||||
comma=cst.Comma(),
|
|
||||||
),
|
|
||||||
]
|
]
|
||||||
|
if block.get("model"):
|
||||||
|
args.append(
|
||||||
|
cst.Arg(
|
||||||
|
keyword=cst.Name("model"),
|
||||||
|
value=_value(block.get("model")),
|
||||||
|
whitespace_after_arg=cst.ParenthesizedWhitespace(
|
||||||
|
indent=True,
|
||||||
|
last_line=cst.SimpleWhitespace(INDENT),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if block.get("label"):
|
||||||
|
args.append(
|
||||||
|
cst.Arg(
|
||||||
|
keyword=cst.Name("label"),
|
||||||
|
value=_value(block.get("label")),
|
||||||
|
whitespace_after_arg=cst.ParenthesizedWhitespace(
|
||||||
|
indent=True,
|
||||||
|
),
|
||||||
|
comma=cst.Comma(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
_mark_last_arg_as_comma(args)
|
||||||
call = cst.Call(
|
call = cst.Call(
|
||||||
func=cst.Attribute(value=cst.Name("skyvern"), attr=cst.Name("action")),
|
func=cst.Attribute(value=cst.Name("skyvern"), attr=cst.Name("action")),
|
||||||
args=args,
|
args=args,
|
||||||
@@ -733,15 +747,30 @@ def _build_extract_statement(
|
|||||||
last_line=cst.SimpleWhitespace(INDENT),
|
last_line=cst.SimpleWhitespace(INDENT),
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
cst.Arg(
|
|
||||||
keyword=cst.Name("label"),
|
|
||||||
value=_value(block_title),
|
|
||||||
whitespace_after_arg=cst.ParenthesizedWhitespace(
|
|
||||||
indent=True,
|
|
||||||
),
|
|
||||||
comma=cst.Comma(),
|
|
||||||
),
|
|
||||||
]
|
]
|
||||||
|
if block.get("model"):
|
||||||
|
args.append(
|
||||||
|
cst.Arg(
|
||||||
|
keyword=cst.Name("model"),
|
||||||
|
value=_value(block.get("model")),
|
||||||
|
whitespace_after_arg=cst.ParenthesizedWhitespace(
|
||||||
|
indent=True,
|
||||||
|
last_line=cst.SimpleWhitespace(INDENT),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if block.get("label"):
|
||||||
|
args.append(
|
||||||
|
cst.Arg(
|
||||||
|
keyword=cst.Name("label"),
|
||||||
|
value=_value(block_title),
|
||||||
|
whitespace_after_arg=cst.ParenthesizedWhitespace(
|
||||||
|
indent=True,
|
||||||
|
),
|
||||||
|
comma=cst.Comma(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
_mark_last_arg_as_comma(args)
|
||||||
|
|
||||||
call = cst.Call(
|
call = cst.Call(
|
||||||
func=cst.Attribute(value=cst.Name("skyvern"), attr=cst.Name("extract")),
|
func=cst.Attribute(value=cst.Name("skyvern"), attr=cst.Name("extract")),
|
||||||
@@ -882,6 +911,18 @@ def _build_validate_statement(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if block.get("model"):
|
||||||
|
args.append(
|
||||||
|
cst.Arg(
|
||||||
|
keyword=cst.Name("model"),
|
||||||
|
value=_value(block.get("model")),
|
||||||
|
whitespace_after_arg=cst.ParenthesizedWhitespace(
|
||||||
|
indent=True,
|
||||||
|
last_line=cst.SimpleWhitespace(INDENT),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Add label if it exists
|
# Add label if it exists
|
||||||
if block.get("label") is not None:
|
if block.get("label") is not None:
|
||||||
args.append(
|
args.append(
|
||||||
@@ -1114,6 +1155,18 @@ def _build_pdf_parser_statement(block: dict[str, Any]) -> cst.SimpleStatementLin
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if block.get("model"):
|
||||||
|
args.append(
|
||||||
|
cst.Arg(
|
||||||
|
keyword=cst.Name("model"),
|
||||||
|
value=_value(block.get("model")),
|
||||||
|
whitespace_after_arg=cst.ParenthesizedWhitespace(
|
||||||
|
indent=True,
|
||||||
|
last_line=cst.SimpleWhitespace(INDENT),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if block.get("label") is not None:
|
if block.get("label") is not None:
|
||||||
args.append(
|
args.append(
|
||||||
cst.Arg(
|
cst.Arg(
|
||||||
@@ -1172,6 +1225,18 @@ def _build_file_url_parser_statement(block: dict[str, Any]) -> cst.SimpleStateme
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if block.get("model"):
|
||||||
|
args.append(
|
||||||
|
cst.Arg(
|
||||||
|
keyword=cst.Name("model"),
|
||||||
|
value=_value(block.get("model")),
|
||||||
|
whitespace_after_arg=cst.ParenthesizedWhitespace(
|
||||||
|
indent=True,
|
||||||
|
last_line=cst.SimpleWhitespace(INDENT),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if block.get("label") is not None:
|
if block.get("label") is not None:
|
||||||
args.append(
|
args.append(
|
||||||
cst.Arg(
|
cst.Arg(
|
||||||
@@ -1319,6 +1384,18 @@ def _build_prompt_statement(block: dict[str, Any]) -> cst.SimpleStatementLine:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if block.get("model"):
|
||||||
|
args.append(
|
||||||
|
cst.Arg(
|
||||||
|
keyword=cst.Name("model"),
|
||||||
|
value=_value(block.get("model")),
|
||||||
|
whitespace_after_arg=cst.ParenthesizedWhitespace(
|
||||||
|
indent=True,
|
||||||
|
last_line=cst.SimpleWhitespace(INDENT),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if block.get("label") is not None:
|
if block.get("label") is not None:
|
||||||
args.append(
|
args.append(
|
||||||
cst.Arg(
|
cst.Arg(
|
||||||
@@ -1331,7 +1408,7 @@ def _build_prompt_statement(block: dict[str, Any]) -> cst.SimpleStatementLine:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if block.get("parameters") is not None:
|
if block.get("parameters"):
|
||||||
parameters = block.get("parameters", [])
|
parameters = block.get("parameters", [])
|
||||||
parameter_list = [parameter["key"] for parameter in parameters]
|
parameter_list = [parameter["key"] for parameter in parameters]
|
||||||
args.append(
|
args.append(
|
||||||
@@ -1340,10 +1417,12 @@ def _build_prompt_statement(block: dict[str, Any]) -> cst.SimpleStatementLine:
|
|||||||
value=_value(parameter_list),
|
value=_value(parameter_list),
|
||||||
whitespace_after_arg=cst.ParenthesizedWhitespace(
|
whitespace_after_arg=cst.ParenthesizedWhitespace(
|
||||||
indent=True,
|
indent=True,
|
||||||
|
last_line=cst.SimpleWhitespace(INDENT),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_mark_last_arg_as_comma(args)
|
||||||
call = cst.Call(
|
call = cst.Call(
|
||||||
func=cst.Attribute(value=cst.Name("skyvern"), attr=cst.Name("prompt")),
|
func=cst.Attribute(value=cst.Name("skyvern"), attr=cst.Name("prompt")),
|
||||||
args=args,
|
args=args,
|
||||||
@@ -1568,6 +1647,17 @@ def __build_base_task_statement(
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
if block.get("model"):
|
||||||
|
args.append(
|
||||||
|
cst.Arg(
|
||||||
|
keyword=cst.Name("model"),
|
||||||
|
value=_value(block.get("model")),
|
||||||
|
whitespace_after_arg=cst.ParenthesizedWhitespace(
|
||||||
|
indent=True,
|
||||||
|
last_line=cst.SimpleWhitespace(INDENT),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
if block.get("block_type") == "task_v2":
|
if block.get("block_type") == "task_v2":
|
||||||
args.append(
|
args.append(
|
||||||
cst.Arg(
|
cst.Arg(
|
||||||
|
|||||||
@@ -130,7 +130,6 @@ CacheInvalidationReason = Literal["updated_block", "new_block", "removed_block"]
|
|||||||
BLOCK_TYPES_THAT_SHOULD_BE_CACHED = {
|
BLOCK_TYPES_THAT_SHOULD_BE_CACHED = {
|
||||||
BlockType.TASK,
|
BlockType.TASK,
|
||||||
BlockType.TaskV2,
|
BlockType.TaskV2,
|
||||||
BlockType.VALIDATION,
|
|
||||||
BlockType.ACTION,
|
BlockType.ACTION,
|
||||||
BlockType.NAVIGATION,
|
BlockType.NAVIGATION,
|
||||||
BlockType.EXTRACTION,
|
BlockType.EXTRACTION,
|
||||||
|
|||||||
@@ -399,6 +399,7 @@ async def _create_workflow_block_run_and_task(
|
|||||||
schema: dict[str, Any] | list | str | None = None,
|
schema: dict[str, Any] | list | str | None = None,
|
||||||
url: str | None = None,
|
url: str | None = None,
|
||||||
label: str | None = None,
|
label: str | None = None,
|
||||||
|
model: dict[str, Any] | None = None,
|
||||||
) -> tuple[str | None, str | None, str | None]:
|
) -> tuple[str | None, str | None, str | None]:
|
||||||
"""
|
"""
|
||||||
Create a workflow block run and optionally a task if workflow_run_id is available in context.
|
Create a workflow block run and optionally a task if workflow_run_id is available in context.
|
||||||
@@ -449,6 +450,7 @@ async def _create_workflow_block_run_and_task(
|
|||||||
status="running",
|
status="running",
|
||||||
organization_id=organization_id,
|
organization_id=organization_id,
|
||||||
workflow_run_id=workflow_run_id,
|
workflow_run_id=workflow_run_id,
|
||||||
|
model=model,
|
||||||
)
|
)
|
||||||
|
|
||||||
task_id = task.task_id
|
task_id = task.task_id
|
||||||
@@ -1210,6 +1212,7 @@ async def run_task(
|
|||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
url=url,
|
url=url,
|
||||||
label=cache_key,
|
label=cache_key,
|
||||||
|
model=model,
|
||||||
)
|
)
|
||||||
prompt = _render_template_with_label(prompt, cache_key)
|
prompt = _render_template_with_label(prompt, cache_key)
|
||||||
# set the prompt in the RunContext
|
# set the prompt in the RunContext
|
||||||
@@ -1256,6 +1259,7 @@ async def run_task(
|
|||||||
totp_verification_url=totp_url,
|
totp_verification_url=totp_url,
|
||||||
include_action_history_in_verification=True,
|
include_action_history_in_verification=True,
|
||||||
engine=RunEngine.skyvern_v1,
|
engine=RunEngine.skyvern_v1,
|
||||||
|
model=model,
|
||||||
)
|
)
|
||||||
await task_block.execute_safe(
|
await task_block.execute_safe(
|
||||||
workflow_run_id=block_validation_output.workflow_run_id,
|
workflow_run_id=block_validation_output.workflow_run_id,
|
||||||
@@ -1275,6 +1279,7 @@ async def download(
|
|||||||
totp_url: str | None = None,
|
totp_url: str | None = None,
|
||||||
label: str | None = None,
|
label: str | None = None,
|
||||||
cache_key: str | None = None,
|
cache_key: str | None = None,
|
||||||
|
model: dict[str, Any] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
cache_key = cache_key or label
|
cache_key = cache_key or label
|
||||||
cached_fn = script_run_context_manager.get_cached_fn(cache_key)
|
cached_fn = script_run_context_manager.get_cached_fn(cache_key)
|
||||||
@@ -1286,6 +1291,7 @@ async def download(
|
|||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
url=url,
|
url=url,
|
||||||
label=cache_key,
|
label=cache_key,
|
||||||
|
model=model,
|
||||||
)
|
)
|
||||||
prompt = _render_template_with_label(prompt, cache_key)
|
prompt = _render_template_with_label(prompt, cache_key)
|
||||||
# set the prompt in the RunContext
|
# set the prompt in the RunContext
|
||||||
@@ -1332,6 +1338,7 @@ async def download(
|
|||||||
totp_verification_url=totp_url,
|
totp_verification_url=totp_url,
|
||||||
include_action_history_in_verification=True,
|
include_action_history_in_verification=True,
|
||||||
engine=RunEngine.skyvern_v1,
|
engine=RunEngine.skyvern_v1,
|
||||||
|
model=model,
|
||||||
)
|
)
|
||||||
await file_download_block.execute_safe(
|
await file_download_block.execute_safe(
|
||||||
workflow_run_id=block_validation_output.workflow_run_id,
|
workflow_run_id=block_validation_output.workflow_run_id,
|
||||||
@@ -1350,6 +1357,7 @@ async def action(
|
|||||||
totp_url: str | None = None,
|
totp_url: str | None = None,
|
||||||
label: str | None = None,
|
label: str | None = None,
|
||||||
cache_key: str | None = None,
|
cache_key: str | None = None,
|
||||||
|
model: dict[str, Any] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
context: skyvern_context.SkyvernContext | None
|
context: skyvern_context.SkyvernContext | None
|
||||||
cache_key = cache_key or label
|
cache_key = cache_key or label
|
||||||
@@ -1361,6 +1369,7 @@ async def action(
|
|||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
url=url,
|
url=url,
|
||||||
label=cache_key,
|
label=cache_key,
|
||||||
|
model=model,
|
||||||
)
|
)
|
||||||
prompt = _render_template_with_label(prompt, cache_key)
|
prompt = _render_template_with_label(prompt, cache_key)
|
||||||
# set the prompt in the RunContext
|
# set the prompt in the RunContext
|
||||||
@@ -1406,6 +1415,7 @@ async def action(
|
|||||||
max_steps_per_run=max_steps,
|
max_steps_per_run=max_steps,
|
||||||
totp_identifier=totp_identifier,
|
totp_identifier=totp_identifier,
|
||||||
totp_verification_url=totp_url,
|
totp_verification_url=totp_url,
|
||||||
|
model=model,
|
||||||
)
|
)
|
||||||
await action_block.execute_safe(
|
await action_block.execute_safe(
|
||||||
workflow_run_id=block_validation_output.workflow_run_id,
|
workflow_run_id=block_validation_output.workflow_run_id,
|
||||||
@@ -1423,6 +1433,7 @@ async def login(
|
|||||||
totp_url: str | None = None,
|
totp_url: str | None = None,
|
||||||
label: str | None = None,
|
label: str | None = None,
|
||||||
cache_key: str | None = None,
|
cache_key: str | None = None,
|
||||||
|
model: dict[str, Any] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
context: skyvern_context.SkyvernContext | None
|
context: skyvern_context.SkyvernContext | None
|
||||||
cache_key = cache_key or label
|
cache_key = cache_key or label
|
||||||
@@ -1435,6 +1446,7 @@ async def login(
|
|||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
url=url,
|
url=url,
|
||||||
label=cache_key,
|
label=cache_key,
|
||||||
|
model=model,
|
||||||
)
|
)
|
||||||
prompt = _render_template_with_label(prompt, cache_key)
|
prompt = _render_template_with_label(prompt, cache_key)
|
||||||
# set the prompt in the RunContext
|
# set the prompt in the RunContext
|
||||||
@@ -1478,6 +1490,7 @@ async def login(
|
|||||||
max_steps_per_run=max_steps,
|
max_steps_per_run=max_steps,
|
||||||
totp_identifier=totp_identifier,
|
totp_identifier=totp_identifier,
|
||||||
totp_verification_url=totp_url,
|
totp_verification_url=totp_url,
|
||||||
|
model=model,
|
||||||
)
|
)
|
||||||
await login_block.execute_safe(
|
await login_block.execute_safe(
|
||||||
workflow_run_id=block_validation_output.workflow_run_id,
|
workflow_run_id=block_validation_output.workflow_run_id,
|
||||||
@@ -1494,6 +1507,7 @@ async def extract(
|
|||||||
max_steps: int | None = None,
|
max_steps: int | None = None,
|
||||||
label: str | None = None,
|
label: str | None = None,
|
||||||
cache_key: str | None = None,
|
cache_key: str | None = None,
|
||||||
|
model: dict[str, Any] | None = None,
|
||||||
) -> dict[str, Any] | list | str | None:
|
) -> dict[str, Any] | list | str | None:
|
||||||
output: dict[str, Any] | list | str | None = None
|
output: dict[str, Any] | list | str | None = None
|
||||||
|
|
||||||
@@ -1508,6 +1522,7 @@ async def extract(
|
|||||||
schema=schema,
|
schema=schema,
|
||||||
url=url,
|
url=url,
|
||||||
label=cache_key,
|
label=cache_key,
|
||||||
|
model=model,
|
||||||
)
|
)
|
||||||
prompt = _render_template_with_label(prompt, cache_key)
|
prompt = _render_template_with_label(prompt, cache_key)
|
||||||
# set the prompt in the RunContext
|
# set the prompt in the RunContext
|
||||||
@@ -1553,6 +1568,7 @@ async def extract(
|
|||||||
max_steps_per_run=max_steps,
|
max_steps_per_run=max_steps,
|
||||||
data_schema=schema,
|
data_schema=schema,
|
||||||
output_parameter=block_validation_output.output_parameter,
|
output_parameter=block_validation_output.output_parameter,
|
||||||
|
model=model,
|
||||||
)
|
)
|
||||||
block_result = await extraction_block.execute_safe(
|
block_result = await extraction_block.execute_safe(
|
||||||
workflow_run_id=block_validation_output.workflow_run_id,
|
workflow_run_id=block_validation_output.workflow_run_id,
|
||||||
@@ -1568,6 +1584,7 @@ async def validate(
|
|||||||
terminate_criterion: str | None = None,
|
terminate_criterion: str | None = None,
|
||||||
error_code_mapping: dict[str, str] | None = None,
|
error_code_mapping: dict[str, str] | None = None,
|
||||||
label: str | None = None,
|
label: str | None = None,
|
||||||
|
model: dict[str, Any] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Validate function that behaves like a ValidationBlock"""
|
"""Validate function that behaves like a ValidationBlock"""
|
||||||
if not complete_criterion and not terminate_criterion:
|
if not complete_criterion and not terminate_criterion:
|
||||||
@@ -1582,6 +1599,7 @@ async def validate(
|
|||||||
terminate_criterion=terminate_criterion,
|
terminate_criterion=terminate_criterion,
|
||||||
error_code_mapping=error_code_mapping,
|
error_code_mapping=error_code_mapping,
|
||||||
max_steps_per_run=2,
|
max_steps_per_run=2,
|
||||||
|
model=model,
|
||||||
)
|
)
|
||||||
result = await validation_block.execute_safe(
|
result = await validation_block.execute_safe(
|
||||||
workflow_run_id=block_validation_output.workflow_run_id,
|
workflow_run_id=block_validation_output.workflow_run_id,
|
||||||
@@ -1927,6 +1945,7 @@ async def parse_file(
|
|||||||
schema: dict[str, Any] | None = None,
|
schema: dict[str, Any] | None = None,
|
||||||
label: str | None = None,
|
label: str | None = None,
|
||||||
parameters: list[str] | None = None,
|
parameters: list[str] | None = None,
|
||||||
|
model: dict[str, Any] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
block_validation_output = await _validate_and_get_output_parameter(label, parameters)
|
block_validation_output = await _validate_and_get_output_parameter(label, parameters)
|
||||||
file_url = _render_template_with_label(file_url, label)
|
file_url = _render_template_with_label(file_url, label)
|
||||||
@@ -1937,6 +1956,7 @@ async def parse_file(
|
|||||||
label=block_validation_output.label,
|
label=block_validation_output.label,
|
||||||
output_parameter=block_validation_output.output_parameter,
|
output_parameter=block_validation_output.output_parameter,
|
||||||
parameters=block_validation_output.input_parameters,
|
parameters=block_validation_output.input_parameters,
|
||||||
|
model=model,
|
||||||
)
|
)
|
||||||
await file_parser_block.execute_safe(
|
await file_parser_block.execute_safe(
|
||||||
workflow_run_id=block_validation_output.workflow_run_id,
|
workflow_run_id=block_validation_output.workflow_run_id,
|
||||||
@@ -2008,6 +2028,7 @@ async def prompt(
|
|||||||
schema: dict[str, Any] | None = None,
|
schema: dict[str, Any] | None = None,
|
||||||
label: str | None = None,
|
label: str | None = None,
|
||||||
parameters: list[str] | None = None,
|
parameters: list[str] | None = None,
|
||||||
|
model: dict[str, Any] | None = None,
|
||||||
) -> dict[str, Any] | list | str | None:
|
) -> dict[str, Any] | list | str | None:
|
||||||
block_validation_output = await _validate_and_get_output_parameter(label, parameters)
|
block_validation_output = await _validate_and_get_output_parameter(label, parameters)
|
||||||
prompt = _render_template_with_label(prompt, label)
|
prompt = _render_template_with_label(prompt, label)
|
||||||
@@ -2017,6 +2038,7 @@ async def prompt(
|
|||||||
label=block_validation_output.label,
|
label=block_validation_output.label,
|
||||||
output_parameter=block_validation_output.output_parameter,
|
output_parameter=block_validation_output.output_parameter,
|
||||||
parameters=block_validation_output.input_parameters,
|
parameters=block_validation_output.input_parameters,
|
||||||
|
model=model,
|
||||||
)
|
)
|
||||||
result = await prompt_block.execute_safe(
|
result = await prompt_block.execute_safe(
|
||||||
workflow_run_id=block_validation_output.workflow_run_id,
|
workflow_run_id=block_validation_output.workflow_run_id,
|
||||||
|
|||||||
Reference in New Issue
Block a user