[SKY-6] Backend: Enable 2FA code detection without TOTP credentials (#4786)

This commit is contained in:
Aaron Perez
2026-02-18 17:21:58 -05:00
committed by GitHub
parent b48bf707c3
commit e3b6d22fb6
28 changed files with 1989 additions and 41 deletions

View File

@@ -0,0 +1,106 @@
"""Tests for NotificationRegistry pub/sub and get_active_verification_requests (SKY-6)."""
import pytest
from skyvern.forge.sdk.db.agent_db import AgentDB
from skyvern.forge.sdk.notification.base import BaseNotificationRegistry
from skyvern.forge.sdk.notification.factory import NotificationRegistryFactory
from skyvern.forge.sdk.notification.local import LocalNotificationRegistry
# === Task 1: NotificationRegistry subscribe / publish / unsubscribe ===
@pytest.mark.asyncio
async def test_subscribe_and_publish():
"""Published messages should be received by subscribers."""
registry = LocalNotificationRegistry()
queue = registry.subscribe("org_1")
registry.publish("org_1", {"type": "verification_code_required", "task_id": "tsk_1"})
msg = queue.get_nowait()
assert msg["type"] == "verification_code_required"
assert msg["task_id"] == "tsk_1"
@pytest.mark.asyncio
async def test_multiple_subscribers():
"""All subscribers for an org should receive the same message."""
registry = LocalNotificationRegistry()
q1 = registry.subscribe("org_1")
q2 = registry.subscribe("org_1")
registry.publish("org_1", {"type": "verification_code_required"})
assert not q1.empty()
assert not q2.empty()
assert q1.get_nowait() == q2.get_nowait()
@pytest.mark.asyncio
async def test_publish_wrong_org_does_not_leak():
"""Messages for org_A should not appear in org_B's queue."""
registry = LocalNotificationRegistry()
q_a = registry.subscribe("org_a")
q_b = registry.subscribe("org_b")
registry.publish("org_a", {"type": "test"})
assert not q_a.empty()
assert q_b.empty()
@pytest.mark.asyncio
async def test_unsubscribe():
"""After unsubscribe, the queue should no longer receive messages."""
registry = LocalNotificationRegistry()
queue = registry.subscribe("org_1")
registry.unsubscribe("org_1", queue)
registry.publish("org_1", {"type": "test"})
assert queue.empty()
@pytest.mark.asyncio
async def test_unsubscribe_idempotent():
"""Unsubscribing a queue that's already removed should not raise."""
registry = LocalNotificationRegistry()
queue = registry.subscribe("org_1")
registry.unsubscribe("org_1", queue)
registry.unsubscribe("org_1", queue) # should not raise
# === Task: BaseNotificationRegistry ABC ===
def test_base_notification_registry_cannot_be_instantiated():
"""ABC should not be directly instantiable."""
with pytest.raises(TypeError):
BaseNotificationRegistry()
# === Task: NotificationRegistryFactory ===
@pytest.mark.asyncio
async def test_factory_returns_local_by_default():
"""Factory should return a LocalNotificationRegistry by default."""
registry = NotificationRegistryFactory.get_registry()
assert isinstance(registry, LocalNotificationRegistry)
@pytest.mark.asyncio
async def test_factory_set_and_get():
"""Factory should allow swapping the registry implementation."""
original = NotificationRegistryFactory.get_registry()
try:
custom = LocalNotificationRegistry()
NotificationRegistryFactory.set_registry(custom)
assert NotificationRegistryFactory.get_registry() is custom
finally:
NotificationRegistryFactory.set_registry(original)
# === Task 2: get_active_verification_requests DB method ===
def test_get_active_verification_requests_method_exists():
"""AgentDB should have get_active_verification_requests method."""
assert hasattr(AgentDB, "get_active_verification_requests")

View File

@@ -0,0 +1,544 @@
"""Tests for manual 2FA input without pre-configured TOTP credentials (SKY-6)."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from skyvern.constants import SPECIAL_FIELD_VERIFICATION_CODE
from skyvern.forge.agent import ForgeAgent
from skyvern.forge.sdk.db.agent_db import AgentDB
from skyvern.forge.sdk.notification.local import LocalNotificationRegistry
from skyvern.forge.sdk.routes.credentials import send_totp_code
from skyvern.forge.sdk.schemas.totp_codes import TOTPCodeCreate
from skyvern.schemas.runs import RunEngine
from skyvern.services.otp_service import OTPValue, _get_otp_value_by_run, poll_otp_value
@pytest.mark.asyncio
async def test_get_otp_codes_by_run_exists():
"""get_otp_codes_by_run should exist on AgentDB."""
assert hasattr(AgentDB, "get_otp_codes_by_run"), "AgentDB missing get_otp_codes_by_run method"
@pytest.mark.asyncio
async def test_get_otp_codes_by_run_returns_empty_without_identifiers():
"""get_otp_codes_by_run should return [] when neither task_id nor workflow_run_id is given."""
db = AgentDB.__new__(AgentDB)
result = await db.get_otp_codes_by_run(
organization_id="org_1",
)
assert result == []
# === Task 2: _get_otp_value_by_run OTP service function ===
@pytest.mark.asyncio
async def test_get_otp_value_by_run_returns_code():
"""_get_otp_value_by_run should find OTP codes by task_id."""
mock_code = MagicMock()
mock_code.code = "123456"
mock_code.otp_type = "totp"
mock_db = AsyncMock()
mock_db.get_otp_codes_by_run.return_value = [mock_code]
mock_app = MagicMock()
mock_app.DATABASE = mock_db
with patch("skyvern.services.otp_service.app", new=mock_app):
result = await _get_otp_value_by_run(
organization_id="org_1",
task_id="tsk_1",
)
assert result is not None
assert result.value == "123456"
@pytest.mark.asyncio
async def test_get_otp_value_by_run_returns_none_when_no_codes():
"""_get_otp_value_by_run should return None when no codes found."""
mock_db = AsyncMock()
mock_db.get_otp_codes_by_run.return_value = []
mock_app = MagicMock()
mock_app.DATABASE = mock_db
with patch("skyvern.services.otp_service.app", new=mock_app):
result = await _get_otp_value_by_run(
organization_id="org_1",
task_id="tsk_1",
)
assert result is None
# === Task 3: poll_otp_value without identifier ===
@pytest.mark.asyncio
async def test_poll_otp_value_without_identifier_uses_run_lookup():
"""poll_otp_value should use _get_otp_value_by_run when no identifier/URL provided."""
mock_code = MagicMock()
mock_code.code = "123456"
mock_code.otp_type = "totp"
mock_db = AsyncMock()
mock_db.get_valid_org_auth_token.return_value = MagicMock(token="tok")
mock_db.get_otp_codes_by_run.return_value = [mock_code]
mock_db.update_task_2fa_state = AsyncMock()
mock_app = MagicMock()
mock_app.DATABASE = mock_db
with (
patch("skyvern.services.otp_service.app", new=mock_app),
patch("skyvern.services.otp_service.asyncio.sleep", new_callable=AsyncMock),
):
result = await poll_otp_value(
organization_id="org_1",
task_id="tsk_1",
)
assert result is not None
assert result.value == "123456"
# === Task 6: Integration test — handle_potential_OTP_actions without TOTP config ===
@pytest.mark.asyncio
async def test_handle_potential_OTP_actions_without_totp_config():
"""When LLM detects 2FA but no TOTP config exists, should still enter verification flow."""
agent = ForgeAgent.__new__(ForgeAgent)
task = MagicMock()
task.organization_id = "org_1"
task.totp_verification_url = None
task.totp_identifier = None
task.task_id = "tsk_1"
task.workflow_run_id = None
step = MagicMock()
scraped_page = MagicMock()
browser_state = MagicMock()
json_response = {
"should_enter_verification_code": True,
"place_to_enter_verification_code": "input#otp-code",
"actions": [],
}
with patch.object(agent, "handle_potential_verification_code", new_callable=AsyncMock) as mock_handler:
mock_handler.return_value = {"actions": []}
with patch("skyvern.forge.agent.parse_actions", return_value=[]):
result_json, result_actions = await agent.handle_potential_OTP_actions(
task, step, scraped_page, browser_state, json_response
)
mock_handler.assert_called_once()
@pytest.mark.asyncio
async def test_handle_potential_OTP_actions_skips_magic_link_without_totp_config():
"""Magic links should still require TOTP config."""
agent = ForgeAgent.__new__(ForgeAgent)
task = MagicMock()
task.organization_id = "org_1"
task.totp_verification_url = None
task.totp_identifier = None
step = MagicMock()
scraped_page = MagicMock()
browser_state = MagicMock()
json_response = {
"should_verify_by_magic_link": True,
}
with patch.object(agent, "handle_potential_magic_link", new_callable=AsyncMock) as mock_handler:
result_json, result_actions = await agent.handle_potential_OTP_actions(
task, step, scraped_page, browser_state, json_response
)
mock_handler.assert_not_called()
assert result_actions == []
# === Task 7: verification_code_check always True in LLM prompt ===
@pytest.mark.asyncio
async def test_verification_code_check_always_true_without_totp_config():
"""_build_extract_action_prompt should receive verification_code_check=True even when task has no TOTP config."""
agent = ForgeAgent.__new__(ForgeAgent)
agent.async_operation_pool = MagicMock()
task = MagicMock()
task.totp_verification_url = None
task.totp_identifier = None
task.task_id = "tsk_1"
task.workflow_run_id = None
task.organization_id = "org_1"
task.url = "https://example.com"
step = MagicMock()
step.step_id = "step_1"
step.order = 0
step.retry_index = 0
scraped_page = MagicMock()
scraped_page.elements = []
browser_state = MagicMock()
with (
patch("skyvern.forge.agent.skyvern_context") as mock_ctx,
patch.object(agent, "_scrape_with_type", new_callable=AsyncMock, return_value=scraped_page),
patch.object(
agent,
"_build_extract_action_prompt",
new_callable=AsyncMock,
return_value=("prompt", False, "extract_action"),
) as mock_build,
):
mock_ctx.current.return_value = None
await agent.build_and_record_step_prompt(
task, step, browser_state, RunEngine.skyvern_v1, persist_artifacts=False
)
mock_build.assert_called_once()
_, kwargs = mock_build.call_args
assert kwargs["verification_code_check"] is True
@pytest.mark.asyncio
async def test_verification_code_check_always_true_with_totp_config():
"""_build_extract_action_prompt should receive verification_code_check=True when task HAS TOTP config (unchanged)."""
agent = ForgeAgent.__new__(ForgeAgent)
agent.async_operation_pool = MagicMock()
task = MagicMock()
task.totp_verification_url = "https://otp.example.com"
task.totp_identifier = "user@example.com"
task.task_id = "tsk_2"
task.workflow_run_id = None
task.organization_id = "org_1"
task.url = "https://example.com"
step = MagicMock()
step.step_id = "step_1"
step.order = 0
step.retry_index = 0
scraped_page = MagicMock()
scraped_page.elements = []
browser_state = MagicMock()
with (
patch("skyvern.forge.agent.skyvern_context") as mock_ctx,
patch.object(agent, "_scrape_with_type", new_callable=AsyncMock, return_value=scraped_page),
patch.object(
agent,
"_build_extract_action_prompt",
new_callable=AsyncMock,
return_value=("prompt", False, "extract_action"),
) as mock_build,
):
mock_ctx.current.return_value = None
await agent.build_and_record_step_prompt(
task, step, browser_state, RunEngine.skyvern_v1, persist_artifacts=False
)
mock_build.assert_called_once()
_, kwargs = mock_build.call_args
assert kwargs["verification_code_check"] is True
# === Fix: poll_otp_value should pass workflow_id, not workflow_permanent_id ===
@pytest.mark.asyncio
async def test_poll_otp_value_passes_workflow_id_not_permanent_id():
"""poll_otp_value should pass workflow_id (w_* format) to _get_otp_value_from_db, not workflow_permanent_id."""
mock_db = AsyncMock()
mock_db.get_valid_org_auth_token.return_value = MagicMock(token="tok")
mock_db.update_workflow_run = AsyncMock()
mock_app = MagicMock()
mock_app.DATABASE = mock_db
with (
patch("skyvern.services.otp_service.app", new=mock_app),
patch("skyvern.services.otp_service.asyncio.sleep", new_callable=AsyncMock),
patch(
"skyvern.services.otp_service._get_otp_value_from_db",
new_callable=AsyncMock,
return_value=OTPValue(value="654321", type="totp"),
) as mock_get_from_db,
):
result = await poll_otp_value(
organization_id="org_1",
workflow_id="w_123",
workflow_run_id="wr_789",
workflow_permanent_id="wpid_456",
totp_identifier="user@example.com",
)
assert result is not None
assert result.value == "654321"
mock_get_from_db.assert_called_once_with(
"org_1",
"user@example.com",
task_id=None,
workflow_id="w_123",
workflow_run_id="wr_789",
)
# === Fix: send_totp_code should resolve wpid_* to w_* before storage ===
@pytest.mark.asyncio
async def test_send_totp_code_resolves_wpid_to_workflow_id():
"""send_totp_code should resolve wpid_* to w_* before storing in DB."""
mock_workflow = MagicMock()
mock_workflow.workflow_id = "w_abc123"
mock_totp_code = MagicMock()
mock_db = AsyncMock()
mock_db.get_workflow_by_permanent_id = AsyncMock(return_value=mock_workflow)
mock_db.create_otp_code = AsyncMock(return_value=mock_totp_code)
mock_app = MagicMock()
mock_app.DATABASE = mock_db
data = TOTPCodeCreate(
totp_identifier="user@example.com",
content="123456",
workflow_id="wpid_xyz789",
)
curr_org = MagicMock()
curr_org.organization_id = "org_1"
with patch("skyvern.forge.sdk.routes.credentials.app", new=mock_app):
await send_totp_code(data=data, curr_org=curr_org)
mock_db.create_otp_code.assert_called_once()
call_kwargs = mock_db.create_otp_code.call_args[1]
assert call_kwargs["workflow_id"] == "w_abc123", f"Expected w_abc123 but got {call_kwargs['workflow_id']}"
@pytest.mark.asyncio
async def test_send_totp_code_w_format_passes_through():
"""send_totp_code should resolve and store w_* format workflow_id correctly."""
mock_workflow = MagicMock()
mock_workflow.workflow_id = "w_abc123"
mock_totp_code = MagicMock()
mock_db = AsyncMock()
mock_db.get_workflow = AsyncMock(return_value=mock_workflow)
mock_db.create_otp_code = AsyncMock(return_value=mock_totp_code)
mock_app = MagicMock()
mock_app.DATABASE = mock_db
data = TOTPCodeCreate(
totp_identifier="user@example.com",
content="123456",
workflow_id="w_abc123",
)
curr_org = MagicMock()
curr_org.organization_id = "org_1"
with patch("skyvern.forge.sdk.routes.credentials.app", new=mock_app):
await send_totp_code(data=data, curr_org=curr_org)
call_kwargs = mock_db.create_otp_code.call_args[1]
assert call_kwargs["workflow_id"] == "w_abc123"
@pytest.mark.asyncio
async def test_send_totp_code_none_workflow_id():
"""send_totp_code should pass None workflow_id when not provided."""
mock_totp_code = MagicMock()
mock_db = AsyncMock()
mock_db.create_otp_code = AsyncMock(return_value=mock_totp_code)
mock_app = MagicMock()
mock_app.DATABASE = mock_db
data = TOTPCodeCreate(
totp_identifier="user@example.com",
content="123456",
)
curr_org = MagicMock()
curr_org.organization_id = "org_1"
with patch("skyvern.forge.sdk.routes.credentials.app", new=mock_app):
await send_totp_code(data=data, curr_org=curr_org)
call_kwargs = mock_db.create_otp_code.call_args[1]
assert call_kwargs["workflow_id"] is None
# === Fix: _build_navigation_payload should inject code without TOTP config ===
def test_build_navigation_payload_injects_code_without_totp_config():
"""_build_navigation_payload should inject SPECIAL_FIELD_VERIFICATION_CODE even when
task has no totp_verification_url or totp_identifier (manual 2FA flow)."""
agent = ForgeAgent.__new__(ForgeAgent)
task = MagicMock()
task.totp_verification_url = None
task.totp_identifier = None
task.task_id = "tsk_manual_2fa"
task.workflow_run_id = "wr_123"
task.navigation_payload = {"username": "user@example.com"}
mock_context = MagicMock()
mock_context.totp_codes = {"tsk_manual_2fa": "123456"}
mock_context.has_magic_link_page.return_value = False
with patch("skyvern.forge.agent.skyvern_context") as mock_skyvern_ctx:
mock_skyvern_ctx.ensure_context.return_value = mock_context
result = agent._build_navigation_payload(task)
assert isinstance(result, dict)
assert SPECIAL_FIELD_VERIFICATION_CODE in result
assert result[SPECIAL_FIELD_VERIFICATION_CODE] == "123456"
# Original payload preserved
assert result["username"] == "user@example.com"
def test_build_navigation_payload_injects_code_when_payload_is_none():
"""_build_navigation_payload should create a dict with the code when payload is None."""
agent = ForgeAgent.__new__(ForgeAgent)
task = MagicMock()
task.totp_verification_url = None
task.totp_identifier = None
task.task_id = "tsk_manual_2fa"
task.workflow_run_id = "wr_123"
task.navigation_payload = None
mock_context = MagicMock()
mock_context.totp_codes = {"tsk_manual_2fa": "999999"}
mock_context.has_magic_link_page.return_value = False
with patch("skyvern.forge.agent.skyvern_context") as mock_skyvern_ctx:
mock_skyvern_ctx.ensure_context.return_value = mock_context
result = agent._build_navigation_payload(task)
assert isinstance(result, dict)
assert result[SPECIAL_FIELD_VERIFICATION_CODE] == "999999"
def test_build_navigation_payload_no_code_no_injection():
"""_build_navigation_payload should NOT inject anything when no code in context."""
agent = ForgeAgent.__new__(ForgeAgent)
task = MagicMock()
task.totp_verification_url = None
task.totp_identifier = None
task.task_id = "tsk_no_code"
task.workflow_run_id = "wr_456"
task.navigation_payload = {"field": "value"}
mock_context = MagicMock()
mock_context.totp_codes = {} # No code in context
mock_context.has_magic_link_page.return_value = False
with patch("skyvern.forge.agent.skyvern_context") as mock_skyvern_ctx:
mock_skyvern_ctx.ensure_context.return_value = mock_context
result = agent._build_navigation_payload(task)
assert isinstance(result, dict)
assert SPECIAL_FIELD_VERIFICATION_CODE not in result
assert result["field"] == "value"
# === Task: poll_otp_value publishes 2FA events to notification registry ===
@pytest.mark.asyncio
async def test_poll_otp_value_publishes_required_event_for_task():
"""poll_otp_value should publish verification_code_required when task waiting state is set."""
mock_code = MagicMock()
mock_code.code = "123456"
mock_code.otp_type = "totp"
mock_db = AsyncMock()
mock_db.get_valid_org_auth_token.return_value = MagicMock(token="tok")
mock_db.get_otp_codes_by_run.return_value = [mock_code]
mock_db.update_task_2fa_state = AsyncMock()
mock_app = MagicMock()
mock_app.DATABASE = mock_db
registry = LocalNotificationRegistry()
queue = registry.subscribe("org_1")
with (
patch("skyvern.services.otp_service.app", new=mock_app),
patch("skyvern.services.otp_service.asyncio.sleep", new_callable=AsyncMock),
patch(
"skyvern.forge.sdk.notification.factory.NotificationRegistryFactory._NotificationRegistryFactory__registry",
new=registry,
),
):
await poll_otp_value(organization_id="org_1", task_id="tsk_1")
# Should have received required + resolved messages
messages = []
while not queue.empty():
messages.append(queue.get_nowait())
types = [m["type"] for m in messages]
assert "verification_code_required" in types
assert "verification_code_resolved" in types
required = next(m for m in messages if m["type"] == "verification_code_required")
assert required["task_id"] == "tsk_1"
resolved = next(m for m in messages if m["type"] == "verification_code_resolved")
assert resolved["task_id"] == "tsk_1"
@pytest.mark.asyncio
async def test_poll_otp_value_publishes_required_event_for_workflow_run():
"""poll_otp_value should publish verification_code_required when workflow run waiting state is set."""
mock_code = MagicMock()
mock_code.code = "654321"
mock_code.otp_type = "totp"
mock_db = AsyncMock()
mock_db.get_valid_org_auth_token.return_value = MagicMock(token="tok")
mock_db.update_workflow_run = AsyncMock()
mock_db.get_otp_codes_by_run.return_value = [mock_code]
mock_app = MagicMock()
mock_app.DATABASE = mock_db
registry = LocalNotificationRegistry()
queue = registry.subscribe("org_1")
with (
patch("skyvern.services.otp_service.app", new=mock_app),
patch("skyvern.services.otp_service.asyncio.sleep", new_callable=AsyncMock),
patch(
"skyvern.forge.sdk.notification.factory.NotificationRegistryFactory._NotificationRegistryFactory__registry",
new=registry,
),
):
await poll_otp_value(organization_id="org_1", workflow_run_id="wr_1")
messages = []
while not queue.empty():
messages.append(queue.get_nowait())
types = [m["type"] for m in messages]
assert "verification_code_required" in types
assert "verification_code_resolved" in types
required = next(m for m in messages if m["type"] == "verification_code_required")
assert required["workflow_run_id"] == "wr_1"

View File

@@ -0,0 +1,22 @@
"""Tests for RedisClientFactory."""
from unittest.mock import MagicMock
from skyvern.forge.sdk.redis.factory import RedisClientFactory
def test_default_is_none():
"""Factory returns None when no client has been set."""
# Reset to default state
RedisClientFactory.set_client(None) # type: ignore[arg-type]
assert RedisClientFactory.get_client() is None
def test_set_and_get():
"""Round-trip: set_client then get_client returns the same object."""
mock_client = MagicMock()
RedisClientFactory.set_client(mock_client)
assert RedisClientFactory.get_client() is mock_client
# Cleanup
RedisClientFactory.set_client(None) # type: ignore[arg-type]

View File

@@ -0,0 +1,237 @@
"""Tests for RedisNotificationRegistry (SKY-6).
All tests use a mock Redis client — no real Redis instance required.
"""
import asyncio
import json
from unittest.mock import AsyncMock, MagicMock
import pytest
from skyvern.forge.sdk.notification.redis import RedisNotificationRegistry
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_mock_redis() -> MagicMock:
"""Return a mock redis.asyncio.Redis client."""
redis = MagicMock()
redis.publish = AsyncMock()
redis.pubsub = MagicMock()
return redis
def _make_mock_pubsub(messages: list[dict] | None = None, *, block: bool = False) -> MagicMock:
"""Return a mock PubSub that yields *messages* from ``listen()``.
Each entry in *messages* should look like:
{"type": "message", "data": '{"key": "val"}'}
If *block* is True the async generator will hang forever after
exhausting *messages*, which keeps the listener task alive so that
cancellation semantics can be tested.
"""
pubsub = MagicMock()
pubsub.subscribe = AsyncMock()
pubsub.unsubscribe = AsyncMock()
pubsub.close = AsyncMock()
async def _listen():
for msg in messages or []:
yield msg
if block:
# Keep the listener alive until cancelled
await asyncio.Event().wait()
pubsub.listen = _listen
return pubsub
# ---------------------------------------------------------------------------
# Tests: subscribe
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_subscribe_creates_queue_and_starts_listener():
"""subscribe() should return a queue and start a background listener task."""
redis = _make_mock_redis()
pubsub = _make_mock_pubsub()
redis.pubsub.return_value = pubsub
registry = RedisNotificationRegistry(redis)
queue = registry.subscribe("org_1")
assert isinstance(queue, asyncio.Queue)
assert "org_1" in registry._listener_tasks
task = registry._listener_tasks["org_1"]
assert isinstance(task, asyncio.Task)
# Cleanup
await registry.close()
@pytest.mark.asyncio
async def test_subscribe_reuses_listener_for_same_org():
"""A second subscribe for the same org should NOT create a new listener task."""
redis = _make_mock_redis()
pubsub = _make_mock_pubsub()
redis.pubsub.return_value = pubsub
registry = RedisNotificationRegistry(redis)
registry.subscribe("org_1")
first_task = registry._listener_tasks["org_1"]
registry.subscribe("org_1")
assert registry._listener_tasks["org_1"] is first_task
await registry.close()
# ---------------------------------------------------------------------------
# Tests: unsubscribe
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_unsubscribe_cancels_listener_when_last_subscriber_leaves():
"""When the last subscriber unsubscribes, the listener task should be cancelled."""
redis = _make_mock_redis()
pubsub = _make_mock_pubsub(block=True)
redis.pubsub.return_value = pubsub
registry = RedisNotificationRegistry(redis)
queue = registry.subscribe("org_1")
# Let the listener task start running
await asyncio.sleep(0)
task = registry._listener_tasks["org_1"]
registry.unsubscribe("org_1", queue)
assert "org_1" not in registry._listener_tasks
# Wait for the task to fully complete after cancellation
await asyncio.gather(task, return_exceptions=True)
assert task.cancelled()
await registry.close()
@pytest.mark.asyncio
async def test_unsubscribe_keeps_listener_when_subscribers_remain():
"""If other subscribers remain, the listener task should stay alive."""
redis = _make_mock_redis()
pubsub = _make_mock_pubsub()
redis.pubsub.return_value = pubsub
registry = RedisNotificationRegistry(redis)
q1 = registry.subscribe("org_1")
registry.subscribe("org_1") # second subscriber
registry.unsubscribe("org_1", q1)
assert "org_1" in registry._listener_tasks
await registry.close()
# ---------------------------------------------------------------------------
# Tests: publish
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_publish_calls_redis_publish():
"""publish() should fire-and-forget a Redis PUBLISH."""
redis = _make_mock_redis()
registry = RedisNotificationRegistry(redis)
registry.publish("org_1", {"type": "verification_code_required"})
# Allow the fire-and-forget task to execute
await asyncio.sleep(0)
redis.publish.assert_awaited_once_with(
"skyvern:notifications:org_1",
json.dumps({"type": "verification_code_required"}),
)
await registry.close()
# ---------------------------------------------------------------------------
# Tests: _dispatch_local
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_dispatch_local_fans_out_to_all_queues():
"""_dispatch_local should put the message into every local queue for the org."""
redis = _make_mock_redis()
pubsub = _make_mock_pubsub()
redis.pubsub.return_value = pubsub
registry = RedisNotificationRegistry(redis)
q1 = registry.subscribe("org_1")
q2 = registry.subscribe("org_1")
msg = {"type": "test", "value": 42}
registry._dispatch_local("org_1", msg)
assert q1.get_nowait() == msg
assert q2.get_nowait() == msg
await registry.close()
@pytest.mark.asyncio
async def test_dispatch_local_does_not_leak_across_orgs():
"""Messages dispatched for org_a should not appear in org_b queues."""
redis = _make_mock_redis()
pubsub = _make_mock_pubsub()
redis.pubsub.return_value = pubsub
registry = RedisNotificationRegistry(redis)
q_a = registry.subscribe("org_a")
q_b = registry.subscribe("org_b")
registry._dispatch_local("org_a", {"type": "test"})
assert not q_a.empty()
assert q_b.empty()
await registry.close()
# ---------------------------------------------------------------------------
# Tests: close
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_close_cancels_all_listeners_and_clears_state():
"""close() should cancel every listener task and empty subscriber maps."""
redis = _make_mock_redis()
pubsub = _make_mock_pubsub(block=True)
redis.pubsub.return_value = pubsub
registry = RedisNotificationRegistry(redis)
registry.subscribe("org_1")
registry.subscribe("org_2")
# Let the listener tasks start running
await asyncio.sleep(0)
task_1 = registry._listener_tasks["org_1"]
task_2 = registry._listener_tasks["org_2"]
await registry.close()
# Wait for the tasks to fully complete after cancellation
await asyncio.gather(task_1, task_2, return_exceptions=True)
assert task_1.cancelled()
assert task_2.cancelled()
assert len(registry._listener_tasks) == 0
assert len(registry._subscribers) == 0

View File

@@ -0,0 +1,262 @@
"""Tests for RedisPubSub (generic pub/sub layer).
All tests use a mock Redis client — no real Redis instance required.
"""
import asyncio
import json
from unittest.mock import AsyncMock, MagicMock
import pytest
from skyvern.forge.sdk.redis.pubsub import RedisPubSub
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_mock_redis() -> MagicMock:
"""Return a mock redis.asyncio.Redis client."""
redis = MagicMock()
redis.publish = AsyncMock()
redis.pubsub = MagicMock()
return redis
def _make_mock_pubsub(messages: list[dict] | None = None, *, block: bool = False) -> MagicMock:
"""Return a mock PubSub that yields *messages* from ``listen()``.
Each entry in *messages* should look like:
{"type": "message", "data": '{"key": "val"}'}
If *block* is True the async generator will hang forever after
exhausting *messages*, which keeps the listener task alive so that
cancellation semantics can be tested.
"""
pubsub = MagicMock()
pubsub.subscribe = AsyncMock()
pubsub.unsubscribe = AsyncMock()
pubsub.close = AsyncMock()
async def _listen():
for msg in messages or []:
yield msg
if block:
await asyncio.Event().wait()
pubsub.listen = _listen
return pubsub
PREFIX = "skyvern:test:"
# ---------------------------------------------------------------------------
# Tests: subscribe
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_subscribe_creates_queue_and_starts_listener():
"""subscribe() should return a queue and start a background listener task."""
redis = _make_mock_redis()
pubsub = _make_mock_pubsub()
redis.pubsub.return_value = pubsub
ps = RedisPubSub(redis, channel_prefix=PREFIX)
queue = ps.subscribe("key_1")
assert isinstance(queue, asyncio.Queue)
assert "key_1" in ps._listener_tasks
task = ps._listener_tasks["key_1"]
assert isinstance(task, asyncio.Task)
await ps.close()
@pytest.mark.asyncio
async def test_subscribe_reuses_listener_for_same_key():
"""A second subscribe for the same key should NOT create a new listener task."""
redis = _make_mock_redis()
pubsub = _make_mock_pubsub()
redis.pubsub.return_value = pubsub
ps = RedisPubSub(redis, channel_prefix=PREFIX)
ps.subscribe("key_1")
first_task = ps._listener_tasks["key_1"]
ps.subscribe("key_1")
assert ps._listener_tasks["key_1"] is first_task
await ps.close()
# ---------------------------------------------------------------------------
# Tests: unsubscribe
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_unsubscribe_cancels_listener_when_last_subscriber_leaves():
"""When the last subscriber unsubscribes, the listener task should be cancelled."""
redis = _make_mock_redis()
pubsub = _make_mock_pubsub(block=True)
redis.pubsub.return_value = pubsub
ps = RedisPubSub(redis, channel_prefix=PREFIX)
queue = ps.subscribe("key_1")
await asyncio.sleep(0)
task = ps._listener_tasks["key_1"]
ps.unsubscribe("key_1", queue)
assert "key_1" not in ps._listener_tasks
await asyncio.gather(task, return_exceptions=True)
assert task.cancelled()
await ps.close()
@pytest.mark.asyncio
async def test_unsubscribe_keeps_listener_when_subscribers_remain():
"""If other subscribers remain, the listener task should stay alive."""
redis = _make_mock_redis()
pubsub = _make_mock_pubsub()
redis.pubsub.return_value = pubsub
ps = RedisPubSub(redis, channel_prefix=PREFIX)
q1 = ps.subscribe("key_1")
ps.subscribe("key_1")
ps.unsubscribe("key_1", q1)
assert "key_1" in ps._listener_tasks
await ps.close()
# ---------------------------------------------------------------------------
# Tests: publish
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_publish_calls_redis_publish():
"""publish() should fire-and-forget a Redis PUBLISH with prefixed channel."""
redis = _make_mock_redis()
ps = RedisPubSub(redis, channel_prefix=PREFIX)
ps.publish("key_1", {"type": "event"})
await asyncio.sleep(0)
redis.publish.assert_awaited_once_with(
f"{PREFIX}key_1",
json.dumps({"type": "event"}),
)
await ps.close()
# ---------------------------------------------------------------------------
# Tests: _dispatch_local
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_dispatch_local_fans_out_to_all_queues():
"""_dispatch_local should put the message into every local queue for the key."""
redis = _make_mock_redis()
pubsub = _make_mock_pubsub()
redis.pubsub.return_value = pubsub
ps = RedisPubSub(redis, channel_prefix=PREFIX)
q1 = ps.subscribe("key_1")
q2 = ps.subscribe("key_1")
msg = {"type": "test", "value": 42}
ps._dispatch_local("key_1", msg)
assert q1.get_nowait() == msg
assert q2.get_nowait() == msg
await ps.close()
@pytest.mark.asyncio
async def test_dispatch_local_does_not_leak_across_keys():
"""Messages dispatched for key_a should not appear in key_b queues."""
redis = _make_mock_redis()
pubsub = _make_mock_pubsub()
redis.pubsub.return_value = pubsub
ps = RedisPubSub(redis, channel_prefix=PREFIX)
q_a = ps.subscribe("key_a")
q_b = ps.subscribe("key_b")
ps._dispatch_local("key_a", {"type": "test"})
assert not q_a.empty()
assert q_b.empty()
await ps.close()
# ---------------------------------------------------------------------------
# Tests: close
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_close_cancels_all_listeners_and_clears_state():
"""close() should cancel every listener task and empty subscriber maps."""
redis = _make_mock_redis()
pubsub = _make_mock_pubsub(block=True)
redis.pubsub.return_value = pubsub
ps = RedisPubSub(redis, channel_prefix=PREFIX)
ps.subscribe("key_1")
ps.subscribe("key_2")
await asyncio.sleep(0)
task_1 = ps._listener_tasks["key_1"]
task_2 = ps._listener_tasks["key_2"]
await ps.close()
await asyncio.gather(task_1, task_2, return_exceptions=True)
assert task_1.cancelled()
assert task_2.cancelled()
assert len(ps._listener_tasks) == 0
assert len(ps._subscribers) == 0
# ---------------------------------------------------------------------------
# Tests: prefix isolation
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_different_prefixes_do_not_interfere():
"""Two RedisPubSub instances with different prefixes use separate channels."""
redis = _make_mock_redis()
ps_a = RedisPubSub(redis, channel_prefix="prefix_a:")
ps_b = RedisPubSub(redis, channel_prefix="prefix_b:")
ps_a.publish("key_1", {"from": "a"})
ps_b.publish("key_1", {"from": "b"})
await asyncio.sleep(0)
calls = redis.publish.await_args_list
assert len(calls) == 2
channels = {call.args[0] for call in calls}
assert "prefix_a:key_1" in channels
assert "prefix_b:key_1" in channels
await ps_a.close()
await ps_b.close()