add max_steps_per_run to task (#297)

Co-authored-by: Shuchang Zheng <wintonzheng0325@gmail.com>
This commit is contained in:
Kerem Yilmaz
2024-05-11 14:13:21 -07:00
committed by GitHub
parent 6feddbde6a
commit 270642c60c
8 changed files with 39 additions and 0 deletions

View File

@@ -0,0 +1,30 @@
"""add max_steps_per_run to task
Revision ID: 8792454ce498
Revises: c4dca14a5e69
Create Date: 2024-05-11 21:04:38.384261+00:00
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "8792454ce498"
down_revision: Union[str, None] = "c4dca14a5e69"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("tasks", sa.Column("max_steps_per_run", sa.Integer(), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("tasks", "max_steps_per_run")
# ### end Alembic commands ###

View File

@@ -136,6 +136,7 @@ class ForgeAgent:
workflow_run_id=workflow_run.workflow_run_id,
order=task_order,
retry=task_retry,
max_steps_per_run=task_block.max_steps_per_run,
error_code_mapping=task_block.error_code_mapping,
)
LOG.info(
@@ -1116,6 +1117,7 @@ class ForgeAgent:
override_max_steps_per_run = context.max_steps_override if context else None
max_steps_per_run = (
override_max_steps_per_run
or task.max_steps_per_run
or organization.max_steps_per_run
or SettingsManager.get_settings().MAX_STEPS_PER_RUN
)

View File

@@ -83,6 +83,7 @@ class AgentDB:
workflow_run_id: str | None = None,
order: int | None = None,
retry: int | None = None,
max_steps_per_run: int | None = None,
error_code_mapping: dict[str, str] | None = None,
) -> Task:
try:
@@ -101,6 +102,7 @@ class AgentDB:
workflow_run_id=workflow_run_id,
order=order,
retry=retry,
max_steps_per_run=max_steps_per_run,
error_code_mapping=error_code_mapping,
)
session.add(new_task)

View File

@@ -46,6 +46,7 @@ class TaskModel(Base):
retry = Column(Integer, nullable=True)
error_code_mapping = Column(JSON, nullable=True)
errors = Column(JSON, default=[], nullable=False)
max_steps_per_run = Column(Integer, nullable=True)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False, index=True)
modified_at = Column(
DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False, index=True

View File

@@ -72,6 +72,7 @@ def convert_to_task(task_obj: TaskModel, debug_enabled: bool = False) -> Task:
workflow_run_id=task_obj.workflow_run_id,
order=task_obj.order,
retry=task_obj.retry,
max_steps_per_run=task_obj.max_steps_per_run,
error_code_mapping=task_obj.error_code_mapping,
errors=task_obj.errors,
)

View File

@@ -143,6 +143,7 @@ class Task(TaskRequest):
workflow_run_id: str | None = None
order: int | None = None
retry: int | None = None
max_steps_per_run: int | None = None
errors: list[dict[str, Any]] = []
def validate_update(

View File

@@ -111,6 +111,7 @@ class TaskBlock(Block):
# error code to error description for the LLM
error_code_mapping: dict[str, str] | None = None
max_retries: int = 0
max_steps_per_run: int | None = None
parameters: list[PARAMETER_TYPE] = []
def get_all_parameters(

View File

@@ -84,6 +84,7 @@ class TaskBlockYAML(BlockYAML):
data_schema: dict[str, Any] | None = None
error_code_mapping: dict[str, str] | None = None
max_retries: int = 0
max_steps_per_run: int | None = None
parameter_keys: list[str] | None = None