102 lines
4.1 KiB
Python
102 lines
4.1 KiB
Python
"""WebSocket endpoint for streaming global 2FA verification code notifications."""
|
|
|
|
import asyncio
|
|
|
|
import structlog
|
|
from fastapi import WebSocket, WebSocketDisconnect
|
|
from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK
|
|
|
|
from skyvern.config import settings
|
|
from skyvern.forge import app
|
|
from skyvern.forge.sdk.notification.factory import NotificationRegistryFactory
|
|
from skyvern.forge.sdk.routes.routers import base_router
|
|
from skyvern.forge.sdk.routes.streaming.auth import _auth as local_auth
|
|
from skyvern.forge.sdk.routes.streaming.auth import auth as real_auth
|
|
|
|
LOG = structlog.get_logger()
|
|
HEARTBEAT_INTERVAL = 60
|
|
|
|
|
|
@base_router.websocket("/stream/notifications")
|
|
async def notification_stream(
|
|
websocket: WebSocket,
|
|
apikey: str | None = None,
|
|
token: str | None = None,
|
|
) -> None:
|
|
auth = local_auth if settings.ENV == "local" else real_auth
|
|
organization_id = await auth(apikey=apikey, token=token, websocket=websocket)
|
|
if not organization_id:
|
|
LOG.info("Notifications: Authentication failed")
|
|
return
|
|
|
|
LOG.info("Notifications: Started streaming", organization_id=organization_id)
|
|
registry = NotificationRegistryFactory.get_registry()
|
|
queue = registry.subscribe(organization_id)
|
|
|
|
try:
|
|
# Send initial state: all currently active verification requests
|
|
active_requests = await app.DATABASE.get_active_verification_requests(organization_id)
|
|
for req in active_requests:
|
|
await websocket.send_json(
|
|
{
|
|
"type": "verification_code_required",
|
|
"task_id": req.get("task_id"),
|
|
"workflow_run_id": req.get("workflow_run_id"),
|
|
"identifier": req.get("verification_code_identifier"),
|
|
"polling_started_at": req.get("verification_code_polling_started_at"),
|
|
}
|
|
)
|
|
|
|
# Watch for client disconnect while streaming events
|
|
disconnect_event = asyncio.Event()
|
|
|
|
async def _watch_disconnect() -> None:
|
|
try:
|
|
while True:
|
|
await websocket.receive()
|
|
except (WebSocketDisconnect, ConnectionClosedOK, ConnectionClosedError):
|
|
disconnect_event.set()
|
|
|
|
watcher = asyncio.create_task(_watch_disconnect())
|
|
try:
|
|
while not disconnect_event.is_set():
|
|
queue_task = asyncio.ensure_future(asyncio.wait_for(queue.get(), timeout=HEARTBEAT_INTERVAL))
|
|
disconnect_wait = asyncio.ensure_future(disconnect_event.wait())
|
|
done, pending = await asyncio.wait({queue_task, disconnect_wait}, return_when=asyncio.FIRST_COMPLETED)
|
|
for p in pending:
|
|
p.cancel()
|
|
|
|
if disconnect_event.is_set():
|
|
return
|
|
|
|
try:
|
|
message = queue_task.result()
|
|
await websocket.send_json(message)
|
|
except TimeoutError:
|
|
try:
|
|
await websocket.send_json({"type": "heartbeat"})
|
|
except Exception:
|
|
LOG.info(
|
|
"Notifications: Client unreachable during heartbeat. Closing.",
|
|
organization_id=organization_id,
|
|
)
|
|
return
|
|
except asyncio.CancelledError:
|
|
return
|
|
finally:
|
|
watcher.cancel()
|
|
|
|
except WebSocketDisconnect:
|
|
LOG.info("Notifications: WebSocket disconnected", organization_id=organization_id)
|
|
except ConnectionClosedOK:
|
|
LOG.info("Notifications: ConnectionClosedOK", organization_id=organization_id)
|
|
except ConnectionClosedError:
|
|
LOG.warning(
|
|
"Notifications: ConnectionClosedError (client likely disconnected)", organization_id=organization_id
|
|
)
|
|
except Exception:
|
|
LOG.warning("Notifications: Error while streaming", organization_id=organization_id, exc_info=True)
|
|
finally:
|
|
registry.unsubscribe(organization_id, queue)
|
|
LOG.info("Notifications: Connection closed", organization_id=organization_id)
|