131 lines
4.9 KiB
Python
131 lines
4.9 KiB
Python
"""Generic Redis pub/sub layer.
|
|
|
|
Extracted from ``RedisNotificationRegistry`` so that any feature
|
|
(notifications, events, cache invalidation, etc.) can reuse the same
|
|
pattern with its own channel prefix.
|
|
"""
|
|
|
|
import asyncio
|
|
import json
|
|
from collections import defaultdict
|
|
|
|
import structlog
|
|
from redis.asyncio import Redis
|
|
|
|
LOG = structlog.get_logger()
|
|
|
|
|
|
class RedisPubSub:
    """Fan-out pub/sub backed by Redis. One Redis PubSub channel per key.

    Local consumers call :meth:`subscribe` to obtain an ``asyncio.Queue`` for
    a key. The first local subscriber for a key starts a background task that
    listens on the Redis channel ``f"{channel_prefix}{key}"`` and copies every
    decoded JSON message into each local queue registered for that key.
    """

    def __init__(self, redis_client: Redis, channel_prefix: str) -> None:
        """Store the client and prefix; no Redis traffic happens here.

        Args:
            redis_client: Async Redis client used for PUBLISH and for
                creating per-key PubSub objects.
            channel_prefix: String prepended to every key to form the Redis
                channel name.
        """
        self._redis = redis_client
        self._channel_prefix = channel_prefix
        self._subscribers: dict[str, list[asyncio.Queue[dict]]] = defaultdict(list)
        # One listener task per key channel
        self._listener_tasks: dict[str, asyncio.Task[None]] = {}
        # Strong references to in-flight fire-and-forget publish tasks. The
        # event loop only keeps weak references to tasks, so without this a
        # publish task could be garbage-collected before it finishes.
        self._publish_tasks: set[asyncio.Task[None]] = set()

    # ------------------------------------------------------------------
    # Public interface
    # ------------------------------------------------------------------

    def subscribe(self, key: str) -> asyncio.Queue[dict]:
        """Register a new local queue for ``key`` and return it.

        Must be called while an event loop is running whenever a listener
        task has to be started.

        Args:
            key: Logical key; the Redis channel is the prefix plus this key.

        Returns:
            A fresh unbounded queue that receives every message dispatched
            for ``key``.

        Raises:
            RuntimeError: If a listener must be started and there is no
                running event loop.
        """
        queue: asyncio.Queue[dict] = asyncio.Queue()
        self._subscribers[key].append(queue)

        # Spin up a Redis listener if this is the first local subscriber, or
        # if the previous listener already finished (e.g. it died on a Redis
        # error); otherwise new subscribers would never receive anything.
        existing = self._listener_tasks.get(key)
        if existing is None or existing.done():
            task = asyncio.get_running_loop().create_task(self._listen(key))
            self._listener_tasks[key] = task

        LOG.info("PubSub subscriber added", key=key, channel_prefix=self._channel_prefix)
        return queue

    def unsubscribe(self, key: str, queue: asyncio.Queue[dict]) -> None:
        """Remove ``queue`` from ``key``; stop the listener when none remain.

        Unknown keys and queues that are not registered are ignored.

        Args:
            key: Key the queue was subscribed under.
            queue: Queue previously returned by :meth:`subscribe`.
        """
        queues = self._subscribers.get(key)
        if queues:
            try:
                queues.remove(queue)
            except ValueError:
                # Queue was not registered for this key; nothing to remove.
                pass
            if not queues:
                # Last local subscriber gone: drop the entry and the listener.
                del self._subscribers[key]
                self._cancel_listener(key)
        LOG.info("PubSub subscriber removed", key=key, channel_prefix=self._channel_prefix)

    def publish(self, key: str, message: dict) -> None:
        """Fire-and-forget Redis PUBLISH.

        Schedules the publish on the running event loop and returns
        immediately. If no loop is running, a warning is logged and the
        message is not sent.

        Args:
            key: Key whose channel receives the message.
            message: Payload; it is serialised with ``json.dumps`` when sent.
        """
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            LOG.warning(
                "No running event loop; cannot publish via Redis",
                key=key,
                channel_prefix=self._channel_prefix,
            )
            return

        task = loop.create_task(self._publish_to_redis(key, message))
        # Hold a reference until the task completes, then drop it.
        self._publish_tasks.add(task)
        task.add_done_callback(self._publish_tasks.discard)

    async def close(self) -> None:
        """Cancel all listener tasks and clear state. Call on shutdown.

        Waits for the cancelled listeners to finish so that their Redis
        unsubscribe/close cleanup has run before this coroutine returns.
        """
        tasks = list(self._listener_tasks.values())
        for key in list(self._listener_tasks):
            self._cancel_listener(key)
        self._subscribers.clear()
        if tasks:
            # return_exceptions=True: cancelled listeners finish with
            # CancelledError, which must not propagate out of close().
            await asyncio.gather(*tasks, return_exceptions=True)
        LOG.info("RedisPubSub closed", channel_prefix=self._channel_prefix)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    async def _publish_to_redis(self, key: str, message: dict) -> None:
        """Serialise ``message`` to JSON and PUBLISH it; log any failure."""
        channel = f"{self._channel_prefix}{key}"
        try:
            await self._redis.publish(channel, json.dumps(message))
        except Exception:
            # Best effort by design: the caller has already moved on.
            LOG.exception("Failed to publish to Redis", key=key, channel_prefix=self._channel_prefix)

    async def _listen(self, key: str) -> None:
        """Subscribe to a Redis channel and fan out messages locally.

        Runs until cancelled or until the Redis stream raises. Entries whose
        type is not ``"message"`` are skipped, as are payloads that fail to
        decode as JSON. The PubSub object is always unsubscribed and closed
        on the way out.
        """
        channel = f"{self._channel_prefix}{key}"
        pubsub = self._redis.pubsub()
        try:
            await pubsub.subscribe(channel)
            LOG.info("Redis listener started", channel=channel)
            async for raw_message in pubsub.listen():
                if raw_message["type"] != "message":
                    continue
                try:
                    data = json.loads(raw_message["data"])
                except (json.JSONDecodeError, TypeError):
                    LOG.warning("Invalid JSON on Redis channel", channel=channel)
                    continue
                self._dispatch_local(key, data)
        except asyncio.CancelledError:
            LOG.info("Redis listener cancelled", channel=channel)
            raise
        except Exception:
            LOG.exception("Redis listener error", channel=channel)
        finally:
            try:
                try:
                    await pubsub.unsubscribe(channel)
                finally:
                    # Close even when unsubscribe fails, so the underlying
                    # connection is not leaked.
                    await pubsub.close()
            except Exception:
                LOG.warning("Error closing Redis pubsub", channel=channel, exc_info=True)

    def _dispatch_local(self, key: str, message: dict) -> None:
        """Fan out a message to all local asyncio queues for this key."""
        queues = self._subscribers.get(key, [])
        for queue in queues:
            try:
                queue.put_nowait(message)
            except asyncio.QueueFull:
                LOG.warning(
                    "Queue full, dropping message",
                    key=key,
                    channel_prefix=self._channel_prefix,
                )

    def _cancel_listener(self, key: str) -> None:
        """Forget the listener task for ``key`` and cancel it if still running."""
        task = self._listener_tasks.pop(key, None)
        if task and not task.done():
            task.cancel()
|