From 267335a0ebd482e963e91cd6d2154ac15409eef6 Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Thu, 18 Jul 2024 17:00:00 -0700 Subject: [PATCH] record max_steps_per_run override in tasks table as well (#622) --- skyvern/forge/agent.py | 14 ++++++++++++++ skyvern/forge/sdk/db/client.py | 3 +++ skyvern/forge/sdk/schemas/tasks.py | 2 ++ 3 files changed, 19 insertions(+) diff --git a/skyvern/forge/agent.py b/skyvern/forge/agent.py index 5a74efae..f5875855 100644 --- a/skyvern/forge/agent.py +++ b/skyvern/forge/agent.py @@ -232,6 +232,20 @@ class ForgeAgent: # TODO: shall we send task response here? return step, None, None + context = skyvern_context.current() + override_max_steps_per_run = context.max_steps_override if context else None + max_steps_per_run = ( + override_max_steps_per_run + or task.max_steps_per_run + or organization.max_steps_per_run + or SettingsManager.get_settings().MAX_STEPS_PER_RUN + ) + if max_steps_per_run and task.max_steps_per_run != max_steps_per_run: + await app.DATABASE.update_task( + task_id=task.task_id, + organization_id=organization.organization_id, + max_steps_per_run=max_steps_per_run, + ) next_step: Step | None = None detailed_output: DetailedAgentStepOutput | None = None num_files_before = 0 diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py index ede377f4..62460c05 100644 --- a/skyvern/forge/sdk/db/client.py +++ b/skyvern/forge/sdk/db/client.py @@ -376,6 +376,7 @@ class AgentDB: extracted_information: dict[str, Any] | list | str | None = None, failure_reason: str | None = None, errors: list[dict[str, Any]] | None = None, + max_steps_per_run: int | None = None, organization_id: str | None = None, ) -> Task: if status is None and extracted_information is None and failure_reason is None and errors is None: @@ -397,6 +398,8 @@ class AgentDB: task.failure_reason = failure_reason if errors is not None: task.errors = errors + if max_steps_per_run is not None: + task.max_steps_per_run = max_steps_per_run await session.commit() updated_task = await self.get_task(task_id, organization_id=organization_id) if not updated_task: diff --git a/skyvern/forge/sdk/schemas/tasks.py b/skyvern/forge/sdk/schemas/tasks.py index c4dcd584..4c704d76 100644 --- a/skyvern/forge/sdk/schemas/tasks.py +++ b/skyvern/forge/sdk/schemas/tasks.py @@ -221,6 +221,7 @@ class Task(TaskRequest): screenshot_url=screenshot_url, recording_url=recording_url, errors=self.errors, + max_steps_per_run=self.max_steps_per_run, ) @@ -236,6 +237,7 @@ class TaskResponse(BaseModel): recording_url: str | None = None failure_reason: str | None = None errors: list[dict[str, Any]] = [] + max_steps_per_run: int | None = None class TaskOutput(BaseModel):