add ai_fallback to workflow_runs (#3581)
This commit is contained in:
@@ -0,0 +1,31 @@
|
|||||||
|
"""add ai_fallback to workflow_runs
|
||||||
|
|
||||||
|
Revision ID: d36daac4941e
|
||||||
|
Revises: c50ee6f26432
|
||||||
|
Create Date: 2025-10-01 21:07:38.185580+00:00
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "d36daac4941e"
|
||||||
|
down_revision: Union[str, None] = "c50ee6f26432"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.add_column("workflow_runs", sa.Column("ai_fallback", sa.Boolean(), nullable=True))
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_column("workflow_runs", "ai_fallback")
|
||||||
|
# ### end Alembic commands ###
|
||||||
@@ -1699,6 +1699,7 @@ class AgentDB:
|
|||||||
sequential_key: str | None = None,
|
sequential_key: str | None = None,
|
||||||
run_with: str | None = None,
|
run_with: str | None = None,
|
||||||
debug_session_id: str | None = None,
|
debug_session_id: str | None = None,
|
||||||
|
ai_fallback: bool | None = None,
|
||||||
) -> WorkflowRun:
|
) -> WorkflowRun:
|
||||||
try:
|
try:
|
||||||
async with self.Session() as session:
|
async with self.Session() as session:
|
||||||
@@ -1719,6 +1720,7 @@ class AgentDB:
|
|||||||
sequential_key=sequential_key,
|
sequential_key=sequential_key,
|
||||||
run_with=run_with,
|
run_with=run_with,
|
||||||
debug_session_id=debug_session_id,
|
debug_session_id=debug_session_id,
|
||||||
|
ai_fallback=ai_fallback,
|
||||||
)
|
)
|
||||||
session.add(workflow_run)
|
session.add(workflow_run)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
@@ -1738,6 +1740,7 @@ class AgentDB:
|
|||||||
job_id: str | None = None,
|
job_id: str | None = None,
|
||||||
run_with: str | None = None,
|
run_with: str | None = None,
|
||||||
sequential_key: str | None = None,
|
sequential_key: str | None = None,
|
||||||
|
ai_fallback: bool | None = None,
|
||||||
) -> WorkflowRun:
|
) -> WorkflowRun:
|
||||||
async with self.Session() as session:
|
async with self.Session() as session:
|
||||||
workflow_run = (
|
workflow_run = (
|
||||||
@@ -1764,6 +1767,8 @@ class AgentDB:
|
|||||||
workflow_run.run_with = run_with
|
workflow_run.run_with = run_with
|
||||||
if sequential_key:
|
if sequential_key:
|
||||||
workflow_run.sequential_key = sequential_key
|
workflow_run.sequential_key = sequential_key
|
||||||
|
if ai_fallback is not None:
|
||||||
|
workflow_run.ai_fallback = ai_fallback
|
||||||
await session.commit()
|
await session.commit()
|
||||||
await session.refresh(workflow_run)
|
await session.refresh(workflow_run)
|
||||||
await save_workflow_run_logs(workflow_run_id)
|
await save_workflow_run_logs(workflow_run_id)
|
||||||
|
|||||||
@@ -291,6 +291,7 @@ class WorkflowRunModel(Base):
|
|||||||
sequential_key = Column(String, nullable=True)
|
sequential_key = Column(String, nullable=True)
|
||||||
run_with = Column(String, nullable=True) # 'agent' or 'code'
|
run_with = Column(String, nullable=True) # 'agent' or 'code'
|
||||||
debug_session_id: Column = Column(String, nullable=True)
|
debug_session_id: Column = Column(String, nullable=True)
|
||||||
|
ai_fallback = Column(Boolean, nullable=True)
|
||||||
|
|
||||||
queued_at = Column(DateTime, nullable=True)
|
queued_at = Column(DateTime, nullable=True)
|
||||||
started_at = Column(DateTime, nullable=True)
|
started_at = Column(DateTime, nullable=True)
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ class WorkflowRequestBody(BaseModel):
|
|||||||
extra_http_headers: dict[str, str] | None = None
|
extra_http_headers: dict[str, str] | None = None
|
||||||
browser_address: str | None = None
|
browser_address: str | None = None
|
||||||
run_with: str | None = None
|
run_with: str | None = None
|
||||||
|
ai_fallback: bool | None = None
|
||||||
|
|
||||||
@field_validator("webhook_callback_url", "totp_verification_url")
|
@field_validator("webhook_callback_url", "totp_verification_url")
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -143,6 +144,7 @@ class WorkflowRun(BaseModel):
|
|||||||
script_run: ScriptRunResponse | None = None
|
script_run: ScriptRunResponse | None = None
|
||||||
job_id: str | None = None
|
job_id: str | None = None
|
||||||
sequential_key: str | None = None
|
sequential_key: str | None = None
|
||||||
|
ai_fallback: bool | None = None
|
||||||
|
|
||||||
queued_at: datetime | None = None
|
queued_at: datetime | None = None
|
||||||
started_at: datetime | None = None
|
started_at: datetime | None = None
|
||||||
|
|||||||
@@ -241,6 +241,8 @@ class WorkflowService:
|
|||||||
proxy_location=workflow_request.proxy_location,
|
proxy_location=workflow_request.proxy_location,
|
||||||
webhook_callback_url=workflow_request.webhook_callback_url,
|
webhook_callback_url=workflow_request.webhook_callback_url,
|
||||||
max_screenshot_scrolling_times=workflow_request.max_screenshot_scrolls,
|
max_screenshot_scrolling_times=workflow_request.max_screenshot_scrolls,
|
||||||
|
ai_fallback=workflow_request.ai_fallback,
|
||||||
|
run_with=workflow_request.run_with,
|
||||||
)
|
)
|
||||||
context: skyvern_context.SkyvernContext | None = skyvern_context.current()
|
context: skyvern_context.SkyvernContext | None = skyvern_context.current()
|
||||||
current_run_id = context.run_id if context and context.run_id else workflow_run.workflow_run_id
|
current_run_id = context.run_id if context and context.run_id else workflow_run.workflow_run_id
|
||||||
@@ -1071,6 +1073,7 @@ class WorkflowService:
|
|||||||
sequential_key=sequential_key,
|
sequential_key=sequential_key,
|
||||||
run_with=workflow_request.run_with,
|
run_with=workflow_request.run_with,
|
||||||
debug_session_id=debug_session_id,
|
debug_session_id=debug_session_id,
|
||||||
|
ai_fallback=workflow_request.ai_fallback,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _update_workflow_run_status(
|
async def _update_workflow_run_status(
|
||||||
@@ -1079,12 +1082,14 @@ class WorkflowService:
|
|||||||
status: WorkflowRunStatus,
|
status: WorkflowRunStatus,
|
||||||
failure_reason: str | None = None,
|
failure_reason: str | None = None,
|
||||||
run_with: str | None = None,
|
run_with: str | None = None,
|
||||||
|
ai_fallback: bool | None = None,
|
||||||
) -> WorkflowRun:
|
) -> WorkflowRun:
|
||||||
workflow_run = await app.DATABASE.update_workflow_run(
|
workflow_run = await app.DATABASE.update_workflow_run(
|
||||||
workflow_run_id=workflow_run_id,
|
workflow_run_id=workflow_run_id,
|
||||||
status=status,
|
status=status,
|
||||||
failure_reason=failure_reason,
|
failure_reason=failure_reason,
|
||||||
run_with=run_with,
|
run_with=run_with,
|
||||||
|
ai_fallback=ai_fallback,
|
||||||
)
|
)
|
||||||
if status in [WorkflowRunStatus.completed, WorkflowRunStatus.failed, WorkflowRunStatus.terminated]:
|
if status in [WorkflowRunStatus.completed, WorkflowRunStatus.failed, WorkflowRunStatus.terminated]:
|
||||||
start_time = (
|
start_time = (
|
||||||
|
|||||||
@@ -358,6 +358,14 @@ class WorkflowRunRequest(BaseModel):
|
|||||||
description="The CDP address for the workflow run.",
|
description="The CDP address for the workflow run.",
|
||||||
examples=["http://127.0.0.1:9222", "ws://127.0.0.1:9222/devtools/browser/1234567890"],
|
examples=["http://127.0.0.1:9222", "ws://127.0.0.1:9222/devtools/browser/1234567890"],
|
||||||
)
|
)
|
||||||
|
ai_fallback: bool | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Whether to fallback to AI if the workflow run fails.",
|
||||||
|
)
|
||||||
|
run_with: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Whether to run the workflow with agent or code.",
|
||||||
|
)
|
||||||
|
|
||||||
@field_validator("webhook_url", "totp_url")
|
@field_validator("webhook_url", "totp_url")
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -700,7 +700,16 @@ async def _fallback_to_ai_run(
|
|||||||
workflow = await app.DATABASE.get_workflow(workflow_id=context.workflow_id, organization_id=organization_id)
|
workflow = await app.DATABASE.get_workflow(workflow_id=context.workflow_id, organization_id=organization_id)
|
||||||
if not workflow:
|
if not workflow:
|
||||||
return
|
return
|
||||||
if not workflow.ai_fallback:
|
workflow_run = await app.DATABASE.get_workflow_run(
|
||||||
|
workflow_run_id=workflow_run_id, organization_id=organization_id
|
||||||
|
)
|
||||||
|
if not workflow_run:
|
||||||
|
return
|
||||||
|
# Use workflow_run.ai_fallback if explicitly set, otherwise fall back to workflow.ai_fallback
|
||||||
|
effective_ai_fallback = (
|
||||||
|
workflow_run.ai_fallback if workflow_run.ai_fallback is not None else workflow.ai_fallback
|
||||||
|
)
|
||||||
|
if not effective_ai_fallback:
|
||||||
LOG.info(
|
LOG.info(
|
||||||
"AI fallback is not enabled for the workflow",
|
"AI fallback is not enabled for the workflow",
|
||||||
workflow_id=workflow_id,
|
workflow_id=workflow_id,
|
||||||
|
|||||||
Reference in New Issue
Block a user