diff --git a/alembic/versions/2024_05_11_2104-8792454ce498_add_max_steps_per_run_to_task.py b/alembic/versions/2024_05_11_2104-8792454ce498_add_max_steps_per_run_to_task.py new file mode 100644 index 00000000..ed3eb185 --- /dev/null +++ b/alembic/versions/2024_05_11_2104-8792454ce498_add_max_steps_per_run_to_task.py @@ -0,0 +1,30 @@ +"""add max_steps_per_run to task + +Revision ID: 8792454ce498 +Revises: c4dca14a5e69 +Create Date: 2024-05-11 21:04:38.384261+00:00 + +""" +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "8792454ce498" +down_revision: Union[str, None] = "c4dca14a5e69" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("tasks", sa.Column("max_steps_per_run", sa.Integer(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("tasks", "max_steps_per_run") + # ### end Alembic commands ### diff --git a/skyvern/forge/agent.py b/skyvern/forge/agent.py index f16425bb..9f2c9d27 100644 --- a/skyvern/forge/agent.py +++ b/skyvern/forge/agent.py @@ -136,6 +136,7 @@ class ForgeAgent: workflow_run_id=workflow_run.workflow_run_id, order=task_order, retry=task_retry, + max_steps_per_run=task_block.max_steps_per_run, error_code_mapping=task_block.error_code_mapping, ) LOG.info( @@ -1116,6 +1117,7 @@ class ForgeAgent: override_max_steps_per_run = context.max_steps_override if context else None max_steps_per_run = ( override_max_steps_per_run + or task.max_steps_per_run or organization.max_steps_per_run or SettingsManager.get_settings().MAX_STEPS_PER_RUN ) diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py index ed8cd549..d3ba88fb 100644 --- a/skyvern/forge/sdk/db/client.py +++ b/skyvern/forge/sdk/db/client.py @@ -83,6 +83,7 @@ class AgentDB: workflow_run_id: str | None = None, order: int | None = None, retry: int | None = None, + max_steps_per_run: int | None = None, error_code_mapping: dict[str, str] | None = None, ) -> Task: try: @@ -101,6 +102,7 @@ class AgentDB: workflow_run_id=workflow_run_id, order=order, retry=retry, + max_steps_per_run=max_steps_per_run, error_code_mapping=error_code_mapping, ) session.add(new_task) diff --git a/skyvern/forge/sdk/db/models.py b/skyvern/forge/sdk/db/models.py index b23a1677..6e3a4f16 100644 --- a/skyvern/forge/sdk/db/models.py +++ b/skyvern/forge/sdk/db/models.py @@ -46,6 +46,7 @@ class TaskModel(Base): retry = Column(Integer, nullable=True) error_code_mapping = Column(JSON, nullable=True) errors = Column(JSON, default=[], nullable=False) + max_steps_per_run = Column(Integer, nullable=True) created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False, index=True) modified_at = Column( DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False, index=True diff --git a/skyvern/forge/sdk/db/utils.py b/skyvern/forge/sdk/db/utils.py index d4f78637..23925fd3 100644 --- a/skyvern/forge/sdk/db/utils.py +++ b/skyvern/forge/sdk/db/utils.py @@ -72,6 +72,7 @@ def convert_to_task(task_obj: TaskModel, debug_enabled: bool = False) -> Task: workflow_run_id=task_obj.workflow_run_id, order=task_obj.order, retry=task_obj.retry, + max_steps_per_run=task_obj.max_steps_per_run, error_code_mapping=task_obj.error_code_mapping, errors=task_obj.errors, ) diff --git a/skyvern/forge/sdk/schemas/tasks.py b/skyvern/forge/sdk/schemas/tasks.py index 4f907a29..84d32075 100644 --- a/skyvern/forge/sdk/schemas/tasks.py +++ b/skyvern/forge/sdk/schemas/tasks.py @@ -143,6 +143,7 @@ class Task(TaskRequest): workflow_run_id: str | None = None order: int | None = None retry: int | None = None + max_steps_per_run: int | None = None errors: list[dict[str, Any]] = [] def validate_update( diff --git a/skyvern/forge/sdk/workflow/models/block.py b/skyvern/forge/sdk/workflow/models/block.py index 5747286b..f848bbb8 100644 --- a/skyvern/forge/sdk/workflow/models/block.py +++ b/skyvern/forge/sdk/workflow/models/block.py @@ -111,6 +111,7 @@ class TaskBlock(Block): # error code to error description for the LLM error_code_mapping: dict[str, str] | None = None max_retries: int = 0 + max_steps_per_run: int | None = None parameters: list[PARAMETER_TYPE] = [] def get_all_parameters( diff --git a/skyvern/forge/sdk/workflow/models/yaml.py b/skyvern/forge/sdk/workflow/models/yaml.py index c65e6dac..b70baf86 100644 --- a/skyvern/forge/sdk/workflow/models/yaml.py +++ b/skyvern/forge/sdk/workflow/models/yaml.py @@ -84,6 +84,7 @@ class TaskBlockYAML(BlockYAML): data_schema: dict[str, Any] | None = None error_code_mapping: dict[str, str] | None = None max_retries: int = 0 + max_steps_per_run: int | None = None parameter_keys: list[str] | None = None