[SKY-6] Backend: Enable 2FA code detection without TOTP credentials (#4786)

This commit is contained in:
Aaron Perez
2026-02-18 17:21:58 -05:00
committed by GitHub
parent b48bf707c3
commit e3b6d22fb6
28 changed files with 1989 additions and 41 deletions

View File

@@ -11,5 +11,6 @@ from skyvern.forge.sdk.routes import sdk # noqa: F401
from skyvern.forge.sdk.routes import webhooks # noqa: F401
from skyvern.forge.sdk.routes import workflow_copilot # noqa: F401
from skyvern.forge.sdk.routes.streaming import messages # noqa: F401
from skyvern.forge.sdk.routes.streaming import notifications # noqa: F401
from skyvern.forge.sdk.routes.streaming import screenshot # noqa: F401
from skyvern.forge.sdk.routes.streaming import vnc # noqa: F401

View File

@@ -131,10 +131,15 @@ async def send_totp_code(
task = await app.DATABASE.get_task(data.task_id, curr_org.organization_id)
if not task:
raise HTTPException(status_code=400, detail=f"Invalid task id: {data.task_id}")
workflow_id_for_storage: str | None = None
if data.workflow_id:
workflow = await app.DATABASE.get_workflow(data.workflow_id, curr_org.organization_id)
if data.workflow_id.startswith("wpid_"):
workflow = await app.DATABASE.get_workflow_by_permanent_id(data.workflow_id, curr_org.organization_id)
else:
workflow = await app.DATABASE.get_workflow(data.workflow_id, curr_org.organization_id)
if not workflow:
raise HTTPException(status_code=400, detail=f"Invalid workflow id: {data.workflow_id}")
workflow_id_for_storage = workflow.workflow_id
if data.workflow_run_id:
workflow_run = await app.DATABASE.get_workflow_run(data.workflow_run_id, curr_org.organization_id)
if not workflow_run:
@@ -162,7 +167,7 @@ async def send_totp_code(
content=data.content,
code=otp_value.value,
task_id=data.task_id,
workflow_id=data.workflow_id,
workflow_id=workflow_id_for_storage,
workflow_run_id=data.workflow_run_id,
source=data.source,
expired_at=data.expired_at,

View File

@@ -0,0 +1,101 @@
"""WebSocket endpoint for streaming global 2FA verification code notifications."""
import asyncio
import structlog
from fastapi import WebSocket, WebSocketDisconnect
from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK
from skyvern.config import settings
from skyvern.forge import app
from skyvern.forge.sdk.notification.factory import NotificationRegistryFactory
from skyvern.forge.sdk.routes.routers import base_router
from skyvern.forge.sdk.routes.streaming.auth import _auth as local_auth
from skyvern.forge.sdk.routes.streaming.auth import auth as real_auth
LOG = structlog.get_logger()
HEARTBEAT_INTERVAL = 60
@base_router.websocket("/stream/notifications")
async def notification_stream(
websocket: WebSocket,
apikey: str | None = None,
token: str | None = None,
) -> None:
auth = local_auth if settings.ENV == "local" else real_auth
organization_id = await auth(apikey=apikey, token=token, websocket=websocket)
if not organization_id:
LOG.info("Notifications: Authentication failed")
return
LOG.info("Notifications: Started streaming", organization_id=organization_id)
registry = NotificationRegistryFactory.get_registry()
queue = registry.subscribe(organization_id)
try:
# Send initial state: all currently active verification requests
active_requests = await app.DATABASE.get_active_verification_requests(organization_id)
for req in active_requests:
await websocket.send_json(
{
"type": "verification_code_required",
"task_id": req.get("task_id"),
"workflow_run_id": req.get("workflow_run_id"),
"identifier": req.get("verification_code_identifier"),
"polling_started_at": req.get("verification_code_polling_started_at"),
}
)
# Watch for client disconnect while streaming events
disconnect_event = asyncio.Event()
async def _watch_disconnect() -> None:
try:
while True:
await websocket.receive()
except (WebSocketDisconnect, ConnectionClosedOK, ConnectionClosedError):
disconnect_event.set()
watcher = asyncio.create_task(_watch_disconnect())
try:
while not disconnect_event.is_set():
queue_task = asyncio.ensure_future(asyncio.wait_for(queue.get(), timeout=HEARTBEAT_INTERVAL))
disconnect_wait = asyncio.ensure_future(disconnect_event.wait())
done, pending = await asyncio.wait({queue_task, disconnect_wait}, return_when=asyncio.FIRST_COMPLETED)
for p in pending:
p.cancel()
if disconnect_event.is_set():
return
try:
message = queue_task.result()
await websocket.send_json(message)
except TimeoutError:
try:
await websocket.send_json({"type": "heartbeat"})
except Exception:
LOG.info(
"Notifications: Client unreachable during heartbeat. Closing.",
organization_id=organization_id,
)
return
except asyncio.CancelledError:
return
finally:
watcher.cancel()
except WebSocketDisconnect:
LOG.info("Notifications: WebSocket disconnected", organization_id=organization_id)
except ConnectionClosedOK:
LOG.info("Notifications: ConnectionClosedOK", organization_id=organization_id)
except ConnectionClosedError:
LOG.warning(
"Notifications: ConnectionClosedError (client likely disconnected)", organization_id=organization_id
)
except Exception:
LOG.warning("Notifications: Error while streaming", organization_id=organization_id, exc_info=True)
finally:
registry.unsubscribe(organization_id, queue)
LOG.info("Notifications: Connection closed", organization_id=organization_id)