Files
Dorod-Sky/skyvern/forge/sdk/routes/streaming/notifications.py

102 lines
4.1 KiB
Python

"""WebSocket endpoint for streaming global 2FA verification code notifications."""
import asyncio
import structlog
from fastapi import WebSocket, WebSocketDisconnect
from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK
from skyvern.config import settings
from skyvern.forge import app
from skyvern.forge.sdk.notification.factory import NotificationRegistryFactory
from skyvern.forge.sdk.routes.routers import base_router
from skyvern.forge.sdk.routes.streaming.auth import _auth as local_auth
from skyvern.forge.sdk.routes.streaming.auth import auth as real_auth
# Seconds of queue silence tolerated before a heartbeat frame is pushed to the client
# (used as the wait timeout in the streaming loop).
HEARTBEAT_INTERVAL: int = 60
# Module-level structured logger for the notification stream.
LOG = structlog.get_logger()
async def _send_active_verification_requests(websocket: WebSocket, organization_id: str) -> None:
    """Replay every currently active verification request to a newly connected client.

    Each row returned by ``app.DATABASE.get_active_verification_requests`` is forwarded as a
    ``verification_code_required`` message so the client starts in sync before live events arrive.
    """
    active_requests = await app.DATABASE.get_active_verification_requests(organization_id)
    for req in active_requests:
        await websocket.send_json(
            {
                "type": "verification_code_required",
                "task_id": req.get("task_id"),
                "workflow_run_id": req.get("workflow_run_id"),
                "identifier": req.get("verification_code_identifier"),
                "polling_started_at": req.get("verification_code_polling_started_at"),
            }
        )


async def _watch_disconnect(websocket: WebSocket, disconnect_event: asyncio.Event, organization_id: str) -> None:
    """Consume inbound frames until the client goes away, then set ``disconnect_event``.

    Starlette's ``WebSocket.receive()`` reports a client disconnect by *returning* a
    ``{"type": "websocket.disconnect"}`` message (and raises ``RuntimeError`` if called again
    afterwards), so the message type is checked explicitly instead of relying on an exception.
    The event is set in ``finally`` so the streaming loop is always released, whatever ends the watch.
    """
    try:
        while True:
            message = await websocket.receive()
            if message.get("type") == "websocket.disconnect":
                return
    except (WebSocketDisconnect, ConnectionClosedOK, ConnectionClosedError):
        return
    except Exception:
        # Any other receive failure also means the client can no longer be reached; log it rather
        # than leaving an unretrieved task exception behind.
        LOG.debug("Notifications: Disconnect watcher stopped", organization_id=organization_id, exc_info=True)
    finally:
        disconnect_event.set()


async def _stream_notifications(
    websocket: WebSocket,
    queue: asyncio.Queue,
    disconnect_event: asyncio.Event,
    organization_id: str,
) -> None:
    """Forward queued notifications to the client until it disconnects.

    ``queue`` is whatever ``registry.subscribe`` returned; it is assumed to be an ``asyncio.Queue``
    of JSON-serialisable messages (NOTE(review): confirm against the registry implementation).

    If nothing arrives within ``HEARTBEAT_INTERVAL`` seconds, a ``{"type": "heartbeat"}`` message is
    sent instead; if that send fails the client is treated as gone and streaming stops quietly.
    Errors while sending a real notification propagate to the caller.
    """
    disconnect_task = asyncio.create_task(disconnect_event.wait())
    get_task: asyncio.Task | None = None
    try:
        while not disconnect_event.is_set():
            # Keep one pending get alive across heartbeats instead of cancelling/re-creating it.
            if get_task is None:
                get_task = asyncio.create_task(queue.get())
            done, _ = await asyncio.wait(
                {get_task, disconnect_task},
                timeout=HEARTBEAT_INTERVAL,
                return_when=asyncio.FIRST_COMPLETED,
            )
            if disconnect_event.is_set():
                return
            if get_task in done:
                message = get_task.result()
                get_task = None
                await websocket.send_json(message)
                continue
            # Timed out with no traffic: ping the client so a dead connection is noticed.
            try:
                await websocket.send_json({"type": "heartbeat"})
            except Exception:
                LOG.info(
                    "Notifications: Client unreachable during heartbeat. Closing.",
                    organization_id=organization_id,
                )
                return
    finally:
        # Cancel helper tasks so nothing outlives the connection -- including when this coroutine is
        # itself cancelled mid-wait, which ``asyncio.wait`` does not clean up on its own.
        disconnect_task.cancel()
        if get_task is not None:
            get_task.cancel()


@base_router.websocket("/stream/notifications")
async def notification_stream(
    websocket: WebSocket,
    apikey: str | None = None,
    token: str | None = None,
) -> None:
    """Stream 2FA verification-code notifications for the caller's organization over a WebSocket.

    Flow:
      1. Authenticate with ``apikey``/``token`` (the local ``_auth`` helper when
         ``settings.ENV == "local"``, otherwise ``auth``); return immediately if no organization id.
      2. Subscribe to the notification registry for that organization.
      3. Replay all currently active verification requests.
      4. Forward live notifications (with periodic heartbeats) until the client disconnects.
      5. Always unsubscribe from the registry on exit.

    Connection errors are logged and end the stream; ``asyncio.CancelledError`` is allowed to
    propagate (after cleanup) so task cancellation behaves correctly.

    Args:
        websocket: The incoming WebSocket connection.
        apikey: Optional API key used for authentication.
        token: Optional token used for authentication.
    """
    auth = local_auth if settings.ENV == "local" else real_auth
    organization_id = await auth(apikey=apikey, token=token, websocket=websocket)
    if not organization_id:
        LOG.info("Notifications: Authentication failed")
        return

    LOG.info("Notifications: Started streaming", organization_id=organization_id)
    registry = NotificationRegistryFactory.get_registry()
    queue = registry.subscribe(organization_id)
    watcher: asyncio.Task | None = None
    try:
        await _send_active_verification_requests(websocket, organization_id)

        # Watch for client disconnect concurrently while streaming events.
        disconnect_event = asyncio.Event()
        watcher = asyncio.create_task(_watch_disconnect(websocket, disconnect_event, organization_id))
        await _stream_notifications(websocket, queue, disconnect_event, organization_id)
    except WebSocketDisconnect:
        LOG.info("Notifications: WebSocket disconnected", organization_id=organization_id)
    except ConnectionClosedOK:
        LOG.info("Notifications: ConnectionClosedOK", organization_id=organization_id)
    except ConnectionClosedError:
        LOG.warning(
            "Notifications: ConnectionClosedError (client likely disconnected)", organization_id=organization_id
        )
    except Exception:
        LOG.warning("Notifications: Error while streaming", organization_id=organization_id, exc_info=True)
    finally:
        if watcher is not None:
            watcher.cancel()
        registry.unsubscribe(organization_id, queue)
        LOG.info("Notifications: Connection closed", organization_id=organization_id)