add ai_fallback to workflow_runs (#3581)

This commit is contained in:
Shuchang Zheng
2025-10-01 14:13:56 -07:00
committed by GitHub
parent db024d42bb
commit f97b53975f
7 changed files with 62 additions and 1 deletions

View File

@@ -0,0 +1,31 @@
"""add ai_fallback to workflow_runs
Revision ID: d36daac4941e
Revises: c50ee6f26432
Create Date: 2025-10-01 21:07:38.185580+00:00
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "d36daac4941e"
down_revision: Union[str, None] = "c50ee6f26432"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("workflow_runs", sa.Column("ai_fallback", sa.Boolean(), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("workflow_runs", "ai_fallback")
# ### end Alembic commands ###

View File

@@ -1699,6 +1699,7 @@ class AgentDB:
sequential_key: str | None = None, sequential_key: str | None = None,
run_with: str | None = None, run_with: str | None = None,
debug_session_id: str | None = None, debug_session_id: str | None = None,
ai_fallback: bool | None = None,
) -> WorkflowRun: ) -> WorkflowRun:
try: try:
async with self.Session() as session: async with self.Session() as session:
@@ -1719,6 +1720,7 @@ class AgentDB:
sequential_key=sequential_key, sequential_key=sequential_key,
run_with=run_with, run_with=run_with,
debug_session_id=debug_session_id, debug_session_id=debug_session_id,
ai_fallback=ai_fallback,
) )
session.add(workflow_run) session.add(workflow_run)
await session.commit() await session.commit()
@@ -1738,6 +1740,7 @@ class AgentDB:
job_id: str | None = None, job_id: str | None = None,
run_with: str | None = None, run_with: str | None = None,
sequential_key: str | None = None, sequential_key: str | None = None,
ai_fallback: bool | None = None,
) -> WorkflowRun: ) -> WorkflowRun:
async with self.Session() as session: async with self.Session() as session:
workflow_run = ( workflow_run = (
@@ -1764,6 +1767,8 @@ class AgentDB:
workflow_run.run_with = run_with workflow_run.run_with = run_with
if sequential_key: if sequential_key:
workflow_run.sequential_key = sequential_key workflow_run.sequential_key = sequential_key
if ai_fallback is not None:
workflow_run.ai_fallback = ai_fallback
await session.commit() await session.commit()
await session.refresh(workflow_run) await session.refresh(workflow_run)
await save_workflow_run_logs(workflow_run_id) await save_workflow_run_logs(workflow_run_id)

View File

@@ -291,6 +291,7 @@ class WorkflowRunModel(Base):
sequential_key = Column(String, nullable=True) sequential_key = Column(String, nullable=True)
run_with = Column(String, nullable=True) # 'agent' or 'code' run_with = Column(String, nullable=True) # 'agent' or 'code'
debug_session_id: Column = Column(String, nullable=True) debug_session_id: Column = Column(String, nullable=True)
ai_fallback = Column(Boolean, nullable=True)
queued_at = Column(DateTime, nullable=True) queued_at = Column(DateTime, nullable=True)
started_at = Column(DateTime, nullable=True) started_at = Column(DateTime, nullable=True)

View File

@@ -27,6 +27,7 @@ class WorkflowRequestBody(BaseModel):
extra_http_headers: dict[str, str] | None = None extra_http_headers: dict[str, str] | None = None
browser_address: str | None = None browser_address: str | None = None
run_with: str | None = None run_with: str | None = None
ai_fallback: bool | None = None
@field_validator("webhook_callback_url", "totp_verification_url") @field_validator("webhook_callback_url", "totp_verification_url")
@classmethod @classmethod
@@ -143,6 +144,7 @@ class WorkflowRun(BaseModel):
script_run: ScriptRunResponse | None = None script_run: ScriptRunResponse | None = None
job_id: str | None = None job_id: str | None = None
sequential_key: str | None = None sequential_key: str | None = None
ai_fallback: bool | None = None
queued_at: datetime | None = None queued_at: datetime | None = None
started_at: datetime | None = None started_at: datetime | None = None

View File

@@ -241,6 +241,8 @@ class WorkflowService:
proxy_location=workflow_request.proxy_location, proxy_location=workflow_request.proxy_location,
webhook_callback_url=workflow_request.webhook_callback_url, webhook_callback_url=workflow_request.webhook_callback_url,
max_screenshot_scrolling_times=workflow_request.max_screenshot_scrolls, max_screenshot_scrolling_times=workflow_request.max_screenshot_scrolls,
ai_fallback=workflow_request.ai_fallback,
run_with=workflow_request.run_with,
) )
context: skyvern_context.SkyvernContext | None = skyvern_context.current() context: skyvern_context.SkyvernContext | None = skyvern_context.current()
current_run_id = context.run_id if context and context.run_id else workflow_run.workflow_run_id current_run_id = context.run_id if context and context.run_id else workflow_run.workflow_run_id
@@ -1071,6 +1073,7 @@ class WorkflowService:
sequential_key=sequential_key, sequential_key=sequential_key,
run_with=workflow_request.run_with, run_with=workflow_request.run_with,
debug_session_id=debug_session_id, debug_session_id=debug_session_id,
ai_fallback=workflow_request.ai_fallback,
) )
async def _update_workflow_run_status( async def _update_workflow_run_status(
@@ -1079,12 +1082,14 @@ class WorkflowService:
status: WorkflowRunStatus, status: WorkflowRunStatus,
failure_reason: str | None = None, failure_reason: str | None = None,
run_with: str | None = None, run_with: str | None = None,
ai_fallback: bool | None = None,
) -> WorkflowRun: ) -> WorkflowRun:
workflow_run = await app.DATABASE.update_workflow_run( workflow_run = await app.DATABASE.update_workflow_run(
workflow_run_id=workflow_run_id, workflow_run_id=workflow_run_id,
status=status, status=status,
failure_reason=failure_reason, failure_reason=failure_reason,
run_with=run_with, run_with=run_with,
ai_fallback=ai_fallback,
) )
if status in [WorkflowRunStatus.completed, WorkflowRunStatus.failed, WorkflowRunStatus.terminated]: if status in [WorkflowRunStatus.completed, WorkflowRunStatus.failed, WorkflowRunStatus.terminated]:
start_time = ( start_time = (

View File

@@ -358,6 +358,14 @@ class WorkflowRunRequest(BaseModel):
description="The CDP address for the workflow run.", description="The CDP address for the workflow run.",
examples=["http://127.0.0.1:9222", "ws://127.0.0.1:9222/devtools/browser/1234567890"], examples=["http://127.0.0.1:9222", "ws://127.0.0.1:9222/devtools/browser/1234567890"],
) )
ai_fallback: bool | None = Field(
default=None,
description="Whether to fallback to AI if the workflow run fails.",
)
run_with: str | None = Field(
default=None,
description="Whether to run the workflow with agent or code.",
)
@field_validator("webhook_url", "totp_url") @field_validator("webhook_url", "totp_url")
@classmethod @classmethod

View File

@@ -700,7 +700,16 @@ async def _fallback_to_ai_run(
workflow = await app.DATABASE.get_workflow(workflow_id=context.workflow_id, organization_id=organization_id) workflow = await app.DATABASE.get_workflow(workflow_id=context.workflow_id, organization_id=organization_id)
if not workflow: if not workflow:
return return
if not workflow.ai_fallback: workflow_run = await app.DATABASE.get_workflow_run(
workflow_run_id=workflow_run_id, organization_id=organization_id
)
if not workflow_run:
return
# Use workflow_run.ai_fallback if explicitly set, otherwise fall back to workflow.ai_fallback
effective_ai_fallback = (
workflow_run.ai_fallback if workflow_run.ai_fallback is not None else workflow.ai_fallback
)
if not effective_ai_fallback:
LOG.info( LOG.info(
"AI fallback is not enabled for the workflow", "AI fallback is not enabled for the workflow",
workflow_id=workflow_id, workflow_id=workflow_id,