cache task run uses block level model override when ai fallback happens (#4073)
This commit is contained in:
@@ -399,6 +399,7 @@ async def _create_workflow_block_run_and_task(
|
||||
schema: dict[str, Any] | list | str | None = None,
|
||||
url: str | None = None,
|
||||
label: str | None = None,
|
||||
model: dict[str, Any] | None = None,
|
||||
) -> tuple[str | None, str | None, str | None]:
|
||||
"""
|
||||
Create a workflow block run and optionally a task if workflow_run_id is available in context.
|
||||
@@ -449,6 +450,7 @@ async def _create_workflow_block_run_and_task(
|
||||
status="running",
|
||||
organization_id=organization_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
model=model,
|
||||
)
|
||||
|
||||
task_id = task.task_id
|
||||
@@ -1210,6 +1212,7 @@ async def run_task(
|
||||
prompt=prompt,
|
||||
url=url,
|
||||
label=cache_key,
|
||||
model=model,
|
||||
)
|
||||
prompt = _render_template_with_label(prompt, cache_key)
|
||||
# set the prompt in the RunContext
|
||||
@@ -1256,6 +1259,7 @@ async def run_task(
|
||||
totp_verification_url=totp_url,
|
||||
include_action_history_in_verification=True,
|
||||
engine=RunEngine.skyvern_v1,
|
||||
model=model,
|
||||
)
|
||||
await task_block.execute_safe(
|
||||
workflow_run_id=block_validation_output.workflow_run_id,
|
||||
@@ -1275,6 +1279,7 @@ async def download(
|
||||
totp_url: str | None = None,
|
||||
label: str | None = None,
|
||||
cache_key: str | None = None,
|
||||
model: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
cache_key = cache_key or label
|
||||
cached_fn = script_run_context_manager.get_cached_fn(cache_key)
|
||||
@@ -1286,6 +1291,7 @@ async def download(
|
||||
prompt=prompt,
|
||||
url=url,
|
||||
label=cache_key,
|
||||
model=model,
|
||||
)
|
||||
prompt = _render_template_with_label(prompt, cache_key)
|
||||
# set the prompt in the RunContext
|
||||
@@ -1332,6 +1338,7 @@ async def download(
|
||||
totp_verification_url=totp_url,
|
||||
include_action_history_in_verification=True,
|
||||
engine=RunEngine.skyvern_v1,
|
||||
model=model,
|
||||
)
|
||||
await file_download_block.execute_safe(
|
||||
workflow_run_id=block_validation_output.workflow_run_id,
|
||||
@@ -1350,6 +1357,7 @@ async def action(
|
||||
totp_url: str | None = None,
|
||||
label: str | None = None,
|
||||
cache_key: str | None = None,
|
||||
model: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
context: skyvern_context.SkyvernContext | None
|
||||
cache_key = cache_key or label
|
||||
@@ -1361,6 +1369,7 @@ async def action(
|
||||
prompt=prompt,
|
||||
url=url,
|
||||
label=cache_key,
|
||||
model=model,
|
||||
)
|
||||
prompt = _render_template_with_label(prompt, cache_key)
|
||||
# set the prompt in the RunContext
|
||||
@@ -1406,6 +1415,7 @@ async def action(
|
||||
max_steps_per_run=max_steps,
|
||||
totp_identifier=totp_identifier,
|
||||
totp_verification_url=totp_url,
|
||||
model=model,
|
||||
)
|
||||
await action_block.execute_safe(
|
||||
workflow_run_id=block_validation_output.workflow_run_id,
|
||||
@@ -1423,6 +1433,7 @@ async def login(
|
||||
totp_url: str | None = None,
|
||||
label: str | None = None,
|
||||
cache_key: str | None = None,
|
||||
model: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
context: skyvern_context.SkyvernContext | None
|
||||
cache_key = cache_key or label
|
||||
@@ -1435,6 +1446,7 @@ async def login(
|
||||
prompt=prompt,
|
||||
url=url,
|
||||
label=cache_key,
|
||||
model=model,
|
||||
)
|
||||
prompt = _render_template_with_label(prompt, cache_key)
|
||||
# set the prompt in the RunContext
|
||||
@@ -1478,6 +1490,7 @@ async def login(
|
||||
max_steps_per_run=max_steps,
|
||||
totp_identifier=totp_identifier,
|
||||
totp_verification_url=totp_url,
|
||||
model=model,
|
||||
)
|
||||
await login_block.execute_safe(
|
||||
workflow_run_id=block_validation_output.workflow_run_id,
|
||||
@@ -1494,6 +1507,7 @@ async def extract(
|
||||
max_steps: int | None = None,
|
||||
label: str | None = None,
|
||||
cache_key: str | None = None,
|
||||
model: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any] | list | str | None:
|
||||
output: dict[str, Any] | list | str | None = None
|
||||
|
||||
@@ -1508,6 +1522,7 @@ async def extract(
|
||||
schema=schema,
|
||||
url=url,
|
||||
label=cache_key,
|
||||
model=model,
|
||||
)
|
||||
prompt = _render_template_with_label(prompt, cache_key)
|
||||
# set the prompt in the RunContext
|
||||
@@ -1553,6 +1568,7 @@ async def extract(
|
||||
max_steps_per_run=max_steps,
|
||||
data_schema=schema,
|
||||
output_parameter=block_validation_output.output_parameter,
|
||||
model=model,
|
||||
)
|
||||
block_result = await extraction_block.execute_safe(
|
||||
workflow_run_id=block_validation_output.workflow_run_id,
|
||||
@@ -1568,6 +1584,7 @@ async def validate(
|
||||
terminate_criterion: str | None = None,
|
||||
error_code_mapping: dict[str, str] | None = None,
|
||||
label: str | None = None,
|
||||
model: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Validate function that behaves like a ValidationBlock"""
|
||||
if not complete_criterion and not terminate_criterion:
|
||||
@@ -1582,6 +1599,7 @@ async def validate(
|
||||
terminate_criterion=terminate_criterion,
|
||||
error_code_mapping=error_code_mapping,
|
||||
max_steps_per_run=2,
|
||||
model=model,
|
||||
)
|
||||
result = await validation_block.execute_safe(
|
||||
workflow_run_id=block_validation_output.workflow_run_id,
|
||||
@@ -1927,6 +1945,7 @@ async def parse_file(
|
||||
schema: dict[str, Any] | None = None,
|
||||
label: str | None = None,
|
||||
parameters: list[str] | None = None,
|
||||
model: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
block_validation_output = await _validate_and_get_output_parameter(label, parameters)
|
||||
file_url = _render_template_with_label(file_url, label)
|
||||
@@ -1937,6 +1956,7 @@ async def parse_file(
|
||||
label=block_validation_output.label,
|
||||
output_parameter=block_validation_output.output_parameter,
|
||||
parameters=block_validation_output.input_parameters,
|
||||
model=model,
|
||||
)
|
||||
await file_parser_block.execute_safe(
|
||||
workflow_run_id=block_validation_output.workflow_run_id,
|
||||
@@ -2008,6 +2028,7 @@ async def prompt(
|
||||
schema: dict[str, Any] | None = None,
|
||||
label: str | None = None,
|
||||
parameters: list[str] | None = None,
|
||||
model: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any] | list | str | None:
|
||||
block_validation_output = await _validate_and_get_output_parameter(label, parameters)
|
||||
prompt = _render_template_with_label(prompt, label)
|
||||
@@ -2017,6 +2038,7 @@ async def prompt(
|
||||
label=block_validation_output.label,
|
||||
output_parameter=block_validation_output.output_parameter,
|
||||
parameters=block_validation_output.input_parameters,
|
||||
model=model,
|
||||
)
|
||||
result = await prompt_block.execute_safe(
|
||||
workflow_run_id=block_validation_output.workflow_run_id,
|
||||
|
||||
Reference in New Issue
Block a user