[SKY-6] Backend: Enable 2FA code detection without TOTP credentials (#4786)
This commit is contained in:
@@ -11,5 +11,6 @@ from skyvern.forge.sdk.routes import sdk # noqa: F401
|
||||
from skyvern.forge.sdk.routes import webhooks # noqa: F401
|
||||
from skyvern.forge.sdk.routes import workflow_copilot # noqa: F401
|
||||
from skyvern.forge.sdk.routes.streaming import messages # noqa: F401
|
||||
from skyvern.forge.sdk.routes.streaming import notifications # noqa: F401
|
||||
from skyvern.forge.sdk.routes.streaming import screenshot # noqa: F401
|
||||
from skyvern.forge.sdk.routes.streaming import vnc # noqa: F401
|
||||
|
||||
@@ -131,10 +131,15 @@ async def send_totp_code(
|
||||
task = await app.DATABASE.get_task(data.task_id, curr_org.organization_id)
|
||||
if not task:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid task id: {data.task_id}")
|
||||
workflow_id_for_storage: str | None = None
|
||||
if data.workflow_id:
|
||||
workflow = await app.DATABASE.get_workflow(data.workflow_id, curr_org.organization_id)
|
||||
if data.workflow_id.startswith("wpid_"):
|
||||
workflow = await app.DATABASE.get_workflow_by_permanent_id(data.workflow_id, curr_org.organization_id)
|
||||
else:
|
||||
workflow = await app.DATABASE.get_workflow(data.workflow_id, curr_org.organization_id)
|
||||
if not workflow:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid workflow id: {data.workflow_id}")
|
||||
workflow_id_for_storage = workflow.workflow_id
|
||||
if data.workflow_run_id:
|
||||
workflow_run = await app.DATABASE.get_workflow_run(data.workflow_run_id, curr_org.organization_id)
|
||||
if not workflow_run:
|
||||
@@ -162,7 +167,7 @@ async def send_totp_code(
|
||||
content=data.content,
|
||||
code=otp_value.value,
|
||||
task_id=data.task_id,
|
||||
workflow_id=data.workflow_id,
|
||||
workflow_id=workflow_id_for_storage,
|
||||
workflow_run_id=data.workflow_run_id,
|
||||
source=data.source,
|
||||
expired_at=data.expired_at,
|
||||
|
||||
101
skyvern/forge/sdk/routes/streaming/notifications.py
Normal file
101
skyvern/forge/sdk/routes/streaming/notifications.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""WebSocket endpoint for streaming global 2FA verification code notifications."""
|
||||
|
||||
import asyncio
|
||||
|
||||
import structlog
|
||||
from fastapi import WebSocket, WebSocketDisconnect
|
||||
from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK
|
||||
|
||||
from skyvern.config import settings
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.notification.factory import NotificationRegistryFactory
|
||||
from skyvern.forge.sdk.routes.routers import base_router
|
||||
from skyvern.forge.sdk.routes.streaming.auth import _auth as local_auth
|
||||
from skyvern.forge.sdk.routes.streaming.auth import auth as real_auth
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
HEARTBEAT_INTERVAL = 60
|
||||
|
||||
|
||||
@base_router.websocket("/stream/notifications")
|
||||
async def notification_stream(
|
||||
websocket: WebSocket,
|
||||
apikey: str | None = None,
|
||||
token: str | None = None,
|
||||
) -> None:
|
||||
auth = local_auth if settings.ENV == "local" else real_auth
|
||||
organization_id = await auth(apikey=apikey, token=token, websocket=websocket)
|
||||
if not organization_id:
|
||||
LOG.info("Notifications: Authentication failed")
|
||||
return
|
||||
|
||||
LOG.info("Notifications: Started streaming", organization_id=organization_id)
|
||||
registry = NotificationRegistryFactory.get_registry()
|
||||
queue = registry.subscribe(organization_id)
|
||||
|
||||
try:
|
||||
# Send initial state: all currently active verification requests
|
||||
active_requests = await app.DATABASE.get_active_verification_requests(organization_id)
|
||||
for req in active_requests:
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "verification_code_required",
|
||||
"task_id": req.get("task_id"),
|
||||
"workflow_run_id": req.get("workflow_run_id"),
|
||||
"identifier": req.get("verification_code_identifier"),
|
||||
"polling_started_at": req.get("verification_code_polling_started_at"),
|
||||
}
|
||||
)
|
||||
|
||||
# Watch for client disconnect while streaming events
|
||||
disconnect_event = asyncio.Event()
|
||||
|
||||
async def _watch_disconnect() -> None:
|
||||
try:
|
||||
while True:
|
||||
await websocket.receive()
|
||||
except (WebSocketDisconnect, ConnectionClosedOK, ConnectionClosedError):
|
||||
disconnect_event.set()
|
||||
|
||||
watcher = asyncio.create_task(_watch_disconnect())
|
||||
try:
|
||||
while not disconnect_event.is_set():
|
||||
queue_task = asyncio.ensure_future(asyncio.wait_for(queue.get(), timeout=HEARTBEAT_INTERVAL))
|
||||
disconnect_wait = asyncio.ensure_future(disconnect_event.wait())
|
||||
done, pending = await asyncio.wait({queue_task, disconnect_wait}, return_when=asyncio.FIRST_COMPLETED)
|
||||
for p in pending:
|
||||
p.cancel()
|
||||
|
||||
if disconnect_event.is_set():
|
||||
return
|
||||
|
||||
try:
|
||||
message = queue_task.result()
|
||||
await websocket.send_json(message)
|
||||
except TimeoutError:
|
||||
try:
|
||||
await websocket.send_json({"type": "heartbeat"})
|
||||
except Exception:
|
||||
LOG.info(
|
||||
"Notifications: Client unreachable during heartbeat. Closing.",
|
||||
organization_id=organization_id,
|
||||
)
|
||||
return
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
finally:
|
||||
watcher.cancel()
|
||||
|
||||
except WebSocketDisconnect:
|
||||
LOG.info("Notifications: WebSocket disconnected", organization_id=organization_id)
|
||||
except ConnectionClosedOK:
|
||||
LOG.info("Notifications: ConnectionClosedOK", organization_id=organization_id)
|
||||
except ConnectionClosedError:
|
||||
LOG.warning(
|
||||
"Notifications: ConnectionClosedError (client likely disconnected)", organization_id=organization_id
|
||||
)
|
||||
except Exception:
|
||||
LOG.warning("Notifications: Error while streaming", organization_id=organization_id, exc_info=True)
|
||||
finally:
|
||||
registry.unsubscribe(organization_id, queue)
|
||||
LOG.info("Notifications: Connection closed", organization_id=organization_id)
|
||||
Reference in New Issue
Block a user