Max retries per step configuration per org (#168)

This commit is contained in:
Kerem Yilmaz
2024-04-08 16:58:45 -07:00
committed by GitHub
parent 8e2aaa95d2
commit ffe917f2b5
7 changed files with 47 additions and 6 deletions

View File

@@ -227,7 +227,7 @@ class ForgeAgent(Agent):
# If the step failed, mark the step as failed and retry
if step.status == StepStatus.failed:
maybe_next_step = await self.handle_failed_step(task, step)
maybe_next_step = await self.handle_failed_step(organization, task, step)
# If there is no next step, it means that the task has failed
if maybe_next_step:
next_step = maybe_next_step
@@ -965,8 +965,14 @@ class ForgeAgent(Agent):
**updates,
)
async def handle_failed_step(self, task: Task, step: Step) -> Step | None:
if step.retry_index >= SettingsManager.get_settings().MAX_RETRIES_PER_STEP:
async def handle_failed_step(self, organization: Organization, task: Task, step: Step) -> Step | None:
max_retries_per_step = (
organization.max_retries_per_step
# we need to check by None because 0 is a valid value for max_retries_per_step
if organization.max_retries_per_step is not None
else SettingsManager.get_settings().MAX_RETRIES_PER_STEP
)
if step.retry_index >= max_retries_per_step:
LOG.warning(
"Step failed after max retries, marking task as failed",
task_id=task.task_id,
@@ -978,7 +984,7 @@ class ForgeAgent(Agent):
await self.update_task(
task,
TaskStatus.failed,
failure_reason=f"Max retries per step ({SettingsManager.get_settings().MAX_RETRIES_PER_STEP}) exceeded",
failure_reason=f"Max retries per step ({max_retries_per_step}) exceeded",
)
return None
else:

View File

@@ -418,12 +418,14 @@ class AgentDB:
organization_name: str,
webhook_callback_url: str | None = None,
max_steps_per_run: int | None = None,
max_retries_per_step: int | None = None,
) -> Organization:
async with self.Session() as session:
org = OrganizationModel(
organization_name=organization_name,
webhook_callback_url=webhook_callback_url,
max_steps_per_run=max_steps_per_run,
max_retries_per_step=max_retries_per_step,
)
session.add(org)
await session.commit()

View File

@@ -74,7 +74,8 @@ class OrganizationModel(Base):
organization_id = Column(String, primary_key=True, index=True, default=generate_org_id)
organization_name = Column(String, nullable=False)
webhook_callback_url = Column(UnicodeText)
max_steps_per_run = Column(Integer)
max_steps_per_run = Column(Integer, nullable=True)
max_retries_per_step = Column(Integer, nullable=True)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime, nullable=False)

View File

@@ -104,6 +104,7 @@ def convert_to_organization(org_model: OrganizationModel) -> Organization:
organization_name=org_model.organization_name,
webhook_callback_url=org_model.webhook_callback_url,
max_steps_per_run=org_model.max_steps_per_run,
max_retries_per_step=org_model.max_retries_per_step,
created_at=org_model.created_at,
modified_at=org_model.modified_at,
)

View File

@@ -117,6 +117,7 @@ class Organization(BaseModel):
organization_name: str
webhook_callback_url: str | None = None
max_steps_per_run: int | None = None
max_retries_per_step: int | None = None
created_at: datetime
modified_at: datetime