add ai_fallback to workflow_runs (#3581)

This commit is contained in:
Shuchang Zheng
2025-10-01 14:13:56 -07:00
committed by GitHub
parent db024d42bb
commit f97b53975f
7 changed files with 62 additions and 1 deletions

View File

@@ -1699,6 +1699,7 @@ class AgentDB:
sequential_key: str | None = None,
run_with: str | None = None,
debug_session_id: str | None = None,
ai_fallback: bool | None = None,
) -> WorkflowRun:
try:
async with self.Session() as session:
@@ -1719,6 +1720,7 @@ class AgentDB:
sequential_key=sequential_key,
run_with=run_with,
debug_session_id=debug_session_id,
ai_fallback=ai_fallback,
)
session.add(workflow_run)
await session.commit()
@@ -1738,6 +1740,7 @@ class AgentDB:
job_id: str | None = None,
run_with: str | None = None,
sequential_key: str | None = None,
ai_fallback: bool | None = None,
) -> WorkflowRun:
async with self.Session() as session:
workflow_run = (
@@ -1764,6 +1767,8 @@ class AgentDB:
workflow_run.run_with = run_with
if sequential_key:
workflow_run.sequential_key = sequential_key
if ai_fallback is not None:
workflow_run.ai_fallback = ai_fallback
await session.commit()
await session.refresh(workflow_run)
await save_workflow_run_logs(workflow_run_id)

View File

@@ -291,6 +291,7 @@ class WorkflowRunModel(Base):
sequential_key = Column(String, nullable=True)
run_with = Column(String, nullable=True) # 'agent' or 'code'
debug_session_id: Column = Column(String, nullable=True)
ai_fallback = Column(Boolean, nullable=True)
queued_at = Column(DateTime, nullable=True)
started_at = Column(DateTime, nullable=True)

View File

@@ -27,6 +27,7 @@ class WorkflowRequestBody(BaseModel):
extra_http_headers: dict[str, str] | None = None
browser_address: str | None = None
run_with: str | None = None
ai_fallback: bool | None = None
@field_validator("webhook_callback_url", "totp_verification_url")
@classmethod
@@ -143,6 +144,7 @@ class WorkflowRun(BaseModel):
script_run: ScriptRunResponse | None = None
job_id: str | None = None
sequential_key: str | None = None
ai_fallback: bool | None = None
queued_at: datetime | None = None
started_at: datetime | None = None

View File

@@ -241,6 +241,8 @@ class WorkflowService:
proxy_location=workflow_request.proxy_location,
webhook_callback_url=workflow_request.webhook_callback_url,
max_screenshot_scrolling_times=workflow_request.max_screenshot_scrolls,
ai_fallback=workflow_request.ai_fallback,
run_with=workflow_request.run_with,
)
context: skyvern_context.SkyvernContext | None = skyvern_context.current()
current_run_id = context.run_id if context and context.run_id else workflow_run.workflow_run_id
@@ -1071,6 +1073,7 @@ class WorkflowService:
sequential_key=sequential_key,
run_with=workflow_request.run_with,
debug_session_id=debug_session_id,
ai_fallback=workflow_request.ai_fallback,
)
async def _update_workflow_run_status(
@@ -1079,12 +1082,14 @@ class WorkflowService:
status: WorkflowRunStatus,
failure_reason: str | None = None,
run_with: str | None = None,
ai_fallback: bool | None = None,
) -> WorkflowRun:
workflow_run = await app.DATABASE.update_workflow_run(
workflow_run_id=workflow_run_id,
status=status,
failure_reason=failure_reason,
run_with=run_with,
ai_fallback=ai_fallback,
)
if status in [WorkflowRunStatus.completed, WorkflowRunStatus.failed, WorkflowRunStatus.terminated]:
start_time = (