[SKY-6] Backend: Enable 2FA code detection without TOTP credentials (#4786)
This commit is contained in:
130
skyvern/forge/sdk/redis/pubsub.py
Normal file
130
skyvern/forge/sdk/redis/pubsub.py
Normal file
@@ -0,0 +1,130 @@
|
||||
"""Generic Redis pub/sub layer.
|
||||
|
||||
Extracted from ``RedisNotificationRegistry`` so that any feature
|
||||
(notifications, events, cache invalidation, etc.) can reuse the same
|
||||
pattern with its own channel prefix.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from collections import defaultdict
|
||||
|
||||
import structlog
|
||||
from redis.asyncio import Redis
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class RedisPubSub:
|
||||
"""Fan-out pub/sub backed by Redis. One Redis PubSub channel per key."""
|
||||
|
||||
def __init__(self, redis_client: Redis, channel_prefix: str) -> None:
|
||||
self._redis = redis_client
|
||||
self._channel_prefix = channel_prefix
|
||||
self._subscribers: dict[str, list[asyncio.Queue[dict]]] = defaultdict(list)
|
||||
# One listener task per key channel
|
||||
self._listener_tasks: dict[str, asyncio.Task[None]] = {}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public interface
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def subscribe(self, key: str) -> asyncio.Queue[dict]:
|
||||
queue: asyncio.Queue[dict] = asyncio.Queue()
|
||||
self._subscribers[key].append(queue)
|
||||
|
||||
# Spin up a Redis listener if this is the first local subscriber
|
||||
if key not in self._listener_tasks:
|
||||
task = asyncio.get_running_loop().create_task(self._listen(key))
|
||||
self._listener_tasks[key] = task
|
||||
|
||||
LOG.info("PubSub subscriber added", key=key, channel_prefix=self._channel_prefix)
|
||||
return queue
|
||||
|
||||
def unsubscribe(self, key: str, queue: asyncio.Queue[dict]) -> None:
|
||||
queues = self._subscribers.get(key)
|
||||
if queues:
|
||||
try:
|
||||
queues.remove(queue)
|
||||
except ValueError:
|
||||
pass
|
||||
if not queues:
|
||||
del self._subscribers[key]
|
||||
self._cancel_listener(key)
|
||||
LOG.info("PubSub subscriber removed", key=key, channel_prefix=self._channel_prefix)
|
||||
|
||||
def publish(self, key: str, message: dict) -> None:
|
||||
"""Fire-and-forget Redis PUBLISH."""
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
loop.create_task(self._publish_to_redis(key, message))
|
||||
except RuntimeError:
|
||||
LOG.warning(
|
||||
"No running event loop; cannot publish via Redis",
|
||||
key=key,
|
||||
channel_prefix=self._channel_prefix,
|
||||
)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Cancel all listener tasks and clear state. Call on shutdown."""
|
||||
for key in list(self._listener_tasks):
|
||||
self._cancel_listener(key)
|
||||
self._subscribers.clear()
|
||||
LOG.info("RedisPubSub closed", channel_prefix=self._channel_prefix)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _publish_to_redis(self, key: str, message: dict) -> None:
|
||||
channel = f"{self._channel_prefix}{key}"
|
||||
try:
|
||||
await self._redis.publish(channel, json.dumps(message))
|
||||
except Exception:
|
||||
LOG.exception("Failed to publish to Redis", key=key, channel_prefix=self._channel_prefix)
|
||||
|
||||
async def _listen(self, key: str) -> None:
|
||||
"""Subscribe to a Redis channel and fan out messages locally."""
|
||||
channel = f"{self._channel_prefix}{key}"
|
||||
pubsub = self._redis.pubsub()
|
||||
try:
|
||||
await pubsub.subscribe(channel)
|
||||
LOG.info("Redis listener started", channel=channel)
|
||||
async for raw_message in pubsub.listen():
|
||||
if raw_message["type"] != "message":
|
||||
continue
|
||||
try:
|
||||
data = json.loads(raw_message["data"])
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
LOG.warning("Invalid JSON on Redis channel", channel=channel)
|
||||
continue
|
||||
self._dispatch_local(key, data)
|
||||
except asyncio.CancelledError:
|
||||
LOG.info("Redis listener cancelled", channel=channel)
|
||||
raise
|
||||
except Exception:
|
||||
LOG.exception("Redis listener error", channel=channel)
|
||||
finally:
|
||||
try:
|
||||
await pubsub.unsubscribe(channel)
|
||||
await pubsub.close()
|
||||
except Exception:
|
||||
LOG.warning("Error closing Redis pubsub", channel=channel)
|
||||
|
||||
def _dispatch_local(self, key: str, message: dict) -> None:
|
||||
"""Fan out a message to all local asyncio queues for this key."""
|
||||
queues = self._subscribers.get(key, [])
|
||||
for queue in queues:
|
||||
try:
|
||||
queue.put_nowait(message)
|
||||
except asyncio.QueueFull:
|
||||
LOG.warning(
|
||||
"Queue full, dropping message",
|
||||
key=key,
|
||||
channel_prefix=self._channel_prefix,
|
||||
)
|
||||
|
||||
def _cancel_listener(self, key: str) -> None:
|
||||
task = self._listener_tasks.pop(key, None)
|
||||
if task and not task.done():
|
||||
task.cancel()
|
||||
Reference in New Issue
Block a user