From e3b6d22fb65bb08839f66bdff254c7f8ae9de6bd Mon Sep 17 00:00:00 2001 From: Aaron Perez Date: Wed, 18 Feb 2026 17:21:58 -0500 Subject: [PATCH] [SKY-6] Backend: Enable 2FA code detection without TOTP credentials (#4786) --- .env.example | 9 + ...02_13_0000-add_2fa_waiting_state_fields.py | 39 ++ skyvern/config.py | 7 + skyvern/forge/agent.py | 21 +- skyvern/forge/api_app.py | 14 + skyvern/forge/forge_app.py | 12 + skyvern/forge/sdk/db/agent_db.py | 139 +++++ skyvern/forge/sdk/db/models.py | 8 + skyvern/forge/sdk/db/utils.py | 6 + skyvern/forge/sdk/notification/base.py | 21 + skyvern/forge/sdk/notification/factory.py | 14 + skyvern/forge/sdk/notification/local.py | 45 ++ skyvern/forge/sdk/notification/redis.py | 55 ++ skyvern/forge/sdk/redis/__init__.py | 0 skyvern/forge/sdk/redis/factory.py | 21 + skyvern/forge/sdk/redis/pubsub.py | 130 +++++ skyvern/forge/sdk/routes/__init__.py | 1 + skyvern/forge/sdk/routes/credentials.py | 9 +- .../sdk/routes/streaming/notifications.py | 101 ++++ skyvern/forge/sdk/schemas/tasks.py | 11 + skyvern/forge/sdk/workflow/models/workflow.py | 8 + skyvern/forge/sdk/workflow/service.py | 4 + skyvern/services/otp_service.py | 184 +++++- .../unit_tests/test_notification_registry.py | 106 ++++ tests/unit_tests/test_otp_no_config.py | 544 ++++++++++++++++++ tests/unit_tests/test_redis_client_factory.py | 22 + .../test_redis_notification_registry.py | 237 ++++++++ tests/unit_tests/test_redis_pubsub.py | 262 +++++++++ 28 files changed, 1989 insertions(+), 41 deletions(-) create mode 100644 alembic/versions/2026_02_13_0000-add_2fa_waiting_state_fields.py create mode 100644 skyvern/forge/sdk/notification/base.py create mode 100644 skyvern/forge/sdk/notification/factory.py create mode 100644 skyvern/forge/sdk/notification/local.py create mode 100644 skyvern/forge/sdk/notification/redis.py create mode 100644 skyvern/forge/sdk/redis/__init__.py create mode 100644 skyvern/forge/sdk/redis/factory.py create mode 100644 skyvern/forge/sdk/redis/pubsub.py create mode 100644 skyvern/forge/sdk/routes/streaming/notifications.py create mode 100644 tests/unit_tests/test_notification_registry.py create mode 100644 tests/unit_tests/test_otp_no_config.py create mode 100644 tests/unit_tests/test_redis_client_factory.py create mode 100644 tests/unit_tests/test_redis_notification_registry.py create mode 100644 tests/unit_tests/test_redis_pubsub.py diff --git a/.env.example b/.env.example index 11bf6c25..030c4027 100644 --- a/.env.example +++ b/.env.example @@ -141,3 +141,12 @@ SKYVERN_AUTH_BITWARDEN_CLIENT_SECRET=your-client-secret-here # Timeout in seconds for Bitwarden operations # BITWARDEN_TIMEOUT_SECONDS=60 + +# Shared Redis URL used by any service that needs Redis (pub/sub, cache, etc.) +# REDIS_URL=redis://localhost:6379/0 + +# Notification registry type: "local" (default, in-process) or "redis" (multi-pod) +# NOTIFICATION_REGISTRY_TYPE=local + +# Optional: override Redis URL specifically for notifications (falls back to REDIS_URL) +# NOTIFICATION_REDIS_URL= \ No newline at end of file diff --git a/alembic/versions/2026_02_13_0000-add_2fa_waiting_state_fields.py b/alembic/versions/2026_02_13_0000-add_2fa_waiting_state_fields.py new file mode 100644 index 00000000..d070ac4c --- /dev/null +++ b/alembic/versions/2026_02_13_0000-add_2fa_waiting_state_fields.py @@ -0,0 +1,39 @@ +"""add 2fa waiting state fields to workflow_runs and tasks + +Revision ID: a1b2c3d4e5f6 +Revises: 43217e31df12 +Create Date: 2026-02-13 00:00:00.000000+00:00 + +""" + +from typing import Sequence, Union + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "a1b2c3d4e5f6" +down_revision: Union[str, None] = "43217e31df12" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.execute( + "ALTER TABLE workflow_runs ADD COLUMN IF NOT EXISTS waiting_for_verification_code BOOLEAN NOT NULL DEFAULT false" + ) + op.execute("ALTER TABLE workflow_runs ADD COLUMN IF NOT EXISTS verification_code_identifier VARCHAR") + op.execute("ALTER TABLE workflow_runs ADD COLUMN IF NOT EXISTS verification_code_polling_started_at TIMESTAMP") + op.execute( + "ALTER TABLE tasks ADD COLUMN IF NOT EXISTS waiting_for_verification_code BOOLEAN NOT NULL DEFAULT false" + ) + op.execute("ALTER TABLE tasks ADD COLUMN IF NOT EXISTS verification_code_identifier VARCHAR") + op.execute("ALTER TABLE tasks ADD COLUMN IF NOT EXISTS verification_code_polling_started_at TIMESTAMP") + + +def downgrade() -> None: + op.execute("ALTER TABLE tasks DROP COLUMN IF EXISTS verification_code_polling_started_at") + op.execute("ALTER TABLE tasks DROP COLUMN IF EXISTS verification_code_identifier") + op.execute("ALTER TABLE tasks DROP COLUMN IF EXISTS waiting_for_verification_code") + op.execute("ALTER TABLE workflow_runs DROP COLUMN IF EXISTS verification_code_polling_started_at") + op.execute("ALTER TABLE workflow_runs DROP COLUMN IF EXISTS verification_code_identifier") + op.execute("ALTER TABLE workflow_runs DROP COLUMN IF EXISTS waiting_for_verification_code") diff --git a/skyvern/config.py b/skyvern/config.py index d2ca9cf4..85e6c6e5 100644 --- a/skyvern/config.py +++ b/skyvern/config.py @@ -98,6 +98,13 @@ class Settings(BaseSettings): # Supported storage types: local, s3cloud, azureblob SKYVERN_STORAGE_TYPE: str = "local" + # Shared Redis URL (used by any service that needs Redis) + REDIS_URL: str = "redis://localhost:6379/0" + + # Notification registry settings ("local" or "redis") + NOTIFICATION_REGISTRY_TYPE: str = "local" + NOTIFICATION_REDIS_URL: str | None = None # Deprecated: falls back to REDIS_URL + # S3/AWS settings AWS_REGION: str = "us-east-1" MAX_UPLOAD_FILE_SIZE: int = 10 * 1024 * 1024 # 10 MB diff --git a/skyvern/forge/agent.py b/skyvern/forge/agent.py index e90b56b6..eca8c7c4 100644 --- a/skyvern/forge/agent.py +++ b/skyvern/forge/agent.py @@ -2544,7 +2544,7 @@ class ForgeAgent: step, browser_state, scraped_page, - verification_code_check=bool(task.totp_verification_url or task.totp_identifier), + verification_code_check=True, expire_verification_code=True, ) @@ -3169,7 +3169,7 @@ class ForgeAgent: current_context = skyvern_context.ensure_context() verification_code = current_context.totp_codes.get(task.task_id) - if (task.totp_verification_url or task.totp_identifier) and verification_code: + if verification_code: if ( isinstance(final_navigation_payload, dict) and SPECIAL_FIELD_VERIFICATION_CODE not in final_navigation_payload @@ -4444,13 +4444,11 @@ class ForgeAgent: if not task.organization_id: return json_response, [] - if not task.totp_verification_url and not task.totp_identifier: - return json_response, [] - should_verify_by_magic_link = json_response.get("should_verify_by_magic_link") place_to_enter_verification_code = json_response.get("place_to_enter_verification_code") should_enter_verification_code = json_response.get("should_enter_verification_code") + # If no OTP verification needed, return early to avoid unnecessary processing if ( not should_verify_by_magic_link and not place_to_enter_verification_code @@ -4466,8 +4464,10 @@ class ForgeAgent: return json_response, actions if should_verify_by_magic_link: - actions = await self.handle_potential_magic_link(task, step, scraped_page, browser_state, json_response) - return json_response, actions + # Magic links still require TOTP config (need a source to poll the link from) + if task.totp_verification_url or task.totp_identifier: + actions = await self.handle_potential_magic_link(task, step, scraped_page, browser_state, json_response) + return json_response, actions return json_response, [] @@ -4524,12 +4524,7 @@ class ForgeAgent: ) -> dict[str, Any]: place_to_enter_verification_code = json_response.get("place_to_enter_verification_code") should_enter_verification_code = json_response.get("should_enter_verification_code") - if ( - place_to_enter_verification_code - and should_enter_verification_code - and (task.totp_verification_url or task.totp_identifier) - and task.organization_id - ): + if place_to_enter_verification_code and should_enter_verification_code and task.organization_id: LOG.info("Need verification code") workflow_id = workflow_permanent_id = None if task.workflow_run_id: diff --git a/skyvern/forge/api_app.py b/skyvern/forge/api_app.py index 339bca5e..46c4ad43 100644 --- a/skyvern/forge/api_app.py +++ b/skyvern/forge/api_app.py @@ -80,6 +80,20 @@ async def lifespan(fastapi_app: FastAPI) -> AsyncGenerator[None, Any]: # Stop cleanup scheduler await stop_cleanup_scheduler() + # Close notification registry (e.g. cancel Redis listener tasks) + from skyvern.forge.sdk.notification.factory import NotificationRegistryFactory + + registry = NotificationRegistryFactory.get_registry() + if hasattr(registry, "close"): + await registry.close() + + # Close shared Redis client (after registry so listener tasks drain first) + from skyvern.forge.sdk.redis.factory import RedisClientFactory + + redis_client = RedisClientFactory.get_client() + if redis_client is not None: + await redis_client.close() + if forge_app.api_app_shutdown_event: LOG.info("Calling api app shutdown event") try: diff --git a/skyvern/forge/forge_app.py b/skyvern/forge/forge_app.py index 4a954828..dcfe9c99 100644 --- a/skyvern/forge/forge_app.py +++ b/skyvern/forge/forge_app.py @@ -110,6 +110,18 @@ def create_forge_app() -> ForgeApp: StorageFactory.set_storage(AzureStorage()) app.STORAGE = StorageFactory.get_storage() app.CACHE = CacheFactory.get_cache() + + if settings.NOTIFICATION_REGISTRY_TYPE == "redis": + from redis.asyncio import from_url as redis_from_url + + from skyvern.forge.sdk.notification.factory import NotificationRegistryFactory + from skyvern.forge.sdk.notification.redis import RedisNotificationRegistry + from skyvern.forge.sdk.redis.factory import RedisClientFactory + + redis_url = settings.NOTIFICATION_REDIS_URL or settings.REDIS_URL + redis_client = redis_from_url(redis_url, decode_responses=True) + RedisClientFactory.set_client(redis_client) + NotificationRegistryFactory.set_registry(RedisNotificationRegistry(redis_client)) app.ARTIFACT_MANAGER = ArtifactManager() app.BROWSER_MANAGER = RealBrowserManager() app.EXPERIMENTATION_PROVIDER = NoOpExperimentationProvider() diff --git a/skyvern/forge/sdk/db/agent_db.py b/skyvern/forge/sdk/db/agent_db.py index 30e343de..de8c9fea 100644 --- a/skyvern/forge/sdk/db/agent_db.py +++ b/skyvern/forge/sdk/db/agent_db.py @@ -849,6 +849,102 @@ class AgentDB(BaseAlchemyDB): LOG.error("UnexpectedError", exc_info=True) raise + async def update_task_2fa_state( + self, + task_id: str, + organization_id: str, + waiting_for_verification_code: bool, + verification_code_identifier: str | None = None, + verification_code_polling_started_at: datetime | None = None, + ) -> Task: + """Update task 2FA verification code waiting state.""" + try: + async with self.Session() as session: + if task := ( + await session.scalars( + select(TaskModel).filter_by(task_id=task_id).filter_by(organization_id=organization_id) + ) + ).first(): + task.waiting_for_verification_code = waiting_for_verification_code + if verification_code_identifier is not None: + task.verification_code_identifier = verification_code_identifier + if verification_code_polling_started_at is not None: + task.verification_code_polling_started_at = verification_code_polling_started_at + if not waiting_for_verification_code: + # Clear identifiers when no longer waiting + task.verification_code_identifier = None + task.verification_code_polling_started_at = None + await session.commit() + updated_task = await self.get_task(task_id, organization_id=organization_id) + if not updated_task: + raise NotFoundError("Task not found") + return updated_task + else: + raise NotFoundError("Task not found") + except SQLAlchemyError: + LOG.error("SQLAlchemyError", exc_info=True) + raise + + @read_retry() + async def get_active_verification_requests(self, organization_id: str) -> list[dict]: + """Return active 2FA verification requests for an organization. + + Queries both tasks and workflow runs where waiting_for_verification_code=True. + Used to provide initial state when a WebSocket notification client connects. + """ + results: list[dict] = [] + async with self.Session() as session: + # Tasks waiting for verification (exclude finalized tasks) + finalized_task_statuses = [s.value for s in TaskStatus if s.is_final()] + task_rows = ( + await session.scalars( + select(TaskModel) + .filter_by(organization_id=organization_id) + .filter_by(waiting_for_verification_code=True) + .filter_by(workflow_run_id=None) + .filter(TaskModel.status.not_in(finalized_task_statuses)) + .filter(TaskModel.created_at > datetime.utcnow() - timedelta(hours=1)) + ) + ).all() + for t in task_rows: + results.append( + { + "task_id": t.task_id, + "workflow_run_id": None, + "verification_code_identifier": t.verification_code_identifier, + "verification_code_polling_started_at": ( + t.verification_code_polling_started_at.isoformat() + if t.verification_code_polling_started_at + else None + ), + } + ) + # Workflow runs waiting for verification (exclude finalized runs) + finalized_wr_statuses = [s.value for s in WorkflowRunStatus if s.is_final()] + wr_rows = ( + await session.scalars( + select(WorkflowRunModel) + .filter_by(organization_id=organization_id) + .filter_by(waiting_for_verification_code=True) + .filter(WorkflowRunModel.status.not_in(finalized_wr_statuses)) + .filter(WorkflowRunModel.created_at > datetime.utcnow() - timedelta(hours=1)) + ) + ).all() + for wr in wr_rows: + results.append( + { + "task_id": None, + "workflow_run_id": wr.workflow_run_id, + "verification_code_identifier": wr.verification_code_identifier, + "verification_code_polling_started_at": ( + wr.verification_code_polling_started_at.isoformat() + if wr.verification_code_polling_started_at + else None + ), + } + ) + return results + async def bulk_update_tasks( self, task_ids: list[str], @@ -2794,6 +2890,9 @@ class AgentDB(BaseAlchemyDB): ai_fallback: bool | None = None, depends_on_workflow_run_id: str | None = None, browser_session_id: str | None = None, + waiting_for_verification_code: bool | None = None, + verification_code_identifier: str | None = None, + verification_code_polling_started_at: datetime | None = None, ) -> WorkflowRun: async with self.Session() as session: workflow_run = ( @@ -2826,6 +2925,17 @@ class AgentDB(BaseAlchemyDB): workflow_run.depends_on_workflow_run_id = depends_on_workflow_run_id if browser_session_id: workflow_run.browser_session_id = browser_session_id + # 2FA verification code waiting state updates + if waiting_for_verification_code is not None: + workflow_run.waiting_for_verification_code = waiting_for_verification_code + if verification_code_identifier is not None: + workflow_run.verification_code_identifier = verification_code_identifier + if verification_code_polling_started_at is not None: + workflow_run.verification_code_polling_started_at = verification_code_polling_started_at + if waiting_for_verification_code is not None and not waiting_for_verification_code: + # Clear related fields when waiting is set to False + workflow_run.verification_code_identifier = None + workflow_run.verification_code_polling_started_at = None await session.commit() await save_workflow_run_logs(workflow_run_id) await session.refresh(workflow_run) @@ -3995,6 +4105,35 @@ class AgentDB(BaseAlchemyDB): totp_code = (await session.scalars(query)).all() return [TOTPCode.model_validate(totp_code) for totp_code in totp_code] + async def get_otp_codes_by_run( + self, + organization_id: str, + task_id: str | None = None, + workflow_run_id: str | None = None, + valid_lifespan_minutes: int = settings.TOTP_LIFESPAN_MINUTES, + limit: int = 1, + ) -> list[TOTPCode]: + """Get OTP codes matching a specific task or workflow run (no totp_identifier required). + + Used when the agent detects a 2FA page but no TOTP credentials are pre-configured. + The user submits codes manually via the UI, and this method finds them by run context. + """ + if not workflow_run_id and not task_id: + return [] + async with self.Session() as session: + query = ( + select(TOTPCodeModel) + .filter_by(organization_id=organization_id) + .filter(TOTPCodeModel.created_at > datetime.utcnow() - timedelta(minutes=valid_lifespan_minutes)) + ) + if workflow_run_id: + query = query.filter(TOTPCodeModel.workflow_run_id == workflow_run_id) + elif task_id: + query = query.filter(TOTPCodeModel.task_id == task_id) + query = query.order_by(TOTPCodeModel.created_at.desc()).limit(limit) + results = (await session.scalars(query)).all() + return [TOTPCode.model_validate(r) for r in results] + async def get_recent_otp_codes( self, organization_id: str, diff --git a/skyvern/forge/sdk/db/models.py b/skyvern/forge/sdk/db/models.py index cbf56c7a..b898ceaa 100644 --- a/skyvern/forge/sdk/db/models.py +++ b/skyvern/forge/sdk/db/models.py @@ -116,6 +116,10 @@ class TaskModel(Base): model = Column(JSON, nullable=True) browser_address = Column(String, nullable=True) download_timeout = Column(Numeric, nullable=True) + # 2FA verification code waiting state fields + waiting_for_verification_code = Column(Boolean, nullable=False, default=False) + verification_code_identifier = Column(String, nullable=True) + verification_code_polling_started_at = Column(DateTime, nullable=True) class StepModel(Base): @@ -350,6 +354,10 @@ class WorkflowRunModel(Base): debug_session_id: Column = Column(String, nullable=True) ai_fallback = Column(Boolean, nullable=True) code_gen = Column(Boolean, nullable=True) + # 2FA verification code waiting state fields + waiting_for_verification_code = Column(Boolean, nullable=False, default=False) + verification_code_identifier = Column(String, nullable=True) + verification_code_polling_started_at = Column(DateTime, nullable=True) queued_at = Column(DateTime, nullable=True) started_at = Column(DateTime, nullable=True) diff --git a/skyvern/forge/sdk/db/utils.py b/skyvern/forge/sdk/db/utils.py index 82c0d352..2f0fdcd8 100644 --- a/skyvern/forge/sdk/db/utils.py +++ b/skyvern/forge/sdk/db/utils.py @@ -211,6 +211,9 @@ def convert_to_task(task_obj: TaskModel, debug_enabled: bool = False, workflow_p browser_session_id=task_obj.browser_session_id, browser_address=task_obj.browser_address, download_timeout=task_obj.download_timeout, + waiting_for_verification_code=task_obj.waiting_for_verification_code or False, + verification_code_identifier=task_obj.verification_code_identifier, + verification_code_polling_started_at=task_obj.verification_code_polling_started_at, ) return task @@ -424,6 +427,9 @@ def convert_to_workflow_run( run_with=workflow_run_model.run_with, code_gen=workflow_run_model.code_gen, ai_fallback=workflow_run_model.ai_fallback, + waiting_for_verification_code=workflow_run_model.waiting_for_verification_code or False, + verification_code_identifier=workflow_run_model.verification_code_identifier, + verification_code_polling_started_at=workflow_run_model.verification_code_polling_started_at, ) diff --git a/skyvern/forge/sdk/notification/base.py b/skyvern/forge/sdk/notification/base.py new file mode 100644 index 00000000..66ef8480 --- /dev/null +++ b/skyvern/forge/sdk/notification/base.py @@ -0,0 +1,21 @@ +"""Abstract base for notification registries.""" + +import asyncio +from abc import ABC, abstractmethod + + +class BaseNotificationRegistry(ABC): + """Abstract pub/sub registry scoped by organization. + + Implementations must fan-out: a single publish call delivers the + message to every active subscriber for that organization. + """ + + @abstractmethod + def subscribe(self, organization_id: str) -> asyncio.Queue[dict]: ... + + @abstractmethod + def unsubscribe(self, organization_id: str, queue: asyncio.Queue[dict]) -> None: ... + + @abstractmethod + def publish(self, organization_id: str, message: dict) -> None: ... diff --git a/skyvern/forge/sdk/notification/factory.py b/skyvern/forge/sdk/notification/factory.py new file mode 100644 index 00000000..85e7bfcf --- /dev/null +++ b/skyvern/forge/sdk/notification/factory.py @@ -0,0 +1,14 @@ +from skyvern.forge.sdk.notification.base import BaseNotificationRegistry +from skyvern.forge.sdk.notification.local import LocalNotificationRegistry + + +class NotificationRegistryFactory: + __registry: BaseNotificationRegistry = LocalNotificationRegistry() + + @staticmethod + def set_registry(registry: BaseNotificationRegistry) -> None: + NotificationRegistryFactory.__registry = registry + + @staticmethod + def get_registry() -> BaseNotificationRegistry: + return NotificationRegistryFactory.__registry diff --git a/skyvern/forge/sdk/notification/local.py b/skyvern/forge/sdk/notification/local.py new file mode 100644 index 00000000..9ed1d7f1 --- /dev/null +++ b/skyvern/forge/sdk/notification/local.py @@ -0,0 +1,45 @@ +"""In-process notification registry using asyncio queues (single-pod only).""" + +import asyncio +from collections import defaultdict + +import structlog + +from skyvern.forge.sdk.notification.base import BaseNotificationRegistry + +LOG = structlog.get_logger() + + +class LocalNotificationRegistry(BaseNotificationRegistry): + """In-process fan-out pub/sub using asyncio queues. Single-pod only.""" + + def __init__(self) -> None: + self._subscribers: dict[str, list[asyncio.Queue[dict]]] = defaultdict(list) + + def subscribe(self, organization_id: str) -> asyncio.Queue[dict]: + queue: asyncio.Queue[dict] = asyncio.Queue() + self._subscribers[organization_id].append(queue) + LOG.info("Notification subscriber added", organization_id=organization_id) + return queue + + def unsubscribe(self, organization_id: str, queue: asyncio.Queue[dict]) -> None: + queues = self._subscribers.get(organization_id) + if queues: + try: + queues.remove(queue) + except ValueError: + pass + if not queues: + del self._subscribers[organization_id] + LOG.info("Notification subscriber removed", organization_id=organization_id) + + def publish(self, organization_id: str, message: dict) -> None: + queues = self._subscribers.get(organization_id, []) + for queue in queues: + try: + queue.put_nowait(message) + except asyncio.QueueFull: + LOG.warning( + "Notification queue full, dropping message", + organization_id=organization_id, + ) diff --git a/skyvern/forge/sdk/notification/redis.py b/skyvern/forge/sdk/notification/redis.py new file mode 100644 index 00000000..fe1e9b84 --- /dev/null +++ b/skyvern/forge/sdk/notification/redis.py @@ -0,0 +1,55 @@ +"""Redis-backed notification registry for multi-pod deployments. + +Thin adapter around :class:`RedisPubSub` — all Redis pub/sub logic +lives in the generic layer; this class maps the ``organization_id`` +domain concept onto generic string keys. +""" + +import asyncio + +from redis.asyncio import Redis + +from skyvern.forge.sdk.notification.base import BaseNotificationRegistry +from skyvern.forge.sdk.redis.pubsub import RedisPubSub + + +class RedisNotificationRegistry(BaseNotificationRegistry): + """Fan-out pub/sub backed by Redis. One Redis PubSub channel per org.""" + + def __init__(self, redis_client: Redis) -> None: + self._pubsub = RedisPubSub(redis_client, channel_prefix="skyvern:notifications:") + + # ------------------------------------------------------------------ + # Property accessors (used by existing tests) + # ------------------------------------------------------------------ + + @property + def _listener_tasks(self) -> dict[str, asyncio.Task[None]]: + return self._pubsub._listener_tasks + + @property + def _subscribers(self) -> dict[str, list[asyncio.Queue[dict]]]: + return self._pubsub._subscribers + + # ------------------------------------------------------------------ + # Public interface + # ------------------------------------------------------------------ + + def subscribe(self, organization_id: str) -> asyncio.Queue[dict]: + return self._pubsub.subscribe(organization_id) + + def unsubscribe(self, organization_id: str, queue: asyncio.Queue[dict]) -> None: + self._pubsub.unsubscribe(organization_id, queue) + + def publish(self, organization_id: str, message: dict) -> None: + self._pubsub.publish(organization_id, message) + + async def close(self) -> None: + await self._pubsub.close() + + # ------------------------------------------------------------------ + # Internal helper (exposed for tests) + # ------------------------------------------------------------------ + + def _dispatch_local(self, organization_id: str, message: dict) -> None: + self._pubsub._dispatch_local(organization_id, message) diff --git a/skyvern/forge/sdk/redis/__init__.py b/skyvern/forge/sdk/redis/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/skyvern/forge/sdk/redis/factory.py b/skyvern/forge/sdk/redis/factory.py new file mode 100644 index 00000000..31a69612 --- /dev/null +++ b/skyvern/forge/sdk/redis/factory.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +from redis.asyncio import Redis + + +class RedisClientFactory: + """Singleton factory for a shared async Redis client. + + Follows the same static set/get pattern as ``CacheFactory``. + Defaults to ``None`` (no Redis in local/OSS mode). + """ + + __client: Redis | None = None + + @staticmethod + def set_client(client: Redis) -> None: + RedisClientFactory.__client = client + + @staticmethod + def get_client() -> Redis | None: + return RedisClientFactory.__client diff --git a/skyvern/forge/sdk/redis/pubsub.py b/skyvern/forge/sdk/redis/pubsub.py new file mode 100644 index 00000000..fed4f79f --- /dev/null +++ b/skyvern/forge/sdk/redis/pubsub.py @@ -0,0 +1,130 @@ +"""Generic Redis pub/sub layer. + +Extracted from ``RedisNotificationRegistry`` so that any feature +(notifications, events, cache invalidation, etc.) can reuse the same +pattern with its own channel prefix. +""" + +import asyncio +import json +from collections import defaultdict + +import structlog +from redis.asyncio import Redis + +LOG = structlog.get_logger() + + +class RedisPubSub: + """Fan-out pub/sub backed by Redis. One Redis PubSub channel per key.""" + + def __init__(self, redis_client: Redis, channel_prefix: str) -> None: + self._redis = redis_client + self._channel_prefix = channel_prefix + self._subscribers: dict[str, list[asyncio.Queue[dict]]] = defaultdict(list) + # One listener task per key channel + self._listener_tasks: dict[str, asyncio.Task[None]] = {} + + # ------------------------------------------------------------------ + # Public interface + # ------------------------------------------------------------------ + + def subscribe(self, key: str) -> asyncio.Queue[dict]: + queue: asyncio.Queue[dict] = asyncio.Queue() + self._subscribers[key].append(queue) + + # Spin up a Redis listener if this is the first local subscriber + if key not in self._listener_tasks: + task = asyncio.get_running_loop().create_task(self._listen(key)) + self._listener_tasks[key] = task + + LOG.info("PubSub subscriber added", key=key, channel_prefix=self._channel_prefix) + return queue + + def unsubscribe(self, key: str, queue: asyncio.Queue[dict]) -> None: + queues = self._subscribers.get(key) + if queues: + try: + queues.remove(queue) + except ValueError: + pass + if not queues: + del self._subscribers[key] + self._cancel_listener(key) + LOG.info("PubSub subscriber removed", key=key, channel_prefix=self._channel_prefix) + + def publish(self, key: str, message: dict) -> None: + """Fire-and-forget Redis PUBLISH.""" + try: + loop = asyncio.get_running_loop() + loop.create_task(self._publish_to_redis(key, message)) + except RuntimeError: + LOG.warning( + "No running event loop; cannot publish via Redis", + key=key, + channel_prefix=self._channel_prefix, + ) + + async def close(self) -> None: + """Cancel all listener tasks and clear state. Call on shutdown.""" + for key in list(self._listener_tasks): + self._cancel_listener(key) + self._subscribers.clear() + LOG.info("RedisPubSub closed", channel_prefix=self._channel_prefix) + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + async def _publish_to_redis(self, key: str, message: dict) -> None: + channel = f"{self._channel_prefix}{key}" + try: + await self._redis.publish(channel, json.dumps(message)) + except Exception: + LOG.exception("Failed to publish to Redis", key=key, channel_prefix=self._channel_prefix) + + async def _listen(self, key: str) -> None: + """Subscribe to a Redis channel and fan out messages locally.""" + channel = f"{self._channel_prefix}{key}" + pubsub = self._redis.pubsub() + try: + await pubsub.subscribe(channel) + LOG.info("Redis listener started", channel=channel) + async for raw_message in pubsub.listen(): + if raw_message["type"] != "message": + continue + try: + data = json.loads(raw_message["data"]) + except (json.JSONDecodeError, TypeError): + LOG.warning("Invalid JSON on Redis channel", channel=channel) + continue + self._dispatch_local(key, data) + except asyncio.CancelledError: + LOG.info("Redis listener cancelled", channel=channel) + raise + except Exception: + LOG.exception("Redis listener error", channel=channel) + finally: + try: + await pubsub.unsubscribe(channel) + await pubsub.close() + except Exception: + LOG.warning("Error closing Redis pubsub", channel=channel) + + def _dispatch_local(self, key: str, message: dict) -> None: + """Fan out a message to all local asyncio queues for this key.""" + queues = self._subscribers.get(key, []) + for queue in queues: + try: + queue.put_nowait(message) + except asyncio.QueueFull: + LOG.warning( + "Queue full, dropping message", + key=key, + channel_prefix=self._channel_prefix, + ) + + def _cancel_listener(self, key: str) -> None: + task = self._listener_tasks.pop(key, None) + if task and not task.done(): + task.cancel() diff --git a/skyvern/forge/sdk/routes/__init__.py b/skyvern/forge/sdk/routes/__init__.py index a24204fc..42e038d0 100644 --- a/skyvern/forge/sdk/routes/__init__.py +++ b/skyvern/forge/sdk/routes/__init__.py @@ -11,5 +11,6 @@ from skyvern.forge.sdk.routes import sdk # noqa: F401 from skyvern.forge.sdk.routes import webhooks # noqa: F401 from skyvern.forge.sdk.routes import workflow_copilot # noqa: F401 from skyvern.forge.sdk.routes.streaming import messages # noqa: F401 +from skyvern.forge.sdk.routes.streaming import notifications # noqa: F401 from skyvern.forge.sdk.routes.streaming import screenshot # noqa: F401 from skyvern.forge.sdk.routes.streaming import vnc # noqa: F401 diff --git a/skyvern/forge/sdk/routes/credentials.py b/skyvern/forge/sdk/routes/credentials.py index 0a3e686f..ab533954 100644 --- a/skyvern/forge/sdk/routes/credentials.py +++ b/skyvern/forge/sdk/routes/credentials.py @@ -131,10 +131,15 @@ async def send_totp_code( task = await app.DATABASE.get_task(data.task_id, curr_org.organization_id) if not task: raise HTTPException(status_code=400, detail=f"Invalid task id: {data.task_id}") + workflow_id_for_storage: str | None = None if data.workflow_id: - workflow = await app.DATABASE.get_workflow(data.workflow_id, curr_org.organization_id) + if data.workflow_id.startswith("wpid_"): + workflow = await app.DATABASE.get_workflow_by_permanent_id(data.workflow_id, curr_org.organization_id) + else: + workflow = await app.DATABASE.get_workflow(data.workflow_id, curr_org.organization_id) if not workflow: raise HTTPException(status_code=400, detail=f"Invalid workflow id: {data.workflow_id}") + workflow_id_for_storage = workflow.workflow_id if data.workflow_run_id: workflow_run = await app.DATABASE.get_workflow_run(data.workflow_run_id, curr_org.organization_id) if not workflow_run: @@ -162,7 +167,7 @@ async def send_totp_code( content=data.content, code=otp_value.value, task_id=data.task_id, - workflow_id=data.workflow_id, + workflow_id=workflow_id_for_storage, workflow_run_id=data.workflow_run_id, source=data.source, expired_at=data.expired_at, diff --git a/skyvern/forge/sdk/routes/streaming/notifications.py b/skyvern/forge/sdk/routes/streaming/notifications.py new file mode 100644 index 00000000..bf713a73 --- /dev/null +++ b/skyvern/forge/sdk/routes/streaming/notifications.py @@ -0,0 +1,101 @@ +"""WebSocket endpoint for streaming global 2FA verification code notifications.""" + +import asyncio + +import structlog +from fastapi import WebSocket, WebSocketDisconnect +from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK + +from skyvern.config import settings +from skyvern.forge import app +from skyvern.forge.sdk.notification.factory import NotificationRegistryFactory +from skyvern.forge.sdk.routes.routers import base_router +from skyvern.forge.sdk.routes.streaming.auth import _auth as local_auth +from skyvern.forge.sdk.routes.streaming.auth import auth as real_auth + +LOG = structlog.get_logger() +HEARTBEAT_INTERVAL = 60 + + +@base_router.websocket("/stream/notifications") +async def notification_stream( + websocket: WebSocket, + apikey: str | None = None, + token: str | None = None, +) -> None: + auth = local_auth if settings.ENV == "local" else real_auth + organization_id = await auth(apikey=apikey, token=token, websocket=websocket) + if not organization_id: + LOG.info("Notifications: Authentication failed") + return + + LOG.info("Notifications: Started streaming", organization_id=organization_id) + registry = NotificationRegistryFactory.get_registry() + queue = registry.subscribe(organization_id) + + try: + # Send initial state: all currently active verification requests + active_requests = await app.DATABASE.get_active_verification_requests(organization_id) + for req in active_requests: + await websocket.send_json( + { + "type": "verification_code_required", + "task_id": req.get("task_id"), + "workflow_run_id": req.get("workflow_run_id"), + "identifier": req.get("verification_code_identifier"), + "polling_started_at": req.get("verification_code_polling_started_at"), + } + ) + + # Watch for client disconnect while streaming events + disconnect_event = asyncio.Event() + + async def _watch_disconnect() -> None: + try: + while True: + await websocket.receive() + except (WebSocketDisconnect, ConnectionClosedOK, ConnectionClosedError): + disconnect_event.set() + + watcher = asyncio.create_task(_watch_disconnect()) + try: + while not disconnect_event.is_set(): + queue_task = asyncio.ensure_future(asyncio.wait_for(queue.get(), timeout=HEARTBEAT_INTERVAL)) + disconnect_wait = asyncio.ensure_future(disconnect_event.wait()) + done, pending = await asyncio.wait({queue_task, disconnect_wait}, return_when=asyncio.FIRST_COMPLETED) + for p in pending: + p.cancel() + + if disconnect_event.is_set(): + return + + try: + message = queue_task.result() + await websocket.send_json(message) + except TimeoutError: + try: + await websocket.send_json({"type": "heartbeat"}) + except Exception: + LOG.info( + "Notifications: Client unreachable during heartbeat. Closing.", + organization_id=organization_id, + ) + return + except asyncio.CancelledError: + return + finally: + watcher.cancel() + + except WebSocketDisconnect: + LOG.info("Notifications: WebSocket disconnected", organization_id=organization_id) + except ConnectionClosedOK: + LOG.info("Notifications: ConnectionClosedOK", organization_id=organization_id) + except ConnectionClosedError: + LOG.warning( + "Notifications: ConnectionClosedError (client likely disconnected)", organization_id=organization_id + ) + except Exception: + LOG.warning("Notifications: Error while streaming", organization_id=organization_id, exc_info=True) + finally: + registry.unsubscribe(organization_id, queue) + LOG.info("Notifications: Connection closed", organization_id=organization_id) diff --git a/skyvern/forge/sdk/schemas/tasks.py b/skyvern/forge/sdk/schemas/tasks.py index 0f60a7fa..42ec74e1 100644 --- a/skyvern/forge/sdk/schemas/tasks.py +++ b/skyvern/forge/sdk/schemas/tasks.py @@ -288,6 +288,10 @@ class Task(TaskBase): queued_at: datetime | None = None started_at: datetime | None = None finished_at: datetime | None = None + # 2FA verification code waiting state fields + waiting_for_verification_code: bool = False + verification_code_identifier: str | None = None + verification_code_polling_started_at: datetime | None = None @property def llm_key(self) -> str | None: @@ -365,6 +369,9 @@ class Task(TaskBase): max_screenshot_scrolls=self.max_screenshot_scrolls, step_count=step_count, browser_session_id=self.browser_session_id, + waiting_for_verification_code=self.waiting_for_verification_code, + verification_code_identifier=self.verification_code_identifier, + verification_code_polling_started_at=self.verification_code_polling_started_at, ) @@ -392,6 +399,10 @@ class TaskResponse(BaseModel): max_screenshot_scrolls: int | None = None step_count: int | None = None browser_session_id: str | None = None + # 2FA verification code waiting state fields + waiting_for_verification_code: bool = False + verification_code_identifier: str | None = None + verification_code_polling_started_at: datetime | None = None class TaskOutput(BaseModel): diff --git a/skyvern/forge/sdk/workflow/models/workflow.py b/skyvern/forge/sdk/workflow/models/workflow.py index e9d3eeb8..d2667698 100644 --- a/skyvern/forge/sdk/workflow/models/workflow.py +++ b/skyvern/forge/sdk/workflow/models/workflow.py @@ -172,6 +172,10 @@ class WorkflowRun(BaseModel): sequential_key: str | None = None ai_fallback: bool | None = None code_gen: bool | None = None + # 2FA verification code waiting state fields + waiting_for_verification_code: bool = False + verification_code_identifier: str | None = None + verification_code_polling_started_at: datetime | None = None queued_at: datetime | None = None started_at: datetime | None = None @@ -226,6 +230,10 @@ class WorkflowRunResponseBase(BaseModel): browser_address: str | None = None script_run: ScriptRunResponse | None = None errors: list[dict[str, Any]] | None = None + # 2FA verification code waiting state fields + waiting_for_verification_code: bool = False + verification_code_identifier: str | None = None + verification_code_polling_started_at: datetime | None = None class WorkflowRunWithWorkflowResponse(WorkflowRunResponseBase): diff --git a/skyvern/forge/sdk/workflow/service.py b/skyvern/forge/sdk/workflow/service.py index 69c17119..ee937984 100644 --- a/skyvern/forge/sdk/workflow/service.py +++ b/skyvern/forge/sdk/workflow/service.py @@ -3019,6 +3019,10 @@ class WorkflowService: browser_address=workflow_run.browser_address, script_run=workflow_run.script_run, errors=errors, + # 2FA verification code waiting state fields + waiting_for_verification_code=workflow_run.waiting_for_verification_code, + verification_code_identifier=workflow_run.verification_code_identifier, + verification_code_polling_started_at=workflow_run.verification_code_polling_started_at, ) async def clean_up_workflow( diff --git a/skyvern/services/otp_service.py b/skyvern/services/otp_service.py index fcd28399..0d5c33fa 100644 --- a/skyvern/services/otp_service.py +++ b/skyvern/services/otp_service.py @@ -11,6 +11,7 @@ from skyvern.forge.prompts import prompt_engine from skyvern.forge.sdk.core.aiohttp_helper import aiohttp_post from skyvern.forge.sdk.core.security import generate_skyvern_webhook_signature from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType +from skyvern.forge.sdk.notification.factory import NotificationRegistryFactory from skyvern.forge.sdk.schemas.totp_codes import OTPType LOG = structlog.get_logger() @@ -80,38 +81,147 @@ async def poll_otp_value( totp_verification_url=totp_verification_url, totp_identifier=totp_identifier, ) - while True: - await asyncio.sleep(10) - # check timeout - if datetime.utcnow() > timeout_datetime: - LOG.warning("Polling otp value timed out") - raise NoTOTPVerificationCodeFound( - task_id=task_id, + + # Set the waiting state in the database when polling starts + identifier_for_ui = totp_identifier + if workflow_run_id: + try: + await app.DATABASE.update_workflow_run( workflow_run_id=workflow_run_id, - workflow_id=workflow_permanent_id, - totp_verification_url=totp_verification_url, - totp_identifier=totp_identifier, + waiting_for_verification_code=True, + verification_code_identifier=identifier_for_ui, + verification_code_polling_started_at=start_datetime, ) - otp_value: OTPValue | None = None - if totp_verification_url: - otp_value = await _get_otp_value_from_url( - organization_id, - totp_verification_url, - org_token.token, - task_id=task_id, + LOG.info( + "Set 2FA waiting state for workflow run", workflow_run_id=workflow_run_id, + verification_code_identifier=identifier_for_ui, ) - elif totp_identifier: - otp_value = await _get_otp_value_from_db( - organization_id, - totp_identifier, + try: + NotificationRegistryFactory.get_registry().publish( + organization_id, + { + "type": "verification_code_required", + "workflow_run_id": workflow_run_id, + "task_id": task_id, + "identifier": identifier_for_ui, + "polling_started_at": start_datetime.isoformat(), + }, + ) + except Exception: + LOG.warning("Failed to publish 2FA required notification for workflow run", exc_info=True) + except Exception: + LOG.warning("Failed to set 2FA waiting state for workflow run", exc_info=True) + elif task_id: + try: + await app.DATABASE.update_task_2fa_state( task_id=task_id, - workflow_id=workflow_permanent_id, - workflow_run_id=workflow_run_id, + organization_id=organization_id, + waiting_for_verification_code=True, + verification_code_identifier=identifier_for_ui, + verification_code_polling_started_at=start_datetime, ) - if otp_value: - LOG.info("Got otp value", otp_value=otp_value) - return otp_value + LOG.info( + "Set 2FA waiting state for task", + task_id=task_id, + verification_code_identifier=identifier_for_ui, + ) + try: + NotificationRegistryFactory.get_registry().publish( + organization_id, + { + "type": "verification_code_required", + "task_id": task_id, + "identifier": identifier_for_ui, + "polling_started_at": start_datetime.isoformat(), + }, + ) + except Exception: + LOG.warning("Failed to publish 2FA required notification for task", exc_info=True) + except Exception: + LOG.warning("Failed to set 2FA waiting state for task", exc_info=True) + + try: + while True: + await asyncio.sleep(10) + # check timeout + if datetime.utcnow() > timeout_datetime: + LOG.warning("Polling otp value timed out") + raise NoTOTPVerificationCodeFound( + task_id=task_id, + workflow_run_id=workflow_run_id, + workflow_id=workflow_permanent_id, + totp_verification_url=totp_verification_url, + totp_identifier=totp_identifier, + ) + otp_value: OTPValue | None = None + if totp_verification_url: + otp_value = await _get_otp_value_from_url( + organization_id, + totp_verification_url, + org_token.token, + task_id=task_id, + workflow_run_id=workflow_run_id, + ) + elif totp_identifier: + otp_value = await _get_otp_value_from_db( + organization_id, + totp_identifier, + task_id=task_id, + workflow_id=workflow_id, + workflow_run_id=workflow_run_id, + ) + if not otp_value: + otp_value = await _get_otp_value_by_run( + organization_id, + task_id=task_id, + workflow_run_id=workflow_run_id, + ) + else: + # No pre-configured TOTP — poll for manually submitted codes by run context + otp_value = await _get_otp_value_by_run( + organization_id, + task_id=task_id, + workflow_run_id=workflow_run_id, + ) + if otp_value: + LOG.info("Got otp value", otp_value=otp_value) + return otp_value + finally: + # Clear the waiting state when polling completes (success, timeout, or error) + if workflow_run_id: + try: + await app.DATABASE.update_workflow_run( + workflow_run_id=workflow_run_id, + waiting_for_verification_code=False, + ) + LOG.info("Cleared 2FA waiting state for workflow run", workflow_run_id=workflow_run_id) + try: + NotificationRegistryFactory.get_registry().publish( + organization_id, + {"type": "verification_code_resolved", "workflow_run_id": workflow_run_id, "task_id": task_id}, + ) + except Exception: + LOG.warning("Failed to publish 2FA resolved notification for workflow run", exc_info=True) + except Exception: + LOG.warning("Failed to clear 2FA waiting state for workflow run", exc_info=True) + elif task_id: + try: + await app.DATABASE.update_task_2fa_state( + task_id=task_id, + organization_id=organization_id, + waiting_for_verification_code=False, + ) + LOG.info("Cleared 2FA waiting state for task", task_id=task_id) + try: + NotificationRegistryFactory.get_registry().publish( + organization_id, + {"type": "verification_code_resolved", "task_id": task_id}, + ) + except Exception: + LOG.warning("Failed to publish 2FA resolved notification for task", exc_info=True) + except Exception: + LOG.warning("Failed to clear 2FA waiting state for task", exc_info=True) async def _get_otp_value_from_url( @@ -175,6 +285,28 @@ async def _get_otp_value_from_url( return otp_value +async def _get_otp_value_by_run( + organization_id: str, + task_id: str | None = None, + workflow_run_id: str | None = None, +) -> OTPValue | None: + """Look up OTP codes by task_id/workflow_run_id when no totp_identifier is configured. + + Used for the manual 2FA input flow where users submit codes through the UI + without pre-configured TOTP credentials. + """ + codes = await app.DATABASE.get_otp_codes_by_run( + organization_id=organization_id, + task_id=task_id, + workflow_run_id=workflow_run_id, + limit=1, + ) + if codes: + code = codes[0] + return OTPValue(value=code.code, type=code.otp_type) + return None + + async def _get_otp_value_from_db( organization_id: str, totp_identifier: str, diff --git a/tests/unit_tests/test_notification_registry.py b/tests/unit_tests/test_notification_registry.py new file mode 100644 index 00000000..1654740a --- /dev/null +++ b/tests/unit_tests/test_notification_registry.py @@ -0,0 +1,106 @@ +"""Tests for NotificationRegistry pub/sub and get_active_verification_requests (SKY-6).""" + +import pytest + +from skyvern.forge.sdk.db.agent_db import AgentDB +from skyvern.forge.sdk.notification.base import BaseNotificationRegistry +from skyvern.forge.sdk.notification.factory import NotificationRegistryFactory +from skyvern.forge.sdk.notification.local import LocalNotificationRegistry + +# === Task 1: NotificationRegistry subscribe / publish / unsubscribe === + + +@pytest.mark.asyncio +async def test_subscribe_and_publish(): + """Published messages should be received by subscribers.""" + registry = LocalNotificationRegistry() + queue = registry.subscribe("org_1") + + registry.publish("org_1", {"type": "verification_code_required", "task_id": "tsk_1"}) + msg = queue.get_nowait() + assert msg["type"] == "verification_code_required" + assert msg["task_id"] == "tsk_1" + + +@pytest.mark.asyncio +async def test_multiple_subscribers(): + """All subscribers for an org should receive the same message.""" + registry = LocalNotificationRegistry() + q1 = registry.subscribe("org_1") + q2 = registry.subscribe("org_1") + + registry.publish("org_1", {"type": "verification_code_required"}) + assert not q1.empty() + assert not q2.empty() + assert q1.get_nowait() == q2.get_nowait() + + +@pytest.mark.asyncio +async def test_publish_wrong_org_does_not_leak(): + """Messages for org_A should not appear in org_B's queue.""" + registry = LocalNotificationRegistry() + q_a = registry.subscribe("org_a") + q_b = registry.subscribe("org_b") + + registry.publish("org_a", {"type": "test"}) + assert not q_a.empty() + assert q_b.empty() + + +@pytest.mark.asyncio +async def test_unsubscribe(): + """After unsubscribe, the queue should no longer receive messages.""" + registry = LocalNotificationRegistry() + queue = registry.subscribe("org_1") + + registry.unsubscribe("org_1", queue) + registry.publish("org_1", {"type": "test"}) + assert queue.empty() + + +@pytest.mark.asyncio +async def test_unsubscribe_idempotent(): + """Unsubscribing a queue that's already removed should not raise.""" + registry = LocalNotificationRegistry() + queue = registry.subscribe("org_1") + registry.unsubscribe("org_1", queue) + registry.unsubscribe("org_1", queue) # should not raise + + +# === Task: BaseNotificationRegistry ABC === + + +def test_base_notification_registry_cannot_be_instantiated(): + """ABC should not be directly instantiable.""" + with pytest.raises(TypeError): + BaseNotificationRegistry() + + +# === Task: NotificationRegistryFactory === + + +@pytest.mark.asyncio +async def test_factory_returns_local_by_default(): + """Factory should return a LocalNotificationRegistry by default.""" + registry = NotificationRegistryFactory.get_registry() + assert isinstance(registry, LocalNotificationRegistry) + + +@pytest.mark.asyncio +async def test_factory_set_and_get(): + """Factory should allow swapping the registry implementation.""" + original = NotificationRegistryFactory.get_registry() + try: + custom = LocalNotificationRegistry() + NotificationRegistryFactory.set_registry(custom) + assert NotificationRegistryFactory.get_registry() is custom + finally: + NotificationRegistryFactory.set_registry(original) + + +# === Task 2: get_active_verification_requests DB method === + + +def test_get_active_verification_requests_method_exists(): + """AgentDB should have get_active_verification_requests method.""" + assert hasattr(AgentDB, "get_active_verification_requests") diff --git a/tests/unit_tests/test_otp_no_config.py b/tests/unit_tests/test_otp_no_config.py new file mode 100644 index 00000000..33fa57bf --- /dev/null +++ b/tests/unit_tests/test_otp_no_config.py @@ -0,0 +1,544 @@ +"""Tests for manual 2FA input without pre-configured TOTP credentials (SKY-6).""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from skyvern.constants import SPECIAL_FIELD_VERIFICATION_CODE +from skyvern.forge.agent import ForgeAgent +from skyvern.forge.sdk.db.agent_db import AgentDB +from skyvern.forge.sdk.notification.local import LocalNotificationRegistry +from skyvern.forge.sdk.routes.credentials import send_totp_code +from skyvern.forge.sdk.schemas.totp_codes import TOTPCodeCreate +from skyvern.schemas.runs import RunEngine +from skyvern.services.otp_service import OTPValue, _get_otp_value_by_run, poll_otp_value + + +@pytest.mark.asyncio +async def test_get_otp_codes_by_run_exists(): + """get_otp_codes_by_run should exist on AgentDB.""" + assert hasattr(AgentDB, "get_otp_codes_by_run"), "AgentDB missing get_otp_codes_by_run method" + + +@pytest.mark.asyncio +async def test_get_otp_codes_by_run_returns_empty_without_identifiers(): + """get_otp_codes_by_run should return [] when neither task_id nor workflow_run_id is given.""" + db = AgentDB.__new__(AgentDB) + result = await db.get_otp_codes_by_run( + organization_id="org_1", + ) + assert result == [] + + +# === Task 2: _get_otp_value_by_run OTP service function === + + +@pytest.mark.asyncio +async def test_get_otp_value_by_run_returns_code(): + """_get_otp_value_by_run should find OTP codes by task_id.""" + mock_code = MagicMock() + mock_code.code = "123456" + mock_code.otp_type = "totp" + + mock_db = AsyncMock() + mock_db.get_otp_codes_by_run.return_value = [mock_code] + + mock_app = MagicMock() + mock_app.DATABASE = mock_db + + with patch("skyvern.services.otp_service.app", new=mock_app): + result = await _get_otp_value_by_run( + organization_id="org_1", + task_id="tsk_1", + ) + assert result is not None + assert result.value == "123456" + + +@pytest.mark.asyncio +async def test_get_otp_value_by_run_returns_none_when_no_codes(): + """_get_otp_value_by_run should return None when no codes found.""" + mock_db = AsyncMock() + mock_db.get_otp_codes_by_run.return_value = [] + + mock_app = MagicMock() + mock_app.DATABASE = mock_db + + with patch("skyvern.services.otp_service.app", new=mock_app): + result = await _get_otp_value_by_run( + organization_id="org_1", + task_id="tsk_1", + ) + assert result is None + + +# === Task 3: poll_otp_value without identifier === + + +@pytest.mark.asyncio +async def test_poll_otp_value_without_identifier_uses_run_lookup(): + """poll_otp_value should use _get_otp_value_by_run when no identifier/URL provided.""" + mock_code = MagicMock() + mock_code.code = "123456" + mock_code.otp_type = "totp" + + mock_db = AsyncMock() + mock_db.get_valid_org_auth_token.return_value = MagicMock(token="tok") + mock_db.get_otp_codes_by_run.return_value = [mock_code] + mock_db.update_task_2fa_state = AsyncMock() + + mock_app = MagicMock() + mock_app.DATABASE = mock_db + + with ( + patch("skyvern.services.otp_service.app", new=mock_app), + patch("skyvern.services.otp_service.asyncio.sleep", new_callable=AsyncMock), + ): + result = await poll_otp_value( + organization_id="org_1", + task_id="tsk_1", + ) + assert result is not None + assert result.value == "123456" + + +# === Task 6: Integration test — handle_potential_OTP_actions without TOTP config === + + +@pytest.mark.asyncio +async def test_handle_potential_OTP_actions_without_totp_config(): + """When LLM detects 2FA but no TOTP config exists, should still enter verification flow.""" + agent = ForgeAgent.__new__(ForgeAgent) + + task = MagicMock() + task.organization_id = "org_1" + task.totp_verification_url = None + task.totp_identifier = None + task.task_id = "tsk_1" + task.workflow_run_id = None + + step = MagicMock() + scraped_page = MagicMock() + browser_state = MagicMock() + + json_response = { + "should_enter_verification_code": True, + "place_to_enter_verification_code": "input#otp-code", + "actions": [], + } + + with patch.object(agent, "handle_potential_verification_code", new_callable=AsyncMock) as mock_handler: + mock_handler.return_value = {"actions": []} + with patch("skyvern.forge.agent.parse_actions", return_value=[]): + result_json, result_actions = await agent.handle_potential_OTP_actions( + task, step, scraped_page, browser_state, json_response + ) + mock_handler.assert_called_once() + + +@pytest.mark.asyncio +async def test_handle_potential_OTP_actions_skips_magic_link_without_totp_config(): + """Magic links should still require TOTP config.""" + agent = ForgeAgent.__new__(ForgeAgent) + + task = MagicMock() + task.organization_id = "org_1" + task.totp_verification_url = None + task.totp_identifier = None + + step = MagicMock() + scraped_page = MagicMock() + browser_state = MagicMock() + + json_response = { + "should_verify_by_magic_link": True, + } + + with patch.object(agent, "handle_potential_magic_link", new_callable=AsyncMock) as mock_handler: + result_json, result_actions = await agent.handle_potential_OTP_actions( + task, step, scraped_page, browser_state, json_response + ) + mock_handler.assert_not_called() + assert result_actions == [] + + +# === Task 7: verification_code_check always True in LLM prompt === + + +@pytest.mark.asyncio +async def test_verification_code_check_always_true_without_totp_config(): + """_build_extract_action_prompt should receive verification_code_check=True even when task has no TOTP config.""" + agent = ForgeAgent.__new__(ForgeAgent) + agent.async_operation_pool = MagicMock() + + task = MagicMock() + task.totp_verification_url = None + task.totp_identifier = None + task.task_id = "tsk_1" + task.workflow_run_id = None + task.organization_id = "org_1" + task.url = "https://example.com" + + step = MagicMock() + step.step_id = "step_1" + step.order = 0 + step.retry_index = 0 + + scraped_page = MagicMock() + scraped_page.elements = [] + + browser_state = MagicMock() + + with ( + patch("skyvern.forge.agent.skyvern_context") as mock_ctx, + patch.object(agent, "_scrape_with_type", new_callable=AsyncMock, return_value=scraped_page), + patch.object( + agent, + "_build_extract_action_prompt", + new_callable=AsyncMock, + return_value=("prompt", False, "extract_action"), + ) as mock_build, + ): + mock_ctx.current.return_value = None + await agent.build_and_record_step_prompt( + task, step, browser_state, RunEngine.skyvern_v1, persist_artifacts=False + ) + mock_build.assert_called_once() + _, kwargs = mock_build.call_args + assert kwargs["verification_code_check"] is True + + +@pytest.mark.asyncio +async def test_verification_code_check_always_true_with_totp_config(): + """_build_extract_action_prompt should receive verification_code_check=True when task HAS TOTP config (unchanged).""" + agent = ForgeAgent.__new__(ForgeAgent) + agent.async_operation_pool = MagicMock() + + task = MagicMock() + task.totp_verification_url = "https://otp.example.com" + task.totp_identifier = "user@example.com" + task.task_id = "tsk_2" + task.workflow_run_id = None + task.organization_id = "org_1" + task.url = "https://example.com" + + step = MagicMock() + step.step_id = "step_1" + step.order = 0 + step.retry_index = 0 + + scraped_page = MagicMock() + scraped_page.elements = [] + + browser_state = MagicMock() + + with ( + patch("skyvern.forge.agent.skyvern_context") as mock_ctx, + patch.object(agent, "_scrape_with_type", new_callable=AsyncMock, return_value=scraped_page), + patch.object( + agent, + "_build_extract_action_prompt", + new_callable=AsyncMock, + return_value=("prompt", False, "extract_action"), + ) as mock_build, + ): + mock_ctx.current.return_value = None + await agent.build_and_record_step_prompt( + task, step, browser_state, RunEngine.skyvern_v1, persist_artifacts=False + ) + mock_build.assert_called_once() + _, kwargs = mock_build.call_args + assert kwargs["verification_code_check"] is True + + +# === Fix: poll_otp_value should pass workflow_id, not workflow_permanent_id === + + +@pytest.mark.asyncio +async def test_poll_otp_value_passes_workflow_id_not_permanent_id(): + """poll_otp_value should pass workflow_id (w_* format) to _get_otp_value_from_db, not workflow_permanent_id.""" + mock_db = AsyncMock() + mock_db.get_valid_org_auth_token.return_value = MagicMock(token="tok") + mock_db.update_workflow_run = AsyncMock() + + mock_app = MagicMock() + mock_app.DATABASE = mock_db + + with ( + patch("skyvern.services.otp_service.app", new=mock_app), + patch("skyvern.services.otp_service.asyncio.sleep", new_callable=AsyncMock), + patch( + "skyvern.services.otp_service._get_otp_value_from_db", + new_callable=AsyncMock, + return_value=OTPValue(value="654321", type="totp"), + ) as mock_get_from_db, + ): + result = await poll_otp_value( + organization_id="org_1", + workflow_id="w_123", + workflow_run_id="wr_789", + workflow_permanent_id="wpid_456", + totp_identifier="user@example.com", + ) + assert result is not None + assert result.value == "654321" + mock_get_from_db.assert_called_once_with( + "org_1", + "user@example.com", + task_id=None, + workflow_id="w_123", + workflow_run_id="wr_789", + ) + + +# === Fix: send_totp_code should resolve wpid_* to w_* before storage === + + +@pytest.mark.asyncio +async def test_send_totp_code_resolves_wpid_to_workflow_id(): + """send_totp_code should resolve wpid_* to w_* before storing in DB.""" + mock_workflow = MagicMock() + mock_workflow.workflow_id = "w_abc123" + + mock_totp_code = MagicMock() + + mock_db = AsyncMock() + mock_db.get_workflow_by_permanent_id = AsyncMock(return_value=mock_workflow) + mock_db.create_otp_code = AsyncMock(return_value=mock_totp_code) + + mock_app = MagicMock() + mock_app.DATABASE = mock_db + + data = TOTPCodeCreate( + totp_identifier="user@example.com", + content="123456", + workflow_id="wpid_xyz789", + ) + curr_org = MagicMock() + curr_org.organization_id = "org_1" + + with patch("skyvern.forge.sdk.routes.credentials.app", new=mock_app): + await send_totp_code(data=data, curr_org=curr_org) + + mock_db.create_otp_code.assert_called_once() + call_kwargs = mock_db.create_otp_code.call_args[1] + assert call_kwargs["workflow_id"] == "w_abc123", f"Expected w_abc123 but got {call_kwargs['workflow_id']}" + + +@pytest.mark.asyncio +async def test_send_totp_code_w_format_passes_through(): + """send_totp_code should resolve and store w_* format workflow_id correctly.""" + mock_workflow = MagicMock() + mock_workflow.workflow_id = "w_abc123" + + mock_totp_code = MagicMock() + + mock_db = AsyncMock() + mock_db.get_workflow = AsyncMock(return_value=mock_workflow) + mock_db.create_otp_code = AsyncMock(return_value=mock_totp_code) + + mock_app = MagicMock() + mock_app.DATABASE = mock_db + + data = TOTPCodeCreate( + totp_identifier="user@example.com", + content="123456", + workflow_id="w_abc123", + ) + curr_org = MagicMock() + curr_org.organization_id = "org_1" + + with patch("skyvern.forge.sdk.routes.credentials.app", new=mock_app): + await send_totp_code(data=data, curr_org=curr_org) + + call_kwargs = mock_db.create_otp_code.call_args[1] + assert call_kwargs["workflow_id"] == "w_abc123" + + +@pytest.mark.asyncio +async def test_send_totp_code_none_workflow_id(): + """send_totp_code should pass None workflow_id when not provided.""" + mock_totp_code = MagicMock() + + mock_db = AsyncMock() + mock_db.create_otp_code = AsyncMock(return_value=mock_totp_code) + + mock_app = MagicMock() + mock_app.DATABASE = mock_db + + data = TOTPCodeCreate( + totp_identifier="user@example.com", + content="123456", + ) + curr_org = MagicMock() + curr_org.organization_id = "org_1" + + with patch("skyvern.forge.sdk.routes.credentials.app", new=mock_app): + await send_totp_code(data=data, curr_org=curr_org) + + call_kwargs = mock_db.create_otp_code.call_args[1] + assert call_kwargs["workflow_id"] is None + + +# === Fix: _build_navigation_payload should inject code without TOTP config === + + +def test_build_navigation_payload_injects_code_without_totp_config(): + """_build_navigation_payload should inject SPECIAL_FIELD_VERIFICATION_CODE even when + task has no totp_verification_url or totp_identifier (manual 2FA flow).""" + agent = ForgeAgent.__new__(ForgeAgent) + + task = MagicMock() + task.totp_verification_url = None + task.totp_identifier = None + task.task_id = "tsk_manual_2fa" + task.workflow_run_id = "wr_123" + task.navigation_payload = {"username": "user@example.com"} + + mock_context = MagicMock() + mock_context.totp_codes = {"tsk_manual_2fa": "123456"} + mock_context.has_magic_link_page.return_value = False + + with patch("skyvern.forge.agent.skyvern_context") as mock_skyvern_ctx: + mock_skyvern_ctx.ensure_context.return_value = mock_context + result = agent._build_navigation_payload(task) + + assert isinstance(result, dict) + assert SPECIAL_FIELD_VERIFICATION_CODE in result + assert result[SPECIAL_FIELD_VERIFICATION_CODE] == "123456" + # Original payload preserved + assert result["username"] == "user@example.com" + + +def test_build_navigation_payload_injects_code_when_payload_is_none(): + """_build_navigation_payload should create a dict with the code when payload is None.""" + agent = ForgeAgent.__new__(ForgeAgent) + + task = MagicMock() + task.totp_verification_url = None + task.totp_identifier = None + task.task_id = "tsk_manual_2fa" + task.workflow_run_id = "wr_123" + task.navigation_payload = None + + mock_context = MagicMock() + mock_context.totp_codes = {"tsk_manual_2fa": "999999"} + mock_context.has_magic_link_page.return_value = False + + with patch("skyvern.forge.agent.skyvern_context") as mock_skyvern_ctx: + mock_skyvern_ctx.ensure_context.return_value = mock_context + result = agent._build_navigation_payload(task) + + assert isinstance(result, dict) + assert result[SPECIAL_FIELD_VERIFICATION_CODE] == "999999" + + +def test_build_navigation_payload_no_code_no_injection(): + """_build_navigation_payload should NOT inject anything when no code in context.""" + agent = ForgeAgent.__new__(ForgeAgent) + + task = MagicMock() + task.totp_verification_url = None + task.totp_identifier = None + task.task_id = "tsk_no_code" + task.workflow_run_id = "wr_456" + task.navigation_payload = {"field": "value"} + + mock_context = MagicMock() + mock_context.totp_codes = {} # No code in context + mock_context.has_magic_link_page.return_value = False + + with patch("skyvern.forge.agent.skyvern_context") as mock_skyvern_ctx: + mock_skyvern_ctx.ensure_context.return_value = mock_context + result = agent._build_navigation_payload(task) + + assert isinstance(result, dict) + assert SPECIAL_FIELD_VERIFICATION_CODE not in result + assert result["field"] == "value" + + +# === Task: poll_otp_value publishes 2FA events to notification registry === + + +@pytest.mark.asyncio +async def test_poll_otp_value_publishes_required_event_for_task(): + """poll_otp_value should publish verification_code_required when task waiting state is set.""" + mock_code = MagicMock() + mock_code.code = "123456" + mock_code.otp_type = "totp" + + mock_db = AsyncMock() + mock_db.get_valid_org_auth_token.return_value = MagicMock(token="tok") + mock_db.get_otp_codes_by_run.return_value = [mock_code] + mock_db.update_task_2fa_state = AsyncMock() + + mock_app = MagicMock() + mock_app.DATABASE = mock_db + + registry = LocalNotificationRegistry() + queue = registry.subscribe("org_1") + + with ( + patch("skyvern.services.otp_service.app", new=mock_app), + patch("skyvern.services.otp_service.asyncio.sleep", new_callable=AsyncMock), + patch( + "skyvern.forge.sdk.notification.factory.NotificationRegistryFactory._NotificationRegistryFactory__registry", + new=registry, + ), + ): + await poll_otp_value(organization_id="org_1", task_id="tsk_1") + + # Should have received required + resolved messages + messages = [] + while not queue.empty(): + messages.append(queue.get_nowait()) + + types = [m["type"] for m in messages] + assert "verification_code_required" in types + assert "verification_code_resolved" in types + + required = next(m for m in messages if m["type"] == "verification_code_required") + assert required["task_id"] == "tsk_1" + + resolved = next(m for m in messages if m["type"] == "verification_code_resolved") + assert resolved["task_id"] == "tsk_1" + + +@pytest.mark.asyncio +async def test_poll_otp_value_publishes_required_event_for_workflow_run(): + """poll_otp_value should publish verification_code_required when workflow run waiting state is set.""" + mock_code = MagicMock() + mock_code.code = "654321" + mock_code.otp_type = "totp" + + mock_db = AsyncMock() + mock_db.get_valid_org_auth_token.return_value = MagicMock(token="tok") + mock_db.update_workflow_run = AsyncMock() + mock_db.get_otp_codes_by_run.return_value = [mock_code] + + mock_app = MagicMock() + mock_app.DATABASE = mock_db + + registry = LocalNotificationRegistry() + queue = registry.subscribe("org_1") + + with ( + patch("skyvern.services.otp_service.app", new=mock_app), + patch("skyvern.services.otp_service.asyncio.sleep", new_callable=AsyncMock), + patch( + "skyvern.forge.sdk.notification.factory.NotificationRegistryFactory._NotificationRegistryFactory__registry", + new=registry, + ), + ): + await poll_otp_value(organization_id="org_1", workflow_run_id="wr_1") + + messages = [] + while not queue.empty(): + messages.append(queue.get_nowait()) + + types = [m["type"] for m in messages] + assert "verification_code_required" in types + assert "verification_code_resolved" in types + + required = next(m for m in messages if m["type"] == "verification_code_required") + assert required["workflow_run_id"] == "wr_1" diff --git a/tests/unit_tests/test_redis_client_factory.py b/tests/unit_tests/test_redis_client_factory.py new file mode 100644 index 00000000..36e31b44 --- /dev/null +++ b/tests/unit_tests/test_redis_client_factory.py @@ -0,0 +1,22 @@ +"""Tests for RedisClientFactory.""" + +from unittest.mock import MagicMock + +from skyvern.forge.sdk.redis.factory import RedisClientFactory + + +def test_default_is_none(): + """Factory returns None when no client has been set.""" + # Reset to default state + RedisClientFactory.set_client(None) # type: ignore[arg-type] + assert RedisClientFactory.get_client() is None + + +def test_set_and_get(): + """Round-trip: set_client then get_client returns the same object.""" + mock_client = MagicMock() + RedisClientFactory.set_client(mock_client) + assert RedisClientFactory.get_client() is mock_client + + # Cleanup + RedisClientFactory.set_client(None) # type: ignore[arg-type] diff --git a/tests/unit_tests/test_redis_notification_registry.py b/tests/unit_tests/test_redis_notification_registry.py new file mode 100644 index 00000000..7370476d --- /dev/null +++ b/tests/unit_tests/test_redis_notification_registry.py @@ -0,0 +1,237 @@ +"""Tests for RedisNotificationRegistry (SKY-6). + +All tests use a mock Redis client — no real Redis instance required. +""" + +import asyncio +import json +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from skyvern.forge.sdk.notification.redis import RedisNotificationRegistry + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_mock_redis() -> MagicMock: + """Return a mock redis.asyncio.Redis client.""" + redis = MagicMock() + redis.publish = AsyncMock() + redis.pubsub = MagicMock() + return redis + + +def _make_mock_pubsub(messages: list[dict] | None = None, *, block: bool = False) -> MagicMock: + """Return a mock PubSub that yields *messages* from ``listen()``. + + Each entry in *messages* should look like: + {"type": "message", "data": '{"key": "val"}'} + + If *block* is True the async generator will hang forever after + exhausting *messages*, which keeps the listener task alive so that + cancellation semantics can be tested. + """ + pubsub = MagicMock() + pubsub.subscribe = AsyncMock() + pubsub.unsubscribe = AsyncMock() + pubsub.close = AsyncMock() + + async def _listen(): + for msg in messages or []: + yield msg + if block: + # Keep the listener alive until cancelled + await asyncio.Event().wait() + + pubsub.listen = _listen + return pubsub + + +# --------------------------------------------------------------------------- +# Tests: subscribe +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_subscribe_creates_queue_and_starts_listener(): + """subscribe() should return a queue and start a background listener task.""" + redis = _make_mock_redis() + pubsub = _make_mock_pubsub() + redis.pubsub.return_value = pubsub + + registry = RedisNotificationRegistry(redis) + queue = registry.subscribe("org_1") + + assert isinstance(queue, asyncio.Queue) + assert "org_1" in registry._listener_tasks + task = registry._listener_tasks["org_1"] + assert isinstance(task, asyncio.Task) + + # Cleanup + await registry.close() + + +@pytest.mark.asyncio +async def test_subscribe_reuses_listener_for_same_org(): + """A second subscribe for the same org should NOT create a new listener task.""" + redis = _make_mock_redis() + pubsub = _make_mock_pubsub() + redis.pubsub.return_value = pubsub + + registry = RedisNotificationRegistry(redis) + registry.subscribe("org_1") + first_task = registry._listener_tasks["org_1"] + + registry.subscribe("org_1") + assert registry._listener_tasks["org_1"] is first_task + + await registry.close() + + +# --------------------------------------------------------------------------- +# Tests: unsubscribe +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_unsubscribe_cancels_listener_when_last_subscriber_leaves(): + """When the last subscriber unsubscribes, the listener task should be cancelled.""" + redis = _make_mock_redis() + pubsub = _make_mock_pubsub(block=True) + redis.pubsub.return_value = pubsub + + registry = RedisNotificationRegistry(redis) + queue = registry.subscribe("org_1") + + # Let the listener task start running + await asyncio.sleep(0) + + task = registry._listener_tasks["org_1"] + + registry.unsubscribe("org_1", queue) + assert "org_1" not in registry._listener_tasks + + # Wait for the task to fully complete after cancellation + await asyncio.gather(task, return_exceptions=True) + assert task.cancelled() + + await registry.close() + + +@pytest.mark.asyncio +async def test_unsubscribe_keeps_listener_when_subscribers_remain(): + """If other subscribers remain, the listener task should stay alive.""" + redis = _make_mock_redis() + pubsub = _make_mock_pubsub() + redis.pubsub.return_value = pubsub + + registry = RedisNotificationRegistry(redis) + q1 = registry.subscribe("org_1") + registry.subscribe("org_1") # second subscriber + + registry.unsubscribe("org_1", q1) + assert "org_1" in registry._listener_tasks + + await registry.close() + + +# --------------------------------------------------------------------------- +# Tests: publish +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_publish_calls_redis_publish(): + """publish() should fire-and-forget a Redis PUBLISH.""" + redis = _make_mock_redis() + registry = RedisNotificationRegistry(redis) + + registry.publish("org_1", {"type": "verification_code_required"}) + + # Allow the fire-and-forget task to execute + await asyncio.sleep(0) + + redis.publish.assert_awaited_once_with( + "skyvern:notifications:org_1", + json.dumps({"type": "verification_code_required"}), + ) + + await registry.close() + + +# --------------------------------------------------------------------------- +# Tests: _dispatch_local +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_dispatch_local_fans_out_to_all_queues(): + """_dispatch_local should put the message into every local queue for the org.""" + redis = _make_mock_redis() + pubsub = _make_mock_pubsub() + redis.pubsub.return_value = pubsub + + registry = RedisNotificationRegistry(redis) + q1 = registry.subscribe("org_1") + q2 = registry.subscribe("org_1") + + msg = {"type": "test", "value": 42} + registry._dispatch_local("org_1", msg) + + assert q1.get_nowait() == msg + assert q2.get_nowait() == msg + + await registry.close() + + +@pytest.mark.asyncio +async def test_dispatch_local_does_not_leak_across_orgs(): + """Messages dispatched for org_a should not appear in org_b queues.""" + redis = _make_mock_redis() + pubsub = _make_mock_pubsub() + redis.pubsub.return_value = pubsub + + registry = RedisNotificationRegistry(redis) + q_a = registry.subscribe("org_a") + q_b = registry.subscribe("org_b") + + registry._dispatch_local("org_a", {"type": "test"}) + assert not q_a.empty() + assert q_b.empty() + + await registry.close() + + +# --------------------------------------------------------------------------- +# Tests: close +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_close_cancels_all_listeners_and_clears_state(): + """close() should cancel every listener task and empty subscriber maps.""" + redis = _make_mock_redis() + pubsub = _make_mock_pubsub(block=True) + redis.pubsub.return_value = pubsub + + registry = RedisNotificationRegistry(redis) + registry.subscribe("org_1") + registry.subscribe("org_2") + + # Let the listener tasks start running + await asyncio.sleep(0) + + task_1 = registry._listener_tasks["org_1"] + task_2 = registry._listener_tasks["org_2"] + + await registry.close() + + # Wait for the tasks to fully complete after cancellation + await asyncio.gather(task_1, task_2, return_exceptions=True) + assert task_1.cancelled() + assert task_2.cancelled() + assert len(registry._listener_tasks) == 0 + assert len(registry._subscribers) == 0 diff --git a/tests/unit_tests/test_redis_pubsub.py b/tests/unit_tests/test_redis_pubsub.py new file mode 100644 index 00000000..97dae81e --- /dev/null +++ b/tests/unit_tests/test_redis_pubsub.py @@ -0,0 +1,262 @@ +"""Tests for RedisPubSub (generic pub/sub layer). + +All tests use a mock Redis client — no real Redis instance required. +""" + +import asyncio +import json +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from skyvern.forge.sdk.redis.pubsub import RedisPubSub + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_mock_redis() -> MagicMock: + """Return a mock redis.asyncio.Redis client.""" + redis = MagicMock() + redis.publish = AsyncMock() + redis.pubsub = MagicMock() + return redis + + +def _make_mock_pubsub(messages: list[dict] | None = None, *, block: bool = False) -> MagicMock: + """Return a mock PubSub that yields *messages* from ``listen()``. + + Each entry in *messages* should look like: + {"type": "message", "data": '{"key": "val"}'} + + If *block* is True the async generator will hang forever after + exhausting *messages*, which keeps the listener task alive so that + cancellation semantics can be tested. + """ + pubsub = MagicMock() + pubsub.subscribe = AsyncMock() + pubsub.unsubscribe = AsyncMock() + pubsub.close = AsyncMock() + + async def _listen(): + for msg in messages or []: + yield msg + if block: + await asyncio.Event().wait() + + pubsub.listen = _listen + return pubsub + + +PREFIX = "skyvern:test:" + + +# --------------------------------------------------------------------------- +# Tests: subscribe +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_subscribe_creates_queue_and_starts_listener(): + """subscribe() should return a queue and start a background listener task.""" + redis = _make_mock_redis() + pubsub = _make_mock_pubsub() + redis.pubsub.return_value = pubsub + + ps = RedisPubSub(redis, channel_prefix=PREFIX) + queue = ps.subscribe("key_1") + + assert isinstance(queue, asyncio.Queue) + assert "key_1" in ps._listener_tasks + task = ps._listener_tasks["key_1"] + assert isinstance(task, asyncio.Task) + + await ps.close() + + +@pytest.mark.asyncio +async def test_subscribe_reuses_listener_for_same_key(): + """A second subscribe for the same key should NOT create a new listener task.""" + redis = _make_mock_redis() + pubsub = _make_mock_pubsub() + redis.pubsub.return_value = pubsub + + ps = RedisPubSub(redis, channel_prefix=PREFIX) + ps.subscribe("key_1") + first_task = ps._listener_tasks["key_1"] + + ps.subscribe("key_1") + assert ps._listener_tasks["key_1"] is first_task + + await ps.close() + + +# --------------------------------------------------------------------------- +# Tests: unsubscribe +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_unsubscribe_cancels_listener_when_last_subscriber_leaves(): + """When the last subscriber unsubscribes, the listener task should be cancelled.""" + redis = _make_mock_redis() + pubsub = _make_mock_pubsub(block=True) + redis.pubsub.return_value = pubsub + + ps = RedisPubSub(redis, channel_prefix=PREFIX) + queue = ps.subscribe("key_1") + + await asyncio.sleep(0) + + task = ps._listener_tasks["key_1"] + + ps.unsubscribe("key_1", queue) + assert "key_1" not in ps._listener_tasks + + await asyncio.gather(task, return_exceptions=True) + assert task.cancelled() + + await ps.close() + + +@pytest.mark.asyncio +async def test_unsubscribe_keeps_listener_when_subscribers_remain(): + """If other subscribers remain, the listener task should stay alive.""" + redis = _make_mock_redis() + pubsub = _make_mock_pubsub() + redis.pubsub.return_value = pubsub + + ps = RedisPubSub(redis, channel_prefix=PREFIX) + q1 = ps.subscribe("key_1") + ps.subscribe("key_1") + + ps.unsubscribe("key_1", q1) + assert "key_1" in ps._listener_tasks + + await ps.close() + + +# --------------------------------------------------------------------------- +# Tests: publish +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_publish_calls_redis_publish(): + """publish() should fire-and-forget a Redis PUBLISH with prefixed channel.""" + redis = _make_mock_redis() + ps = RedisPubSub(redis, channel_prefix=PREFIX) + + ps.publish("key_1", {"type": "event"}) + + await asyncio.sleep(0) + + redis.publish.assert_awaited_once_with( + f"{PREFIX}key_1", + json.dumps({"type": "event"}), + ) + + await ps.close() + + +# --------------------------------------------------------------------------- +# Tests: _dispatch_local +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_dispatch_local_fans_out_to_all_queues(): + """_dispatch_local should put the message into every local queue for the key.""" + redis = _make_mock_redis() + pubsub = _make_mock_pubsub() + redis.pubsub.return_value = pubsub + + ps = RedisPubSub(redis, channel_prefix=PREFIX) + q1 = ps.subscribe("key_1") + q2 = ps.subscribe("key_1") + + msg = {"type": "test", "value": 42} + ps._dispatch_local("key_1", msg) + + assert q1.get_nowait() == msg + assert q2.get_nowait() == msg + + await ps.close() + + +@pytest.mark.asyncio +async def test_dispatch_local_does_not_leak_across_keys(): + """Messages dispatched for key_a should not appear in key_b queues.""" + redis = _make_mock_redis() + pubsub = _make_mock_pubsub() + redis.pubsub.return_value = pubsub + + ps = RedisPubSub(redis, channel_prefix=PREFIX) + q_a = ps.subscribe("key_a") + q_b = ps.subscribe("key_b") + + ps._dispatch_local("key_a", {"type": "test"}) + assert not q_a.empty() + assert q_b.empty() + + await ps.close() + + +# --------------------------------------------------------------------------- +# Tests: close +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_close_cancels_all_listeners_and_clears_state(): + """close() should cancel every listener task and empty subscriber maps.""" + redis = _make_mock_redis() + pubsub = _make_mock_pubsub(block=True) + redis.pubsub.return_value = pubsub + + ps = RedisPubSub(redis, channel_prefix=PREFIX) + ps.subscribe("key_1") + ps.subscribe("key_2") + + await asyncio.sleep(0) + + task_1 = ps._listener_tasks["key_1"] + task_2 = ps._listener_tasks["key_2"] + + await ps.close() + + await asyncio.gather(task_1, task_2, return_exceptions=True) + assert task_1.cancelled() + assert task_2.cancelled() + assert len(ps._listener_tasks) == 0 + assert len(ps._subscribers) == 0 + + +# --------------------------------------------------------------------------- +# Tests: prefix isolation +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_different_prefixes_do_not_interfere(): + """Two RedisPubSub instances with different prefixes use separate channels.""" + redis = _make_mock_redis() + + ps_a = RedisPubSub(redis, channel_prefix="prefix_a:") + ps_b = RedisPubSub(redis, channel_prefix="prefix_b:") + + ps_a.publish("key_1", {"from": "a"}) + ps_b.publish("key_1", {"from": "b"}) + + await asyncio.sleep(0) + + calls = redis.publish.await_args_list + assert len(calls) == 2 + + channels = {call.args[0] for call in calls} + assert "prefix_a:key_1" in channels + assert "prefix_b:key_1" in channels + + await ps_a.close() + await ps_b.close()