From f97b53975f499121f3bad5d451391fb8435040e1 Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Wed, 1 Oct 2025 14:13:56 -0700 Subject: [PATCH] add ai_fallback to workflow_runs (#3581) --- ...c4941e_add_ai_fallback_to_workflow_runs.py | 31 +++++++++++++++++++ skyvern/forge/sdk/db/client.py | 5 +++ skyvern/forge/sdk/db/models.py | 1 + skyvern/forge/sdk/workflow/models/workflow.py | 2 ++ skyvern/forge/sdk/workflow/service.py | 5 +++ skyvern/schemas/runs.py | 8 +++++ skyvern/services/script_service.py | 11 ++++++- 7 files changed, 62 insertions(+), 1 deletion(-) create mode 100644 alembic/versions/2025_10_01_2107-d36daac4941e_add_ai_fallback_to_workflow_runs.py diff --git a/alembic/versions/2025_10_01_2107-d36daac4941e_add_ai_fallback_to_workflow_runs.py b/alembic/versions/2025_10_01_2107-d36daac4941e_add_ai_fallback_to_workflow_runs.py new file mode 100644 index 00000000..1afc0697 --- /dev/null +++ b/alembic/versions/2025_10_01_2107-d36daac4941e_add_ai_fallback_to_workflow_runs.py @@ -0,0 +1,31 @@ +"""add ai_fallback to workflow_runs + +Revision ID: d36daac4941e +Revises: c50ee6f26432 +Create Date: 2025-10-01 21:07:38.185580+00:00 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "d36daac4941e" +down_revision: Union[str, None] = "c50ee6f26432" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("workflow_runs", sa.Column("ai_fallback", sa.Boolean(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("workflow_runs", "ai_fallback") + # ### end Alembic commands ### diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py index 80bef714..2c0da470 100644 --- a/skyvern/forge/sdk/db/client.py +++ b/skyvern/forge/sdk/db/client.py @@ -1699,6 +1699,7 @@ class AgentDB: sequential_key: str | None = None, run_with: str | None = None, debug_session_id: str | None = None, + ai_fallback: bool | None = None, ) -> WorkflowRun: try: async with self.Session() as session: @@ -1719,6 +1720,7 @@ class AgentDB: sequential_key=sequential_key, run_with=run_with, debug_session_id=debug_session_id, + ai_fallback=ai_fallback, ) session.add(workflow_run) await session.commit() @@ -1738,6 +1740,7 @@ class AgentDB: job_id: str | None = None, run_with: str | None = None, sequential_key: str | None = None, + ai_fallback: bool | None = None, ) -> WorkflowRun: async with self.Session() as session: workflow_run = ( @@ -1764,6 +1767,8 @@ class AgentDB: workflow_run.run_with = run_with if sequential_key: workflow_run.sequential_key = sequential_key + if ai_fallback is not None: + workflow_run.ai_fallback = ai_fallback await session.commit() await session.refresh(workflow_run) await save_workflow_run_logs(workflow_run_id) diff --git a/skyvern/forge/sdk/db/models.py b/skyvern/forge/sdk/db/models.py index 56baa8f8..dd10fc91 100644 --- a/skyvern/forge/sdk/db/models.py +++ b/skyvern/forge/sdk/db/models.py @@ -291,6 +291,7 @@ class WorkflowRunModel(Base): sequential_key = Column(String, nullable=True) run_with = Column(String, nullable=True) # 'agent' or 'code' debug_session_id: Column = Column(String, nullable=True) + ai_fallback = Column(Boolean, nullable=True) queued_at = Column(DateTime, nullable=True) started_at = Column(DateTime, nullable=True) diff --git a/skyvern/forge/sdk/workflow/models/workflow.py b/skyvern/forge/sdk/workflow/models/workflow.py index e359765b..6759d819 100644 --- a/skyvern/forge/sdk/workflow/models/workflow.py +++ b/skyvern/forge/sdk/workflow/models/workflow.py @@ -27,6 +27,7 @@ class WorkflowRequestBody(BaseModel): extra_http_headers: dict[str, str] | None = None browser_address: str | None = None run_with: str | None = None + ai_fallback: bool | None = None @field_validator("webhook_callback_url", "totp_verification_url") @classmethod @@ -143,6 +144,7 @@ class WorkflowRun(BaseModel): script_run: ScriptRunResponse | None = None job_id: str | None = None sequential_key: str | None = None + ai_fallback: bool | None = None queued_at: datetime | None = None started_at: datetime | None = None diff --git a/skyvern/forge/sdk/workflow/service.py b/skyvern/forge/sdk/workflow/service.py index 8bb22b62..cda8334e 100644 --- a/skyvern/forge/sdk/workflow/service.py +++ b/skyvern/forge/sdk/workflow/service.py @@ -241,6 +241,8 @@ class WorkflowService: proxy_location=workflow_request.proxy_location, webhook_callback_url=workflow_request.webhook_callback_url, max_screenshot_scrolling_times=workflow_request.max_screenshot_scrolls, + ai_fallback=workflow_request.ai_fallback, + run_with=workflow_request.run_with, ) context: skyvern_context.SkyvernContext | None = skyvern_context.current() current_run_id = context.run_id if context and context.run_id else workflow_run.workflow_run_id @@ -1071,6 +1073,7 @@ class WorkflowService: sequential_key=sequential_key, run_with=workflow_request.run_with, debug_session_id=debug_session_id, + ai_fallback=workflow_request.ai_fallback, ) async def _update_workflow_run_status( @@ -1079,12 +1082,14 @@ class WorkflowService: status: WorkflowRunStatus, failure_reason: str | None = None, run_with: str | None = None, + ai_fallback: bool | None = None, ) -> WorkflowRun: workflow_run = await app.DATABASE.update_workflow_run( workflow_run_id=workflow_run_id, status=status, failure_reason=failure_reason, run_with=run_with, + ai_fallback=ai_fallback, ) if status in [WorkflowRunStatus.completed, WorkflowRunStatus.failed, WorkflowRunStatus.terminated]: start_time = ( diff --git a/skyvern/schemas/runs.py b/skyvern/schemas/runs.py index ce102724..75988086 100644 --- a/skyvern/schemas/runs.py +++ b/skyvern/schemas/runs.py @@ -358,6 +358,14 @@ class WorkflowRunRequest(BaseModel): description="The CDP address for the workflow run.", examples=["http://127.0.0.1:9222", "ws://127.0.0.1:9222/devtools/browser/1234567890"], ) + ai_fallback: bool | None = Field( + default=None, + description="Whether to fallback to AI if the workflow run fails.", + ) + run_with: str | None = Field( + default=None, + description="Whether to run the workflow with agent or code.", + ) @field_validator("webhook_url", "totp_url") @classmethod diff --git a/skyvern/services/script_service.py b/skyvern/services/script_service.py index 13d11c5a..3d6f6821 100644 --- a/skyvern/services/script_service.py +++ b/skyvern/services/script_service.py @@ -700,7 +700,16 @@ async def _fallback_to_ai_run( workflow = await app.DATABASE.get_workflow(workflow_id=context.workflow_id, organization_id=organization_id) if not workflow: return - if not workflow.ai_fallback: + workflow_run = await app.DATABASE.get_workflow_run( + workflow_run_id=workflow_run_id, organization_id=organization_id + ) + if not workflow_run: + return + # Use workflow_run.ai_fallback if explicitly set, otherwise fall back to workflow.ai_fallback + effective_ai_fallback = ( + workflow_run.ai_fallback if workflow_run.ai_fallback is not None else workflow.ai_fallback + ) + if not effective_ai_fallback: LOG.info( "AI fallback is not enabled for the workflow", workflow_id=workflow_id,