Max retries per step configuration per org (#168)
This commit is contained in:
@@ -0,0 +1,30 @@
|
|||||||
|
"""Add orgs.max_retries_per_step
|
||||||
|
|
||||||
|
Revision ID: ea8e24d0bc8e
|
||||||
|
Revises: 4630ab8c198e
|
||||||
|
Create Date: 2024-04-08 23:47:46.306300+00:00
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "ea8e24d0bc8e"
|
||||||
|
down_revision: Union[str, None] = "4630ab8c198e"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.add_column("organizations", sa.Column("max_retries_per_step", sa.Integer(), nullable=True))
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_column("organizations", "max_retries_per_step")
|
||||||
|
# ### end Alembic commands ###
|
||||||
@@ -19,7 +19,7 @@ class Settings(BaseSettings):
|
|||||||
# Ratio should be between 0 and 1.
|
# Ratio should be between 0 and 1.
|
||||||
# If the task has been running for more steps than this ratio of the max steps per run, then we'll log a warning.
|
# If the task has been running for more steps than this ratio of the max steps per run, then we'll log a warning.
|
||||||
LONG_RUNNING_TASK_WARNING_RATIO: float = 0.95
|
LONG_RUNNING_TASK_WARNING_RATIO: float = 0.95
|
||||||
MAX_RETRIES_PER_STEP: int = 2
|
MAX_RETRIES_PER_STEP: int = 5
|
||||||
DEBUG_MODE: bool = False
|
DEBUG_MODE: bool = False
|
||||||
DATABASE_STRING: str = "postgresql+psycopg://skyvern@localhost/skyvern"
|
DATABASE_STRING: str = "postgresql+psycopg://skyvern@localhost/skyvern"
|
||||||
PROMPT_ACTION_HISTORY_WINDOW: int = 5
|
PROMPT_ACTION_HISTORY_WINDOW: int = 5
|
||||||
|
|||||||
@@ -227,7 +227,7 @@ class ForgeAgent(Agent):
|
|||||||
|
|
||||||
# If the step failed, mark the step as failed and retry
|
# If the step failed, mark the step as failed and retry
|
||||||
if step.status == StepStatus.failed:
|
if step.status == StepStatus.failed:
|
||||||
maybe_next_step = await self.handle_failed_step(task, step)
|
maybe_next_step = await self.handle_failed_step(organization, task, step)
|
||||||
# If there is no next step, it means that the task has failed
|
# If there is no next step, it means that the task has failed
|
||||||
if maybe_next_step:
|
if maybe_next_step:
|
||||||
next_step = maybe_next_step
|
next_step = maybe_next_step
|
||||||
@@ -965,8 +965,14 @@ class ForgeAgent(Agent):
|
|||||||
**updates,
|
**updates,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def handle_failed_step(self, task: Task, step: Step) -> Step | None:
|
async def handle_failed_step(self, organization: Organization, task: Task, step: Step) -> Step | None:
|
||||||
if step.retry_index >= SettingsManager.get_settings().MAX_RETRIES_PER_STEP:
|
max_retries_per_step = (
|
||||||
|
organization.max_retries_per_step
|
||||||
|
# we need to check by None because 0 is a valid value for max_retries_per_step
|
||||||
|
if organization.max_retries_per_step is not None
|
||||||
|
else SettingsManager.get_settings().MAX_RETRIES_PER_STEP
|
||||||
|
)
|
||||||
|
if step.retry_index >= max_retries_per_step:
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
"Step failed after max retries, marking task as failed",
|
"Step failed after max retries, marking task as failed",
|
||||||
task_id=task.task_id,
|
task_id=task.task_id,
|
||||||
@@ -978,7 +984,7 @@ class ForgeAgent(Agent):
|
|||||||
await self.update_task(
|
await self.update_task(
|
||||||
task,
|
task,
|
||||||
TaskStatus.failed,
|
TaskStatus.failed,
|
||||||
failure_reason=f"Max retries per step ({SettingsManager.get_settings().MAX_RETRIES_PER_STEP}) exceeded",
|
failure_reason=f"Max retries per step ({max_retries_per_step}) exceeded",
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -418,12 +418,14 @@ class AgentDB:
|
|||||||
organization_name: str,
|
organization_name: str,
|
||||||
webhook_callback_url: str | None = None,
|
webhook_callback_url: str | None = None,
|
||||||
max_steps_per_run: int | None = None,
|
max_steps_per_run: int | None = None,
|
||||||
|
max_retries_per_step: int | None = None,
|
||||||
) -> Organization:
|
) -> Organization:
|
||||||
async with self.Session() as session:
|
async with self.Session() as session:
|
||||||
org = OrganizationModel(
|
org = OrganizationModel(
|
||||||
organization_name=organization_name,
|
organization_name=organization_name,
|
||||||
webhook_callback_url=webhook_callback_url,
|
webhook_callback_url=webhook_callback_url,
|
||||||
max_steps_per_run=max_steps_per_run,
|
max_steps_per_run=max_steps_per_run,
|
||||||
|
max_retries_per_step=max_retries_per_step,
|
||||||
)
|
)
|
||||||
session.add(org)
|
session.add(org)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|||||||
@@ -74,7 +74,8 @@ class OrganizationModel(Base):
|
|||||||
organization_id = Column(String, primary_key=True, index=True, default=generate_org_id)
|
organization_id = Column(String, primary_key=True, index=True, default=generate_org_id)
|
||||||
organization_name = Column(String, nullable=False)
|
organization_name = Column(String, nullable=False)
|
||||||
webhook_callback_url = Column(UnicodeText)
|
webhook_callback_url = Column(UnicodeText)
|
||||||
max_steps_per_run = Column(Integer)
|
max_steps_per_run = Column(Integer, nullable=True)
|
||||||
|
max_retries_per_step = Column(Integer, nullable=True)
|
||||||
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
|
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
|
||||||
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime, nullable=False)
|
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime, nullable=False)
|
||||||
|
|
||||||
|
|||||||
@@ -104,6 +104,7 @@ def convert_to_organization(org_model: OrganizationModel) -> Organization:
|
|||||||
organization_name=org_model.organization_name,
|
organization_name=org_model.organization_name,
|
||||||
webhook_callback_url=org_model.webhook_callback_url,
|
webhook_callback_url=org_model.webhook_callback_url,
|
||||||
max_steps_per_run=org_model.max_steps_per_run,
|
max_steps_per_run=org_model.max_steps_per_run,
|
||||||
|
max_retries_per_step=org_model.max_retries_per_step,
|
||||||
created_at=org_model.created_at,
|
created_at=org_model.created_at,
|
||||||
modified_at=org_model.modified_at,
|
modified_at=org_model.modified_at,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -117,6 +117,7 @@ class Organization(BaseModel):
|
|||||||
organization_name: str
|
organization_name: str
|
||||||
webhook_callback_url: str | None = None
|
webhook_callback_url: str | None = None
|
||||||
max_steps_per_run: int | None = None
|
max_steps_per_run: int | None = None
|
||||||
|
max_retries_per_step: int | None = None
|
||||||
|
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
modified_at: datetime
|
modified_at: datetime
|
||||||
|
|||||||
Reference in New Issue
Block a user