[SKY-6] Backend: Enable 2FA code detection without TOTP credentials (#4786)
This commit is contained in:
106
tests/unit_tests/test_notification_registry.py
Normal file
106
tests/unit_tests/test_notification_registry.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""Tests for NotificationRegistry pub/sub and get_active_verification_requests (SKY-6)."""
|
||||
|
||||
import pytest
|
||||
|
||||
from skyvern.forge.sdk.db.agent_db import AgentDB
|
||||
from skyvern.forge.sdk.notification.base import BaseNotificationRegistry
|
||||
from skyvern.forge.sdk.notification.factory import NotificationRegistryFactory
|
||||
from skyvern.forge.sdk.notification.local import LocalNotificationRegistry
|
||||
|
||||
# === Task 1: NotificationRegistry subscribe / publish / unsubscribe ===
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subscribe_and_publish():
|
||||
"""Published messages should be received by subscribers."""
|
||||
registry = LocalNotificationRegistry()
|
||||
queue = registry.subscribe("org_1")
|
||||
|
||||
registry.publish("org_1", {"type": "verification_code_required", "task_id": "tsk_1"})
|
||||
msg = queue.get_nowait()
|
||||
assert msg["type"] == "verification_code_required"
|
||||
assert msg["task_id"] == "tsk_1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_subscribers():
|
||||
"""All subscribers for an org should receive the same message."""
|
||||
registry = LocalNotificationRegistry()
|
||||
q1 = registry.subscribe("org_1")
|
||||
q2 = registry.subscribe("org_1")
|
||||
|
||||
registry.publish("org_1", {"type": "verification_code_required"})
|
||||
assert not q1.empty()
|
||||
assert not q2.empty()
|
||||
assert q1.get_nowait() == q2.get_nowait()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_wrong_org_does_not_leak():
|
||||
"""Messages for org_A should not appear in org_B's queue."""
|
||||
registry = LocalNotificationRegistry()
|
||||
q_a = registry.subscribe("org_a")
|
||||
q_b = registry.subscribe("org_b")
|
||||
|
||||
registry.publish("org_a", {"type": "test"})
|
||||
assert not q_a.empty()
|
||||
assert q_b.empty()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unsubscribe():
|
||||
"""After unsubscribe, the queue should no longer receive messages."""
|
||||
registry = LocalNotificationRegistry()
|
||||
queue = registry.subscribe("org_1")
|
||||
|
||||
registry.unsubscribe("org_1", queue)
|
||||
registry.publish("org_1", {"type": "test"})
|
||||
assert queue.empty()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unsubscribe_idempotent():
|
||||
"""Unsubscribing a queue that's already removed should not raise."""
|
||||
registry = LocalNotificationRegistry()
|
||||
queue = registry.subscribe("org_1")
|
||||
registry.unsubscribe("org_1", queue)
|
||||
registry.unsubscribe("org_1", queue) # should not raise
|
||||
|
||||
|
||||
# === Task: BaseNotificationRegistry ABC ===
|
||||
|
||||
|
||||
def test_base_notification_registry_cannot_be_instantiated():
|
||||
"""ABC should not be directly instantiable."""
|
||||
with pytest.raises(TypeError):
|
||||
BaseNotificationRegistry()
|
||||
|
||||
|
||||
# === Task: NotificationRegistryFactory ===
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_factory_returns_local_by_default():
|
||||
"""Factory should return a LocalNotificationRegistry by default."""
|
||||
registry = NotificationRegistryFactory.get_registry()
|
||||
assert isinstance(registry, LocalNotificationRegistry)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_factory_set_and_get():
|
||||
"""Factory should allow swapping the registry implementation."""
|
||||
original = NotificationRegistryFactory.get_registry()
|
||||
try:
|
||||
custom = LocalNotificationRegistry()
|
||||
NotificationRegistryFactory.set_registry(custom)
|
||||
assert NotificationRegistryFactory.get_registry() is custom
|
||||
finally:
|
||||
NotificationRegistryFactory.set_registry(original)
|
||||
|
||||
|
||||
# === Task 2: get_active_verification_requests DB method ===
|
||||
|
||||
|
||||
def test_get_active_verification_requests_method_exists():
|
||||
"""AgentDB should have get_active_verification_requests method."""
|
||||
assert hasattr(AgentDB, "get_active_verification_requests")
|
||||
544
tests/unit_tests/test_otp_no_config.py
Normal file
544
tests/unit_tests/test_otp_no_config.py
Normal file
@@ -0,0 +1,544 @@
|
||||
"""Tests for manual 2FA input without pre-configured TOTP credentials (SKY-6)."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from skyvern.constants import SPECIAL_FIELD_VERIFICATION_CODE
|
||||
from skyvern.forge.agent import ForgeAgent
|
||||
from skyvern.forge.sdk.db.agent_db import AgentDB
|
||||
from skyvern.forge.sdk.notification.local import LocalNotificationRegistry
|
||||
from skyvern.forge.sdk.routes.credentials import send_totp_code
|
||||
from skyvern.forge.sdk.schemas.totp_codes import TOTPCodeCreate
|
||||
from skyvern.schemas.runs import RunEngine
|
||||
from skyvern.services.otp_service import OTPValue, _get_otp_value_by_run, poll_otp_value
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_otp_codes_by_run_exists():
|
||||
"""get_otp_codes_by_run should exist on AgentDB."""
|
||||
assert hasattr(AgentDB, "get_otp_codes_by_run"), "AgentDB missing get_otp_codes_by_run method"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_otp_codes_by_run_returns_empty_without_identifiers():
|
||||
"""get_otp_codes_by_run should return [] when neither task_id nor workflow_run_id is given."""
|
||||
db = AgentDB.__new__(AgentDB)
|
||||
result = await db.get_otp_codes_by_run(
|
||||
organization_id="org_1",
|
||||
)
|
||||
assert result == []
|
||||
|
||||
|
||||
# === Task 2: _get_otp_value_by_run OTP service function ===
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_otp_value_by_run_returns_code():
|
||||
"""_get_otp_value_by_run should find OTP codes by task_id."""
|
||||
mock_code = MagicMock()
|
||||
mock_code.code = "123456"
|
||||
mock_code.otp_type = "totp"
|
||||
|
||||
mock_db = AsyncMock()
|
||||
mock_db.get_otp_codes_by_run.return_value = [mock_code]
|
||||
|
||||
mock_app = MagicMock()
|
||||
mock_app.DATABASE = mock_db
|
||||
|
||||
with patch("skyvern.services.otp_service.app", new=mock_app):
|
||||
result = await _get_otp_value_by_run(
|
||||
organization_id="org_1",
|
||||
task_id="tsk_1",
|
||||
)
|
||||
assert result is not None
|
||||
assert result.value == "123456"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_otp_value_by_run_returns_none_when_no_codes():
|
||||
"""_get_otp_value_by_run should return None when no codes found."""
|
||||
mock_db = AsyncMock()
|
||||
mock_db.get_otp_codes_by_run.return_value = []
|
||||
|
||||
mock_app = MagicMock()
|
||||
mock_app.DATABASE = mock_db
|
||||
|
||||
with patch("skyvern.services.otp_service.app", new=mock_app):
|
||||
result = await _get_otp_value_by_run(
|
||||
organization_id="org_1",
|
||||
task_id="tsk_1",
|
||||
)
|
||||
assert result is None
|
||||
|
||||
|
||||
# === Task 3: poll_otp_value without identifier ===
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_poll_otp_value_without_identifier_uses_run_lookup():
|
||||
"""poll_otp_value should use _get_otp_value_by_run when no identifier/URL provided."""
|
||||
mock_code = MagicMock()
|
||||
mock_code.code = "123456"
|
||||
mock_code.otp_type = "totp"
|
||||
|
||||
mock_db = AsyncMock()
|
||||
mock_db.get_valid_org_auth_token.return_value = MagicMock(token="tok")
|
||||
mock_db.get_otp_codes_by_run.return_value = [mock_code]
|
||||
mock_db.update_task_2fa_state = AsyncMock()
|
||||
|
||||
mock_app = MagicMock()
|
||||
mock_app.DATABASE = mock_db
|
||||
|
||||
with (
|
||||
patch("skyvern.services.otp_service.app", new=mock_app),
|
||||
patch("skyvern.services.otp_service.asyncio.sleep", new_callable=AsyncMock),
|
||||
):
|
||||
result = await poll_otp_value(
|
||||
organization_id="org_1",
|
||||
task_id="tsk_1",
|
||||
)
|
||||
assert result is not None
|
||||
assert result.value == "123456"
|
||||
|
||||
|
||||
# === Task 6: Integration test — handle_potential_OTP_actions without TOTP config ===
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_potential_OTP_actions_without_totp_config():
|
||||
"""When LLM detects 2FA but no TOTP config exists, should still enter verification flow."""
|
||||
agent = ForgeAgent.__new__(ForgeAgent)
|
||||
|
||||
task = MagicMock()
|
||||
task.organization_id = "org_1"
|
||||
task.totp_verification_url = None
|
||||
task.totp_identifier = None
|
||||
task.task_id = "tsk_1"
|
||||
task.workflow_run_id = None
|
||||
|
||||
step = MagicMock()
|
||||
scraped_page = MagicMock()
|
||||
browser_state = MagicMock()
|
||||
|
||||
json_response = {
|
||||
"should_enter_verification_code": True,
|
||||
"place_to_enter_verification_code": "input#otp-code",
|
||||
"actions": [],
|
||||
}
|
||||
|
||||
with patch.object(agent, "handle_potential_verification_code", new_callable=AsyncMock) as mock_handler:
|
||||
mock_handler.return_value = {"actions": []}
|
||||
with patch("skyvern.forge.agent.parse_actions", return_value=[]):
|
||||
result_json, result_actions = await agent.handle_potential_OTP_actions(
|
||||
task, step, scraped_page, browser_state, json_response
|
||||
)
|
||||
mock_handler.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_potential_OTP_actions_skips_magic_link_without_totp_config():
|
||||
"""Magic links should still require TOTP config."""
|
||||
agent = ForgeAgent.__new__(ForgeAgent)
|
||||
|
||||
task = MagicMock()
|
||||
task.organization_id = "org_1"
|
||||
task.totp_verification_url = None
|
||||
task.totp_identifier = None
|
||||
|
||||
step = MagicMock()
|
||||
scraped_page = MagicMock()
|
||||
browser_state = MagicMock()
|
||||
|
||||
json_response = {
|
||||
"should_verify_by_magic_link": True,
|
||||
}
|
||||
|
||||
with patch.object(agent, "handle_potential_magic_link", new_callable=AsyncMock) as mock_handler:
|
||||
result_json, result_actions = await agent.handle_potential_OTP_actions(
|
||||
task, step, scraped_page, browser_state, json_response
|
||||
)
|
||||
mock_handler.assert_not_called()
|
||||
assert result_actions == []
|
||||
|
||||
|
||||
# === Task 7: verification_code_check always True in LLM prompt ===
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verification_code_check_always_true_without_totp_config():
|
||||
"""_build_extract_action_prompt should receive verification_code_check=True even when task has no TOTP config."""
|
||||
agent = ForgeAgent.__new__(ForgeAgent)
|
||||
agent.async_operation_pool = MagicMock()
|
||||
|
||||
task = MagicMock()
|
||||
task.totp_verification_url = None
|
||||
task.totp_identifier = None
|
||||
task.task_id = "tsk_1"
|
||||
task.workflow_run_id = None
|
||||
task.organization_id = "org_1"
|
||||
task.url = "https://example.com"
|
||||
|
||||
step = MagicMock()
|
||||
step.step_id = "step_1"
|
||||
step.order = 0
|
||||
step.retry_index = 0
|
||||
|
||||
scraped_page = MagicMock()
|
||||
scraped_page.elements = []
|
||||
|
||||
browser_state = MagicMock()
|
||||
|
||||
with (
|
||||
patch("skyvern.forge.agent.skyvern_context") as mock_ctx,
|
||||
patch.object(agent, "_scrape_with_type", new_callable=AsyncMock, return_value=scraped_page),
|
||||
patch.object(
|
||||
agent,
|
||||
"_build_extract_action_prompt",
|
||||
new_callable=AsyncMock,
|
||||
return_value=("prompt", False, "extract_action"),
|
||||
) as mock_build,
|
||||
):
|
||||
mock_ctx.current.return_value = None
|
||||
await agent.build_and_record_step_prompt(
|
||||
task, step, browser_state, RunEngine.skyvern_v1, persist_artifacts=False
|
||||
)
|
||||
mock_build.assert_called_once()
|
||||
_, kwargs = mock_build.call_args
|
||||
assert kwargs["verification_code_check"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verification_code_check_always_true_with_totp_config():
|
||||
"""_build_extract_action_prompt should receive verification_code_check=True when task HAS TOTP config (unchanged)."""
|
||||
agent = ForgeAgent.__new__(ForgeAgent)
|
||||
agent.async_operation_pool = MagicMock()
|
||||
|
||||
task = MagicMock()
|
||||
task.totp_verification_url = "https://otp.example.com"
|
||||
task.totp_identifier = "user@example.com"
|
||||
task.task_id = "tsk_2"
|
||||
task.workflow_run_id = None
|
||||
task.organization_id = "org_1"
|
||||
task.url = "https://example.com"
|
||||
|
||||
step = MagicMock()
|
||||
step.step_id = "step_1"
|
||||
step.order = 0
|
||||
step.retry_index = 0
|
||||
|
||||
scraped_page = MagicMock()
|
||||
scraped_page.elements = []
|
||||
|
||||
browser_state = MagicMock()
|
||||
|
||||
with (
|
||||
patch("skyvern.forge.agent.skyvern_context") as mock_ctx,
|
||||
patch.object(agent, "_scrape_with_type", new_callable=AsyncMock, return_value=scraped_page),
|
||||
patch.object(
|
||||
agent,
|
||||
"_build_extract_action_prompt",
|
||||
new_callable=AsyncMock,
|
||||
return_value=("prompt", False, "extract_action"),
|
||||
) as mock_build,
|
||||
):
|
||||
mock_ctx.current.return_value = None
|
||||
await agent.build_and_record_step_prompt(
|
||||
task, step, browser_state, RunEngine.skyvern_v1, persist_artifacts=False
|
||||
)
|
||||
mock_build.assert_called_once()
|
||||
_, kwargs = mock_build.call_args
|
||||
assert kwargs["verification_code_check"] is True
|
||||
|
||||
|
||||
# === Fix: poll_otp_value should pass workflow_id, not workflow_permanent_id ===
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_poll_otp_value_passes_workflow_id_not_permanent_id():
|
||||
"""poll_otp_value should pass workflow_id (w_* format) to _get_otp_value_from_db, not workflow_permanent_id."""
|
||||
mock_db = AsyncMock()
|
||||
mock_db.get_valid_org_auth_token.return_value = MagicMock(token="tok")
|
||||
mock_db.update_workflow_run = AsyncMock()
|
||||
|
||||
mock_app = MagicMock()
|
||||
mock_app.DATABASE = mock_db
|
||||
|
||||
with (
|
||||
patch("skyvern.services.otp_service.app", new=mock_app),
|
||||
patch("skyvern.services.otp_service.asyncio.sleep", new_callable=AsyncMock),
|
||||
patch(
|
||||
"skyvern.services.otp_service._get_otp_value_from_db",
|
||||
new_callable=AsyncMock,
|
||||
return_value=OTPValue(value="654321", type="totp"),
|
||||
) as mock_get_from_db,
|
||||
):
|
||||
result = await poll_otp_value(
|
||||
organization_id="org_1",
|
||||
workflow_id="w_123",
|
||||
workflow_run_id="wr_789",
|
||||
workflow_permanent_id="wpid_456",
|
||||
totp_identifier="user@example.com",
|
||||
)
|
||||
assert result is not None
|
||||
assert result.value == "654321"
|
||||
mock_get_from_db.assert_called_once_with(
|
||||
"org_1",
|
||||
"user@example.com",
|
||||
task_id=None,
|
||||
workflow_id="w_123",
|
||||
workflow_run_id="wr_789",
|
||||
)
|
||||
|
||||
|
||||
# === Fix: send_totp_code should resolve wpid_* to w_* before storage ===
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_totp_code_resolves_wpid_to_workflow_id():
|
||||
"""send_totp_code should resolve wpid_* to w_* before storing in DB."""
|
||||
mock_workflow = MagicMock()
|
||||
mock_workflow.workflow_id = "w_abc123"
|
||||
|
||||
mock_totp_code = MagicMock()
|
||||
|
||||
mock_db = AsyncMock()
|
||||
mock_db.get_workflow_by_permanent_id = AsyncMock(return_value=mock_workflow)
|
||||
mock_db.create_otp_code = AsyncMock(return_value=mock_totp_code)
|
||||
|
||||
mock_app = MagicMock()
|
||||
mock_app.DATABASE = mock_db
|
||||
|
||||
data = TOTPCodeCreate(
|
||||
totp_identifier="user@example.com",
|
||||
content="123456",
|
||||
workflow_id="wpid_xyz789",
|
||||
)
|
||||
curr_org = MagicMock()
|
||||
curr_org.organization_id = "org_1"
|
||||
|
||||
with patch("skyvern.forge.sdk.routes.credentials.app", new=mock_app):
|
||||
await send_totp_code(data=data, curr_org=curr_org)
|
||||
|
||||
mock_db.create_otp_code.assert_called_once()
|
||||
call_kwargs = mock_db.create_otp_code.call_args[1]
|
||||
assert call_kwargs["workflow_id"] == "w_abc123", f"Expected w_abc123 but got {call_kwargs['workflow_id']}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_totp_code_w_format_passes_through():
|
||||
"""send_totp_code should resolve and store w_* format workflow_id correctly."""
|
||||
mock_workflow = MagicMock()
|
||||
mock_workflow.workflow_id = "w_abc123"
|
||||
|
||||
mock_totp_code = MagicMock()
|
||||
|
||||
mock_db = AsyncMock()
|
||||
mock_db.get_workflow = AsyncMock(return_value=mock_workflow)
|
||||
mock_db.create_otp_code = AsyncMock(return_value=mock_totp_code)
|
||||
|
||||
mock_app = MagicMock()
|
||||
mock_app.DATABASE = mock_db
|
||||
|
||||
data = TOTPCodeCreate(
|
||||
totp_identifier="user@example.com",
|
||||
content="123456",
|
||||
workflow_id="w_abc123",
|
||||
)
|
||||
curr_org = MagicMock()
|
||||
curr_org.organization_id = "org_1"
|
||||
|
||||
with patch("skyvern.forge.sdk.routes.credentials.app", new=mock_app):
|
||||
await send_totp_code(data=data, curr_org=curr_org)
|
||||
|
||||
call_kwargs = mock_db.create_otp_code.call_args[1]
|
||||
assert call_kwargs["workflow_id"] == "w_abc123"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_totp_code_none_workflow_id():
|
||||
"""send_totp_code should pass None workflow_id when not provided."""
|
||||
mock_totp_code = MagicMock()
|
||||
|
||||
mock_db = AsyncMock()
|
||||
mock_db.create_otp_code = AsyncMock(return_value=mock_totp_code)
|
||||
|
||||
mock_app = MagicMock()
|
||||
mock_app.DATABASE = mock_db
|
||||
|
||||
data = TOTPCodeCreate(
|
||||
totp_identifier="user@example.com",
|
||||
content="123456",
|
||||
)
|
||||
curr_org = MagicMock()
|
||||
curr_org.organization_id = "org_1"
|
||||
|
||||
with patch("skyvern.forge.sdk.routes.credentials.app", new=mock_app):
|
||||
await send_totp_code(data=data, curr_org=curr_org)
|
||||
|
||||
call_kwargs = mock_db.create_otp_code.call_args[1]
|
||||
assert call_kwargs["workflow_id"] is None
|
||||
|
||||
|
||||
# === Fix: _build_navigation_payload should inject code without TOTP config ===
|
||||
|
||||
|
||||
def test_build_navigation_payload_injects_code_without_totp_config():
|
||||
"""_build_navigation_payload should inject SPECIAL_FIELD_VERIFICATION_CODE even when
|
||||
task has no totp_verification_url or totp_identifier (manual 2FA flow)."""
|
||||
agent = ForgeAgent.__new__(ForgeAgent)
|
||||
|
||||
task = MagicMock()
|
||||
task.totp_verification_url = None
|
||||
task.totp_identifier = None
|
||||
task.task_id = "tsk_manual_2fa"
|
||||
task.workflow_run_id = "wr_123"
|
||||
task.navigation_payload = {"username": "user@example.com"}
|
||||
|
||||
mock_context = MagicMock()
|
||||
mock_context.totp_codes = {"tsk_manual_2fa": "123456"}
|
||||
mock_context.has_magic_link_page.return_value = False
|
||||
|
||||
with patch("skyvern.forge.agent.skyvern_context") as mock_skyvern_ctx:
|
||||
mock_skyvern_ctx.ensure_context.return_value = mock_context
|
||||
result = agent._build_navigation_payload(task)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert SPECIAL_FIELD_VERIFICATION_CODE in result
|
||||
assert result[SPECIAL_FIELD_VERIFICATION_CODE] == "123456"
|
||||
# Original payload preserved
|
||||
assert result["username"] == "user@example.com"
|
||||
|
||||
|
||||
def test_build_navigation_payload_injects_code_when_payload_is_none():
|
||||
"""_build_navigation_payload should create a dict with the code when payload is None."""
|
||||
agent = ForgeAgent.__new__(ForgeAgent)
|
||||
|
||||
task = MagicMock()
|
||||
task.totp_verification_url = None
|
||||
task.totp_identifier = None
|
||||
task.task_id = "tsk_manual_2fa"
|
||||
task.workflow_run_id = "wr_123"
|
||||
task.navigation_payload = None
|
||||
|
||||
mock_context = MagicMock()
|
||||
mock_context.totp_codes = {"tsk_manual_2fa": "999999"}
|
||||
mock_context.has_magic_link_page.return_value = False
|
||||
|
||||
with patch("skyvern.forge.agent.skyvern_context") as mock_skyvern_ctx:
|
||||
mock_skyvern_ctx.ensure_context.return_value = mock_context
|
||||
result = agent._build_navigation_payload(task)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result[SPECIAL_FIELD_VERIFICATION_CODE] == "999999"
|
||||
|
||||
|
||||
def test_build_navigation_payload_no_code_no_injection():
|
||||
"""_build_navigation_payload should NOT inject anything when no code in context."""
|
||||
agent = ForgeAgent.__new__(ForgeAgent)
|
||||
|
||||
task = MagicMock()
|
||||
task.totp_verification_url = None
|
||||
task.totp_identifier = None
|
||||
task.task_id = "tsk_no_code"
|
||||
task.workflow_run_id = "wr_456"
|
||||
task.navigation_payload = {"field": "value"}
|
||||
|
||||
mock_context = MagicMock()
|
||||
mock_context.totp_codes = {} # No code in context
|
||||
mock_context.has_magic_link_page.return_value = False
|
||||
|
||||
with patch("skyvern.forge.agent.skyvern_context") as mock_skyvern_ctx:
|
||||
mock_skyvern_ctx.ensure_context.return_value = mock_context
|
||||
result = agent._build_navigation_payload(task)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert SPECIAL_FIELD_VERIFICATION_CODE not in result
|
||||
assert result["field"] == "value"
|
||||
|
||||
|
||||
# === Task: poll_otp_value publishes 2FA events to notification registry ===
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_poll_otp_value_publishes_required_event_for_task():
|
||||
"""poll_otp_value should publish verification_code_required when task waiting state is set."""
|
||||
mock_code = MagicMock()
|
||||
mock_code.code = "123456"
|
||||
mock_code.otp_type = "totp"
|
||||
|
||||
mock_db = AsyncMock()
|
||||
mock_db.get_valid_org_auth_token.return_value = MagicMock(token="tok")
|
||||
mock_db.get_otp_codes_by_run.return_value = [mock_code]
|
||||
mock_db.update_task_2fa_state = AsyncMock()
|
||||
|
||||
mock_app = MagicMock()
|
||||
mock_app.DATABASE = mock_db
|
||||
|
||||
registry = LocalNotificationRegistry()
|
||||
queue = registry.subscribe("org_1")
|
||||
|
||||
with (
|
||||
patch("skyvern.services.otp_service.app", new=mock_app),
|
||||
patch("skyvern.services.otp_service.asyncio.sleep", new_callable=AsyncMock),
|
||||
patch(
|
||||
"skyvern.forge.sdk.notification.factory.NotificationRegistryFactory._NotificationRegistryFactory__registry",
|
||||
new=registry,
|
||||
),
|
||||
):
|
||||
await poll_otp_value(organization_id="org_1", task_id="tsk_1")
|
||||
|
||||
# Should have received required + resolved messages
|
||||
messages = []
|
||||
while not queue.empty():
|
||||
messages.append(queue.get_nowait())
|
||||
|
||||
types = [m["type"] for m in messages]
|
||||
assert "verification_code_required" in types
|
||||
assert "verification_code_resolved" in types
|
||||
|
||||
required = next(m for m in messages if m["type"] == "verification_code_required")
|
||||
assert required["task_id"] == "tsk_1"
|
||||
|
||||
resolved = next(m for m in messages if m["type"] == "verification_code_resolved")
|
||||
assert resolved["task_id"] == "tsk_1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_poll_otp_value_publishes_required_event_for_workflow_run():
|
||||
"""poll_otp_value should publish verification_code_required when workflow run waiting state is set."""
|
||||
mock_code = MagicMock()
|
||||
mock_code.code = "654321"
|
||||
mock_code.otp_type = "totp"
|
||||
|
||||
mock_db = AsyncMock()
|
||||
mock_db.get_valid_org_auth_token.return_value = MagicMock(token="tok")
|
||||
mock_db.update_workflow_run = AsyncMock()
|
||||
mock_db.get_otp_codes_by_run.return_value = [mock_code]
|
||||
|
||||
mock_app = MagicMock()
|
||||
mock_app.DATABASE = mock_db
|
||||
|
||||
registry = LocalNotificationRegistry()
|
||||
queue = registry.subscribe("org_1")
|
||||
|
||||
with (
|
||||
patch("skyvern.services.otp_service.app", new=mock_app),
|
||||
patch("skyvern.services.otp_service.asyncio.sleep", new_callable=AsyncMock),
|
||||
patch(
|
||||
"skyvern.forge.sdk.notification.factory.NotificationRegistryFactory._NotificationRegistryFactory__registry",
|
||||
new=registry,
|
||||
),
|
||||
):
|
||||
await poll_otp_value(organization_id="org_1", workflow_run_id="wr_1")
|
||||
|
||||
messages = []
|
||||
while not queue.empty():
|
||||
messages.append(queue.get_nowait())
|
||||
|
||||
types = [m["type"] for m in messages]
|
||||
assert "verification_code_required" in types
|
||||
assert "verification_code_resolved" in types
|
||||
|
||||
required = next(m for m in messages if m["type"] == "verification_code_required")
|
||||
assert required["workflow_run_id"] == "wr_1"
|
||||
22
tests/unit_tests/test_redis_client_factory.py
Normal file
22
tests/unit_tests/test_redis_client_factory.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""Tests for RedisClientFactory."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from skyvern.forge.sdk.redis.factory import RedisClientFactory
|
||||
|
||||
|
||||
def test_default_is_none():
|
||||
"""Factory returns None when no client has been set."""
|
||||
# Reset to default state
|
||||
RedisClientFactory.set_client(None) # type: ignore[arg-type]
|
||||
assert RedisClientFactory.get_client() is None
|
||||
|
||||
|
||||
def test_set_and_get():
|
||||
"""Round-trip: set_client then get_client returns the same object."""
|
||||
mock_client = MagicMock()
|
||||
RedisClientFactory.set_client(mock_client)
|
||||
assert RedisClientFactory.get_client() is mock_client
|
||||
|
||||
# Cleanup
|
||||
RedisClientFactory.set_client(None) # type: ignore[arg-type]
|
||||
237
tests/unit_tests/test_redis_notification_registry.py
Normal file
237
tests/unit_tests/test_redis_notification_registry.py
Normal file
@@ -0,0 +1,237 @@
|
||||
"""Tests for RedisNotificationRegistry (SKY-6).
|
||||
|
||||
All tests use a mock Redis client — no real Redis instance required.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from skyvern.forge.sdk.notification.redis import RedisNotificationRegistry
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_mock_redis() -> MagicMock:
|
||||
"""Return a mock redis.asyncio.Redis client."""
|
||||
redis = MagicMock()
|
||||
redis.publish = AsyncMock()
|
||||
redis.pubsub = MagicMock()
|
||||
return redis
|
||||
|
||||
|
||||
def _make_mock_pubsub(messages: list[dict] | None = None, *, block: bool = False) -> MagicMock:
|
||||
"""Return a mock PubSub that yields *messages* from ``listen()``.
|
||||
|
||||
Each entry in *messages* should look like:
|
||||
{"type": "message", "data": '{"key": "val"}'}
|
||||
|
||||
If *block* is True the async generator will hang forever after
|
||||
exhausting *messages*, which keeps the listener task alive so that
|
||||
cancellation semantics can be tested.
|
||||
"""
|
||||
pubsub = MagicMock()
|
||||
pubsub.subscribe = AsyncMock()
|
||||
pubsub.unsubscribe = AsyncMock()
|
||||
pubsub.close = AsyncMock()
|
||||
|
||||
async def _listen():
|
||||
for msg in messages or []:
|
||||
yield msg
|
||||
if block:
|
||||
# Keep the listener alive until cancelled
|
||||
await asyncio.Event().wait()
|
||||
|
||||
pubsub.listen = _listen
|
||||
return pubsub
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: subscribe
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subscribe_creates_queue_and_starts_listener():
|
||||
"""subscribe() should return a queue and start a background listener task."""
|
||||
redis = _make_mock_redis()
|
||||
pubsub = _make_mock_pubsub()
|
||||
redis.pubsub.return_value = pubsub
|
||||
|
||||
registry = RedisNotificationRegistry(redis)
|
||||
queue = registry.subscribe("org_1")
|
||||
|
||||
assert isinstance(queue, asyncio.Queue)
|
||||
assert "org_1" in registry._listener_tasks
|
||||
task = registry._listener_tasks["org_1"]
|
||||
assert isinstance(task, asyncio.Task)
|
||||
|
||||
# Cleanup
|
||||
await registry.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subscribe_reuses_listener_for_same_org():
|
||||
"""A second subscribe for the same org should NOT create a new listener task."""
|
||||
redis = _make_mock_redis()
|
||||
pubsub = _make_mock_pubsub()
|
||||
redis.pubsub.return_value = pubsub
|
||||
|
||||
registry = RedisNotificationRegistry(redis)
|
||||
registry.subscribe("org_1")
|
||||
first_task = registry._listener_tasks["org_1"]
|
||||
|
||||
registry.subscribe("org_1")
|
||||
assert registry._listener_tasks["org_1"] is first_task
|
||||
|
||||
await registry.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: unsubscribe
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unsubscribe_cancels_listener_when_last_subscriber_leaves():
|
||||
"""When the last subscriber unsubscribes, the listener task should be cancelled."""
|
||||
redis = _make_mock_redis()
|
||||
pubsub = _make_mock_pubsub(block=True)
|
||||
redis.pubsub.return_value = pubsub
|
||||
|
||||
registry = RedisNotificationRegistry(redis)
|
||||
queue = registry.subscribe("org_1")
|
||||
|
||||
# Let the listener task start running
|
||||
await asyncio.sleep(0)
|
||||
|
||||
task = registry._listener_tasks["org_1"]
|
||||
|
||||
registry.unsubscribe("org_1", queue)
|
||||
assert "org_1" not in registry._listener_tasks
|
||||
|
||||
# Wait for the task to fully complete after cancellation
|
||||
await asyncio.gather(task, return_exceptions=True)
|
||||
assert task.cancelled()
|
||||
|
||||
await registry.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unsubscribe_keeps_listener_when_subscribers_remain():
|
||||
"""If other subscribers remain, the listener task should stay alive."""
|
||||
redis = _make_mock_redis()
|
||||
pubsub = _make_mock_pubsub()
|
||||
redis.pubsub.return_value = pubsub
|
||||
|
||||
registry = RedisNotificationRegistry(redis)
|
||||
q1 = registry.subscribe("org_1")
|
||||
registry.subscribe("org_1") # second subscriber
|
||||
|
||||
registry.unsubscribe("org_1", q1)
|
||||
assert "org_1" in registry._listener_tasks
|
||||
|
||||
await registry.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: publish
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_calls_redis_publish():
|
||||
"""publish() should fire-and-forget a Redis PUBLISH."""
|
||||
redis = _make_mock_redis()
|
||||
registry = RedisNotificationRegistry(redis)
|
||||
|
||||
registry.publish("org_1", {"type": "verification_code_required"})
|
||||
|
||||
# Allow the fire-and-forget task to execute
|
||||
await asyncio.sleep(0)
|
||||
|
||||
redis.publish.assert_awaited_once_with(
|
||||
"skyvern:notifications:org_1",
|
||||
json.dumps({"type": "verification_code_required"}),
|
||||
)
|
||||
|
||||
await registry.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: _dispatch_local
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_local_fans_out_to_all_queues():
|
||||
"""_dispatch_local should put the message into every local queue for the org."""
|
||||
redis = _make_mock_redis()
|
||||
pubsub = _make_mock_pubsub()
|
||||
redis.pubsub.return_value = pubsub
|
||||
|
||||
registry = RedisNotificationRegistry(redis)
|
||||
q1 = registry.subscribe("org_1")
|
||||
q2 = registry.subscribe("org_1")
|
||||
|
||||
msg = {"type": "test", "value": 42}
|
||||
registry._dispatch_local("org_1", msg)
|
||||
|
||||
assert q1.get_nowait() == msg
|
||||
assert q2.get_nowait() == msg
|
||||
|
||||
await registry.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_local_does_not_leak_across_orgs():
|
||||
"""Messages dispatched for org_a should not appear in org_b queues."""
|
||||
redis = _make_mock_redis()
|
||||
pubsub = _make_mock_pubsub()
|
||||
redis.pubsub.return_value = pubsub
|
||||
|
||||
registry = RedisNotificationRegistry(redis)
|
||||
q_a = registry.subscribe("org_a")
|
||||
q_b = registry.subscribe("org_b")
|
||||
|
||||
registry._dispatch_local("org_a", {"type": "test"})
|
||||
assert not q_a.empty()
|
||||
assert q_b.empty()
|
||||
|
||||
await registry.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: close
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_cancels_all_listeners_and_clears_state():
|
||||
"""close() should cancel every listener task and empty subscriber maps."""
|
||||
redis = _make_mock_redis()
|
||||
pubsub = _make_mock_pubsub(block=True)
|
||||
redis.pubsub.return_value = pubsub
|
||||
|
||||
registry = RedisNotificationRegistry(redis)
|
||||
registry.subscribe("org_1")
|
||||
registry.subscribe("org_2")
|
||||
|
||||
# Let the listener tasks start running
|
||||
await asyncio.sleep(0)
|
||||
|
||||
task_1 = registry._listener_tasks["org_1"]
|
||||
task_2 = registry._listener_tasks["org_2"]
|
||||
|
||||
await registry.close()
|
||||
|
||||
# Wait for the tasks to fully complete after cancellation
|
||||
await asyncio.gather(task_1, task_2, return_exceptions=True)
|
||||
assert task_1.cancelled()
|
||||
assert task_2.cancelled()
|
||||
assert len(registry._listener_tasks) == 0
|
||||
assert len(registry._subscribers) == 0
|
||||
262
tests/unit_tests/test_redis_pubsub.py
Normal file
262
tests/unit_tests/test_redis_pubsub.py
Normal file
@@ -0,0 +1,262 @@
|
||||
"""Tests for RedisPubSub (generic pub/sub layer).
|
||||
|
||||
All tests use a mock Redis client — no real Redis instance required.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from skyvern.forge.sdk.redis.pubsub import RedisPubSub
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_mock_redis() -> MagicMock:
|
||||
"""Return a mock redis.asyncio.Redis client."""
|
||||
redis = MagicMock()
|
||||
redis.publish = AsyncMock()
|
||||
redis.pubsub = MagicMock()
|
||||
return redis
|
||||
|
||||
|
||||
def _make_mock_pubsub(messages: list[dict] | None = None, *, block: bool = False) -> MagicMock:
|
||||
"""Return a mock PubSub that yields *messages* from ``listen()``.
|
||||
|
||||
Each entry in *messages* should look like:
|
||||
{"type": "message", "data": '{"key": "val"}'}
|
||||
|
||||
If *block* is True the async generator will hang forever after
|
||||
exhausting *messages*, which keeps the listener task alive so that
|
||||
cancellation semantics can be tested.
|
||||
"""
|
||||
pubsub = MagicMock()
|
||||
pubsub.subscribe = AsyncMock()
|
||||
pubsub.unsubscribe = AsyncMock()
|
||||
pubsub.close = AsyncMock()
|
||||
|
||||
async def _listen():
|
||||
for msg in messages or []:
|
||||
yield msg
|
||||
if block:
|
||||
await asyncio.Event().wait()
|
||||
|
||||
pubsub.listen = _listen
|
||||
return pubsub
|
||||
|
||||
|
||||
PREFIX = "skyvern:test:"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: subscribe
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subscribe_creates_queue_and_starts_listener():
|
||||
"""subscribe() should return a queue and start a background listener task."""
|
||||
redis = _make_mock_redis()
|
||||
pubsub = _make_mock_pubsub()
|
||||
redis.pubsub.return_value = pubsub
|
||||
|
||||
ps = RedisPubSub(redis, channel_prefix=PREFIX)
|
||||
queue = ps.subscribe("key_1")
|
||||
|
||||
assert isinstance(queue, asyncio.Queue)
|
||||
assert "key_1" in ps._listener_tasks
|
||||
task = ps._listener_tasks["key_1"]
|
||||
assert isinstance(task, asyncio.Task)
|
||||
|
||||
await ps.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subscribe_reuses_listener_for_same_key():
|
||||
"""A second subscribe for the same key should NOT create a new listener task."""
|
||||
redis = _make_mock_redis()
|
||||
pubsub = _make_mock_pubsub()
|
||||
redis.pubsub.return_value = pubsub
|
||||
|
||||
ps = RedisPubSub(redis, channel_prefix=PREFIX)
|
||||
ps.subscribe("key_1")
|
||||
first_task = ps._listener_tasks["key_1"]
|
||||
|
||||
ps.subscribe("key_1")
|
||||
assert ps._listener_tasks["key_1"] is first_task
|
||||
|
||||
await ps.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: unsubscribe
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unsubscribe_cancels_listener_when_last_subscriber_leaves():
|
||||
"""When the last subscriber unsubscribes, the listener task should be cancelled."""
|
||||
redis = _make_mock_redis()
|
||||
pubsub = _make_mock_pubsub(block=True)
|
||||
redis.pubsub.return_value = pubsub
|
||||
|
||||
ps = RedisPubSub(redis, channel_prefix=PREFIX)
|
||||
queue = ps.subscribe("key_1")
|
||||
|
||||
await asyncio.sleep(0)
|
||||
|
||||
task = ps._listener_tasks["key_1"]
|
||||
|
||||
ps.unsubscribe("key_1", queue)
|
||||
assert "key_1" not in ps._listener_tasks
|
||||
|
||||
await asyncio.gather(task, return_exceptions=True)
|
||||
assert task.cancelled()
|
||||
|
||||
await ps.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unsubscribe_keeps_listener_when_subscribers_remain():
|
||||
"""If other subscribers remain, the listener task should stay alive."""
|
||||
redis = _make_mock_redis()
|
||||
pubsub = _make_mock_pubsub()
|
||||
redis.pubsub.return_value = pubsub
|
||||
|
||||
ps = RedisPubSub(redis, channel_prefix=PREFIX)
|
||||
q1 = ps.subscribe("key_1")
|
||||
ps.subscribe("key_1")
|
||||
|
||||
ps.unsubscribe("key_1", q1)
|
||||
assert "key_1" in ps._listener_tasks
|
||||
|
||||
await ps.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: publish
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_calls_redis_publish():
|
||||
"""publish() should fire-and-forget a Redis PUBLISH with prefixed channel."""
|
||||
redis = _make_mock_redis()
|
||||
ps = RedisPubSub(redis, channel_prefix=PREFIX)
|
||||
|
||||
ps.publish("key_1", {"type": "event"})
|
||||
|
||||
await asyncio.sleep(0)
|
||||
|
||||
redis.publish.assert_awaited_once_with(
|
||||
f"{PREFIX}key_1",
|
||||
json.dumps({"type": "event"}),
|
||||
)
|
||||
|
||||
await ps.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: _dispatch_local
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_local_fans_out_to_all_queues():
|
||||
"""_dispatch_local should put the message into every local queue for the key."""
|
||||
redis = _make_mock_redis()
|
||||
pubsub = _make_mock_pubsub()
|
||||
redis.pubsub.return_value = pubsub
|
||||
|
||||
ps = RedisPubSub(redis, channel_prefix=PREFIX)
|
||||
q1 = ps.subscribe("key_1")
|
||||
q2 = ps.subscribe("key_1")
|
||||
|
||||
msg = {"type": "test", "value": 42}
|
||||
ps._dispatch_local("key_1", msg)
|
||||
|
||||
assert q1.get_nowait() == msg
|
||||
assert q2.get_nowait() == msg
|
||||
|
||||
await ps.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_local_does_not_leak_across_keys():
|
||||
"""Messages dispatched for key_a should not appear in key_b queues."""
|
||||
redis = _make_mock_redis()
|
||||
pubsub = _make_mock_pubsub()
|
||||
redis.pubsub.return_value = pubsub
|
||||
|
||||
ps = RedisPubSub(redis, channel_prefix=PREFIX)
|
||||
q_a = ps.subscribe("key_a")
|
||||
q_b = ps.subscribe("key_b")
|
||||
|
||||
ps._dispatch_local("key_a", {"type": "test"})
|
||||
assert not q_a.empty()
|
||||
assert q_b.empty()
|
||||
|
||||
await ps.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: close
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_cancels_all_listeners_and_clears_state():
|
||||
"""close() should cancel every listener task and empty subscriber maps."""
|
||||
redis = _make_mock_redis()
|
||||
pubsub = _make_mock_pubsub(block=True)
|
||||
redis.pubsub.return_value = pubsub
|
||||
|
||||
ps = RedisPubSub(redis, channel_prefix=PREFIX)
|
||||
ps.subscribe("key_1")
|
||||
ps.subscribe("key_2")
|
||||
|
||||
await asyncio.sleep(0)
|
||||
|
||||
task_1 = ps._listener_tasks["key_1"]
|
||||
task_2 = ps._listener_tasks["key_2"]
|
||||
|
||||
await ps.close()
|
||||
|
||||
await asyncio.gather(task_1, task_2, return_exceptions=True)
|
||||
assert task_1.cancelled()
|
||||
assert task_2.cancelled()
|
||||
assert len(ps._listener_tasks) == 0
|
||||
assert len(ps._subscribers) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: prefix isolation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_different_prefixes_do_not_interfere():
|
||||
"""Two RedisPubSub instances with different prefixes use separate channels."""
|
||||
redis = _make_mock_redis()
|
||||
|
||||
ps_a = RedisPubSub(redis, channel_prefix="prefix_a:")
|
||||
ps_b = RedisPubSub(redis, channel_prefix="prefix_b:")
|
||||
|
||||
ps_a.publish("key_1", {"from": "a"})
|
||||
ps_b.publish("key_1", {"from": "b"})
|
||||
|
||||
await asyncio.sleep(0)
|
||||
|
||||
calls = redis.publish.await_args_list
|
||||
assert len(calls) == 2
|
||||
|
||||
channels = {call.args[0] for call in calls}
|
||||
assert "prefix_a:key_1" in channels
|
||||
assert "prefix_b:key_1" in channels
|
||||
|
||||
await ps_a.close()
|
||||
await ps_b.close()
|
||||
Reference in New Issue
Block a user