[SKY-6] Backend: Enable 2FA code detection without TOTP credentials (#4786)

This commit is contained in:
Aaron Perez
2026-02-18 17:21:58 -05:00
committed by GitHub
parent b48bf707c3
commit e3b6d22fb6
28 changed files with 1989 additions and 41 deletions

View File

@@ -849,6 +849,102 @@ class AgentDB(BaseAlchemyDB):
LOG.error("UnexpectedError", exc_info=True)
raise
async def update_task_2fa_state(
    self,
    task_id: str,
    organization_id: str,
    waiting_for_verification_code: bool,
    verification_code_identifier: str | None = None,
    verification_code_polling_started_at: datetime | None = None,
) -> Task:
    """Set or clear the 2FA verification-code waiting state on a task.

    When ``waiting_for_verification_code`` is False, the identifier and
    polling-start timestamp are wiped regardless of the other arguments.
    Returns the freshly re-read Task.

    Raises:
        NotFoundError: no task matches (task_id, organization_id).
        SQLAlchemyError: re-raised after logging.
    """
    try:
        async with self.Session() as session:
            record = (
                await session.scalars(
                    select(TaskModel)
                    .filter_by(task_id=task_id)
                    .filter_by(organization_id=organization_id)
                )
            ).first()
            if record is None:
                raise NotFoundError("Task not found")
            record.waiting_for_verification_code = waiting_for_verification_code
            if verification_code_identifier is not None:
                record.verification_code_identifier = verification_code_identifier
            if verification_code_polling_started_at is not None:
                record.verification_code_polling_started_at = verification_code_polling_started_at
            if not waiting_for_verification_code:
                # No longer waiting: drop identifier and polling timestamp.
                record.verification_code_identifier = None
                record.verification_code_polling_started_at = None
            await session.commit()
            # Re-read through the normal accessor so the returned object is
            # the canonical Task representation.
            refreshed = await self.get_task(task_id, organization_id=organization_id)
            if not refreshed:
                raise NotFoundError("Task not found")
            return refreshed
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
@read_retry()
async def get_active_verification_requests(self, organization_id: str) -> list[dict]:
    """Return active 2FA verification requests for an organization.

    Looks at both standalone tasks (workflow_run_id is NULL) and workflow
    runs flagged waiting_for_verification_code=True, excluding finalized
    ones and anything created more than one hour ago. Used to seed initial
    state when a WebSocket notification client connects.
    """
    cutoff = datetime.utcnow() - timedelta(hours=1)
    async with self.Session() as session:
        # Standalone tasks waiting for a code (finalized statuses excluded).
        task_terminal_statuses = [status.value for status in TaskStatus if status.is_final()]
        task_rows = (
            await session.scalars(
                select(TaskModel)
                .filter_by(organization_id=organization_id)
                .filter_by(waiting_for_verification_code=True)
                .filter_by(workflow_run_id=None)
                .filter(TaskModel.status.not_in(task_terminal_statuses))
                .filter(TaskModel.created_at > cutoff)
            )
        ).all()
        requests: list[dict] = [
            {
                "task_id": row.task_id,
                "workflow_run_id": None,
                "verification_code_identifier": row.verification_code_identifier,
                "verification_code_polling_started_at": (
                    row.verification_code_polling_started_at.isoformat()
                    if row.verification_code_polling_started_at
                    else None
                ),
            }
            for row in task_rows
        ]
        # Workflow runs waiting for a code (finalized statuses excluded).
        run_terminal_statuses = [status.value for status in WorkflowRunStatus if status.is_final()]
        run_rows = (
            await session.scalars(
                select(WorkflowRunModel)
                .filter_by(organization_id=organization_id)
                .filter_by(waiting_for_verification_code=True)
                .filter(WorkflowRunModel.status.not_in(run_terminal_statuses))
                .filter(WorkflowRunModel.created_at > cutoff)
            )
        ).all()
        requests.extend(
            {
                "task_id": None,
                "workflow_run_id": row.workflow_run_id,
                "verification_code_identifier": row.verification_code_identifier,
                "verification_code_polling_started_at": (
                    row.verification_code_polling_started_at.isoformat()
                    if row.verification_code_polling_started_at
                    else None
                ),
            }
            for row in run_rows
        )
        return requests
async def bulk_update_tasks(
self,
task_ids: list[str],
@@ -2794,6 +2890,9 @@ class AgentDB(BaseAlchemyDB):
ai_fallback: bool | None = None,
depends_on_workflow_run_id: str | None = None,
browser_session_id: str | None = None,
waiting_for_verification_code: bool | None = None,
verification_code_identifier: str | None = None,
verification_code_polling_started_at: datetime | None = None,
) -> WorkflowRun:
async with self.Session() as session:
workflow_run = (
@@ -2826,6 +2925,17 @@ class AgentDB(BaseAlchemyDB):
workflow_run.depends_on_workflow_run_id = depends_on_workflow_run_id
if browser_session_id:
workflow_run.browser_session_id = browser_session_id
# 2FA verification code waiting state updates
if waiting_for_verification_code is not None:
workflow_run.waiting_for_verification_code = waiting_for_verification_code
if verification_code_identifier is not None:
workflow_run.verification_code_identifier = verification_code_identifier
if verification_code_polling_started_at is not None:
workflow_run.verification_code_polling_started_at = verification_code_polling_started_at
if waiting_for_verification_code is not None and not waiting_for_verification_code:
# Clear related fields when waiting is set to False
workflow_run.verification_code_identifier = None
workflow_run.verification_code_polling_started_at = None
await session.commit()
await save_workflow_run_logs(workflow_run_id)
await session.refresh(workflow_run)
@@ -3995,6 +4105,35 @@ class AgentDB(BaseAlchemyDB):
totp_code = (await session.scalars(query)).all()
return [TOTPCode.model_validate(totp_code) for totp_code in totp_code]
async def get_otp_codes_by_run(
    self,
    organization_id: str,
    task_id: str | None = None,
    workflow_run_id: str | None = None,
    valid_lifespan_minutes: int = settings.TOTP_LIFESPAN_MINUTES,
    limit: int = 1,
) -> list[TOTPCode]:
    """Fetch recent OTP codes tied to a specific task or workflow run.

    Unlike the identifier-based lookup, no totp_identifier is required:
    this path is used when the agent detects a 2FA page without
    pre-configured TOTP credentials and the user submits codes manually
    via the UI. ``workflow_run_id`` takes precedence over ``task_id``
    when both are given; returns [] when neither is provided. Results are
    newest-first, capped at ``limit``.
    """
    if not workflow_run_id and not task_id:
        return []
    async with self.Session() as session:
        freshness_floor = datetime.utcnow() - timedelta(minutes=valid_lifespan_minutes)
        query = (
            select(TOTPCodeModel)
            .filter_by(organization_id=organization_id)
            .filter(TOTPCodeModel.created_at > freshness_floor)
        )
        if workflow_run_id:
            query = query.filter(TOTPCodeModel.workflow_run_id == workflow_run_id)
        else:
            # The guard above guarantees task_id is truthy here.
            query = query.filter(TOTPCodeModel.task_id == task_id)
        query = query.order_by(TOTPCodeModel.created_at.desc()).limit(limit)
        rows = (await session.scalars(query)).all()
        return [TOTPCode.model_validate(row) for row in rows]
async def get_recent_otp_codes(
self,
organization_id: str,

View File

@@ -116,6 +116,10 @@ class TaskModel(Base):
model = Column(JSON, nullable=True)
browser_address = Column(String, nullable=True)
download_timeout = Column(Numeric, nullable=True)
# 2FA verification code waiting state fields.
# NOTE(review): default=False is an ORM-side default only — confirm the
# migration adds a matching server default for rows written outside the ORM.
waiting_for_verification_code = Column(Boolean, nullable=False, default=False)
# Presumably the destination the code is sent to (e.g. email/phone label) — confirm with callers.
verification_code_identifier = Column(String, nullable=True)
# Naive DateTime; assumed UTC — confirm producers use datetime.utcnow().
verification_code_polling_started_at = Column(DateTime, nullable=True)
class StepModel(Base):
@@ -350,6 +354,10 @@ class WorkflowRunModel(Base):
debug_session_id: Column = Column(String, nullable=True)
ai_fallback = Column(Boolean, nullable=True)
code_gen = Column(Boolean, nullable=True)
# 2FA verification code waiting state fields (mirrors TaskModel).
# NOTE(review): default=False is an ORM-side default only — confirm the
# migration adds a matching server default for rows written outside the ORM.
waiting_for_verification_code = Column(Boolean, nullable=False, default=False)
# Presumably the destination the code is sent to (e.g. email/phone label) — confirm with callers.
verification_code_identifier = Column(String, nullable=True)
# Naive DateTime; assumed UTC — confirm producers use datetime.utcnow().
verification_code_polling_started_at = Column(DateTime, nullable=True)
queued_at = Column(DateTime, nullable=True)
started_at = Column(DateTime, nullable=True)

View File

@@ -211,6 +211,9 @@ def convert_to_task(task_obj: TaskModel, debug_enabled: bool = False, workflow_p
browser_session_id=task_obj.browser_session_id,
browser_address=task_obj.browser_address,
download_timeout=task_obj.download_timeout,
waiting_for_verification_code=task_obj.waiting_for_verification_code or False,
verification_code_identifier=task_obj.verification_code_identifier,
verification_code_polling_started_at=task_obj.verification_code_polling_started_at,
)
return task
@@ -424,6 +427,9 @@ def convert_to_workflow_run(
run_with=workflow_run_model.run_with,
code_gen=workflow_run_model.code_gen,
ai_fallback=workflow_run_model.ai_fallback,
waiting_for_verification_code=workflow_run_model.waiting_for_verification_code or False,
verification_code_identifier=workflow_run_model.verification_code_identifier,
verification_code_polling_started_at=workflow_run_model.verification_code_polling_started_at,
)

View File

@@ -0,0 +1,21 @@
"""Abstract base for notification registries."""
import asyncio
from abc import ABC, abstractmethod
class BaseNotificationRegistry(ABC):
    """Abstract pub/sub registry keyed by organization.

    Concrete implementations must fan out: a single ``publish`` call
    delivers the message to every queue currently subscribed for that
    organization.
    """

    @abstractmethod
    def subscribe(self, organization_id: str) -> asyncio.Queue[dict]:
        """Register and return a fresh queue that receives the org's messages."""

    @abstractmethod
    def unsubscribe(self, organization_id: str, queue: asyncio.Queue[dict]) -> None:
        """Detach a queue previously returned by ``subscribe``."""

    @abstractmethod
    def publish(self, organization_id: str, message: dict) -> None:
        """Deliver *message* to every active subscriber of the organization."""

View File

@@ -0,0 +1,14 @@
from skyvern.forge.sdk.notification.base import BaseNotificationRegistry
from skyvern.forge.sdk.notification.local import LocalNotificationRegistry
class NotificationRegistryFactory:
    """Process-wide holder for the active notification registry.

    Defaults to the in-process ``LocalNotificationRegistry``; deployments
    needing cross-process fan-out can install a different
    ``BaseNotificationRegistry`` via ``set_registry``.
    """

    __registry: BaseNotificationRegistry = LocalNotificationRegistry()

    @staticmethod
    def get_registry() -> BaseNotificationRegistry:
        """Return the currently installed registry."""
        return NotificationRegistryFactory.__registry

    @staticmethod
    def set_registry(registry: BaseNotificationRegistry) -> None:
        """Install *registry* as the process-wide instance."""
        NotificationRegistryFactory.__registry = registry

View File

@@ -0,0 +1,45 @@
"""In-process notification registry using asyncio queues (single-pod only)."""
import asyncio
from collections import defaultdict
import structlog
from skyvern.forge.sdk.notification.base import BaseNotificationRegistry
LOG = structlog.get_logger()
class LocalNotificationRegistry(BaseNotificationRegistry):
    """Single-process fan-out pub/sub built on asyncio queues.

    Suitable only for single-pod deployments: subscribers in other
    processes never see messages published here.
    """

    def __init__(self) -> None:
        # organization_id -> queues of all currently connected subscribers
        self._subscribers: dict[str, list[asyncio.Queue[dict]]] = defaultdict(list)

    def subscribe(self, organization_id: str) -> asyncio.Queue[dict]:
        """Create, register, and return a new unbounded queue for the org."""
        new_queue: asyncio.Queue[dict] = asyncio.Queue()
        self._subscribers[organization_id].append(new_queue)
        LOG.info("Notification subscriber added", organization_id=organization_id)
        return new_queue

    def unsubscribe(self, organization_id: str, queue: asyncio.Queue[dict]) -> None:
        """Drop *queue*; the org entry is removed entirely when it was the last one."""
        org_queues = self._subscribers.get(organization_id)
        if not org_queues:
            return
        try:
            org_queues.remove(queue)
        except ValueError:
            # Already removed — nothing to do.
            pass
        if not org_queues:
            del self._subscribers[organization_id]
        LOG.info("Notification subscriber removed", organization_id=organization_id)

    def publish(self, organization_id: str, message: dict) -> None:
        """Fan *message* out to every queue subscribed for the organization."""
        for subscriber_queue in self._subscribers.get(organization_id, []):
            try:
                subscriber_queue.put_nowait(message)
            except asyncio.QueueFull:
                # Queues are created unbounded, but guard defensively anyway.
                LOG.warning(
                    "Notification queue full, dropping message",
                    organization_id=organization_id,
                )

View File

@@ -0,0 +1,55 @@
"""Redis-backed notification registry for multi-pod deployments.
Thin adapter around :class:`RedisPubSub` — all Redis pub/sub logic
lives in the generic layer; this class maps the ``organization_id``
domain concept onto generic string keys.
"""
import asyncio
from redis.asyncio import Redis
from skyvern.forge.sdk.notification.base import BaseNotificationRegistry
from skyvern.forge.sdk.redis.pubsub import RedisPubSub
class RedisNotificationRegistry(BaseNotificationRegistry):
    """Organization-scoped fan-out pub/sub backed by Redis.

    Thin adapter over the generic :class:`RedisPubSub`: each organization
    maps to one Redis channel under the ``skyvern:notifications:`` prefix;
    all pub/sub mechanics live in the generic layer.
    """

    def __init__(self, redis_client: Redis) -> None:
        self._pubsub = RedisPubSub(redis_client, channel_prefix="skyvern:notifications:")

    # -- public interface ----------------------------------------------

    def subscribe(self, organization_id: str) -> asyncio.Queue[dict]:
        """Register a local queue; starts the org's Redis listener if needed."""
        return self._pubsub.subscribe(organization_id)

    def unsubscribe(self, organization_id: str, queue: asyncio.Queue[dict]) -> None:
        """Detach a queue; the Redis listener stops when none remain for the org."""
        self._pubsub.unsubscribe(organization_id, queue)

    def publish(self, organization_id: str, message: dict) -> None:
        """Fire-and-forget publish of *message* onto the org's Redis channel."""
        self._pubsub.publish(organization_id, message)

    async def close(self) -> None:
        """Tear down all Redis listener tasks; call on shutdown."""
        await self._pubsub.close()

    # -- introspection passthroughs kept for existing tests -------------

    @property
    def _listener_tasks(self) -> dict[str, asyncio.Task[None]]:
        return self._pubsub._listener_tasks

    @property
    def _subscribers(self) -> dict[str, list[asyncio.Queue[dict]]]:
        return self._pubsub._subscribers

    def _dispatch_local(self, organization_id: str, message: dict) -> None:
        """Bypass Redis and fan out to local queues (test helper)."""
        self._pubsub._dispatch_local(organization_id, message)

View File

View File

@@ -0,0 +1,21 @@
from __future__ import annotations
from redis.asyncio import Redis
class RedisClientFactory:
    """Holds one shared async Redis client for the whole process.

    Same static set/get pattern as ``CacheFactory``. The default is
    ``None``: local/OSS deployments run without Redis.
    """

    __client: Redis | None = None

    @staticmethod
    def get_client() -> Redis | None:
        """Return the shared client, or ``None`` when Redis is not configured."""
        return RedisClientFactory.__client

    @staticmethod
    def set_client(client: Redis) -> None:
        """Install *client* as the process-wide Redis client."""
        RedisClientFactory.__client = client

View File

@@ -0,0 +1,130 @@
"""Generic Redis pub/sub layer.
Extracted from ``RedisNotificationRegistry`` so that any feature
(notifications, events, cache invalidation, etc.) can reuse the same
pattern with its own channel prefix.
"""
import asyncio
import json
from collections import defaultdict
import structlog
from redis.asyncio import Redis
LOG = structlog.get_logger()
class RedisPubSub:
    """Generic fan-out pub/sub backed by Redis.

    One Redis channel per *key* (channel name = ``channel_prefix + key``).
    Local subscribers each get an :class:`asyncio.Queue`; a single listener
    task per key relays Redis messages into every local queue. Any feature
    (notifications, events, cache invalidation, ...) can reuse this layer
    with its own channel prefix.
    """

    def __init__(self, redis_client: Redis, channel_prefix: str) -> None:
        self._redis = redis_client
        self._channel_prefix = channel_prefix
        # key -> all local subscriber queues for that key
        self._subscribers: dict[str, list[asyncio.Queue[dict]]] = defaultdict(list)
        # One listener task per key channel
        self._listener_tasks: dict[str, asyncio.Task[None]] = {}
        # Strong references to in-flight publish tasks. The event loop keeps
        # only weak references to tasks, so a fire-and-forget create_task()
        # result can be garbage-collected before it runs; holding it here
        # until completion prevents dropped publishes.
        self._publish_tasks: set[asyncio.Task[None]] = set()

    # ------------------------------------------------------------------
    # Public interface
    # ------------------------------------------------------------------
    def subscribe(self, key: str) -> asyncio.Queue[dict]:
        """Register and return a new local queue for *key*.

        Starts the Redis listener for this key's channel on the first local
        subscriber. Must be called from within a running event loop.
        """
        queue: asyncio.Queue[dict] = asyncio.Queue()
        self._subscribers[key].append(queue)
        # Spin up a Redis listener if this is the first local subscriber
        if key not in self._listener_tasks:
            task = asyncio.get_running_loop().create_task(self._listen(key))
            self._listener_tasks[key] = task
        LOG.info("PubSub subscriber added", key=key, channel_prefix=self._channel_prefix)
        return queue

    def unsubscribe(self, key: str, queue: asyncio.Queue[dict]) -> None:
        """Detach *queue*; cancels the key's listener when no queues remain."""
        queues = self._subscribers.get(key)
        if queues:
            try:
                queues.remove(queue)
            except ValueError:
                # Already removed — nothing to do.
                pass
            if not queues:
                del self._subscribers[key]
                self._cancel_listener(key)
            LOG.info("PubSub subscriber removed", key=key, channel_prefix=self._channel_prefix)

    def publish(self, key: str, message: dict) -> None:
        """Fire-and-forget PUBLISH of *message* (JSON-encoded) to *key*'s channel.

        Must be called with a running event loop; otherwise the message is
        dropped with a warning. Errors during the actual PUBLISH are logged
        inside the spawned task, never raised to the caller.
        """
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            LOG.warning(
                "No running event loop; cannot publish via Redis",
                key=key,
                channel_prefix=self._channel_prefix,
            )
            return
        task = loop.create_task(self._publish_to_redis(key, message))
        # Keep a strong reference until completion (see __init__).
        self._publish_tasks.add(task)
        task.add_done_callback(self._publish_tasks.discard)

    async def close(self) -> None:
        """Cancel all listener tasks and clear subscriber state. Call on shutdown."""
        for key in list(self._listener_tasks):
            self._cancel_listener(key)
        self._subscribers.clear()
        LOG.info("RedisPubSub closed", channel_prefix=self._channel_prefix)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    async def _publish_to_redis(self, key: str, message: dict) -> None:
        """Serialize *message* to JSON and PUBLISH; failures are logged, not raised."""
        channel = f"{self._channel_prefix}{key}"
        try:
            await self._redis.publish(channel, json.dumps(message))
        except Exception:
            LOG.exception("Failed to publish to Redis", key=key, channel_prefix=self._channel_prefix)

    async def _listen(self, key: str) -> None:
        """Subscribe to *key*'s Redis channel and fan messages out locally.

        Runs until cancelled (via _cancel_listener) or a fatal error occurs;
        non-JSON payloads are skipped with a warning.
        """
        channel = f"{self._channel_prefix}{key}"
        pubsub = self._redis.pubsub()
        try:
            await pubsub.subscribe(channel)
            LOG.info("Redis listener started", channel=channel)
            async for raw_message in pubsub.listen():
                if raw_message["type"] != "message":
                    # Ignore subscribe/unsubscribe confirmations etc.
                    continue
                try:
                    data = json.loads(raw_message["data"])
                except (json.JSONDecodeError, TypeError):
                    LOG.warning("Invalid JSON on Redis channel", channel=channel)
                    continue
                self._dispatch_local(key, data)
        except asyncio.CancelledError:
            LOG.info("Redis listener cancelled", channel=channel)
            raise
        except Exception:
            LOG.exception("Redis listener error", channel=channel)
        finally:
            try:
                await pubsub.unsubscribe(channel)
                # NOTE(review): PubSub.close() is deprecated in newer redis-py
                # in favor of aclose() — confirm against the pinned version.
                await pubsub.close()
            except Exception:
                LOG.warning("Error closing Redis pubsub", channel=channel)

    def _dispatch_local(self, key: str, message: dict) -> None:
        """Fan one message out to every local asyncio queue registered for *key*."""
        queues = self._subscribers.get(key, [])
        for queue in queues:
            try:
                queue.put_nowait(message)
            except asyncio.QueueFull:
                LOG.warning(
                    "Queue full, dropping message",
                    key=key,
                    channel_prefix=self._channel_prefix,
                )

    def _cancel_listener(self, key: str) -> None:
        """Cancel and forget the listener task for *key*, if still running."""
        task = self._listener_tasks.pop(key, None)
        if task and not task.done():
            task.cancel()

View File

@@ -11,5 +11,6 @@ from skyvern.forge.sdk.routes import sdk # noqa: F401
from skyvern.forge.sdk.routes import webhooks # noqa: F401
from skyvern.forge.sdk.routes import workflow_copilot # noqa: F401
from skyvern.forge.sdk.routes.streaming import messages # noqa: F401
from skyvern.forge.sdk.routes.streaming import notifications # noqa: F401
from skyvern.forge.sdk.routes.streaming import screenshot # noqa: F401
from skyvern.forge.sdk.routes.streaming import vnc # noqa: F401

View File

@@ -131,10 +131,15 @@ async def send_totp_code(
task = await app.DATABASE.get_task(data.task_id, curr_org.organization_id)
if not task:
raise HTTPException(status_code=400, detail=f"Invalid task id: {data.task_id}")
workflow_id_for_storage: str | None = None
if data.workflow_id:
workflow = await app.DATABASE.get_workflow(data.workflow_id, curr_org.organization_id)
if data.workflow_id.startswith("wpid_"):
workflow = await app.DATABASE.get_workflow_by_permanent_id(data.workflow_id, curr_org.organization_id)
else:
workflow = await app.DATABASE.get_workflow(data.workflow_id, curr_org.organization_id)
if not workflow:
raise HTTPException(status_code=400, detail=f"Invalid workflow id: {data.workflow_id}")
workflow_id_for_storage = workflow.workflow_id
if data.workflow_run_id:
workflow_run = await app.DATABASE.get_workflow_run(data.workflow_run_id, curr_org.organization_id)
if not workflow_run:
@@ -162,7 +167,7 @@ async def send_totp_code(
content=data.content,
code=otp_value.value,
task_id=data.task_id,
workflow_id=data.workflow_id,
workflow_id=workflow_id_for_storage,
workflow_run_id=data.workflow_run_id,
source=data.source,
expired_at=data.expired_at,

View File

@@ -0,0 +1,101 @@
"""WebSocket endpoint for streaming global 2FA verification code notifications."""
import asyncio
import structlog
from fastapi import WebSocket, WebSocketDisconnect
from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK
from skyvern.config import settings
from skyvern.forge import app
from skyvern.forge.sdk.notification.factory import NotificationRegistryFactory
from skyvern.forge.sdk.routes.routers import base_router
from skyvern.forge.sdk.routes.streaming.auth import _auth as local_auth
from skyvern.forge.sdk.routes.streaming.auth import auth as real_auth
LOG = structlog.get_logger()
HEARTBEAT_INTERVAL = 60
@base_router.websocket("/stream/notifications")
async def notification_stream(
    websocket: WebSocket,
    apikey: str | None = None,
    token: str | None = None,
) -> None:
    """Stream 2FA verification-code notifications for an organization.

    On connect: authenticates, replays every currently-active verification
    request as a ``verification_code_required`` message, then relays registry
    events as they arrive. Sends a ``heartbeat`` message whenever no event
    arrives within HEARTBEAT_INTERVAL seconds, and closes cleanly when the
    client disconnects.
    """
    auth = local_auth if settings.ENV == "local" else real_auth
    organization_id = await auth(apikey=apikey, token=token, websocket=websocket)
    if not organization_id:
        LOG.info("Notifications: Authentication failed")
        return
    LOG.info("Notifications: Started streaming", organization_id=organization_id)
    registry = NotificationRegistryFactory.get_registry()
    # Subscribe before reading the DB snapshot so no event falls in the gap.
    # NOTE(review): an event published between subscribe and the snapshot may
    # be delivered twice (once from each source) — confirm clients dedupe.
    queue = registry.subscribe(organization_id)
    try:
        # Initial state: everything currently waiting for a code.
        active_requests = await app.DATABASE.get_active_verification_requests(organization_id)
        for req in active_requests:
            await websocket.send_json(
                {
                    "type": "verification_code_required",
                    "task_id": req.get("task_id"),
                    "workflow_run_id": req.get("workflow_run_id"),
                    "identifier": req.get("verification_code_identifier"),
                    "polling_started_at": req.get("verification_code_polling_started_at"),
                }
            )
        # Watch for client disconnect while streaming events.
        disconnect_event = asyncio.Event()

        async def _watch_disconnect() -> None:
            # Drain incoming frames purely to detect the disconnect.
            try:
                while True:
                    await websocket.receive()
            except (WebSocketDisconnect, ConnectionClosedOK, ConnectionClosedError):
                disconnect_event.set()

        watcher = asyncio.create_task(_watch_disconnect())
        try:
            while not disconnect_event.is_set():
                queue_task = asyncio.ensure_future(asyncio.wait_for(queue.get(), timeout=HEARTBEAT_INTERVAL))
                disconnect_wait = asyncio.ensure_future(disconnect_event.wait())
                done, pending = await asyncio.wait({queue_task, disconnect_wait}, return_when=asyncio.FIRST_COMPLETED)
                for p in pending:
                    p.cancel()
                if disconnect_event.is_set():
                    return
                try:
                    message = queue_task.result()
                    await websocket.send_json(message)
                # Catch asyncio.TimeoutError, not the builtin: on Python < 3.11
                # wait_for raises asyncio.TimeoutError, which is a distinct
                # class there; from 3.11 on it aliases the builtin, so this
                # spelling is correct on every supported version.
                except asyncio.TimeoutError:
                    try:
                        await websocket.send_json({"type": "heartbeat"})
                    except Exception:
                        LOG.info(
                            "Notifications: Client unreachable during heartbeat. Closing.",
                            organization_id=organization_id,
                        )
                        return
                except asyncio.CancelledError:
                    # Server shutdown / task cancellation: exit quietly.
                    return
        finally:
            watcher.cancel()
    except WebSocketDisconnect:
        LOG.info("Notifications: WebSocket disconnected", organization_id=organization_id)
    except ConnectionClosedOK:
        LOG.info("Notifications: ConnectionClosedOK", organization_id=organization_id)
    except ConnectionClosedError:
        LOG.warning(
            "Notifications: ConnectionClosedError (client likely disconnected)", organization_id=organization_id
        )
    except Exception:
        LOG.warning("Notifications: Error while streaming", organization_id=organization_id, exc_info=True)
    finally:
        registry.unsubscribe(organization_id, queue)
        LOG.info("Notifications: Connection closed", organization_id=organization_id)

View File

@@ -288,6 +288,10 @@ class Task(TaskBase):
queued_at: datetime | None = None
started_at: datetime | None = None
finished_at: datetime | None = None
# 2FA verification code waiting state fields
waiting_for_verification_code: bool = False
verification_code_identifier: str | None = None
verification_code_polling_started_at: datetime | None = None
@property
def llm_key(self) -> str | None:
@@ -365,6 +369,9 @@ class Task(TaskBase):
max_screenshot_scrolls=self.max_screenshot_scrolls,
step_count=step_count,
browser_session_id=self.browser_session_id,
waiting_for_verification_code=self.waiting_for_verification_code,
verification_code_identifier=self.verification_code_identifier,
verification_code_polling_started_at=self.verification_code_polling_started_at,
)
@@ -392,6 +399,10 @@ class TaskResponse(BaseModel):
max_screenshot_scrolls: int | None = None
step_count: int | None = None
browser_session_id: str | None = None
# 2FA verification code waiting state fields
waiting_for_verification_code: bool = False
verification_code_identifier: str | None = None
verification_code_polling_started_at: datetime | None = None
class TaskOutput(BaseModel):

View File

@@ -172,6 +172,10 @@ class WorkflowRun(BaseModel):
sequential_key: str | None = None
ai_fallback: bool | None = None
code_gen: bool | None = None
# 2FA verification code waiting state fields
waiting_for_verification_code: bool = False
verification_code_identifier: str | None = None
verification_code_polling_started_at: datetime | None = None
queued_at: datetime | None = None
started_at: datetime | None = None
@@ -226,6 +230,10 @@ class WorkflowRunResponseBase(BaseModel):
browser_address: str | None = None
script_run: ScriptRunResponse | None = None
errors: list[dict[str, Any]] | None = None
# 2FA verification code waiting state fields
waiting_for_verification_code: bool = False
verification_code_identifier: str | None = None
verification_code_polling_started_at: datetime | None = None
class WorkflowRunWithWorkflowResponse(WorkflowRunResponseBase):

View File

@@ -3019,6 +3019,10 @@ class WorkflowService:
browser_address=workflow_run.browser_address,
script_run=workflow_run.script_run,
errors=errors,
# 2FA verification code waiting state fields
waiting_for_verification_code=workflow_run.waiting_for_verification_code,
verification_code_identifier=workflow_run.verification_code_identifier,
verification_code_polling_started_at=workflow_run.verification_code_polling_started_at,
)
async def clean_up_workflow(