[SKY-6] Backend: Enable 2FA code detection without TOTP credentials (#4786)
This commit is contained in:
@@ -141,3 +141,12 @@ SKYVERN_AUTH_BITWARDEN_CLIENT_SECRET=your-client-secret-here
|
||||
|
||||
# Timeout in seconds for Bitwarden operations
|
||||
# BITWARDEN_TIMEOUT_SECONDS=60
|
||||
|
||||
# Shared Redis URL used by any service that needs Redis (pub/sub, cache, etc.)
|
||||
# REDIS_URL=redis://localhost:6379/0
|
||||
|
||||
# Notification registry type: "local" (default, in-process) or "redis" (multi-pod)
|
||||
# NOTIFICATION_REGISTRY_TYPE=local
|
||||
|
||||
# Optional: override Redis URL specifically for notifications (falls back to REDIS_URL)
|
||||
# NOTIFICATION_REDIS_URL=
|
||||
@@ -0,0 +1,39 @@
|
||||
"""add 2fa waiting state fields to workflow_runs and tasks
|
||||
|
||||
Revision ID: a1b2c3d4e5f6
|
||||
Revises: 43217e31df12
|
||||
Create Date: 2026-02-13 00:00:00.000000+00:00
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "a1b2c3d4e5f6"
|
||||
down_revision: Union[str, None] = "43217e31df12"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute(
|
||||
"ALTER TABLE workflow_runs ADD COLUMN IF NOT EXISTS waiting_for_verification_code BOOLEAN NOT NULL DEFAULT false"
|
||||
)
|
||||
op.execute("ALTER TABLE workflow_runs ADD COLUMN IF NOT EXISTS verification_code_identifier VARCHAR")
|
||||
op.execute("ALTER TABLE workflow_runs ADD COLUMN IF NOT EXISTS verification_code_polling_started_at TIMESTAMP")
|
||||
op.execute(
|
||||
"ALTER TABLE tasks ADD COLUMN IF NOT EXISTS waiting_for_verification_code BOOLEAN NOT NULL DEFAULT false"
|
||||
)
|
||||
op.execute("ALTER TABLE tasks ADD COLUMN IF NOT EXISTS verification_code_identifier VARCHAR")
|
||||
op.execute("ALTER TABLE tasks ADD COLUMN IF NOT EXISTS verification_code_polling_started_at TIMESTAMP")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute("ALTER TABLE tasks DROP COLUMN IF EXISTS verification_code_polling_started_at")
|
||||
op.execute("ALTER TABLE tasks DROP COLUMN IF EXISTS verification_code_identifier")
|
||||
op.execute("ALTER TABLE tasks DROP COLUMN IF EXISTS waiting_for_verification_code")
|
||||
op.execute("ALTER TABLE workflow_runs DROP COLUMN IF EXISTS verification_code_polling_started_at")
|
||||
op.execute("ALTER TABLE workflow_runs DROP COLUMN IF EXISTS verification_code_identifier")
|
||||
op.execute("ALTER TABLE workflow_runs DROP COLUMN IF EXISTS waiting_for_verification_code")
|
||||
@@ -98,6 +98,13 @@ class Settings(BaseSettings):
|
||||
# Supported storage types: local, s3cloud, azureblob
|
||||
SKYVERN_STORAGE_TYPE: str = "local"
|
||||
|
||||
# Shared Redis URL (used by any service that needs Redis)
|
||||
REDIS_URL: str = "redis://localhost:6379/0"
|
||||
|
||||
# Notification registry settings ("local" or "redis")
|
||||
NOTIFICATION_REGISTRY_TYPE: str = "local"
|
||||
NOTIFICATION_REDIS_URL: str | None = None # Deprecated: falls back to REDIS_URL
|
||||
|
||||
# S3/AWS settings
|
||||
AWS_REGION: str = "us-east-1"
|
||||
MAX_UPLOAD_FILE_SIZE: int = 10 * 1024 * 1024 # 10 MB
|
||||
|
||||
@@ -2544,7 +2544,7 @@ class ForgeAgent:
|
||||
step,
|
||||
browser_state,
|
||||
scraped_page,
|
||||
verification_code_check=bool(task.totp_verification_url or task.totp_identifier),
|
||||
verification_code_check=True,
|
||||
expire_verification_code=True,
|
||||
)
|
||||
|
||||
@@ -3169,7 +3169,7 @@ class ForgeAgent:
|
||||
|
||||
current_context = skyvern_context.ensure_context()
|
||||
verification_code = current_context.totp_codes.get(task.task_id)
|
||||
if (task.totp_verification_url or task.totp_identifier) and verification_code:
|
||||
if verification_code:
|
||||
if (
|
||||
isinstance(final_navigation_payload, dict)
|
||||
and SPECIAL_FIELD_VERIFICATION_CODE not in final_navigation_payload
|
||||
@@ -4444,13 +4444,11 @@ class ForgeAgent:
|
||||
if not task.organization_id:
|
||||
return json_response, []
|
||||
|
||||
if not task.totp_verification_url and not task.totp_identifier:
|
||||
return json_response, []
|
||||
|
||||
should_verify_by_magic_link = json_response.get("should_verify_by_magic_link")
|
||||
place_to_enter_verification_code = json_response.get("place_to_enter_verification_code")
|
||||
should_enter_verification_code = json_response.get("should_enter_verification_code")
|
||||
|
||||
# If no OTP verification needed, return early to avoid unnecessary processing
|
||||
if (
|
||||
not should_verify_by_magic_link
|
||||
and not place_to_enter_verification_code
|
||||
@@ -4466,8 +4464,10 @@ class ForgeAgent:
|
||||
return json_response, actions
|
||||
|
||||
if should_verify_by_magic_link:
|
||||
actions = await self.handle_potential_magic_link(task, step, scraped_page, browser_state, json_response)
|
||||
return json_response, actions
|
||||
# Magic links still require TOTP config (need a source to poll the link from)
|
||||
if task.totp_verification_url or task.totp_identifier:
|
||||
actions = await self.handle_potential_magic_link(task, step, scraped_page, browser_state, json_response)
|
||||
return json_response, actions
|
||||
|
||||
return json_response, []
|
||||
|
||||
@@ -4524,12 +4524,7 @@ class ForgeAgent:
|
||||
) -> dict[str, Any]:
|
||||
place_to_enter_verification_code = json_response.get("place_to_enter_verification_code")
|
||||
should_enter_verification_code = json_response.get("should_enter_verification_code")
|
||||
if (
|
||||
place_to_enter_verification_code
|
||||
and should_enter_verification_code
|
||||
and (task.totp_verification_url or task.totp_identifier)
|
||||
and task.organization_id
|
||||
):
|
||||
if place_to_enter_verification_code and should_enter_verification_code and task.organization_id:
|
||||
LOG.info("Need verification code")
|
||||
workflow_id = workflow_permanent_id = None
|
||||
if task.workflow_run_id:
|
||||
|
||||
@@ -80,6 +80,20 @@ async def lifespan(fastapi_app: FastAPI) -> AsyncGenerator[None, Any]:
|
||||
# Stop cleanup scheduler
|
||||
await stop_cleanup_scheduler()
|
||||
|
||||
# Close notification registry (e.g. cancel Redis listener tasks)
|
||||
from skyvern.forge.sdk.notification.factory import NotificationRegistryFactory
|
||||
|
||||
registry = NotificationRegistryFactory.get_registry()
|
||||
if hasattr(registry, "close"):
|
||||
await registry.close()
|
||||
|
||||
# Close shared Redis client (after registry so listener tasks drain first)
|
||||
from skyvern.forge.sdk.redis.factory import RedisClientFactory
|
||||
|
||||
redis_client = RedisClientFactory.get_client()
|
||||
if redis_client is not None:
|
||||
await redis_client.close()
|
||||
|
||||
if forge_app.api_app_shutdown_event:
|
||||
LOG.info("Calling api app shutdown event")
|
||||
try:
|
||||
|
||||
@@ -110,6 +110,18 @@ def create_forge_app() -> ForgeApp:
|
||||
StorageFactory.set_storage(AzureStorage())
|
||||
app.STORAGE = StorageFactory.get_storage()
|
||||
app.CACHE = CacheFactory.get_cache()
|
||||
|
||||
if settings.NOTIFICATION_REGISTRY_TYPE == "redis":
|
||||
from redis.asyncio import from_url as redis_from_url
|
||||
|
||||
from skyvern.forge.sdk.notification.factory import NotificationRegistryFactory
|
||||
from skyvern.forge.sdk.notification.redis import RedisNotificationRegistry
|
||||
from skyvern.forge.sdk.redis.factory import RedisClientFactory
|
||||
|
||||
redis_url = settings.NOTIFICATION_REDIS_URL or settings.REDIS_URL
|
||||
redis_client = redis_from_url(redis_url, decode_responses=True)
|
||||
RedisClientFactory.set_client(redis_client)
|
||||
NotificationRegistryFactory.set_registry(RedisNotificationRegistry(redis_client))
|
||||
app.ARTIFACT_MANAGER = ArtifactManager()
|
||||
app.BROWSER_MANAGER = RealBrowserManager()
|
||||
app.EXPERIMENTATION_PROVIDER = NoOpExperimentationProvider()
|
||||
|
||||
@@ -849,6 +849,102 @@ class AgentDB(BaseAlchemyDB):
|
||||
LOG.error("UnexpectedError", exc_info=True)
|
||||
raise
|
||||
|
||||
async def update_task_2fa_state(
|
||||
self,
|
||||
task_id: str,
|
||||
organization_id: str,
|
||||
waiting_for_verification_code: bool,
|
||||
verification_code_identifier: str | None = None,
|
||||
verification_code_polling_started_at: datetime | None = None,
|
||||
) -> Task:
|
||||
"""Update task 2FA verification code waiting state."""
|
||||
try:
|
||||
async with self.Session() as session:
|
||||
if task := (
|
||||
await session.scalars(
|
||||
select(TaskModel).filter_by(task_id=task_id).filter_by(organization_id=organization_id)
|
||||
)
|
||||
).first():
|
||||
task.waiting_for_verification_code = waiting_for_verification_code
|
||||
if verification_code_identifier is not None:
|
||||
task.verification_code_identifier = verification_code_identifier
|
||||
if verification_code_polling_started_at is not None:
|
||||
task.verification_code_polling_started_at = verification_code_polling_started_at
|
||||
if not waiting_for_verification_code:
|
||||
# Clear identifiers when no longer waiting
|
||||
task.verification_code_identifier = None
|
||||
task.verification_code_polling_started_at = None
|
||||
await session.commit()
|
||||
updated_task = await self.get_task(task_id, organization_id=organization_id)
|
||||
if not updated_task:
|
||||
raise NotFoundError("Task not found")
|
||||
return updated_task
|
||||
else:
|
||||
raise NotFoundError("Task not found")
|
||||
except SQLAlchemyError:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
raise
|
||||
|
||||
@read_retry()
|
||||
async def get_active_verification_requests(self, organization_id: str) -> list[dict]:
|
||||
"""Return active 2FA verification requests for an organization.
|
||||
|
||||
Queries both tasks and workflow runs where waiting_for_verification_code=True.
|
||||
Used to provide initial state when a WebSocket notification client connects.
|
||||
"""
|
||||
results: list[dict] = []
|
||||
async with self.Session() as session:
|
||||
# Tasks waiting for verification (exclude finalized tasks)
|
||||
finalized_task_statuses = [s.value for s in TaskStatus if s.is_final()]
|
||||
task_rows = (
|
||||
await session.scalars(
|
||||
select(TaskModel)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter_by(waiting_for_verification_code=True)
|
||||
.filter_by(workflow_run_id=None)
|
||||
.filter(TaskModel.status.not_in(finalized_task_statuses))
|
||||
.filter(TaskModel.created_at > datetime.utcnow() - timedelta(hours=1))
|
||||
)
|
||||
).all()
|
||||
for t in task_rows:
|
||||
results.append(
|
||||
{
|
||||
"task_id": t.task_id,
|
||||
"workflow_run_id": None,
|
||||
"verification_code_identifier": t.verification_code_identifier,
|
||||
"verification_code_polling_started_at": (
|
||||
t.verification_code_polling_started_at.isoformat()
|
||||
if t.verification_code_polling_started_at
|
||||
else None
|
||||
),
|
||||
}
|
||||
)
|
||||
# Workflow runs waiting for verification (exclude finalized runs)
|
||||
finalized_wr_statuses = [s.value for s in WorkflowRunStatus if s.is_final()]
|
||||
wr_rows = (
|
||||
await session.scalars(
|
||||
select(WorkflowRunModel)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter_by(waiting_for_verification_code=True)
|
||||
.filter(WorkflowRunModel.status.not_in(finalized_wr_statuses))
|
||||
.filter(WorkflowRunModel.created_at > datetime.utcnow() - timedelta(hours=1))
|
||||
)
|
||||
).all()
|
||||
for wr in wr_rows:
|
||||
results.append(
|
||||
{
|
||||
"task_id": None,
|
||||
"workflow_run_id": wr.workflow_run_id,
|
||||
"verification_code_identifier": wr.verification_code_identifier,
|
||||
"verification_code_polling_started_at": (
|
||||
wr.verification_code_polling_started_at.isoformat()
|
||||
if wr.verification_code_polling_started_at
|
||||
else None
|
||||
),
|
||||
}
|
||||
)
|
||||
return results
|
||||
|
||||
async def bulk_update_tasks(
|
||||
self,
|
||||
task_ids: list[str],
|
||||
@@ -2794,6 +2890,9 @@ class AgentDB(BaseAlchemyDB):
|
||||
ai_fallback: bool | None = None,
|
||||
depends_on_workflow_run_id: str | None = None,
|
||||
browser_session_id: str | None = None,
|
||||
waiting_for_verification_code: bool | None = None,
|
||||
verification_code_identifier: str | None = None,
|
||||
verification_code_polling_started_at: datetime | None = None,
|
||||
) -> WorkflowRun:
|
||||
async with self.Session() as session:
|
||||
workflow_run = (
|
||||
@@ -2826,6 +2925,17 @@ class AgentDB(BaseAlchemyDB):
|
||||
workflow_run.depends_on_workflow_run_id = depends_on_workflow_run_id
|
||||
if browser_session_id:
|
||||
workflow_run.browser_session_id = browser_session_id
|
||||
# 2FA verification code waiting state updates
|
||||
if waiting_for_verification_code is not None:
|
||||
workflow_run.waiting_for_verification_code = waiting_for_verification_code
|
||||
if verification_code_identifier is not None:
|
||||
workflow_run.verification_code_identifier = verification_code_identifier
|
||||
if verification_code_polling_started_at is not None:
|
||||
workflow_run.verification_code_polling_started_at = verification_code_polling_started_at
|
||||
if waiting_for_verification_code is not None and not waiting_for_verification_code:
|
||||
# Clear related fields when waiting is set to False
|
||||
workflow_run.verification_code_identifier = None
|
||||
workflow_run.verification_code_polling_started_at = None
|
||||
await session.commit()
|
||||
await save_workflow_run_logs(workflow_run_id)
|
||||
await session.refresh(workflow_run)
|
||||
@@ -3995,6 +4105,35 @@ class AgentDB(BaseAlchemyDB):
|
||||
totp_code = (await session.scalars(query)).all()
|
||||
return [TOTPCode.model_validate(totp_code) for totp_code in totp_code]
|
||||
|
||||
async def get_otp_codes_by_run(
|
||||
self,
|
||||
organization_id: str,
|
||||
task_id: str | None = None,
|
||||
workflow_run_id: str | None = None,
|
||||
valid_lifespan_minutes: int = settings.TOTP_LIFESPAN_MINUTES,
|
||||
limit: int = 1,
|
||||
) -> list[TOTPCode]:
|
||||
"""Get OTP codes matching a specific task or workflow run (no totp_identifier required).
|
||||
|
||||
Used when the agent detects a 2FA page but no TOTP credentials are pre-configured.
|
||||
The user submits codes manually via the UI, and this method finds them by run context.
|
||||
"""
|
||||
if not workflow_run_id and not task_id:
|
||||
return []
|
||||
async with self.Session() as session:
|
||||
query = (
|
||||
select(TOTPCodeModel)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter(TOTPCodeModel.created_at > datetime.utcnow() - timedelta(minutes=valid_lifespan_minutes))
|
||||
)
|
||||
if workflow_run_id:
|
||||
query = query.filter(TOTPCodeModel.workflow_run_id == workflow_run_id)
|
||||
elif task_id:
|
||||
query = query.filter(TOTPCodeModel.task_id == task_id)
|
||||
query = query.order_by(TOTPCodeModel.created_at.desc()).limit(limit)
|
||||
results = (await session.scalars(query)).all()
|
||||
return [TOTPCode.model_validate(r) for r in results]
|
||||
|
||||
async def get_recent_otp_codes(
|
||||
self,
|
||||
organization_id: str,
|
||||
|
||||
@@ -116,6 +116,10 @@ class TaskModel(Base):
|
||||
model = Column(JSON, nullable=True)
|
||||
browser_address = Column(String, nullable=True)
|
||||
download_timeout = Column(Numeric, nullable=True)
|
||||
# 2FA verification code waiting state fields
|
||||
waiting_for_verification_code = Column(Boolean, nullable=False, default=False)
|
||||
verification_code_identifier = Column(String, nullable=True)
|
||||
verification_code_polling_started_at = Column(DateTime, nullable=True)
|
||||
|
||||
|
||||
class StepModel(Base):
|
||||
@@ -350,6 +354,10 @@ class WorkflowRunModel(Base):
|
||||
debug_session_id: Column = Column(String, nullable=True)
|
||||
ai_fallback = Column(Boolean, nullable=True)
|
||||
code_gen = Column(Boolean, nullable=True)
|
||||
# 2FA verification code waiting state fields
|
||||
waiting_for_verification_code = Column(Boolean, nullable=False, default=False)
|
||||
verification_code_identifier = Column(String, nullable=True)
|
||||
verification_code_polling_started_at = Column(DateTime, nullable=True)
|
||||
|
||||
queued_at = Column(DateTime, nullable=True)
|
||||
started_at = Column(DateTime, nullable=True)
|
||||
|
||||
@@ -211,6 +211,9 @@ def convert_to_task(task_obj: TaskModel, debug_enabled: bool = False, workflow_p
|
||||
browser_session_id=task_obj.browser_session_id,
|
||||
browser_address=task_obj.browser_address,
|
||||
download_timeout=task_obj.download_timeout,
|
||||
waiting_for_verification_code=task_obj.waiting_for_verification_code or False,
|
||||
verification_code_identifier=task_obj.verification_code_identifier,
|
||||
verification_code_polling_started_at=task_obj.verification_code_polling_started_at,
|
||||
)
|
||||
return task
|
||||
|
||||
@@ -424,6 +427,9 @@ def convert_to_workflow_run(
|
||||
run_with=workflow_run_model.run_with,
|
||||
code_gen=workflow_run_model.code_gen,
|
||||
ai_fallback=workflow_run_model.ai_fallback,
|
||||
waiting_for_verification_code=workflow_run_model.waiting_for_verification_code or False,
|
||||
verification_code_identifier=workflow_run_model.verification_code_identifier,
|
||||
verification_code_polling_started_at=workflow_run_model.verification_code_polling_started_at,
|
||||
)
|
||||
|
||||
|
||||
|
||||
21
skyvern/forge/sdk/notification/base.py
Normal file
21
skyvern/forge/sdk/notification/base.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""Abstract base for notification registries."""
|
||||
|
||||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class BaseNotificationRegistry(ABC):
|
||||
"""Abstract pub/sub registry scoped by organization.
|
||||
|
||||
Implementations must fan-out: a single publish call delivers the
|
||||
message to every active subscriber for that organization.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def subscribe(self, organization_id: str) -> asyncio.Queue[dict]: ...
|
||||
|
||||
@abstractmethod
|
||||
def unsubscribe(self, organization_id: str, queue: asyncio.Queue[dict]) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
def publish(self, organization_id: str, message: dict) -> None: ...
|
||||
14
skyvern/forge/sdk/notification/factory.py
Normal file
14
skyvern/forge/sdk/notification/factory.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from skyvern.forge.sdk.notification.base import BaseNotificationRegistry
|
||||
from skyvern.forge.sdk.notification.local import LocalNotificationRegistry
|
||||
|
||||
|
||||
class NotificationRegistryFactory:
|
||||
__registry: BaseNotificationRegistry = LocalNotificationRegistry()
|
||||
|
||||
@staticmethod
|
||||
def set_registry(registry: BaseNotificationRegistry) -> None:
|
||||
NotificationRegistryFactory.__registry = registry
|
||||
|
||||
@staticmethod
|
||||
def get_registry() -> BaseNotificationRegistry:
|
||||
return NotificationRegistryFactory.__registry
|
||||
45
skyvern/forge/sdk/notification/local.py
Normal file
45
skyvern/forge/sdk/notification/local.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""In-process notification registry using asyncio queues (single-pod only)."""
|
||||
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
|
||||
import structlog
|
||||
|
||||
from skyvern.forge.sdk.notification.base import BaseNotificationRegistry
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class LocalNotificationRegistry(BaseNotificationRegistry):
|
||||
"""In-process fan-out pub/sub using asyncio queues. Single-pod only."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._subscribers: dict[str, list[asyncio.Queue[dict]]] = defaultdict(list)
|
||||
|
||||
def subscribe(self, organization_id: str) -> asyncio.Queue[dict]:
|
||||
queue: asyncio.Queue[dict] = asyncio.Queue()
|
||||
self._subscribers[organization_id].append(queue)
|
||||
LOG.info("Notification subscriber added", organization_id=organization_id)
|
||||
return queue
|
||||
|
||||
def unsubscribe(self, organization_id: str, queue: asyncio.Queue[dict]) -> None:
|
||||
queues = self._subscribers.get(organization_id)
|
||||
if queues:
|
||||
try:
|
||||
queues.remove(queue)
|
||||
except ValueError:
|
||||
pass
|
||||
if not queues:
|
||||
del self._subscribers[organization_id]
|
||||
LOG.info("Notification subscriber removed", organization_id=organization_id)
|
||||
|
||||
def publish(self, organization_id: str, message: dict) -> None:
|
||||
queues = self._subscribers.get(organization_id, [])
|
||||
for queue in queues:
|
||||
try:
|
||||
queue.put_nowait(message)
|
||||
except asyncio.QueueFull:
|
||||
LOG.warning(
|
||||
"Notification queue full, dropping message",
|
||||
organization_id=organization_id,
|
||||
)
|
||||
55
skyvern/forge/sdk/notification/redis.py
Normal file
55
skyvern/forge/sdk/notification/redis.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""Redis-backed notification registry for multi-pod deployments.
|
||||
|
||||
Thin adapter around :class:`RedisPubSub` — all Redis pub/sub logic
|
||||
lives in the generic layer; this class maps the ``organization_id``
|
||||
domain concept onto generic string keys.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
from redis.asyncio import Redis
|
||||
|
||||
from skyvern.forge.sdk.notification.base import BaseNotificationRegistry
|
||||
from skyvern.forge.sdk.redis.pubsub import RedisPubSub
|
||||
|
||||
|
||||
class RedisNotificationRegistry(BaseNotificationRegistry):
|
||||
"""Fan-out pub/sub backed by Redis. One Redis PubSub channel per org."""
|
||||
|
||||
def __init__(self, redis_client: Redis) -> None:
|
||||
self._pubsub = RedisPubSub(redis_client, channel_prefix="skyvern:notifications:")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Property accessors (used by existing tests)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@property
|
||||
def _listener_tasks(self) -> dict[str, asyncio.Task[None]]:
|
||||
return self._pubsub._listener_tasks
|
||||
|
||||
@property
|
||||
def _subscribers(self) -> dict[str, list[asyncio.Queue[dict]]]:
|
||||
return self._pubsub._subscribers
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public interface
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def subscribe(self, organization_id: str) -> asyncio.Queue[dict]:
|
||||
return self._pubsub.subscribe(organization_id)
|
||||
|
||||
def unsubscribe(self, organization_id: str, queue: asyncio.Queue[dict]) -> None:
|
||||
self._pubsub.unsubscribe(organization_id, queue)
|
||||
|
||||
def publish(self, organization_id: str, message: dict) -> None:
|
||||
self._pubsub.publish(organization_id, message)
|
||||
|
||||
async def close(self) -> None:
|
||||
await self._pubsub.close()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal helper (exposed for tests)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _dispatch_local(self, organization_id: str, message: dict) -> None:
|
||||
self._pubsub._dispatch_local(organization_id, message)
|
||||
0
skyvern/forge/sdk/redis/__init__.py
Normal file
0
skyvern/forge/sdk/redis/__init__.py
Normal file
21
skyvern/forge/sdk/redis/factory.py
Normal file
21
skyvern/forge/sdk/redis/factory.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from redis.asyncio import Redis
|
||||
|
||||
|
||||
class RedisClientFactory:
|
||||
"""Singleton factory for a shared async Redis client.
|
||||
|
||||
Follows the same static set/get pattern as ``CacheFactory``.
|
||||
Defaults to ``None`` (no Redis in local/OSS mode).
|
||||
"""
|
||||
|
||||
__client: Redis | None = None
|
||||
|
||||
@staticmethod
|
||||
def set_client(client: Redis) -> None:
|
||||
RedisClientFactory.__client = client
|
||||
|
||||
@staticmethod
|
||||
def get_client() -> Redis | None:
|
||||
return RedisClientFactory.__client
|
||||
130
skyvern/forge/sdk/redis/pubsub.py
Normal file
130
skyvern/forge/sdk/redis/pubsub.py
Normal file
@@ -0,0 +1,130 @@
|
||||
"""Generic Redis pub/sub layer.
|
||||
|
||||
Extracted from ``RedisNotificationRegistry`` so that any feature
|
||||
(notifications, events, cache invalidation, etc.) can reuse the same
|
||||
pattern with its own channel prefix.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from collections import defaultdict
|
||||
|
||||
import structlog
|
||||
from redis.asyncio import Redis
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class RedisPubSub:
|
||||
"""Fan-out pub/sub backed by Redis. One Redis PubSub channel per key."""
|
||||
|
||||
def __init__(self, redis_client: Redis, channel_prefix: str) -> None:
|
||||
self._redis = redis_client
|
||||
self._channel_prefix = channel_prefix
|
||||
self._subscribers: dict[str, list[asyncio.Queue[dict]]] = defaultdict(list)
|
||||
# One listener task per key channel
|
||||
self._listener_tasks: dict[str, asyncio.Task[None]] = {}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public interface
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def subscribe(self, key: str) -> asyncio.Queue[dict]:
|
||||
queue: asyncio.Queue[dict] = asyncio.Queue()
|
||||
self._subscribers[key].append(queue)
|
||||
|
||||
# Spin up a Redis listener if this is the first local subscriber
|
||||
if key not in self._listener_tasks:
|
||||
task = asyncio.get_running_loop().create_task(self._listen(key))
|
||||
self._listener_tasks[key] = task
|
||||
|
||||
LOG.info("PubSub subscriber added", key=key, channel_prefix=self._channel_prefix)
|
||||
return queue
|
||||
|
||||
def unsubscribe(self, key: str, queue: asyncio.Queue[dict]) -> None:
|
||||
queues = self._subscribers.get(key)
|
||||
if queues:
|
||||
try:
|
||||
queues.remove(queue)
|
||||
except ValueError:
|
||||
pass
|
||||
if not queues:
|
||||
del self._subscribers[key]
|
||||
self._cancel_listener(key)
|
||||
LOG.info("PubSub subscriber removed", key=key, channel_prefix=self._channel_prefix)
|
||||
|
||||
def publish(self, key: str, message: dict) -> None:
|
||||
"""Fire-and-forget Redis PUBLISH."""
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
loop.create_task(self._publish_to_redis(key, message))
|
||||
except RuntimeError:
|
||||
LOG.warning(
|
||||
"No running event loop; cannot publish via Redis",
|
||||
key=key,
|
||||
channel_prefix=self._channel_prefix,
|
||||
)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Cancel all listener tasks and clear state. Call on shutdown."""
|
||||
for key in list(self._listener_tasks):
|
||||
self._cancel_listener(key)
|
||||
self._subscribers.clear()
|
||||
LOG.info("RedisPubSub closed", channel_prefix=self._channel_prefix)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _publish_to_redis(self, key: str, message: dict) -> None:
|
||||
channel = f"{self._channel_prefix}{key}"
|
||||
try:
|
||||
await self._redis.publish(channel, json.dumps(message))
|
||||
except Exception:
|
||||
LOG.exception("Failed to publish to Redis", key=key, channel_prefix=self._channel_prefix)
|
||||
|
||||
async def _listen(self, key: str) -> None:
|
||||
"""Subscribe to a Redis channel and fan out messages locally."""
|
||||
channel = f"{self._channel_prefix}{key}"
|
||||
pubsub = self._redis.pubsub()
|
||||
try:
|
||||
await pubsub.subscribe(channel)
|
||||
LOG.info("Redis listener started", channel=channel)
|
||||
async for raw_message in pubsub.listen():
|
||||
if raw_message["type"] != "message":
|
||||
continue
|
||||
try:
|
||||
data = json.loads(raw_message["data"])
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
LOG.warning("Invalid JSON on Redis channel", channel=channel)
|
||||
continue
|
||||
self._dispatch_local(key, data)
|
||||
except asyncio.CancelledError:
|
||||
LOG.info("Redis listener cancelled", channel=channel)
|
||||
raise
|
||||
except Exception:
|
||||
LOG.exception("Redis listener error", channel=channel)
|
||||
finally:
|
||||
try:
|
||||
await pubsub.unsubscribe(channel)
|
||||
await pubsub.close()
|
||||
except Exception:
|
||||
LOG.warning("Error closing Redis pubsub", channel=channel)
|
||||
|
||||
def _dispatch_local(self, key: str, message: dict) -> None:
|
||||
"""Fan out a message to all local asyncio queues for this key."""
|
||||
queues = self._subscribers.get(key, [])
|
||||
for queue in queues:
|
||||
try:
|
||||
queue.put_nowait(message)
|
||||
except asyncio.QueueFull:
|
||||
LOG.warning(
|
||||
"Queue full, dropping message",
|
||||
key=key,
|
||||
channel_prefix=self._channel_prefix,
|
||||
)
|
||||
|
||||
def _cancel_listener(self, key: str) -> None:
|
||||
task = self._listener_tasks.pop(key, None)
|
||||
if task and not task.done():
|
||||
task.cancel()
|
||||
@@ -11,5 +11,6 @@ from skyvern.forge.sdk.routes import sdk # noqa: F401
|
||||
from skyvern.forge.sdk.routes import webhooks # noqa: F401
|
||||
from skyvern.forge.sdk.routes import workflow_copilot # noqa: F401
|
||||
from skyvern.forge.sdk.routes.streaming import messages # noqa: F401
|
||||
from skyvern.forge.sdk.routes.streaming import notifications # noqa: F401
|
||||
from skyvern.forge.sdk.routes.streaming import screenshot # noqa: F401
|
||||
from skyvern.forge.sdk.routes.streaming import vnc # noqa: F401
|
||||
|
||||
@@ -131,10 +131,15 @@ async def send_totp_code(
|
||||
task = await app.DATABASE.get_task(data.task_id, curr_org.organization_id)
|
||||
if not task:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid task id: {data.task_id}")
|
||||
workflow_id_for_storage: str | None = None
|
||||
if data.workflow_id:
|
||||
workflow = await app.DATABASE.get_workflow(data.workflow_id, curr_org.organization_id)
|
||||
if data.workflow_id.startswith("wpid_"):
|
||||
workflow = await app.DATABASE.get_workflow_by_permanent_id(data.workflow_id, curr_org.organization_id)
|
||||
else:
|
||||
workflow = await app.DATABASE.get_workflow(data.workflow_id, curr_org.organization_id)
|
||||
if not workflow:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid workflow id: {data.workflow_id}")
|
||||
workflow_id_for_storage = workflow.workflow_id
|
||||
if data.workflow_run_id:
|
||||
workflow_run = await app.DATABASE.get_workflow_run(data.workflow_run_id, curr_org.organization_id)
|
||||
if not workflow_run:
|
||||
@@ -162,7 +167,7 @@ async def send_totp_code(
|
||||
content=data.content,
|
||||
code=otp_value.value,
|
||||
task_id=data.task_id,
|
||||
workflow_id=data.workflow_id,
|
||||
workflow_id=workflow_id_for_storage,
|
||||
workflow_run_id=data.workflow_run_id,
|
||||
source=data.source,
|
||||
expired_at=data.expired_at,
|
||||
|
||||
101
skyvern/forge/sdk/routes/streaming/notifications.py
Normal file
101
skyvern/forge/sdk/routes/streaming/notifications.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""WebSocket endpoint for streaming global 2FA verification code notifications."""
|
||||
|
||||
import asyncio
|
||||
|
||||
import structlog
|
||||
from fastapi import WebSocket, WebSocketDisconnect
|
||||
from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK
|
||||
|
||||
from skyvern.config import settings
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.notification.factory import NotificationRegistryFactory
|
||||
from skyvern.forge.sdk.routes.routers import base_router
|
||||
from skyvern.forge.sdk.routes.streaming.auth import _auth as local_auth
|
||||
from skyvern.forge.sdk.routes.streaming.auth import auth as real_auth
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
HEARTBEAT_INTERVAL = 60
|
||||
|
||||
|
||||
@base_router.websocket("/stream/notifications")
|
||||
async def notification_stream(
|
||||
websocket: WebSocket,
|
||||
apikey: str | None = None,
|
||||
token: str | None = None,
|
||||
) -> None:
|
||||
auth = local_auth if settings.ENV == "local" else real_auth
|
||||
organization_id = await auth(apikey=apikey, token=token, websocket=websocket)
|
||||
if not organization_id:
|
||||
LOG.info("Notifications: Authentication failed")
|
||||
return
|
||||
|
||||
LOG.info("Notifications: Started streaming", organization_id=organization_id)
|
||||
registry = NotificationRegistryFactory.get_registry()
|
||||
queue = registry.subscribe(organization_id)
|
||||
|
||||
try:
|
||||
# Send initial state: all currently active verification requests
|
||||
active_requests = await app.DATABASE.get_active_verification_requests(organization_id)
|
||||
for req in active_requests:
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "verification_code_required",
|
||||
"task_id": req.get("task_id"),
|
||||
"workflow_run_id": req.get("workflow_run_id"),
|
||||
"identifier": req.get("verification_code_identifier"),
|
||||
"polling_started_at": req.get("verification_code_polling_started_at"),
|
||||
}
|
||||
)
|
||||
|
||||
# Watch for client disconnect while streaming events
|
||||
disconnect_event = asyncio.Event()
|
||||
|
||||
async def _watch_disconnect() -> None:
|
||||
try:
|
||||
while True:
|
||||
await websocket.receive()
|
||||
except (WebSocketDisconnect, ConnectionClosedOK, ConnectionClosedError):
|
||||
disconnect_event.set()
|
||||
|
||||
watcher = asyncio.create_task(_watch_disconnect())
|
||||
try:
|
||||
while not disconnect_event.is_set():
|
||||
queue_task = asyncio.ensure_future(asyncio.wait_for(queue.get(), timeout=HEARTBEAT_INTERVAL))
|
||||
disconnect_wait = asyncio.ensure_future(disconnect_event.wait())
|
||||
done, pending = await asyncio.wait({queue_task, disconnect_wait}, return_when=asyncio.FIRST_COMPLETED)
|
||||
for p in pending:
|
||||
p.cancel()
|
||||
|
||||
if disconnect_event.is_set():
|
||||
return
|
||||
|
||||
try:
|
||||
message = queue_task.result()
|
||||
await websocket.send_json(message)
|
||||
except TimeoutError:
|
||||
try:
|
||||
await websocket.send_json({"type": "heartbeat"})
|
||||
except Exception:
|
||||
LOG.info(
|
||||
"Notifications: Client unreachable during heartbeat. Closing.",
|
||||
organization_id=organization_id,
|
||||
)
|
||||
return
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
finally:
|
||||
watcher.cancel()
|
||||
|
||||
except WebSocketDisconnect:
|
||||
LOG.info("Notifications: WebSocket disconnected", organization_id=organization_id)
|
||||
except ConnectionClosedOK:
|
||||
LOG.info("Notifications: ConnectionClosedOK", organization_id=organization_id)
|
||||
except ConnectionClosedError:
|
||||
LOG.warning(
|
||||
"Notifications: ConnectionClosedError (client likely disconnected)", organization_id=organization_id
|
||||
)
|
||||
except Exception:
|
||||
LOG.warning("Notifications: Error while streaming", organization_id=organization_id, exc_info=True)
|
||||
finally:
|
||||
registry.unsubscribe(organization_id, queue)
|
||||
LOG.info("Notifications: Connection closed", organization_id=organization_id)
|
||||
@@ -288,6 +288,10 @@ class Task(TaskBase):
|
||||
queued_at: datetime | None = None
|
||||
started_at: datetime | None = None
|
||||
finished_at: datetime | None = None
|
||||
# 2FA verification code waiting state fields
|
||||
waiting_for_verification_code: bool = False
|
||||
verification_code_identifier: str | None = None
|
||||
verification_code_polling_started_at: datetime | None = None
|
||||
|
||||
@property
|
||||
def llm_key(self) -> str | None:
|
||||
@@ -365,6 +369,9 @@ class Task(TaskBase):
|
||||
max_screenshot_scrolls=self.max_screenshot_scrolls,
|
||||
step_count=step_count,
|
||||
browser_session_id=self.browser_session_id,
|
||||
waiting_for_verification_code=self.waiting_for_verification_code,
|
||||
verification_code_identifier=self.verification_code_identifier,
|
||||
verification_code_polling_started_at=self.verification_code_polling_started_at,
|
||||
)
|
||||
|
||||
|
||||
@@ -392,6 +399,10 @@ class TaskResponse(BaseModel):
|
||||
max_screenshot_scrolls: int | None = None
|
||||
step_count: int | None = None
|
||||
browser_session_id: str | None = None
|
||||
# 2FA verification code waiting state fields
|
||||
waiting_for_verification_code: bool = False
|
||||
verification_code_identifier: str | None = None
|
||||
verification_code_polling_started_at: datetime | None = None
|
||||
|
||||
|
||||
class TaskOutput(BaseModel):
|
||||
|
||||
@@ -172,6 +172,10 @@ class WorkflowRun(BaseModel):
|
||||
sequential_key: str | None = None
|
||||
ai_fallback: bool | None = None
|
||||
code_gen: bool | None = None
|
||||
# 2FA verification code waiting state fields
|
||||
waiting_for_verification_code: bool = False
|
||||
verification_code_identifier: str | None = None
|
||||
verification_code_polling_started_at: datetime | None = None
|
||||
|
||||
queued_at: datetime | None = None
|
||||
started_at: datetime | None = None
|
||||
@@ -226,6 +230,10 @@ class WorkflowRunResponseBase(BaseModel):
|
||||
browser_address: str | None = None
|
||||
script_run: ScriptRunResponse | None = None
|
||||
errors: list[dict[str, Any]] | None = None
|
||||
# 2FA verification code waiting state fields
|
||||
waiting_for_verification_code: bool = False
|
||||
verification_code_identifier: str | None = None
|
||||
verification_code_polling_started_at: datetime | None = None
|
||||
|
||||
|
||||
class WorkflowRunWithWorkflowResponse(WorkflowRunResponseBase):
|
||||
|
||||
@@ -3019,6 +3019,10 @@ class WorkflowService:
|
||||
browser_address=workflow_run.browser_address,
|
||||
script_run=workflow_run.script_run,
|
||||
errors=errors,
|
||||
# 2FA verification code waiting state fields
|
||||
waiting_for_verification_code=workflow_run.waiting_for_verification_code,
|
||||
verification_code_identifier=workflow_run.verification_code_identifier,
|
||||
verification_code_polling_started_at=workflow_run.verification_code_polling_started_at,
|
||||
)
|
||||
|
||||
async def clean_up_workflow(
|
||||
|
||||
@@ -11,6 +11,7 @@ from skyvern.forge.prompts import prompt_engine
|
||||
from skyvern.forge.sdk.core.aiohttp_helper import aiohttp_post
|
||||
from skyvern.forge.sdk.core.security import generate_skyvern_webhook_signature
|
||||
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
|
||||
from skyvern.forge.sdk.notification.factory import NotificationRegistryFactory
|
||||
from skyvern.forge.sdk.schemas.totp_codes import OTPType
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
@@ -80,38 +81,147 @@ async def poll_otp_value(
|
||||
totp_verification_url=totp_verification_url,
|
||||
totp_identifier=totp_identifier,
|
||||
)
|
||||
while True:
|
||||
await asyncio.sleep(10)
|
||||
# check timeout
|
||||
if datetime.utcnow() > timeout_datetime:
|
||||
LOG.warning("Polling otp value timed out")
|
||||
raise NoTOTPVerificationCodeFound(
|
||||
task_id=task_id,
|
||||
|
||||
# Set the waiting state in the database when polling starts
|
||||
identifier_for_ui = totp_identifier
|
||||
if workflow_run_id:
|
||||
try:
|
||||
await app.DATABASE.update_workflow_run(
|
||||
workflow_run_id=workflow_run_id,
|
||||
workflow_id=workflow_permanent_id,
|
||||
totp_verification_url=totp_verification_url,
|
||||
totp_identifier=totp_identifier,
|
||||
waiting_for_verification_code=True,
|
||||
verification_code_identifier=identifier_for_ui,
|
||||
verification_code_polling_started_at=start_datetime,
|
||||
)
|
||||
otp_value: OTPValue | None = None
|
||||
if totp_verification_url:
|
||||
otp_value = await _get_otp_value_from_url(
|
||||
organization_id,
|
||||
totp_verification_url,
|
||||
org_token.token,
|
||||
task_id=task_id,
|
||||
LOG.info(
|
||||
"Set 2FA waiting state for workflow run",
|
||||
workflow_run_id=workflow_run_id,
|
||||
verification_code_identifier=identifier_for_ui,
|
||||
)
|
||||
elif totp_identifier:
|
||||
otp_value = await _get_otp_value_from_db(
|
||||
organization_id,
|
||||
totp_identifier,
|
||||
try:
|
||||
NotificationRegistryFactory.get_registry().publish(
|
||||
organization_id,
|
||||
{
|
||||
"type": "verification_code_required",
|
||||
"workflow_run_id": workflow_run_id,
|
||||
"task_id": task_id,
|
||||
"identifier": identifier_for_ui,
|
||||
"polling_started_at": start_datetime.isoformat(),
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
LOG.warning("Failed to publish 2FA required notification for workflow run", exc_info=True)
|
||||
except Exception:
|
||||
LOG.warning("Failed to set 2FA waiting state for workflow run", exc_info=True)
|
||||
elif task_id:
|
||||
try:
|
||||
await app.DATABASE.update_task_2fa_state(
|
||||
task_id=task_id,
|
||||
workflow_id=workflow_permanent_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
waiting_for_verification_code=True,
|
||||
verification_code_identifier=identifier_for_ui,
|
||||
verification_code_polling_started_at=start_datetime,
|
||||
)
|
||||
if otp_value:
|
||||
LOG.info("Got otp value", otp_value=otp_value)
|
||||
return otp_value
|
||||
LOG.info(
|
||||
"Set 2FA waiting state for task",
|
||||
task_id=task_id,
|
||||
verification_code_identifier=identifier_for_ui,
|
||||
)
|
||||
try:
|
||||
NotificationRegistryFactory.get_registry().publish(
|
||||
organization_id,
|
||||
{
|
||||
"type": "verification_code_required",
|
||||
"task_id": task_id,
|
||||
"identifier": identifier_for_ui,
|
||||
"polling_started_at": start_datetime.isoformat(),
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
LOG.warning("Failed to publish 2FA required notification for task", exc_info=True)
|
||||
except Exception:
|
||||
LOG.warning("Failed to set 2FA waiting state for task", exc_info=True)
|
||||
|
||||
try:
|
||||
while True:
|
||||
await asyncio.sleep(10)
|
||||
# check timeout
|
||||
if datetime.utcnow() > timeout_datetime:
|
||||
LOG.warning("Polling otp value timed out")
|
||||
raise NoTOTPVerificationCodeFound(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
workflow_id=workflow_permanent_id,
|
||||
totp_verification_url=totp_verification_url,
|
||||
totp_identifier=totp_identifier,
|
||||
)
|
||||
otp_value: OTPValue | None = None
|
||||
if totp_verification_url:
|
||||
otp_value = await _get_otp_value_from_url(
|
||||
organization_id,
|
||||
totp_verification_url,
|
||||
org_token.token,
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
)
|
||||
elif totp_identifier:
|
||||
otp_value = await _get_otp_value_from_db(
|
||||
organization_id,
|
||||
totp_identifier,
|
||||
task_id=task_id,
|
||||
workflow_id=workflow_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
)
|
||||
if not otp_value:
|
||||
otp_value = await _get_otp_value_by_run(
|
||||
organization_id,
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
)
|
||||
else:
|
||||
# No pre-configured TOTP — poll for manually submitted codes by run context
|
||||
otp_value = await _get_otp_value_by_run(
|
||||
organization_id,
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
)
|
||||
if otp_value:
|
||||
LOG.info("Got otp value", otp_value=otp_value)
|
||||
return otp_value
|
||||
finally:
|
||||
# Clear the waiting state when polling completes (success, timeout, or error)
|
||||
if workflow_run_id:
|
||||
try:
|
||||
await app.DATABASE.update_workflow_run(
|
||||
workflow_run_id=workflow_run_id,
|
||||
waiting_for_verification_code=False,
|
||||
)
|
||||
LOG.info("Cleared 2FA waiting state for workflow run", workflow_run_id=workflow_run_id)
|
||||
try:
|
||||
NotificationRegistryFactory.get_registry().publish(
|
||||
organization_id,
|
||||
{"type": "verification_code_resolved", "workflow_run_id": workflow_run_id, "task_id": task_id},
|
||||
)
|
||||
except Exception:
|
||||
LOG.warning("Failed to publish 2FA resolved notification for workflow run", exc_info=True)
|
||||
except Exception:
|
||||
LOG.warning("Failed to clear 2FA waiting state for workflow run", exc_info=True)
|
||||
elif task_id:
|
||||
try:
|
||||
await app.DATABASE.update_task_2fa_state(
|
||||
task_id=task_id,
|
||||
organization_id=organization_id,
|
||||
waiting_for_verification_code=False,
|
||||
)
|
||||
LOG.info("Cleared 2FA waiting state for task", task_id=task_id)
|
||||
try:
|
||||
NotificationRegistryFactory.get_registry().publish(
|
||||
organization_id,
|
||||
{"type": "verification_code_resolved", "task_id": task_id},
|
||||
)
|
||||
except Exception:
|
||||
LOG.warning("Failed to publish 2FA resolved notification for task", exc_info=True)
|
||||
except Exception:
|
||||
LOG.warning("Failed to clear 2FA waiting state for task", exc_info=True)
|
||||
|
||||
|
||||
async def _get_otp_value_from_url(
|
||||
@@ -175,6 +285,28 @@ async def _get_otp_value_from_url(
|
||||
return otp_value
|
||||
|
||||
|
||||
async def _get_otp_value_by_run(
|
||||
organization_id: str,
|
||||
task_id: str | None = None,
|
||||
workflow_run_id: str | None = None,
|
||||
) -> OTPValue | None:
|
||||
"""Look up OTP codes by task_id/workflow_run_id when no totp_identifier is configured.
|
||||
|
||||
Used for the manual 2FA input flow where users submit codes through the UI
|
||||
without pre-configured TOTP credentials.
|
||||
"""
|
||||
codes = await app.DATABASE.get_otp_codes_by_run(
|
||||
organization_id=organization_id,
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
limit=1,
|
||||
)
|
||||
if codes:
|
||||
code = codes[0]
|
||||
return OTPValue(value=code.code, type=code.otp_type)
|
||||
return None
|
||||
|
||||
|
||||
async def _get_otp_value_from_db(
|
||||
organization_id: str,
|
||||
totp_identifier: str,
|
||||
|
||||
106
tests/unit_tests/test_notification_registry.py
Normal file
106
tests/unit_tests/test_notification_registry.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""Tests for NotificationRegistry pub/sub and get_active_verification_requests (SKY-6)."""
|
||||
|
||||
import pytest
|
||||
|
||||
from skyvern.forge.sdk.db.agent_db import AgentDB
|
||||
from skyvern.forge.sdk.notification.base import BaseNotificationRegistry
|
||||
from skyvern.forge.sdk.notification.factory import NotificationRegistryFactory
|
||||
from skyvern.forge.sdk.notification.local import LocalNotificationRegistry
|
||||
|
||||
# === Task 1: NotificationRegistry subscribe / publish / unsubscribe ===
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subscribe_and_publish():
|
||||
"""Published messages should be received by subscribers."""
|
||||
registry = LocalNotificationRegistry()
|
||||
queue = registry.subscribe("org_1")
|
||||
|
||||
registry.publish("org_1", {"type": "verification_code_required", "task_id": "tsk_1"})
|
||||
msg = queue.get_nowait()
|
||||
assert msg["type"] == "verification_code_required"
|
||||
assert msg["task_id"] == "tsk_1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_subscribers():
|
||||
"""All subscribers for an org should receive the same message."""
|
||||
registry = LocalNotificationRegistry()
|
||||
q1 = registry.subscribe("org_1")
|
||||
q2 = registry.subscribe("org_1")
|
||||
|
||||
registry.publish("org_1", {"type": "verification_code_required"})
|
||||
assert not q1.empty()
|
||||
assert not q2.empty()
|
||||
assert q1.get_nowait() == q2.get_nowait()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_wrong_org_does_not_leak():
|
||||
"""Messages for org_A should not appear in org_B's queue."""
|
||||
registry = LocalNotificationRegistry()
|
||||
q_a = registry.subscribe("org_a")
|
||||
q_b = registry.subscribe("org_b")
|
||||
|
||||
registry.publish("org_a", {"type": "test"})
|
||||
assert not q_a.empty()
|
||||
assert q_b.empty()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unsubscribe():
|
||||
"""After unsubscribe, the queue should no longer receive messages."""
|
||||
registry = LocalNotificationRegistry()
|
||||
queue = registry.subscribe("org_1")
|
||||
|
||||
registry.unsubscribe("org_1", queue)
|
||||
registry.publish("org_1", {"type": "test"})
|
||||
assert queue.empty()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unsubscribe_idempotent():
|
||||
"""Unsubscribing a queue that's already removed should not raise."""
|
||||
registry = LocalNotificationRegistry()
|
||||
queue = registry.subscribe("org_1")
|
||||
registry.unsubscribe("org_1", queue)
|
||||
registry.unsubscribe("org_1", queue) # should not raise
|
||||
|
||||
|
||||
# === Task: BaseNotificationRegistry ABC ===
|
||||
|
||||
|
||||
def test_base_notification_registry_cannot_be_instantiated():
|
||||
"""ABC should not be directly instantiable."""
|
||||
with pytest.raises(TypeError):
|
||||
BaseNotificationRegistry()
|
||||
|
||||
|
||||
# === Task: NotificationRegistryFactory ===
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_factory_returns_local_by_default():
|
||||
"""Factory should return a LocalNotificationRegistry by default."""
|
||||
registry = NotificationRegistryFactory.get_registry()
|
||||
assert isinstance(registry, LocalNotificationRegistry)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_factory_set_and_get():
|
||||
"""Factory should allow swapping the registry implementation."""
|
||||
original = NotificationRegistryFactory.get_registry()
|
||||
try:
|
||||
custom = LocalNotificationRegistry()
|
||||
NotificationRegistryFactory.set_registry(custom)
|
||||
assert NotificationRegistryFactory.get_registry() is custom
|
||||
finally:
|
||||
NotificationRegistryFactory.set_registry(original)
|
||||
|
||||
|
||||
# === Task 2: get_active_verification_requests DB method ===
|
||||
|
||||
|
||||
def test_get_active_verification_requests_method_exists():
|
||||
"""AgentDB should have get_active_verification_requests method."""
|
||||
assert hasattr(AgentDB, "get_active_verification_requests")
|
||||
544
tests/unit_tests/test_otp_no_config.py
Normal file
544
tests/unit_tests/test_otp_no_config.py
Normal file
@@ -0,0 +1,544 @@
|
||||
"""Tests for manual 2FA input without pre-configured TOTP credentials (SKY-6)."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from skyvern.constants import SPECIAL_FIELD_VERIFICATION_CODE
|
||||
from skyvern.forge.agent import ForgeAgent
|
||||
from skyvern.forge.sdk.db.agent_db import AgentDB
|
||||
from skyvern.forge.sdk.notification.local import LocalNotificationRegistry
|
||||
from skyvern.forge.sdk.routes.credentials import send_totp_code
|
||||
from skyvern.forge.sdk.schemas.totp_codes import TOTPCodeCreate
|
||||
from skyvern.schemas.runs import RunEngine
|
||||
from skyvern.services.otp_service import OTPValue, _get_otp_value_by_run, poll_otp_value
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_otp_codes_by_run_exists():
|
||||
"""get_otp_codes_by_run should exist on AgentDB."""
|
||||
assert hasattr(AgentDB, "get_otp_codes_by_run"), "AgentDB missing get_otp_codes_by_run method"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_otp_codes_by_run_returns_empty_without_identifiers():
|
||||
"""get_otp_codes_by_run should return [] when neither task_id nor workflow_run_id is given."""
|
||||
db = AgentDB.__new__(AgentDB)
|
||||
result = await db.get_otp_codes_by_run(
|
||||
organization_id="org_1",
|
||||
)
|
||||
assert result == []
|
||||
|
||||
|
||||
# === Task 2: _get_otp_value_by_run OTP service function ===
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_otp_value_by_run_returns_code():
|
||||
"""_get_otp_value_by_run should find OTP codes by task_id."""
|
||||
mock_code = MagicMock()
|
||||
mock_code.code = "123456"
|
||||
mock_code.otp_type = "totp"
|
||||
|
||||
mock_db = AsyncMock()
|
||||
mock_db.get_otp_codes_by_run.return_value = [mock_code]
|
||||
|
||||
mock_app = MagicMock()
|
||||
mock_app.DATABASE = mock_db
|
||||
|
||||
with patch("skyvern.services.otp_service.app", new=mock_app):
|
||||
result = await _get_otp_value_by_run(
|
||||
organization_id="org_1",
|
||||
task_id="tsk_1",
|
||||
)
|
||||
assert result is not None
|
||||
assert result.value == "123456"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_otp_value_by_run_returns_none_when_no_codes():
|
||||
"""_get_otp_value_by_run should return None when no codes found."""
|
||||
mock_db = AsyncMock()
|
||||
mock_db.get_otp_codes_by_run.return_value = []
|
||||
|
||||
mock_app = MagicMock()
|
||||
mock_app.DATABASE = mock_db
|
||||
|
||||
with patch("skyvern.services.otp_service.app", new=mock_app):
|
||||
result = await _get_otp_value_by_run(
|
||||
organization_id="org_1",
|
||||
task_id="tsk_1",
|
||||
)
|
||||
assert result is None
|
||||
|
||||
|
||||
# === Task 3: poll_otp_value without identifier ===
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_poll_otp_value_without_identifier_uses_run_lookup():
|
||||
"""poll_otp_value should use _get_otp_value_by_run when no identifier/URL provided."""
|
||||
mock_code = MagicMock()
|
||||
mock_code.code = "123456"
|
||||
mock_code.otp_type = "totp"
|
||||
|
||||
mock_db = AsyncMock()
|
||||
mock_db.get_valid_org_auth_token.return_value = MagicMock(token="tok")
|
||||
mock_db.get_otp_codes_by_run.return_value = [mock_code]
|
||||
mock_db.update_task_2fa_state = AsyncMock()
|
||||
|
||||
mock_app = MagicMock()
|
||||
mock_app.DATABASE = mock_db
|
||||
|
||||
with (
|
||||
patch("skyvern.services.otp_service.app", new=mock_app),
|
||||
patch("skyvern.services.otp_service.asyncio.sleep", new_callable=AsyncMock),
|
||||
):
|
||||
result = await poll_otp_value(
|
||||
organization_id="org_1",
|
||||
task_id="tsk_1",
|
||||
)
|
||||
assert result is not None
|
||||
assert result.value == "123456"
|
||||
|
||||
|
||||
# === Task 6: Integration test — handle_potential_OTP_actions without TOTP config ===
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_potential_OTP_actions_without_totp_config():
|
||||
"""When LLM detects 2FA but no TOTP config exists, should still enter verification flow."""
|
||||
agent = ForgeAgent.__new__(ForgeAgent)
|
||||
|
||||
task = MagicMock()
|
||||
task.organization_id = "org_1"
|
||||
task.totp_verification_url = None
|
||||
task.totp_identifier = None
|
||||
task.task_id = "tsk_1"
|
||||
task.workflow_run_id = None
|
||||
|
||||
step = MagicMock()
|
||||
scraped_page = MagicMock()
|
||||
browser_state = MagicMock()
|
||||
|
||||
json_response = {
|
||||
"should_enter_verification_code": True,
|
||||
"place_to_enter_verification_code": "input#otp-code",
|
||||
"actions": [],
|
||||
}
|
||||
|
||||
with patch.object(agent, "handle_potential_verification_code", new_callable=AsyncMock) as mock_handler:
|
||||
mock_handler.return_value = {"actions": []}
|
||||
with patch("skyvern.forge.agent.parse_actions", return_value=[]):
|
||||
result_json, result_actions = await agent.handle_potential_OTP_actions(
|
||||
task, step, scraped_page, browser_state, json_response
|
||||
)
|
||||
mock_handler.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_potential_OTP_actions_skips_magic_link_without_totp_config():
|
||||
"""Magic links should still require TOTP config."""
|
||||
agent = ForgeAgent.__new__(ForgeAgent)
|
||||
|
||||
task = MagicMock()
|
||||
task.organization_id = "org_1"
|
||||
task.totp_verification_url = None
|
||||
task.totp_identifier = None
|
||||
|
||||
step = MagicMock()
|
||||
scraped_page = MagicMock()
|
||||
browser_state = MagicMock()
|
||||
|
||||
json_response = {
|
||||
"should_verify_by_magic_link": True,
|
||||
}
|
||||
|
||||
with patch.object(agent, "handle_potential_magic_link", new_callable=AsyncMock) as mock_handler:
|
||||
result_json, result_actions = await agent.handle_potential_OTP_actions(
|
||||
task, step, scraped_page, browser_state, json_response
|
||||
)
|
||||
mock_handler.assert_not_called()
|
||||
assert result_actions == []
|
||||
|
||||
|
||||
# === Task 7: verification_code_check always True in LLM prompt ===
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verification_code_check_always_true_without_totp_config():
|
||||
"""_build_extract_action_prompt should receive verification_code_check=True even when task has no TOTP config."""
|
||||
agent = ForgeAgent.__new__(ForgeAgent)
|
||||
agent.async_operation_pool = MagicMock()
|
||||
|
||||
task = MagicMock()
|
||||
task.totp_verification_url = None
|
||||
task.totp_identifier = None
|
||||
task.task_id = "tsk_1"
|
||||
task.workflow_run_id = None
|
||||
task.organization_id = "org_1"
|
||||
task.url = "https://example.com"
|
||||
|
||||
step = MagicMock()
|
||||
step.step_id = "step_1"
|
||||
step.order = 0
|
||||
step.retry_index = 0
|
||||
|
||||
scraped_page = MagicMock()
|
||||
scraped_page.elements = []
|
||||
|
||||
browser_state = MagicMock()
|
||||
|
||||
with (
|
||||
patch("skyvern.forge.agent.skyvern_context") as mock_ctx,
|
||||
patch.object(agent, "_scrape_with_type", new_callable=AsyncMock, return_value=scraped_page),
|
||||
patch.object(
|
||||
agent,
|
||||
"_build_extract_action_prompt",
|
||||
new_callable=AsyncMock,
|
||||
return_value=("prompt", False, "extract_action"),
|
||||
) as mock_build,
|
||||
):
|
||||
mock_ctx.current.return_value = None
|
||||
await agent.build_and_record_step_prompt(
|
||||
task, step, browser_state, RunEngine.skyvern_v1, persist_artifacts=False
|
||||
)
|
||||
mock_build.assert_called_once()
|
||||
_, kwargs = mock_build.call_args
|
||||
assert kwargs["verification_code_check"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verification_code_check_always_true_with_totp_config():
|
||||
"""_build_extract_action_prompt should receive verification_code_check=True when task HAS TOTP config (unchanged)."""
|
||||
agent = ForgeAgent.__new__(ForgeAgent)
|
||||
agent.async_operation_pool = MagicMock()
|
||||
|
||||
task = MagicMock()
|
||||
task.totp_verification_url = "https://otp.example.com"
|
||||
task.totp_identifier = "user@example.com"
|
||||
task.task_id = "tsk_2"
|
||||
task.workflow_run_id = None
|
||||
task.organization_id = "org_1"
|
||||
task.url = "https://example.com"
|
||||
|
||||
step = MagicMock()
|
||||
step.step_id = "step_1"
|
||||
step.order = 0
|
||||
step.retry_index = 0
|
||||
|
||||
scraped_page = MagicMock()
|
||||
scraped_page.elements = []
|
||||
|
||||
browser_state = MagicMock()
|
||||
|
||||
with (
|
||||
patch("skyvern.forge.agent.skyvern_context") as mock_ctx,
|
||||
patch.object(agent, "_scrape_with_type", new_callable=AsyncMock, return_value=scraped_page),
|
||||
patch.object(
|
||||
agent,
|
||||
"_build_extract_action_prompt",
|
||||
new_callable=AsyncMock,
|
||||
return_value=("prompt", False, "extract_action"),
|
||||
) as mock_build,
|
||||
):
|
||||
mock_ctx.current.return_value = None
|
||||
await agent.build_and_record_step_prompt(
|
||||
task, step, browser_state, RunEngine.skyvern_v1, persist_artifacts=False
|
||||
)
|
||||
mock_build.assert_called_once()
|
||||
_, kwargs = mock_build.call_args
|
||||
assert kwargs["verification_code_check"] is True
|
||||
|
||||
|
||||
# === Fix: poll_otp_value should pass workflow_id, not workflow_permanent_id ===
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_poll_otp_value_passes_workflow_id_not_permanent_id():
|
||||
"""poll_otp_value should pass workflow_id (w_* format) to _get_otp_value_from_db, not workflow_permanent_id."""
|
||||
mock_db = AsyncMock()
|
||||
mock_db.get_valid_org_auth_token.return_value = MagicMock(token="tok")
|
||||
mock_db.update_workflow_run = AsyncMock()
|
||||
|
||||
mock_app = MagicMock()
|
||||
mock_app.DATABASE = mock_db
|
||||
|
||||
with (
|
||||
patch("skyvern.services.otp_service.app", new=mock_app),
|
||||
patch("skyvern.services.otp_service.asyncio.sleep", new_callable=AsyncMock),
|
||||
patch(
|
||||
"skyvern.services.otp_service._get_otp_value_from_db",
|
||||
new_callable=AsyncMock,
|
||||
return_value=OTPValue(value="654321", type="totp"),
|
||||
) as mock_get_from_db,
|
||||
):
|
||||
result = await poll_otp_value(
|
||||
organization_id="org_1",
|
||||
workflow_id="w_123",
|
||||
workflow_run_id="wr_789",
|
||||
workflow_permanent_id="wpid_456",
|
||||
totp_identifier="user@example.com",
|
||||
)
|
||||
assert result is not None
|
||||
assert result.value == "654321"
|
||||
mock_get_from_db.assert_called_once_with(
|
||||
"org_1",
|
||||
"user@example.com",
|
||||
task_id=None,
|
||||
workflow_id="w_123",
|
||||
workflow_run_id="wr_789",
|
||||
)
|
||||
|
||||
|
||||
# === Fix: send_totp_code should resolve wpid_* to w_* before storage ===
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_totp_code_resolves_wpid_to_workflow_id():
|
||||
"""send_totp_code should resolve wpid_* to w_* before storing in DB."""
|
||||
mock_workflow = MagicMock()
|
||||
mock_workflow.workflow_id = "w_abc123"
|
||||
|
||||
mock_totp_code = MagicMock()
|
||||
|
||||
mock_db = AsyncMock()
|
||||
mock_db.get_workflow_by_permanent_id = AsyncMock(return_value=mock_workflow)
|
||||
mock_db.create_otp_code = AsyncMock(return_value=mock_totp_code)
|
||||
|
||||
mock_app = MagicMock()
|
||||
mock_app.DATABASE = mock_db
|
||||
|
||||
data = TOTPCodeCreate(
|
||||
totp_identifier="user@example.com",
|
||||
content="123456",
|
||||
workflow_id="wpid_xyz789",
|
||||
)
|
||||
curr_org = MagicMock()
|
||||
curr_org.organization_id = "org_1"
|
||||
|
||||
with patch("skyvern.forge.sdk.routes.credentials.app", new=mock_app):
|
||||
await send_totp_code(data=data, curr_org=curr_org)
|
||||
|
||||
mock_db.create_otp_code.assert_called_once()
|
||||
call_kwargs = mock_db.create_otp_code.call_args[1]
|
||||
assert call_kwargs["workflow_id"] == "w_abc123", f"Expected w_abc123 but got {call_kwargs['workflow_id']}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_totp_code_w_format_passes_through():
|
||||
"""send_totp_code should resolve and store w_* format workflow_id correctly."""
|
||||
mock_workflow = MagicMock()
|
||||
mock_workflow.workflow_id = "w_abc123"
|
||||
|
||||
mock_totp_code = MagicMock()
|
||||
|
||||
mock_db = AsyncMock()
|
||||
mock_db.get_workflow = AsyncMock(return_value=mock_workflow)
|
||||
mock_db.create_otp_code = AsyncMock(return_value=mock_totp_code)
|
||||
|
||||
mock_app = MagicMock()
|
||||
mock_app.DATABASE = mock_db
|
||||
|
||||
data = TOTPCodeCreate(
|
||||
totp_identifier="user@example.com",
|
||||
content="123456",
|
||||
workflow_id="w_abc123",
|
||||
)
|
||||
curr_org = MagicMock()
|
||||
curr_org.organization_id = "org_1"
|
||||
|
||||
with patch("skyvern.forge.sdk.routes.credentials.app", new=mock_app):
|
||||
await send_totp_code(data=data, curr_org=curr_org)
|
||||
|
||||
call_kwargs = mock_db.create_otp_code.call_args[1]
|
||||
assert call_kwargs["workflow_id"] == "w_abc123"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_totp_code_none_workflow_id():
|
||||
"""send_totp_code should pass None workflow_id when not provided."""
|
||||
mock_totp_code = MagicMock()
|
||||
|
||||
mock_db = AsyncMock()
|
||||
mock_db.create_otp_code = AsyncMock(return_value=mock_totp_code)
|
||||
|
||||
mock_app = MagicMock()
|
||||
mock_app.DATABASE = mock_db
|
||||
|
||||
data = TOTPCodeCreate(
|
||||
totp_identifier="user@example.com",
|
||||
content="123456",
|
||||
)
|
||||
curr_org = MagicMock()
|
||||
curr_org.organization_id = "org_1"
|
||||
|
||||
with patch("skyvern.forge.sdk.routes.credentials.app", new=mock_app):
|
||||
await send_totp_code(data=data, curr_org=curr_org)
|
||||
|
||||
call_kwargs = mock_db.create_otp_code.call_args[1]
|
||||
assert call_kwargs["workflow_id"] is None
|
||||
|
||||
|
||||
# === Fix: _build_navigation_payload should inject code without TOTP config ===
|
||||
|
||||
|
||||
def test_build_navigation_payload_injects_code_without_totp_config():
|
||||
"""_build_navigation_payload should inject SPECIAL_FIELD_VERIFICATION_CODE even when
|
||||
task has no totp_verification_url or totp_identifier (manual 2FA flow)."""
|
||||
agent = ForgeAgent.__new__(ForgeAgent)
|
||||
|
||||
task = MagicMock()
|
||||
task.totp_verification_url = None
|
||||
task.totp_identifier = None
|
||||
task.task_id = "tsk_manual_2fa"
|
||||
task.workflow_run_id = "wr_123"
|
||||
task.navigation_payload = {"username": "user@example.com"}
|
||||
|
||||
mock_context = MagicMock()
|
||||
mock_context.totp_codes = {"tsk_manual_2fa": "123456"}
|
||||
mock_context.has_magic_link_page.return_value = False
|
||||
|
||||
with patch("skyvern.forge.agent.skyvern_context") as mock_skyvern_ctx:
|
||||
mock_skyvern_ctx.ensure_context.return_value = mock_context
|
||||
result = agent._build_navigation_payload(task)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert SPECIAL_FIELD_VERIFICATION_CODE in result
|
||||
assert result[SPECIAL_FIELD_VERIFICATION_CODE] == "123456"
|
||||
# Original payload preserved
|
||||
assert result["username"] == "user@example.com"
|
||||
|
||||
|
||||
def test_build_navigation_payload_injects_code_when_payload_is_none():
|
||||
"""_build_navigation_payload should create a dict with the code when payload is None."""
|
||||
agent = ForgeAgent.__new__(ForgeAgent)
|
||||
|
||||
task = MagicMock()
|
||||
task.totp_verification_url = None
|
||||
task.totp_identifier = None
|
||||
task.task_id = "tsk_manual_2fa"
|
||||
task.workflow_run_id = "wr_123"
|
||||
task.navigation_payload = None
|
||||
|
||||
mock_context = MagicMock()
|
||||
mock_context.totp_codes = {"tsk_manual_2fa": "999999"}
|
||||
mock_context.has_magic_link_page.return_value = False
|
||||
|
||||
with patch("skyvern.forge.agent.skyvern_context") as mock_skyvern_ctx:
|
||||
mock_skyvern_ctx.ensure_context.return_value = mock_context
|
||||
result = agent._build_navigation_payload(task)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result[SPECIAL_FIELD_VERIFICATION_CODE] == "999999"
|
||||
|
||||
|
||||
def test_build_navigation_payload_no_code_no_injection():
|
||||
"""_build_navigation_payload should NOT inject anything when no code in context."""
|
||||
agent = ForgeAgent.__new__(ForgeAgent)
|
||||
|
||||
task = MagicMock()
|
||||
task.totp_verification_url = None
|
||||
task.totp_identifier = None
|
||||
task.task_id = "tsk_no_code"
|
||||
task.workflow_run_id = "wr_456"
|
||||
task.navigation_payload = {"field": "value"}
|
||||
|
||||
mock_context = MagicMock()
|
||||
mock_context.totp_codes = {} # No code in context
|
||||
mock_context.has_magic_link_page.return_value = False
|
||||
|
||||
with patch("skyvern.forge.agent.skyvern_context") as mock_skyvern_ctx:
|
||||
mock_skyvern_ctx.ensure_context.return_value = mock_context
|
||||
result = agent._build_navigation_payload(task)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert SPECIAL_FIELD_VERIFICATION_CODE not in result
|
||||
assert result["field"] == "value"
|
||||
|
||||
|
||||
# === Task: poll_otp_value publishes 2FA events to notification registry ===
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_poll_otp_value_publishes_required_event_for_task():
|
||||
"""poll_otp_value should publish verification_code_required when task waiting state is set."""
|
||||
mock_code = MagicMock()
|
||||
mock_code.code = "123456"
|
||||
mock_code.otp_type = "totp"
|
||||
|
||||
mock_db = AsyncMock()
|
||||
mock_db.get_valid_org_auth_token.return_value = MagicMock(token="tok")
|
||||
mock_db.get_otp_codes_by_run.return_value = [mock_code]
|
||||
mock_db.update_task_2fa_state = AsyncMock()
|
||||
|
||||
mock_app = MagicMock()
|
||||
mock_app.DATABASE = mock_db
|
||||
|
||||
registry = LocalNotificationRegistry()
|
||||
queue = registry.subscribe("org_1")
|
||||
|
||||
with (
|
||||
patch("skyvern.services.otp_service.app", new=mock_app),
|
||||
patch("skyvern.services.otp_service.asyncio.sleep", new_callable=AsyncMock),
|
||||
patch(
|
||||
"skyvern.forge.sdk.notification.factory.NotificationRegistryFactory._NotificationRegistryFactory__registry",
|
||||
new=registry,
|
||||
),
|
||||
):
|
||||
await poll_otp_value(organization_id="org_1", task_id="tsk_1")
|
||||
|
||||
# Should have received required + resolved messages
|
||||
messages = []
|
||||
while not queue.empty():
|
||||
messages.append(queue.get_nowait())
|
||||
|
||||
types = [m["type"] for m in messages]
|
||||
assert "verification_code_required" in types
|
||||
assert "verification_code_resolved" in types
|
||||
|
||||
required = next(m for m in messages if m["type"] == "verification_code_required")
|
||||
assert required["task_id"] == "tsk_1"
|
||||
|
||||
resolved = next(m for m in messages if m["type"] == "verification_code_resolved")
|
||||
assert resolved["task_id"] == "tsk_1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_poll_otp_value_publishes_required_event_for_workflow_run():
|
||||
"""poll_otp_value should publish verification_code_required when workflow run waiting state is set."""
|
||||
mock_code = MagicMock()
|
||||
mock_code.code = "654321"
|
||||
mock_code.otp_type = "totp"
|
||||
|
||||
mock_db = AsyncMock()
|
||||
mock_db.get_valid_org_auth_token.return_value = MagicMock(token="tok")
|
||||
mock_db.update_workflow_run = AsyncMock()
|
||||
mock_db.get_otp_codes_by_run.return_value = [mock_code]
|
||||
|
||||
mock_app = MagicMock()
|
||||
mock_app.DATABASE = mock_db
|
||||
|
||||
registry = LocalNotificationRegistry()
|
||||
queue = registry.subscribe("org_1")
|
||||
|
||||
with (
|
||||
patch("skyvern.services.otp_service.app", new=mock_app),
|
||||
patch("skyvern.services.otp_service.asyncio.sleep", new_callable=AsyncMock),
|
||||
patch(
|
||||
"skyvern.forge.sdk.notification.factory.NotificationRegistryFactory._NotificationRegistryFactory__registry",
|
||||
new=registry,
|
||||
),
|
||||
):
|
||||
await poll_otp_value(organization_id="org_1", workflow_run_id="wr_1")
|
||||
|
||||
messages = []
|
||||
while not queue.empty():
|
||||
messages.append(queue.get_nowait())
|
||||
|
||||
types = [m["type"] for m in messages]
|
||||
assert "verification_code_required" in types
|
||||
assert "verification_code_resolved" in types
|
||||
|
||||
required = next(m for m in messages if m["type"] == "verification_code_required")
|
||||
assert required["workflow_run_id"] == "wr_1"
|
||||
22
tests/unit_tests/test_redis_client_factory.py
Normal file
22
tests/unit_tests/test_redis_client_factory.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""Tests for RedisClientFactory."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from skyvern.forge.sdk.redis.factory import RedisClientFactory
|
||||
|
||||
|
||||
def test_default_is_none():
|
||||
"""Factory returns None when no client has been set."""
|
||||
# Reset to default state
|
||||
RedisClientFactory.set_client(None) # type: ignore[arg-type]
|
||||
assert RedisClientFactory.get_client() is None
|
||||
|
||||
|
||||
def test_set_and_get():
|
||||
"""Round-trip: set_client then get_client returns the same object."""
|
||||
mock_client = MagicMock()
|
||||
RedisClientFactory.set_client(mock_client)
|
||||
assert RedisClientFactory.get_client() is mock_client
|
||||
|
||||
# Cleanup
|
||||
RedisClientFactory.set_client(None) # type: ignore[arg-type]
|
||||
237
tests/unit_tests/test_redis_notification_registry.py
Normal file
237
tests/unit_tests/test_redis_notification_registry.py
Normal file
@@ -0,0 +1,237 @@
|
||||
"""Tests for RedisNotificationRegistry (SKY-6).
|
||||
|
||||
All tests use a mock Redis client — no real Redis instance required.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from skyvern.forge.sdk.notification.redis import RedisNotificationRegistry
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_mock_redis() -> MagicMock:
|
||||
"""Return a mock redis.asyncio.Redis client."""
|
||||
redis = MagicMock()
|
||||
redis.publish = AsyncMock()
|
||||
redis.pubsub = MagicMock()
|
||||
return redis
|
||||
|
||||
|
||||
def _make_mock_pubsub(messages: list[dict] | None = None, *, block: bool = False) -> MagicMock:
|
||||
"""Return a mock PubSub that yields *messages* from ``listen()``.
|
||||
|
||||
Each entry in *messages* should look like:
|
||||
{"type": "message", "data": '{"key": "val"}'}
|
||||
|
||||
If *block* is True the async generator will hang forever after
|
||||
exhausting *messages*, which keeps the listener task alive so that
|
||||
cancellation semantics can be tested.
|
||||
"""
|
||||
pubsub = MagicMock()
|
||||
pubsub.subscribe = AsyncMock()
|
||||
pubsub.unsubscribe = AsyncMock()
|
||||
pubsub.close = AsyncMock()
|
||||
|
||||
async def _listen():
|
||||
for msg in messages or []:
|
||||
yield msg
|
||||
if block:
|
||||
# Keep the listener alive until cancelled
|
||||
await asyncio.Event().wait()
|
||||
|
||||
pubsub.listen = _listen
|
||||
return pubsub
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: subscribe
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subscribe_creates_queue_and_starts_listener():
|
||||
"""subscribe() should return a queue and start a background listener task."""
|
||||
redis = _make_mock_redis()
|
||||
pubsub = _make_mock_pubsub()
|
||||
redis.pubsub.return_value = pubsub
|
||||
|
||||
registry = RedisNotificationRegistry(redis)
|
||||
queue = registry.subscribe("org_1")
|
||||
|
||||
assert isinstance(queue, asyncio.Queue)
|
||||
assert "org_1" in registry._listener_tasks
|
||||
task = registry._listener_tasks["org_1"]
|
||||
assert isinstance(task, asyncio.Task)
|
||||
|
||||
# Cleanup
|
||||
await registry.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subscribe_reuses_listener_for_same_org():
|
||||
"""A second subscribe for the same org should NOT create a new listener task."""
|
||||
redis = _make_mock_redis()
|
||||
pubsub = _make_mock_pubsub()
|
||||
redis.pubsub.return_value = pubsub
|
||||
|
||||
registry = RedisNotificationRegistry(redis)
|
||||
registry.subscribe("org_1")
|
||||
first_task = registry._listener_tasks["org_1"]
|
||||
|
||||
registry.subscribe("org_1")
|
||||
assert registry._listener_tasks["org_1"] is first_task
|
||||
|
||||
await registry.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: unsubscribe
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unsubscribe_cancels_listener_when_last_subscriber_leaves():
|
||||
"""When the last subscriber unsubscribes, the listener task should be cancelled."""
|
||||
redis = _make_mock_redis()
|
||||
pubsub = _make_mock_pubsub(block=True)
|
||||
redis.pubsub.return_value = pubsub
|
||||
|
||||
registry = RedisNotificationRegistry(redis)
|
||||
queue = registry.subscribe("org_1")
|
||||
|
||||
# Let the listener task start running
|
||||
await asyncio.sleep(0)
|
||||
|
||||
task = registry._listener_tasks["org_1"]
|
||||
|
||||
registry.unsubscribe("org_1", queue)
|
||||
assert "org_1" not in registry._listener_tasks
|
||||
|
||||
# Wait for the task to fully complete after cancellation
|
||||
await asyncio.gather(task, return_exceptions=True)
|
||||
assert task.cancelled()
|
||||
|
||||
await registry.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unsubscribe_keeps_listener_when_subscribers_remain():
|
||||
"""If other subscribers remain, the listener task should stay alive."""
|
||||
redis = _make_mock_redis()
|
||||
pubsub = _make_mock_pubsub()
|
||||
redis.pubsub.return_value = pubsub
|
||||
|
||||
registry = RedisNotificationRegistry(redis)
|
||||
q1 = registry.subscribe("org_1")
|
||||
registry.subscribe("org_1") # second subscriber
|
||||
|
||||
registry.unsubscribe("org_1", q1)
|
||||
assert "org_1" in registry._listener_tasks
|
||||
|
||||
await registry.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: publish
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_calls_redis_publish():
|
||||
"""publish() should fire-and-forget a Redis PUBLISH."""
|
||||
redis = _make_mock_redis()
|
||||
registry = RedisNotificationRegistry(redis)
|
||||
|
||||
registry.publish("org_1", {"type": "verification_code_required"})
|
||||
|
||||
# Allow the fire-and-forget task to execute
|
||||
await asyncio.sleep(0)
|
||||
|
||||
redis.publish.assert_awaited_once_with(
|
||||
"skyvern:notifications:org_1",
|
||||
json.dumps({"type": "verification_code_required"}),
|
||||
)
|
||||
|
||||
await registry.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: _dispatch_local
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_local_fans_out_to_all_queues():
|
||||
"""_dispatch_local should put the message into every local queue for the org."""
|
||||
redis = _make_mock_redis()
|
||||
pubsub = _make_mock_pubsub()
|
||||
redis.pubsub.return_value = pubsub
|
||||
|
||||
registry = RedisNotificationRegistry(redis)
|
||||
q1 = registry.subscribe("org_1")
|
||||
q2 = registry.subscribe("org_1")
|
||||
|
||||
msg = {"type": "test", "value": 42}
|
||||
registry._dispatch_local("org_1", msg)
|
||||
|
||||
assert q1.get_nowait() == msg
|
||||
assert q2.get_nowait() == msg
|
||||
|
||||
await registry.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_local_does_not_leak_across_orgs():
|
||||
"""Messages dispatched for org_a should not appear in org_b queues."""
|
||||
redis = _make_mock_redis()
|
||||
pubsub = _make_mock_pubsub()
|
||||
redis.pubsub.return_value = pubsub
|
||||
|
||||
registry = RedisNotificationRegistry(redis)
|
||||
q_a = registry.subscribe("org_a")
|
||||
q_b = registry.subscribe("org_b")
|
||||
|
||||
registry._dispatch_local("org_a", {"type": "test"})
|
||||
assert not q_a.empty()
|
||||
assert q_b.empty()
|
||||
|
||||
await registry.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: close
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_cancels_all_listeners_and_clears_state():
|
||||
"""close() should cancel every listener task and empty subscriber maps."""
|
||||
redis = _make_mock_redis()
|
||||
pubsub = _make_mock_pubsub(block=True)
|
||||
redis.pubsub.return_value = pubsub
|
||||
|
||||
registry = RedisNotificationRegistry(redis)
|
||||
registry.subscribe("org_1")
|
||||
registry.subscribe("org_2")
|
||||
|
||||
# Let the listener tasks start running
|
||||
await asyncio.sleep(0)
|
||||
|
||||
task_1 = registry._listener_tasks["org_1"]
|
||||
task_2 = registry._listener_tasks["org_2"]
|
||||
|
||||
await registry.close()
|
||||
|
||||
# Wait for the tasks to fully complete after cancellation
|
||||
await asyncio.gather(task_1, task_2, return_exceptions=True)
|
||||
assert task_1.cancelled()
|
||||
assert task_2.cancelled()
|
||||
assert len(registry._listener_tasks) == 0
|
||||
assert len(registry._subscribers) == 0
|
||||
262
tests/unit_tests/test_redis_pubsub.py
Normal file
262
tests/unit_tests/test_redis_pubsub.py
Normal file
@@ -0,0 +1,262 @@
|
||||
"""Tests for RedisPubSub (generic pub/sub layer).
|
||||
|
||||
All tests use a mock Redis client — no real Redis instance required.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from skyvern.forge.sdk.redis.pubsub import RedisPubSub
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_mock_redis() -> MagicMock:
|
||||
"""Return a mock redis.asyncio.Redis client."""
|
||||
redis = MagicMock()
|
||||
redis.publish = AsyncMock()
|
||||
redis.pubsub = MagicMock()
|
||||
return redis
|
||||
|
||||
|
||||
def _make_mock_pubsub(messages: list[dict] | None = None, *, block: bool = False) -> MagicMock:
|
||||
"""Return a mock PubSub that yields *messages* from ``listen()``.
|
||||
|
||||
Each entry in *messages* should look like:
|
||||
{"type": "message", "data": '{"key": "val"}'}
|
||||
|
||||
If *block* is True the async generator will hang forever after
|
||||
exhausting *messages*, which keeps the listener task alive so that
|
||||
cancellation semantics can be tested.
|
||||
"""
|
||||
pubsub = MagicMock()
|
||||
pubsub.subscribe = AsyncMock()
|
||||
pubsub.unsubscribe = AsyncMock()
|
||||
pubsub.close = AsyncMock()
|
||||
|
||||
async def _listen():
|
||||
for msg in messages or []:
|
||||
yield msg
|
||||
if block:
|
||||
await asyncio.Event().wait()
|
||||
|
||||
pubsub.listen = _listen
|
||||
return pubsub
|
||||
|
||||
|
||||
PREFIX = "skyvern:test:"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: subscribe
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subscribe_creates_queue_and_starts_listener():
|
||||
"""subscribe() should return a queue and start a background listener task."""
|
||||
redis = _make_mock_redis()
|
||||
pubsub = _make_mock_pubsub()
|
||||
redis.pubsub.return_value = pubsub
|
||||
|
||||
ps = RedisPubSub(redis, channel_prefix=PREFIX)
|
||||
queue = ps.subscribe("key_1")
|
||||
|
||||
assert isinstance(queue, asyncio.Queue)
|
||||
assert "key_1" in ps._listener_tasks
|
||||
task = ps._listener_tasks["key_1"]
|
||||
assert isinstance(task, asyncio.Task)
|
||||
|
||||
await ps.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subscribe_reuses_listener_for_same_key():
|
||||
"""A second subscribe for the same key should NOT create a new listener task."""
|
||||
redis = _make_mock_redis()
|
||||
pubsub = _make_mock_pubsub()
|
||||
redis.pubsub.return_value = pubsub
|
||||
|
||||
ps = RedisPubSub(redis, channel_prefix=PREFIX)
|
||||
ps.subscribe("key_1")
|
||||
first_task = ps._listener_tasks["key_1"]
|
||||
|
||||
ps.subscribe("key_1")
|
||||
assert ps._listener_tasks["key_1"] is first_task
|
||||
|
||||
await ps.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: unsubscribe
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unsubscribe_cancels_listener_when_last_subscriber_leaves():
|
||||
"""When the last subscriber unsubscribes, the listener task should be cancelled."""
|
||||
redis = _make_mock_redis()
|
||||
pubsub = _make_mock_pubsub(block=True)
|
||||
redis.pubsub.return_value = pubsub
|
||||
|
||||
ps = RedisPubSub(redis, channel_prefix=PREFIX)
|
||||
queue = ps.subscribe("key_1")
|
||||
|
||||
await asyncio.sleep(0)
|
||||
|
||||
task = ps._listener_tasks["key_1"]
|
||||
|
||||
ps.unsubscribe("key_1", queue)
|
||||
assert "key_1" not in ps._listener_tasks
|
||||
|
||||
await asyncio.gather(task, return_exceptions=True)
|
||||
assert task.cancelled()
|
||||
|
||||
await ps.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unsubscribe_keeps_listener_when_subscribers_remain():
|
||||
"""If other subscribers remain, the listener task should stay alive."""
|
||||
redis = _make_mock_redis()
|
||||
pubsub = _make_mock_pubsub()
|
||||
redis.pubsub.return_value = pubsub
|
||||
|
||||
ps = RedisPubSub(redis, channel_prefix=PREFIX)
|
||||
q1 = ps.subscribe("key_1")
|
||||
ps.subscribe("key_1")
|
||||
|
||||
ps.unsubscribe("key_1", q1)
|
||||
assert "key_1" in ps._listener_tasks
|
||||
|
||||
await ps.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: publish
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_calls_redis_publish():
|
||||
"""publish() should fire-and-forget a Redis PUBLISH with prefixed channel."""
|
||||
redis = _make_mock_redis()
|
||||
ps = RedisPubSub(redis, channel_prefix=PREFIX)
|
||||
|
||||
ps.publish("key_1", {"type": "event"})
|
||||
|
||||
await asyncio.sleep(0)
|
||||
|
||||
redis.publish.assert_awaited_once_with(
|
||||
f"{PREFIX}key_1",
|
||||
json.dumps({"type": "event"}),
|
||||
)
|
||||
|
||||
await ps.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: _dispatch_local
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_local_fans_out_to_all_queues():
|
||||
"""_dispatch_local should put the message into every local queue for the key."""
|
||||
redis = _make_mock_redis()
|
||||
pubsub = _make_mock_pubsub()
|
||||
redis.pubsub.return_value = pubsub
|
||||
|
||||
ps = RedisPubSub(redis, channel_prefix=PREFIX)
|
||||
q1 = ps.subscribe("key_1")
|
||||
q2 = ps.subscribe("key_1")
|
||||
|
||||
msg = {"type": "test", "value": 42}
|
||||
ps._dispatch_local("key_1", msg)
|
||||
|
||||
assert q1.get_nowait() == msg
|
||||
assert q2.get_nowait() == msg
|
||||
|
||||
await ps.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_local_does_not_leak_across_keys():
|
||||
"""Messages dispatched for key_a should not appear in key_b queues."""
|
||||
redis = _make_mock_redis()
|
||||
pubsub = _make_mock_pubsub()
|
||||
redis.pubsub.return_value = pubsub
|
||||
|
||||
ps = RedisPubSub(redis, channel_prefix=PREFIX)
|
||||
q_a = ps.subscribe("key_a")
|
||||
q_b = ps.subscribe("key_b")
|
||||
|
||||
ps._dispatch_local("key_a", {"type": "test"})
|
||||
assert not q_a.empty()
|
||||
assert q_b.empty()
|
||||
|
||||
await ps.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: close
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_cancels_all_listeners_and_clears_state():
|
||||
"""close() should cancel every listener task and empty subscriber maps."""
|
||||
redis = _make_mock_redis()
|
||||
pubsub = _make_mock_pubsub(block=True)
|
||||
redis.pubsub.return_value = pubsub
|
||||
|
||||
ps = RedisPubSub(redis, channel_prefix=PREFIX)
|
||||
ps.subscribe("key_1")
|
||||
ps.subscribe("key_2")
|
||||
|
||||
await asyncio.sleep(0)
|
||||
|
||||
task_1 = ps._listener_tasks["key_1"]
|
||||
task_2 = ps._listener_tasks["key_2"]
|
||||
|
||||
await ps.close()
|
||||
|
||||
await asyncio.gather(task_1, task_2, return_exceptions=True)
|
||||
assert task_1.cancelled()
|
||||
assert task_2.cancelled()
|
||||
assert len(ps._listener_tasks) == 0
|
||||
assert len(ps._subscribers) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: prefix isolation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_different_prefixes_do_not_interfere():
|
||||
"""Two RedisPubSub instances with different prefixes use separate channels."""
|
||||
redis = _make_mock_redis()
|
||||
|
||||
ps_a = RedisPubSub(redis, channel_prefix="prefix_a:")
|
||||
ps_b = RedisPubSub(redis, channel_prefix="prefix_b:")
|
||||
|
||||
ps_a.publish("key_1", {"from": "a"})
|
||||
ps_b.publish("key_1", {"from": "b"})
|
||||
|
||||
await asyncio.sleep(0)
|
||||
|
||||
calls = redis.publish.await_args_list
|
||||
assert len(calls) == 2
|
||||
|
||||
channels = {call.args[0] for call in calls}
|
||||
assert "prefix_a:key_1" in channels
|
||||
assert "prefix_b:key_1" in channels
|
||||
|
||||
await ps_a.close()
|
||||
await ps_b.close()
|
||||
Reference in New Issue
Block a user