diff --git a/skyvern/core/script_generations/generate_script.py b/skyvern/core/script_generations/generate_script.py index 8460e5d5..fdea2c24 100644 --- a/skyvern/core/script_generations/generate_script.py +++ b/skyvern/core/script_generations/generate_script.py @@ -673,16 +673,30 @@ def _build_action_statement( last_line=cst.SimpleWhitespace(INDENT), ), ), - cst.Arg( - keyword=cst.Name("label"), - value=_value(block_title), - whitespace_after_arg=cst.ParenthesizedWhitespace( - indent=True, - ), - comma=cst.Comma(), - ), ] - + if block.get("model"): + args.append( + cst.Arg( + keyword=cst.Name("model"), + value=_value(block.get("model")), + whitespace_after_arg=cst.ParenthesizedWhitespace( + indent=True, + last_line=cst.SimpleWhitespace(INDENT), + ), + ) + ) + if block.get("label"): + args.append( + cst.Arg( + keyword=cst.Name("label"), + value=_value(block.get("label")), + whitespace_after_arg=cst.ParenthesizedWhitespace( + indent=True, + ), + comma=cst.Comma(), + ) + ) + _mark_last_arg_as_comma(args) call = cst.Call( func=cst.Attribute(value=cst.Name("skyvern"), attr=cst.Name("action")), args=args, @@ -733,15 +747,30 @@ def _build_extract_statement( last_line=cst.SimpleWhitespace(INDENT), ), ), - cst.Arg( - keyword=cst.Name("label"), - value=_value(block_title), - whitespace_after_arg=cst.ParenthesizedWhitespace( - indent=True, - ), - comma=cst.Comma(), - ), ] + if block.get("model"): + args.append( + cst.Arg( + keyword=cst.Name("model"), + value=_value(block.get("model")), + whitespace_after_arg=cst.ParenthesizedWhitespace( + indent=True, + last_line=cst.SimpleWhitespace(INDENT), + ), + ) + ) + if block.get("label"): + args.append( + cst.Arg( + keyword=cst.Name("label"), + value=_value(block_title), + whitespace_after_arg=cst.ParenthesizedWhitespace( + indent=True, + ), + comma=cst.Comma(), + ) + ) + _mark_last_arg_as_comma(args) call = cst.Call( func=cst.Attribute(value=cst.Name("skyvern"), attr=cst.Name("extract")), @@ -882,6 +911,18 @@ def _build_validate_statement( ) ) + if block.get("model"): + args.append( + cst.Arg( + keyword=cst.Name("model"), + value=_value(block.get("model")), + whitespace_after_arg=cst.ParenthesizedWhitespace( + indent=True, + last_line=cst.SimpleWhitespace(INDENT), + ), + ) + ) + # Add label if it exists if block.get("label") is not None: args.append( @@ -1114,6 +1155,18 @@ def _build_pdf_parser_statement(block: dict[str, Any]) -> cst.SimpleStatementLin ) ) + if block.get("model"): + args.append( + cst.Arg( + keyword=cst.Name("model"), + value=_value(block.get("model")), + whitespace_after_arg=cst.ParenthesizedWhitespace( + indent=True, + last_line=cst.SimpleWhitespace(INDENT), + ), + ) + ) + if block.get("label") is not None: args.append( cst.Arg( @@ -1172,6 +1225,18 @@ def _build_file_url_parser_statement(block: dict[str, Any]) -> cst.SimpleStateme ) ) + if block.get("model"): + args.append( + cst.Arg( + keyword=cst.Name("model"), + value=_value(block.get("model")), + whitespace_after_arg=cst.ParenthesizedWhitespace( + indent=True, + last_line=cst.SimpleWhitespace(INDENT), + ), + ) + ) + if block.get("label") is not None: args.append( cst.Arg( @@ -1319,6 +1384,18 @@ def _build_prompt_statement(block: dict[str, Any]) -> cst.SimpleStatementLine: ) ) + if block.get("model"): + args.append( + cst.Arg( + keyword=cst.Name("model"), + value=_value(block.get("model")), + whitespace_after_arg=cst.ParenthesizedWhitespace( + indent=True, + last_line=cst.SimpleWhitespace(INDENT), + ), + ) + ) + if block.get("label") is not None: args.append( cst.Arg( @@ -1331,7 +1408,7 @@ def _build_prompt_statement(block: dict[str, Any]) -> cst.SimpleStatementLine: ) ) - if block.get("parameters") is not None: + if block.get("parameters"): parameters = block.get("parameters", []) parameter_list = [parameter["key"] for parameter in parameters] args.append( @@ -1340,10 +1417,12 @@ def _build_prompt_statement(block: dict[str, Any]) -> cst.SimpleStatementLine: value=_value(parameter_list), whitespace_after_arg=cst.ParenthesizedWhitespace( indent=True, + last_line=cst.SimpleWhitespace(INDENT), ), ) ) + _mark_last_arg_as_comma(args) call = cst.Call( func=cst.Attribute(value=cst.Name("skyvern"), attr=cst.Name("prompt")), args=args, @@ -1568,6 +1647,17 @@ def __build_base_task_statement( ), ) ) + if block.get("model"): + args.append( + cst.Arg( + keyword=cst.Name("model"), + value=_value(block.get("model")), + whitespace_after_arg=cst.ParenthesizedWhitespace( + indent=True, + last_line=cst.SimpleWhitespace(INDENT), + ), + ) + ) if block.get("block_type") == "task_v2": args.append( cst.Arg( diff --git a/skyvern/forge/sdk/workflow/service.py b/skyvern/forge/sdk/workflow/service.py index 4c8626f3..7a82e1e9 100644 --- a/skyvern/forge/sdk/workflow/service.py +++ b/skyvern/forge/sdk/workflow/service.py @@ -130,7 +130,6 @@ CacheInvalidationReason = Literal["updated_block", "new_block", "removed_block"] BLOCK_TYPES_THAT_SHOULD_BE_CACHED = { BlockType.TASK, BlockType.TaskV2, - BlockType.VALIDATION, BlockType.ACTION, BlockType.NAVIGATION, BlockType.EXTRACTION, diff --git a/skyvern/services/script_service.py b/skyvern/services/script_service.py index 2541b17c..c3fe816a 100644 --- a/skyvern/services/script_service.py +++ b/skyvern/services/script_service.py @@ -399,6 +399,7 @@ async def _create_workflow_block_run_and_task( schema: dict[str, Any] | list | str | None = None, url: str | None = None, label: str | None = None, + model: dict[str, Any] | None = None, ) -> tuple[str | None, str | None, str | None]: """ Create a workflow block run and optionally a task if workflow_run_id is available in context. @@ -449,6 +450,7 @@ async def _create_workflow_block_run_and_task( status="running", organization_id=organization_id, workflow_run_id=workflow_run_id, + model=model, ) task_id = task.task_id @@ -1210,6 +1212,7 @@ async def run_task( prompt=prompt, url=url, label=cache_key, + model=model, ) prompt = _render_template_with_label(prompt, cache_key) # set the prompt in the RunContext @@ -1256,6 +1259,7 @@ async def run_task( totp_verification_url=totp_url, include_action_history_in_verification=True, engine=RunEngine.skyvern_v1, + model=model, ) await task_block.execute_safe( workflow_run_id=block_validation_output.workflow_run_id, @@ -1275,6 +1279,7 @@ async def download( totp_url: str | None = None, label: str | None = None, cache_key: str | None = None, + model: dict[str, Any] | None = None, ) -> None: cache_key = cache_key or label cached_fn = script_run_context_manager.get_cached_fn(cache_key) @@ -1286,6 +1291,7 @@ async def download( prompt=prompt, url=url, label=cache_key, + model=model, ) prompt = _render_template_with_label(prompt, cache_key) # set the prompt in the RunContext @@ -1332,6 +1338,7 @@ async def download( totp_verification_url=totp_url, include_action_history_in_verification=True, engine=RunEngine.skyvern_v1, + model=model, ) await file_download_block.execute_safe( workflow_run_id=block_validation_output.workflow_run_id, @@ -1350,6 +1357,7 @@ async def action( totp_url: str | None = None, label: str | None = None, cache_key: str | None = None, + model: dict[str, Any] | None = None, ) -> None: context: skyvern_context.SkyvernContext | None cache_key = cache_key or label @@ -1361,6 +1369,7 @@ async def action( prompt=prompt, url=url, label=cache_key, + model=model, ) prompt = _render_template_with_label(prompt, cache_key) # set the prompt in the RunContext @@ -1406,6 +1415,7 @@ async def action( max_steps_per_run=max_steps, totp_identifier=totp_identifier, totp_verification_url=totp_url, + model=model, ) await action_block.execute_safe( workflow_run_id=block_validation_output.workflow_run_id, @@ -1423,6 +1433,7 @@ async def login( totp_url: str | None = None, label: str | None = None, cache_key: str | None = None, + model: dict[str, Any] | None = None, ) -> None: context: skyvern_context.SkyvernContext | None cache_key = cache_key or label @@ -1435,6 +1446,7 @@ async def login( prompt=prompt, url=url, label=cache_key, + model=model, ) prompt = _render_template_with_label(prompt, cache_key) # set the prompt in the RunContext @@ -1478,6 +1490,7 @@ async def login( max_steps_per_run=max_steps, totp_identifier=totp_identifier, totp_verification_url=totp_url, + model=model, ) await login_block.execute_safe( workflow_run_id=block_validation_output.workflow_run_id, @@ -1494,6 +1507,7 @@ async def extract( max_steps: int | None = None, label: str | None = None, cache_key: str | None = None, + model: dict[str, Any] | None = None, ) -> dict[str, Any] | list | str | None: output: dict[str, Any] | list | str | None = None @@ -1508,6 +1522,7 @@ async def extract( schema=schema, url=url, label=cache_key, + model=model, ) prompt = _render_template_with_label(prompt, cache_key) # set the prompt in the RunContext @@ -1553,6 +1568,7 @@ async def extract( max_steps_per_run=max_steps, data_schema=schema, output_parameter=block_validation_output.output_parameter, + model=model, ) block_result = await extraction_block.execute_safe( workflow_run_id=block_validation_output.workflow_run_id, @@ -1568,6 +1584,7 @@ async def validate( terminate_criterion: str | None = None, error_code_mapping: dict[str, str] | None = None, label: str | None = None, + model: dict[str, Any] | None = None, ) -> None: """Validate function that behaves like a ValidationBlock""" if not complete_criterion and not terminate_criterion: @@ -1582,6 +1599,7 @@ async def validate( terminate_criterion=terminate_criterion, error_code_mapping=error_code_mapping, max_steps_per_run=2, + model=model, ) result = await validation_block.execute_safe( workflow_run_id=block_validation_output.workflow_run_id, @@ -1927,6 +1945,7 @@ async def parse_file( schema: dict[str, Any] | None = None, label: str | None = None, parameters: list[str] | None = None, + model: dict[str, Any] | None = None, ) -> None: block_validation_output = await _validate_and_get_output_parameter(label, parameters) file_url = _render_template_with_label(file_url, label) @@ -1937,6 +1956,7 @@ async def parse_file( label=block_validation_output.label, output_parameter=block_validation_output.output_parameter, parameters=block_validation_output.input_parameters, + model=model, ) await file_parser_block.execute_safe( workflow_run_id=block_validation_output.workflow_run_id, @@ -2008,6 +2028,7 @@ async def prompt( schema: dict[str, Any] | None = None, label: str | None = None, parameters: list[str] | None = None, + model: dict[str, Any] | None = None, ) -> dict[str, Any] | list | str | None: block_validation_output = await _validate_and_get_output_parameter(label, parameters) prompt = _render_template_with_label(prompt, label) @@ -2017,6 +2038,7 @@ async def prompt( label=block_validation_output.label, output_parameter=block_validation_output.output_parameter, parameters=block_validation_output.input_parameters, + model=model, ) result = await prompt_block.execute_safe( workflow_run_id=block_validation_output.workflow_run_id,