flag to control ai fallback (#3313)

This commit is contained in:
Shuchang Zheng
2025-08-29 05:24:17 +08:00
committed by GitHub
parent 015194f2a4
commit 916ab6c067
7 changed files with 48 additions and 2 deletions

View File

@@ -0,0 +1,31 @@
"""add ai_fallback field to workflows table
Revision ID: d3ec63728c2a
Revises: 1bba8a38ddc7
Create Date: 2025-08-28 21:12:54.750395+00:00
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "d3ec63728c2a"
down_revision: Union[str, None] = "1bba8a38ddc7"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("workflows", sa.Column("ai_fallback", sa.Boolean(), nullable=False, server_default=sa.false()))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("workflows", "ai_fallback")
# ### end Alembic commands ###

View File

@@ -1366,6 +1366,7 @@ class AgentDB:
is_saved_task: bool = False, is_saved_task: bool = False,
status: WorkflowStatus = WorkflowStatus.published, status: WorkflowStatus = WorkflowStatus.published,
generate_script: bool = False, generate_script: bool = False,
ai_fallback: bool = False,
cache_key: str | None = None, cache_key: str | None = None,
) -> Workflow: ) -> Workflow:
async with self.Session() as session: async with self.Session() as session:
@@ -1385,6 +1386,7 @@ class AgentDB:
is_saved_task=is_saved_task, is_saved_task=is_saved_task,
status=status, status=status,
generate_script=generate_script, generate_script=generate_script,
ai_fallback=ai_fallback,
cache_key=cache_key, cache_key=cache_key,
) )
if workflow_permanent_id: if workflow_permanent_id:

View File

@@ -243,6 +243,7 @@ class WorkflowModel(Base):
model = Column(JSON, nullable=True) model = Column(JSON, nullable=True)
status = Column(String, nullable=False, default="published") status = Column(String, nullable=False, default="published")
generate_script = Column(Boolean, default=False, nullable=False) generate_script = Column(Boolean, default=False, nullable=False)
ai_fallback = Column(Boolean, default=False, nullable=False)
cache_key = Column(String, nullable=True) cache_key = Column(String, nullable=True)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)

View File

@@ -264,6 +264,7 @@ def convert_to_workflow(workflow_model: WorkflowModel, debug_enabled: bool = Fal
status=WorkflowStatus(workflow_model.status), status=WorkflowStatus(workflow_model.status),
extra_http_headers=workflow_model.extra_http_headers, extra_http_headers=workflow_model.extra_http_headers,
generate_script=workflow_model.generate_script, generate_script=workflow_model.generate_script,
ai_fallback=workflow_model.ai_fallback,
cache_key=workflow_model.cache_key, cache_key=workflow_model.cache_key,
) )

View File

@@ -77,6 +77,7 @@ class Workflow(BaseModel):
max_screenshot_scrolls: int | None = None max_screenshot_scrolls: int | None = None
extra_http_headers: dict[str, str] | None = None extra_http_headers: dict[str, str] | None = None
generate_script: bool = False generate_script: bool = False
ai_fallback: bool = False
cache_key: str | None = None cache_key: str | None = None
created_at: datetime created_at: datetime

View File

@@ -506,6 +506,7 @@ class WorkflowCreateYAMLRequest(BaseModel):
extra_http_headers: dict[str, str] | None = None extra_http_headers: dict[str, str] | None = None
status: WorkflowStatus = WorkflowStatus.published status: WorkflowStatus = WorkflowStatus.published
generate_script: bool = False generate_script: bool = False
ai_fallback: bool = False
cache_key: str | None = None cache_key: str | None = None

View File

@@ -478,7 +478,7 @@ async def _fallback_to_ai_run(
try: try:
organization_id = context.organization_id organization_id = context.organization_id
LOG.info( LOG.info(
"Script fallback to AI run", "Script trying to fallback to AI run",
cache_key=cache_key, cache_key=cache_key,
organization_id=organization_id, organization_id=organization_id,
workflow_id=context.workflow_id, workflow_id=context.workflow_id,
@@ -510,13 +510,22 @@ async def _fallback_to_ai_run(
if not task: if not task:
raise Exception(f"Task is missing task_id={context.task_id}") raise Exception(f"Task is missing task_id={context.task_id}")
workflow = await app.DATABASE.get_workflow(workflow_id=context.workflow_id, organization_id=organization_id) workflow = await app.DATABASE.get_workflow(workflow_id=context.workflow_id, organization_id=organization_id)
if not workflow: if not workflow or not workflow.ai_fallback:
return return
# get the output_paramter # get the output_paramter
output_parameter = workflow.get_output_parameter(cache_key) output_parameter = workflow.get_output_parameter(cache_key)
if not output_parameter: if not output_parameter:
return return
LOG.info(
"Script starting to fallback to AI run",
cache_key=cache_key,
organization_id=organization_id,
workflow_id=context.workflow_id,
workflow_run_id=context.workflow_run_id,
task_id=context.task_id,
step_id=context.step_id,
)
task_block = TaskBlock( task_block = TaskBlock(
label=cache_key, label=cache_key,