238 lines
7.1 KiB
Python
238 lines
7.1 KiB
Python
"""Tests for RedisNotificationRegistry (SKY-6).
|
|
|
|
All tests use a mock Redis client — no real Redis instance required.
|
|
"""
|
|
|
|
import asyncio
|
|
import json
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
import pytest
|
|
|
|
from skyvern.forge.sdk.notification.redis import RedisNotificationRegistry
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _make_mock_redis() -> MagicMock:
|
|
"""Return a mock redis.asyncio.Redis client."""
|
|
redis = MagicMock()
|
|
redis.publish = AsyncMock()
|
|
redis.pubsub = MagicMock()
|
|
return redis
|
|
|
|
|
|
def _make_mock_pubsub(messages: list[dict] | None = None, *, block: bool = False) -> MagicMock:
|
|
"""Return a mock PubSub that yields *messages* from ``listen()``.
|
|
|
|
Each entry in *messages* should look like:
|
|
{"type": "message", "data": '{"key": "val"}'}
|
|
|
|
If *block* is True the async generator will hang forever after
|
|
exhausting *messages*, which keeps the listener task alive so that
|
|
cancellation semantics can be tested.
|
|
"""
|
|
pubsub = MagicMock()
|
|
pubsub.subscribe = AsyncMock()
|
|
pubsub.unsubscribe = AsyncMock()
|
|
pubsub.close = AsyncMock()
|
|
|
|
async def _listen():
|
|
for msg in messages or []:
|
|
yield msg
|
|
if block:
|
|
# Keep the listener alive until cancelled
|
|
await asyncio.Event().wait()
|
|
|
|
pubsub.listen = _listen
|
|
return pubsub
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tests: subscribe
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_subscribe_creates_queue_and_starts_listener():
    """subscribe() hands back an asyncio.Queue and registers a listener task for the org."""
    mock_redis = _make_mock_redis()
    mock_redis.pubsub.return_value = _make_mock_pubsub()
    registry = RedisNotificationRegistry(mock_redis)

    subscriber_queue = registry.subscribe("org_1")
    assert isinstance(subscriber_queue, asyncio.Queue)

    assert "org_1" in registry._listener_tasks
    listener = registry._listener_tasks["org_1"]
    assert isinstance(listener, asyncio.Task)

    # Tear down the background listener so no task outlives the test.
    await registry.close()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_subscribe_reuses_listener_for_same_org():
    """Subscribing to the same org twice must keep the original listener task."""
    mock_redis = _make_mock_redis()
    mock_redis.pubsub.return_value = _make_mock_pubsub()
    registry = RedisNotificationRegistry(mock_redis)

    registry.subscribe("org_1")
    original_listener = registry._listener_tasks["org_1"]

    # A second subscriber for the same org should share the existing task.
    registry.subscribe("org_1")
    current_listener = registry._listener_tasks["org_1"]
    assert current_listener is original_listener

    await registry.close()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tests: unsubscribe
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_unsubscribe_cancels_listener_when_last_subscriber_leaves():
    """Dropping the final subscriber for an org must cancel its listener task."""
    mock_redis = _make_mock_redis()
    # A blocking pubsub keeps the listener alive, so cancellation is observable.
    mock_redis.pubsub.return_value = _make_mock_pubsub(block=True)
    registry = RedisNotificationRegistry(mock_redis)

    only_queue = registry.subscribe("org_1")
    await asyncio.sleep(0)  # give the listener task a chance to start
    listener = registry._listener_tasks["org_1"]

    registry.unsubscribe("org_1", only_queue)
    assert "org_1" not in registry._listener_tasks

    # Let the cancellation settle before inspecting the task's final state.
    await asyncio.gather(listener, return_exceptions=True)
    assert listener.cancelled()

    await registry.close()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_unsubscribe_keeps_listener_when_subscribers_remain():
    """If other subscribers remain, the listener task should stay alive.

    The pubsub blocks after its (empty) message stream. A listener that was
    *not* cancelled is therefore observably still running, rather than merely
    still present in ``_listener_tasks``.
    """
    redis = _make_mock_redis()
    pubsub = _make_mock_pubsub(block=True)
    redis.pubsub.return_value = pubsub

    registry = RedisNotificationRegistry(redis)
    q1 = registry.subscribe("org_1")
    q2 = registry.subscribe("org_1")  # second subscriber

    # Let the listener task start running
    await asyncio.sleep(0)
    task = registry._listener_tasks["org_1"]

    registry.unsubscribe("org_1", q1)
    # Yield once more: a wrongly requested cancel() would take effect here.
    await asyncio.sleep(0)

    assert registry._listener_tasks.get("org_1") is task
    assert not task.done()

    # The remaining subscriber still receives local dispatches; the removed one does not.
    msg = {"type": "test"}
    registry._dispatch_local("org_1", msg)
    assert q2.get_nowait() == msg
    assert q1.empty()

    await registry.close()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tests: publish
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_publish_calls_redis_publish():
    """publish() schedules a Redis PUBLISH of the JSON payload on the org's channel."""
    mock_redis = _make_mock_redis()
    registry = RedisNotificationRegistry(mock_redis)

    payload = {"type": "verification_code_required"}
    # Serialize before publishing so the expectation cannot drift with the payload object.
    expected_data = json.dumps(payload)
    expected_channel = "skyvern:notifications:org_1"

    registry.publish("org_1", payload)
    await asyncio.sleep(0)  # one loop tick lets the fire-and-forget task run

    mock_redis.publish.assert_awaited_once_with(expected_channel, expected_data)

    await registry.close()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tests: _dispatch_local
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_dispatch_local_fans_out_to_all_queues():
    """Every local queue subscribed to the org receives the dispatched message."""
    mock_redis = _make_mock_redis()
    mock_redis.pubsub.return_value = _make_mock_pubsub()
    registry = RedisNotificationRegistry(mock_redis)

    subscriber_queues = [registry.subscribe("org_1") for _ in range(2)]

    payload = {"type": "test", "value": 42}
    registry._dispatch_local("org_1", payload)

    for queue in subscriber_queues:
        assert queue.get_nowait() == payload

    await registry.close()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_dispatch_local_does_not_leak_across_orgs():
    """A message dispatched for org_a must reach org_a's queue and never org_b's."""
    mock_redis = _make_mock_redis()
    mock_redis.pubsub.return_value = _make_mock_pubsub()
    registry = RedisNotificationRegistry(mock_redis)

    queue_a = registry.subscribe("org_a")
    queue_b = registry.subscribe("org_b")

    registry._dispatch_local("org_a", {"type": "test"})

    received_a = not queue_a.empty()
    received_b = not queue_b.empty()
    assert received_a
    assert not received_b

    await registry.close()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tests: close
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_close_cancels_all_listeners_and_clears_state():
    """close() must cancel each org's listener and leave no tasks or subscribers behind."""
    mock_redis = _make_mock_redis()
    # Blocking pubsub keeps both listeners running until they are cancelled.
    mock_redis.pubsub.return_value = _make_mock_pubsub(block=True)
    registry = RedisNotificationRegistry(mock_redis)

    org_ids = ("org_1", "org_2")
    for org_id in org_ids:
        registry.subscribe(org_id)

    await asyncio.sleep(0)  # start both listener tasks
    listeners = [registry._listener_tasks[org_id] for org_id in org_ids]

    await registry.close()

    # Let both cancellations settle before checking the final task states.
    await asyncio.gather(*listeners, return_exceptions=True)
    assert all(listener.cancelled() for listener in listeners)
    assert not registry._listener_tasks
    assert not registry._subscribers
|