[SKY-6] Backend: Enable 2FA code detection without TOTP credentials (#4786)

This commit is contained in:
Aaron Perez
2026-02-18 17:21:58 -05:00
committed by GitHub
parent b48bf707c3
commit e3b6d22fb6
28 changed files with 1989 additions and 41 deletions

View File

@@ -141,3 +141,12 @@ SKYVERN_AUTH_BITWARDEN_CLIENT_SECRET=your-client-secret-here
# Timeout in seconds for Bitwarden operations
# BITWARDEN_TIMEOUT_SECONDS=60
# Shared Redis URL used by any service that needs Redis (pub/sub, cache, etc.)
# REDIS_URL=redis://localhost:6379/0
# Notification registry type: "local" (default, in-process) or "redis" (multi-pod)
# NOTIFICATION_REGISTRY_TYPE=local
# Optional: override Redis URL specifically for notifications (falls back to REDIS_URL)
# NOTIFICATION_REDIS_URL=

View File

@@ -0,0 +1,39 @@
"""add 2fa waiting state fields to workflow_runs and tasks
Revision ID: a1b2c3d4e5f6
Revises: 43217e31df12
Create Date: 2026-02-13 00:00:00.000000+00:00
"""
from typing import Sequence, Union
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "a1b2c3d4e5f6"
down_revision: Union[str, None] = "43217e31df12"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Add the 2FA waiting-state columns to workflow_runs and tasks.

    Raw ``ALTER TABLE ... IF NOT EXISTS`` statements keep the migration
    idempotent, so re-running it on a partially migrated schema is safe.
    """
    for table in ("workflow_runs", "tasks"):
        op.execute(
            f"ALTER TABLE {table} ADD COLUMN IF NOT EXISTS waiting_for_verification_code BOOLEAN NOT NULL DEFAULT false"
        )
        op.execute(f"ALTER TABLE {table} ADD COLUMN IF NOT EXISTS verification_code_identifier VARCHAR")
        op.execute(f"ALTER TABLE {table} ADD COLUMN IF NOT EXISTS verification_code_polling_started_at TIMESTAMP")
def downgrade() -> None:
    """Drop the 2FA waiting-state columns (mirror of :func:`upgrade`).

    Tables and columns are dropped in the reverse order of the upgrade,
    and ``IF EXISTS`` keeps repeated downgrades harmless.
    """
    for table in ("tasks", "workflow_runs"):
        for column in (
            "verification_code_polling_started_at",
            "verification_code_identifier",
            "waiting_for_verification_code",
        ):
            op.execute(f"ALTER TABLE {table} DROP COLUMN IF EXISTS {column}")

View File

@@ -98,6 +98,13 @@ class Settings(BaseSettings):
# Supported storage types: local, s3cloud, azureblob
SKYVERN_STORAGE_TYPE: str = "local"
# Shared Redis URL (used by any service that needs Redis)
REDIS_URL: str = "redis://localhost:6379/0"
# Notification registry settings ("local" or "redis")
NOTIFICATION_REGISTRY_TYPE: str = "local"
NOTIFICATION_REDIS_URL: str | None = None # Deprecated: falls back to REDIS_URL
# S3/AWS settings
AWS_REGION: str = "us-east-1"
MAX_UPLOAD_FILE_SIZE: int = 10 * 1024 * 1024 # 10 MB

View File

@@ -2544,7 +2544,7 @@ class ForgeAgent:
step,
browser_state,
scraped_page,
verification_code_check=bool(task.totp_verification_url or task.totp_identifier),
verification_code_check=True,
expire_verification_code=True,
)
@@ -3169,7 +3169,7 @@ class ForgeAgent:
current_context = skyvern_context.ensure_context()
verification_code = current_context.totp_codes.get(task.task_id)
if (task.totp_verification_url or task.totp_identifier) and verification_code:
if verification_code:
if (
isinstance(final_navigation_payload, dict)
and SPECIAL_FIELD_VERIFICATION_CODE not in final_navigation_payload
@@ -4444,13 +4444,11 @@ class ForgeAgent:
if not task.organization_id:
return json_response, []
if not task.totp_verification_url and not task.totp_identifier:
return json_response, []
should_verify_by_magic_link = json_response.get("should_verify_by_magic_link")
place_to_enter_verification_code = json_response.get("place_to_enter_verification_code")
should_enter_verification_code = json_response.get("should_enter_verification_code")
# If no OTP verification needed, return early to avoid unnecessary processing
if (
not should_verify_by_magic_link
and not place_to_enter_verification_code
@@ -4466,8 +4464,10 @@ class ForgeAgent:
return json_response, actions
if should_verify_by_magic_link:
actions = await self.handle_potential_magic_link(task, step, scraped_page, browser_state, json_response)
return json_response, actions
# Magic links still require TOTP config (need a source to poll the link from)
if task.totp_verification_url or task.totp_identifier:
actions = await self.handle_potential_magic_link(task, step, scraped_page, browser_state, json_response)
return json_response, actions
return json_response, []
@@ -4524,12 +4524,7 @@ class ForgeAgent:
) -> dict[str, Any]:
place_to_enter_verification_code = json_response.get("place_to_enter_verification_code")
should_enter_verification_code = json_response.get("should_enter_verification_code")
if (
place_to_enter_verification_code
and should_enter_verification_code
and (task.totp_verification_url or task.totp_identifier)
and task.organization_id
):
if place_to_enter_verification_code and should_enter_verification_code and task.organization_id:
LOG.info("Need verification code")
workflow_id = workflow_permanent_id = None
if task.workflow_run_id:

View File

@@ -80,6 +80,20 @@ async def lifespan(fastapi_app: FastAPI) -> AsyncGenerator[None, Any]:
# Stop cleanup scheduler
await stop_cleanup_scheduler()
# Close notification registry (e.g. cancel Redis listener tasks)
from skyvern.forge.sdk.notification.factory import NotificationRegistryFactory
registry = NotificationRegistryFactory.get_registry()
if hasattr(registry, "close"):
await registry.close()
# Close shared Redis client (after registry so listener tasks drain first)
from skyvern.forge.sdk.redis.factory import RedisClientFactory
redis_client = RedisClientFactory.get_client()
if redis_client is not None:
await redis_client.close()
if forge_app.api_app_shutdown_event:
LOG.info("Calling api app shutdown event")
try:

View File

@@ -110,6 +110,18 @@ def create_forge_app() -> ForgeApp:
StorageFactory.set_storage(AzureStorage())
app.STORAGE = StorageFactory.get_storage()
app.CACHE = CacheFactory.get_cache()
if settings.NOTIFICATION_REGISTRY_TYPE == "redis":
from redis.asyncio import from_url as redis_from_url
from skyvern.forge.sdk.notification.factory import NotificationRegistryFactory
from skyvern.forge.sdk.notification.redis import RedisNotificationRegistry
from skyvern.forge.sdk.redis.factory import RedisClientFactory
redis_url = settings.NOTIFICATION_REDIS_URL or settings.REDIS_URL
redis_client = redis_from_url(redis_url, decode_responses=True)
RedisClientFactory.set_client(redis_client)
NotificationRegistryFactory.set_registry(RedisNotificationRegistry(redis_client))
app.ARTIFACT_MANAGER = ArtifactManager()
app.BROWSER_MANAGER = RealBrowserManager()
app.EXPERIMENTATION_PROVIDER = NoOpExperimentationProvider()

View File

@@ -849,6 +849,102 @@ class AgentDB(BaseAlchemyDB):
LOG.error("UnexpectedError", exc_info=True)
raise
async def update_task_2fa_state(
    self,
    task_id: str,
    organization_id: str,
    waiting_for_verification_code: bool,
    verification_code_identifier: str | None = None,
    verification_code_polling_started_at: datetime | None = None,
) -> Task:
    """Update task 2FA verification code waiting state.

    Args:
        task_id: ID of the task to update (scoped to the organization).
        organization_id: Owning organization; used to authorize the lookup.
        waiting_for_verification_code: New waiting flag. Passing ``False``
            also clears the identifier and polling timestamp.
        verification_code_identifier: Identifier shown to the user (e.g. the
            account the code was sent to); only written when not ``None``.
        verification_code_polling_started_at: When polling began; only
            written when not ``None``.

    Returns:
        The re-fetched, committed Task.

    Raises:
        NotFoundError: If the task does not exist (before or after commit).
        SQLAlchemyError: Re-raised after logging on database failure.
    """
    try:
        async with self.Session() as session:
            if task := (
                await session.scalars(
                    select(TaskModel).filter_by(task_id=task_id).filter_by(organization_id=organization_id)
                )
            ).first():
                task.waiting_for_verification_code = waiting_for_verification_code
                # None means "leave unchanged" for the optional fields.
                if verification_code_identifier is not None:
                    task.verification_code_identifier = verification_code_identifier
                if verification_code_polling_started_at is not None:
                    task.verification_code_polling_started_at = verification_code_polling_started_at
                if not waiting_for_verification_code:
                    # Clear identifiers when no longer waiting
                    task.verification_code_identifier = None
                    task.verification_code_polling_started_at = None
                await session.commit()
                # Re-read through get_task so the returned object reflects the
                # committed state (and any conversion logic that applies).
                updated_task = await self.get_task(task_id, organization_id=organization_id)
                if not updated_task:
                    raise NotFoundError("Task not found")
                return updated_task
            else:
                raise NotFoundError("Task not found")
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
@read_retry()
async def get_active_verification_requests(self, organization_id: str) -> list[dict]:
    """Return active 2FA verification requests for an organization.

    Queries both tasks and workflow runs where waiting_for_verification_code=True.
    Used to provide initial state when a WebSocket notification client connects.

    Only rows created within the last hour and not in a final status are
    returned, so stale waiting flags from crashed runs do not linger in the UI.
    Each entry carries exactly one of ``task_id`` / ``workflow_run_id``
    (the other is ``None``), plus the identifier and an ISO-formatted
    polling start time (or ``None``).
    """
    results: list[dict] = []
    async with self.Session() as session:
        # Tasks waiting for verification (exclude finalized tasks)
        finalized_task_statuses = [s.value for s in TaskStatus if s.is_final()]
        # Standalone tasks only: tasks belonging to a workflow run are
        # surfaced via the workflow-run query below (workflow_run_id=None filter).
        task_rows = (
            await session.scalars(
                select(TaskModel)
                .filter_by(organization_id=organization_id)
                .filter_by(waiting_for_verification_code=True)
                .filter_by(workflow_run_id=None)
                .filter(TaskModel.status.not_in(finalized_task_statuses))
                # NOTE(review): naive utcnow() compared against created_at —
                # assumes created_at is stored as naive UTC; confirm.
                .filter(TaskModel.created_at > datetime.utcnow() - timedelta(hours=1))
            )
        ).all()
        for t in task_rows:
            results.append(
                {
                    "task_id": t.task_id,
                    "workflow_run_id": None,
                    "verification_code_identifier": t.verification_code_identifier,
                    "verification_code_polling_started_at": (
                        t.verification_code_polling_started_at.isoformat()
                        if t.verification_code_polling_started_at
                        else None
                    ),
                }
            )
        # Workflow runs waiting for verification (exclude finalized runs)
        finalized_wr_statuses = [s.value for s in WorkflowRunStatus if s.is_final()]
        wr_rows = (
            await session.scalars(
                select(WorkflowRunModel)
                .filter_by(organization_id=organization_id)
                .filter_by(waiting_for_verification_code=True)
                .filter(WorkflowRunModel.status.not_in(finalized_wr_statuses))
                .filter(WorkflowRunModel.created_at > datetime.utcnow() - timedelta(hours=1))
            )
        ).all()
        for wr in wr_rows:
            results.append(
                {
                    "task_id": None,
                    "workflow_run_id": wr.workflow_run_id,
                    "verification_code_identifier": wr.verification_code_identifier,
                    "verification_code_polling_started_at": (
                        wr.verification_code_polling_started_at.isoformat()
                        if wr.verification_code_polling_started_at
                        else None
                    ),
                }
            )
    return results
async def bulk_update_tasks(
self,
task_ids: list[str],
@@ -2794,6 +2890,9 @@ class AgentDB(BaseAlchemyDB):
ai_fallback: bool | None = None,
depends_on_workflow_run_id: str | None = None,
browser_session_id: str | None = None,
waiting_for_verification_code: bool | None = None,
verification_code_identifier: str | None = None,
verification_code_polling_started_at: datetime | None = None,
) -> WorkflowRun:
async with self.Session() as session:
workflow_run = (
@@ -2826,6 +2925,17 @@ class AgentDB(BaseAlchemyDB):
workflow_run.depends_on_workflow_run_id = depends_on_workflow_run_id
if browser_session_id:
workflow_run.browser_session_id = browser_session_id
# 2FA verification code waiting state updates
if waiting_for_verification_code is not None:
workflow_run.waiting_for_verification_code = waiting_for_verification_code
if verification_code_identifier is not None:
workflow_run.verification_code_identifier = verification_code_identifier
if verification_code_polling_started_at is not None:
workflow_run.verification_code_polling_started_at = verification_code_polling_started_at
if waiting_for_verification_code is not None and not waiting_for_verification_code:
# Clear related fields when waiting is set to False
workflow_run.verification_code_identifier = None
workflow_run.verification_code_polling_started_at = None
await session.commit()
await save_workflow_run_logs(workflow_run_id)
await session.refresh(workflow_run)
@@ -3995,6 +4105,35 @@ class AgentDB(BaseAlchemyDB):
totp_code = (await session.scalars(query)).all()
return [TOTPCode.model_validate(totp_code) for totp_code in totp_code]
async def get_otp_codes_by_run(
    self,
    organization_id: str,
    task_id: str | None = None,
    workflow_run_id: str | None = None,
    valid_lifespan_minutes: int = settings.TOTP_LIFESPAN_MINUTES,
    limit: int = 1,
) -> list[TOTPCode]:
    """Get OTP codes matching a specific task or workflow run (no totp_identifier required).

    Used when the agent detects a 2FA page but no TOTP credentials are pre-configured.
    The user submits codes manually via the UI, and this method finds them by run context.

    Args:
        organization_id: Owning organization of the codes.
        task_id: Match codes attached to this task (used only when no
            workflow_run_id is given).
        workflow_run_id: Match codes attached to this workflow run; takes
            precedence over task_id.
        valid_lifespan_minutes: Only codes created within this many minutes
            are considered valid.
        limit: Maximum number of codes to return (newest first).

    Returns:
        Up to ``limit`` recent TOTPCode rows, newest first; empty list when
        neither task_id nor workflow_run_id is provided.
    """
    if not workflow_run_id and not task_id:
        return []
    async with self.Session() as session:
        query = (
            select(TOTPCodeModel)
            .filter_by(organization_id=organization_id)
            .filter(TOTPCodeModel.created_at > datetime.utcnow() - timedelta(minutes=valid_lifespan_minutes))
        )
        # Workflow-run scope wins when both identifiers are supplied.
        if workflow_run_id:
            query = query.filter(TOTPCodeModel.workflow_run_id == workflow_run_id)
        elif task_id:
            query = query.filter(TOTPCodeModel.task_id == task_id)
        query = query.order_by(TOTPCodeModel.created_at.desc()).limit(limit)
        results = (await session.scalars(query)).all()
        return [TOTPCode.model_validate(r) for r in results]
async def get_recent_otp_codes(
self,
organization_id: str,

View File

@@ -116,6 +116,10 @@ class TaskModel(Base):
model = Column(JSON, nullable=True)
browser_address = Column(String, nullable=True)
download_timeout = Column(Numeric, nullable=True)
# 2FA verification code waiting state fields
waiting_for_verification_code = Column(Boolean, nullable=False, default=False)
verification_code_identifier = Column(String, nullable=True)
verification_code_polling_started_at = Column(DateTime, nullable=True)
class StepModel(Base):
@@ -350,6 +354,10 @@ class WorkflowRunModel(Base):
debug_session_id: Column = Column(String, nullable=True)
ai_fallback = Column(Boolean, nullable=True)
code_gen = Column(Boolean, nullable=True)
# 2FA verification code waiting state fields
waiting_for_verification_code = Column(Boolean, nullable=False, default=False)
verification_code_identifier = Column(String, nullable=True)
verification_code_polling_started_at = Column(DateTime, nullable=True)
queued_at = Column(DateTime, nullable=True)
started_at = Column(DateTime, nullable=True)

View File

@@ -211,6 +211,9 @@ def convert_to_task(task_obj: TaskModel, debug_enabled: bool = False, workflow_p
browser_session_id=task_obj.browser_session_id,
browser_address=task_obj.browser_address,
download_timeout=task_obj.download_timeout,
waiting_for_verification_code=task_obj.waiting_for_verification_code or False,
verification_code_identifier=task_obj.verification_code_identifier,
verification_code_polling_started_at=task_obj.verification_code_polling_started_at,
)
return task
@@ -424,6 +427,9 @@ def convert_to_workflow_run(
run_with=workflow_run_model.run_with,
code_gen=workflow_run_model.code_gen,
ai_fallback=workflow_run_model.ai_fallback,
waiting_for_verification_code=workflow_run_model.waiting_for_verification_code or False,
verification_code_identifier=workflow_run_model.verification_code_identifier,
verification_code_polling_started_at=workflow_run_model.verification_code_polling_started_at,
)

View File

@@ -0,0 +1,21 @@
"""Abstract base for notification registries."""
import asyncio
from abc import ABC, abstractmethod
class BaseNotificationRegistry(ABC):
    """Abstract pub/sub registry scoped by organization.

    Implementations must fan-out: a single publish call delivers the
    message to every active subscriber for that organization.
    """

    @abstractmethod
    def subscribe(self, organization_id: str) -> asyncio.Queue[dict]:
        """Register a new subscriber and return the queue it will receive on."""

    @abstractmethod
    def unsubscribe(self, organization_id: str, queue: asyncio.Queue[dict]) -> None:
        """Remove a previously registered subscriber queue."""

    @abstractmethod
    def publish(self, organization_id: str, message: dict) -> None:
        """Deliver *message* to every active subscriber of the organization."""

View File

@@ -0,0 +1,14 @@
from skyvern.forge.sdk.notification.base import BaseNotificationRegistry
from skyvern.forge.sdk.notification.local import LocalNotificationRegistry
class NotificationRegistryFactory:
    """Process-wide holder for the active notification registry.

    Defaults to the in-process local registry; multi-pod deployments swap
    in a Redis-backed implementation at startup via :meth:`set_registry`.
    """

    __registry: BaseNotificationRegistry = LocalNotificationRegistry()

    @classmethod
    def set_registry(cls, registry: BaseNotificationRegistry) -> None:
        """Install *registry* as the shared notification registry."""
        NotificationRegistryFactory.__registry = registry

    @classmethod
    def get_registry(cls) -> BaseNotificationRegistry:
        """Return the currently configured registry."""
        return NotificationRegistryFactory.__registry

View File

@@ -0,0 +1,45 @@
"""In-process notification registry using asyncio queues (single-pod only)."""
import asyncio
from collections import defaultdict
import structlog
from skyvern.forge.sdk.notification.base import BaseNotificationRegistry
LOG = structlog.get_logger()
class LocalNotificationRegistry(BaseNotificationRegistry):
    """In-process fan-out pub/sub using asyncio queues. Single-pod only."""

    def __init__(self) -> None:
        # organization_id -> list of subscriber queues
        self._subscribers: dict[str, list[asyncio.Queue[dict]]] = defaultdict(list)

    def subscribe(self, organization_id: str) -> asyncio.Queue[dict]:
        """Create and register an unbounded queue for a new subscriber."""
        subscriber_queue: asyncio.Queue[dict] = asyncio.Queue()
        self._subscribers[organization_id].append(subscriber_queue)
        LOG.info("Notification subscriber added", organization_id=organization_id)
        return subscriber_queue

    def unsubscribe(self, organization_id: str, queue: asyncio.Queue[dict]) -> None:
        """Deregister *queue*; drop the org entry when no subscribers remain."""
        registered = self._subscribers.get(organization_id)
        if registered:
            # Single-threaded event loop, so a membership check is race-free here.
            if queue in registered:
                registered.remove(queue)
            if not registered:
                del self._subscribers[organization_id]
        LOG.info("Notification subscriber removed", organization_id=organization_id)

    def publish(self, organization_id: str, message: dict) -> None:
        """Fan *message* out to every local subscriber queue for the org."""
        for subscriber_queue in self._subscribers.get(organization_id, []):
            try:
                subscriber_queue.put_nowait(message)
            except asyncio.QueueFull:
                # Best-effort delivery: a slow consumer loses messages rather
                # than blocking the publisher.
                LOG.warning(
                    "Notification queue full, dropping message",
                    organization_id=organization_id,
                )

View File

@@ -0,0 +1,55 @@
"""Redis-backed notification registry for multi-pod deployments.
Thin adapter around :class:`RedisPubSub` — all Redis pub/sub logic
lives in the generic layer; this class maps the ``organization_id``
domain concept onto generic string keys.
"""
import asyncio
from redis.asyncio import Redis
from skyvern.forge.sdk.notification.base import BaseNotificationRegistry
from skyvern.forge.sdk.redis.pubsub import RedisPubSub
class RedisNotificationRegistry(BaseNotificationRegistry):
    """Fan-out pub/sub backed by Redis. One Redis PubSub channel per org."""

    def __init__(self, redis_client: Redis) -> None:
        # All transport mechanics live in the generic RedisPubSub layer;
        # this adapter only maps organization ids onto channel keys.
        self._pubsub = RedisPubSub(redis_client, channel_prefix="skyvern:notifications:")

    # ------------------------------------------------------------------
    # Public interface (delegated)
    # ------------------------------------------------------------------
    def subscribe(self, organization_id: str) -> asyncio.Queue[dict]:
        """Register a local subscriber on the organization's channel."""
        return self._pubsub.subscribe(organization_id)

    def unsubscribe(self, organization_id: str, queue: asyncio.Queue[dict]) -> None:
        """Remove a subscriber queue from the organization's channel."""
        self._pubsub.unsubscribe(organization_id, queue)

    def publish(self, organization_id: str, message: dict) -> None:
        """Fire-and-forget publish onto the organization's channel."""
        self._pubsub.publish(organization_id, message)

    async def close(self) -> None:
        """Shut down the underlying pub/sub layer (cancels listener tasks)."""
        await self._pubsub.close()

    # ------------------------------------------------------------------
    # Accessors kept for existing tests
    # ------------------------------------------------------------------
    @property
    def _listener_tasks(self) -> dict[str, asyncio.Task[None]]:
        return self._pubsub._listener_tasks

    @property
    def _subscribers(self) -> dict[str, list[asyncio.Queue[dict]]]:
        return self._pubsub._subscribers

    def _dispatch_local(self, organization_id: str, message: dict) -> None:
        # Exposed so tests can inject messages without a live Redis.
        self._pubsub._dispatch_local(organization_id, message)

View File

View File

@@ -0,0 +1,21 @@
from __future__ import annotations
from redis.asyncio import Redis
class RedisClientFactory:
    """Singleton factory for a shared async Redis client.

    Follows the same static set/get pattern as ``CacheFactory``.
    Defaults to ``None`` (no Redis in local/OSS mode).
    """

    __client: Redis | None = None

    @classmethod
    def set_client(cls, client: Redis) -> None:
        """Install the process-wide Redis client."""
        RedisClientFactory.__client = client

    @classmethod
    def get_client(cls) -> Redis | None:
        """Return the shared client, or ``None`` when Redis is not configured."""
        return RedisClientFactory.__client

View File

@@ -0,0 +1,130 @@
"""Generic Redis pub/sub layer.
Extracted from ``RedisNotificationRegistry`` so that any feature
(notifications, events, cache invalidation, etc.) can reuse the same
pattern with its own channel prefix.
"""
import asyncio
import json
from collections import defaultdict
import structlog
from redis.asyncio import Redis
LOG = structlog.get_logger()
class RedisPubSub:
    """Fan-out pub/sub backed by Redis. One Redis PubSub channel per key.

    ``subscribe`` hands each local subscriber its own ``asyncio.Queue`` and
    lazily starts one Redis listener task per key. ``publish`` is
    fire-and-forget. Call :meth:`close` on shutdown to cancel listeners.
    """

    def __init__(self, redis_client: Redis, channel_prefix: str) -> None:
        self._redis = redis_client
        self._channel_prefix = channel_prefix
        self._subscribers: dict[str, list[asyncio.Queue[dict]]] = defaultdict(list)
        # One listener task per key channel
        self._listener_tasks: dict[str, asyncio.Task[None]] = {}
        # BUGFIX: the event loop only keeps *weak* references to tasks, so a
        # fire-and-forget task created without retaining a reference can be
        # garbage-collected before it runs. Hold strong refs to in-flight
        # publish tasks and discard them on completion.
        self._publish_tasks: set[asyncio.Task[None]] = set()

    # ------------------------------------------------------------------
    # Public interface
    # ------------------------------------------------------------------
    def subscribe(self, key: str) -> asyncio.Queue[dict]:
        """Register a local subscriber for *key* and return its queue."""
        queue: asyncio.Queue[dict] = asyncio.Queue()
        self._subscribers[key].append(queue)
        # Spin up a Redis listener if this is the first local subscriber
        if key not in self._listener_tasks:
            task = asyncio.get_running_loop().create_task(self._listen(key))
            self._listener_tasks[key] = task
        LOG.info("PubSub subscriber added", key=key, channel_prefix=self._channel_prefix)
        return queue

    def unsubscribe(self, key: str, queue: asyncio.Queue[dict]) -> None:
        """Remove a subscriber queue; cancel the listener when none remain."""
        queues = self._subscribers.get(key)
        if queues:
            try:
                queues.remove(queue)
            except ValueError:
                pass  # already removed — ignore
            if not queues:
                del self._subscribers[key]
                self._cancel_listener(key)
        LOG.info("PubSub subscriber removed", key=key, channel_prefix=self._channel_prefix)

    def publish(self, key: str, message: dict) -> None:
        """Fire-and-forget Redis PUBLISH."""
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            LOG.warning(
                "No running event loop; cannot publish via Redis",
                key=key,
                channel_prefix=self._channel_prefix,
            )
            return
        task = loop.create_task(self._publish_to_redis(key, message))
        # Keep a strong reference until the task finishes (see __init__).
        self._publish_tasks.add(task)
        task.add_done_callback(self._publish_tasks.discard)

    async def close(self) -> None:
        """Cancel all listener tasks and clear state. Call on shutdown."""
        for key in list(self._listener_tasks):
            self._cancel_listener(key)
        self._subscribers.clear()
        LOG.info("RedisPubSub closed", channel_prefix=self._channel_prefix)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    async def _publish_to_redis(self, key: str, message: dict) -> None:
        """PUBLISH the JSON-encoded message; log (not raise) on failure."""
        channel = f"{self._channel_prefix}{key}"
        try:
            await self._redis.publish(channel, json.dumps(message))
        except Exception:
            LOG.exception("Failed to publish to Redis", key=key, channel_prefix=self._channel_prefix)

    async def _listen(self, key: str) -> None:
        """Subscribe to a Redis channel and fan out messages locally."""
        channel = f"{self._channel_prefix}{key}"
        pubsub = self._redis.pubsub()
        try:
            await pubsub.subscribe(channel)
            LOG.info("Redis listener started", channel=channel)
            async for raw_message in pubsub.listen():
                # Skip subscribe/unsubscribe confirmations etc.
                if raw_message["type"] != "message":
                    continue
                try:
                    data = json.loads(raw_message["data"])
                except (json.JSONDecodeError, TypeError):
                    LOG.warning("Invalid JSON on Redis channel", channel=channel)
                    continue
                self._dispatch_local(key, data)
        except asyncio.CancelledError:
            LOG.info("Redis listener cancelled", channel=channel)
            raise
        except Exception:
            LOG.exception("Redis listener error", channel=channel)
        finally:
            # Best-effort cleanup; the connection may already be gone.
            try:
                await pubsub.unsubscribe(channel)
                await pubsub.close()
            except Exception:
                LOG.warning("Error closing Redis pubsub", channel=channel)

    def _dispatch_local(self, key: str, message: dict) -> None:
        """Fan out a message to all local asyncio queues for this key."""
        queues = self._subscribers.get(key, [])
        for queue in queues:
            try:
                queue.put_nowait(message)
            except asyncio.QueueFull:
                LOG.warning(
                    "Queue full, dropping message",
                    key=key,
                    channel_prefix=self._channel_prefix,
                )

    def _cancel_listener(self, key: str) -> None:
        """Cancel and forget the listener task for *key*, if still running."""
        task = self._listener_tasks.pop(key, None)
        if task and not task.done():
            task.cancel()

View File

@@ -11,5 +11,6 @@ from skyvern.forge.sdk.routes import sdk # noqa: F401
from skyvern.forge.sdk.routes import webhooks # noqa: F401
from skyvern.forge.sdk.routes import workflow_copilot # noqa: F401
from skyvern.forge.sdk.routes.streaming import messages # noqa: F401
from skyvern.forge.sdk.routes.streaming import notifications # noqa: F401
from skyvern.forge.sdk.routes.streaming import screenshot # noqa: F401
from skyvern.forge.sdk.routes.streaming import vnc # noqa: F401

View File

@@ -131,10 +131,15 @@ async def send_totp_code(
task = await app.DATABASE.get_task(data.task_id, curr_org.organization_id)
if not task:
raise HTTPException(status_code=400, detail=f"Invalid task id: {data.task_id}")
workflow_id_for_storage: str | None = None
if data.workflow_id:
workflow = await app.DATABASE.get_workflow(data.workflow_id, curr_org.organization_id)
if data.workflow_id.startswith("wpid_"):
workflow = await app.DATABASE.get_workflow_by_permanent_id(data.workflow_id, curr_org.organization_id)
else:
workflow = await app.DATABASE.get_workflow(data.workflow_id, curr_org.organization_id)
if not workflow:
raise HTTPException(status_code=400, detail=f"Invalid workflow id: {data.workflow_id}")
workflow_id_for_storage = workflow.workflow_id
if data.workflow_run_id:
workflow_run = await app.DATABASE.get_workflow_run(data.workflow_run_id, curr_org.organization_id)
if not workflow_run:
@@ -162,7 +167,7 @@ async def send_totp_code(
content=data.content,
code=otp_value.value,
task_id=data.task_id,
workflow_id=data.workflow_id,
workflow_id=workflow_id_for_storage,
workflow_run_id=data.workflow_run_id,
source=data.source,
expired_at=data.expired_at,

View File

@@ -0,0 +1,101 @@
"""WebSocket endpoint for streaming global 2FA verification code notifications."""
import asyncio
import structlog
from fastapi import WebSocket, WebSocketDisconnect
from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK
from skyvern.config import settings
from skyvern.forge import app
from skyvern.forge.sdk.notification.factory import NotificationRegistryFactory
from skyvern.forge.sdk.routes.routers import base_router
from skyvern.forge.sdk.routes.streaming.auth import _auth as local_auth
from skyvern.forge.sdk.routes.streaming.auth import auth as real_auth
LOG = structlog.get_logger()
HEARTBEAT_INTERVAL = 60
@base_router.websocket("/stream/notifications")
async def notification_stream(
    websocket: WebSocket,
    apikey: str | None = None,
    token: str | None = None,
) -> None:
    """Stream 2FA verification-code notifications for an organization.

    On connect: authenticates, replays all currently active verification
    requests, then forwards registry messages as they arrive. Sends a
    heartbeat frame when no message arrives within HEARTBEAT_INTERVAL
    seconds, and exits when the client disconnects.
    """
    # Local auth bypass only in the "local" environment.
    auth = local_auth if settings.ENV == "local" else real_auth
    organization_id = await auth(apikey=apikey, token=token, websocket=websocket)
    if not organization_id:
        LOG.info("Notifications: Authentication failed")
        return
    LOG.info("Notifications: Started streaming", organization_id=organization_id)
    # Subscribe before the try so the finally-block unsubscribe is always valid.
    registry = NotificationRegistryFactory.get_registry()
    queue = registry.subscribe(organization_id)
    try:
        # Send initial state: all currently active verification requests
        active_requests = await app.DATABASE.get_active_verification_requests(organization_id)
        for req in active_requests:
            await websocket.send_json(
                {
                    "type": "verification_code_required",
                    "task_id": req.get("task_id"),
                    "workflow_run_id": req.get("workflow_run_id"),
                    "identifier": req.get("verification_code_identifier"),
                    "polling_started_at": req.get("verification_code_polling_started_at"),
                }
            )
        # Watch for client disconnect while streaming events
        disconnect_event = asyncio.Event()

        async def _watch_disconnect() -> None:
            # Drain client frames forever; a closed socket raises and flips the event.
            try:
                while True:
                    await websocket.receive()
            except (WebSocketDisconnect, ConnectionClosedOK, ConnectionClosedError):
                disconnect_event.set()

        watcher = asyncio.create_task(_watch_disconnect())
        try:
            while not disconnect_event.is_set():
                # Race the next queue message (bounded by the heartbeat
                # interval) against the disconnect event.
                queue_task = asyncio.ensure_future(asyncio.wait_for(queue.get(), timeout=HEARTBEAT_INTERVAL))
                disconnect_wait = asyncio.ensure_future(disconnect_event.wait())
                done, pending = await asyncio.wait({queue_task, disconnect_wait}, return_when=asyncio.FIRST_COMPLETED)
                # Cancel whichever side lost the race to avoid leaked tasks.
                for p in pending:
                    p.cancel()
                if disconnect_event.is_set():
                    return
                try:
                    # wait_for timing out raises TimeoutError here (3.11+ alias
                    # of asyncio.TimeoutError) — handled as a heartbeat tick.
                    message = queue_task.result()
                    await websocket.send_json(message)
                except TimeoutError:
                    try:
                        await websocket.send_json({"type": "heartbeat"})
                    except Exception:
                        LOG.info(
                            "Notifications: Client unreachable during heartbeat. Closing.",
                            organization_id=organization_id,
                        )
                        return
                except asyncio.CancelledError:
                    return
        finally:
            # Always stop the disconnect watcher before leaving the stream loop.
            watcher.cancel()
    except WebSocketDisconnect:
        LOG.info("Notifications: WebSocket disconnected", organization_id=organization_id)
    except ConnectionClosedOK:
        LOG.info("Notifications: ConnectionClosedOK", organization_id=organization_id)
    except ConnectionClosedError:
        LOG.warning(
            "Notifications: ConnectionClosedError (client likely disconnected)", organization_id=organization_id
        )
    except Exception:
        LOG.warning("Notifications: Error while streaming", organization_id=organization_id, exc_info=True)
    finally:
        # Unsubscribe exactly once per connection, whatever the exit path.
        registry.unsubscribe(organization_id, queue)
        LOG.info("Notifications: Connection closed", organization_id=organization_id)

View File

@@ -288,6 +288,10 @@ class Task(TaskBase):
queued_at: datetime | None = None
started_at: datetime | None = None
finished_at: datetime | None = None
# 2FA verification code waiting state fields
waiting_for_verification_code: bool = False
verification_code_identifier: str | None = None
verification_code_polling_started_at: datetime | None = None
@property
def llm_key(self) -> str | None:
@@ -365,6 +369,9 @@ class Task(TaskBase):
max_screenshot_scrolls=self.max_screenshot_scrolls,
step_count=step_count,
browser_session_id=self.browser_session_id,
waiting_for_verification_code=self.waiting_for_verification_code,
verification_code_identifier=self.verification_code_identifier,
verification_code_polling_started_at=self.verification_code_polling_started_at,
)
@@ -392,6 +399,10 @@ class TaskResponse(BaseModel):
max_screenshot_scrolls: int | None = None
step_count: int | None = None
browser_session_id: str | None = None
# 2FA verification code waiting state fields
waiting_for_verification_code: bool = False
verification_code_identifier: str | None = None
verification_code_polling_started_at: datetime | None = None
class TaskOutput(BaseModel):

View File

@@ -172,6 +172,10 @@ class WorkflowRun(BaseModel):
sequential_key: str | None = None
ai_fallback: bool | None = None
code_gen: bool | None = None
# 2FA verification code waiting state fields
waiting_for_verification_code: bool = False
verification_code_identifier: str | None = None
verification_code_polling_started_at: datetime | None = None
queued_at: datetime | None = None
started_at: datetime | None = None
@@ -226,6 +230,10 @@ class WorkflowRunResponseBase(BaseModel):
browser_address: str | None = None
script_run: ScriptRunResponse | None = None
errors: list[dict[str, Any]] | None = None
# 2FA verification code waiting state fields
waiting_for_verification_code: bool = False
verification_code_identifier: str | None = None
verification_code_polling_started_at: datetime | None = None
class WorkflowRunWithWorkflowResponse(WorkflowRunResponseBase):

View File

@@ -3019,6 +3019,10 @@ class WorkflowService:
browser_address=workflow_run.browser_address,
script_run=workflow_run.script_run,
errors=errors,
# 2FA verification code waiting state fields
waiting_for_verification_code=workflow_run.waiting_for_verification_code,
verification_code_identifier=workflow_run.verification_code_identifier,
verification_code_polling_started_at=workflow_run.verification_code_polling_started_at,
)
async def clean_up_workflow(

View File

@@ -11,6 +11,7 @@ from skyvern.forge.prompts import prompt_engine
from skyvern.forge.sdk.core.aiohttp_helper import aiohttp_post
from skyvern.forge.sdk.core.security import generate_skyvern_webhook_signature
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
from skyvern.forge.sdk.notification.factory import NotificationRegistryFactory
from skyvern.forge.sdk.schemas.totp_codes import OTPType
LOG = structlog.get_logger()
@@ -80,38 +81,147 @@ async def poll_otp_value(
totp_verification_url=totp_verification_url,
totp_identifier=totp_identifier,
)
while True:
await asyncio.sleep(10)
# check timeout
if datetime.utcnow() > timeout_datetime:
LOG.warning("Polling otp value timed out")
raise NoTOTPVerificationCodeFound(
task_id=task_id,
# Set the waiting state in the database when polling starts
identifier_for_ui = totp_identifier
if workflow_run_id:
try:
await app.DATABASE.update_workflow_run(
workflow_run_id=workflow_run_id,
workflow_id=workflow_permanent_id,
totp_verification_url=totp_verification_url,
totp_identifier=totp_identifier,
waiting_for_verification_code=True,
verification_code_identifier=identifier_for_ui,
verification_code_polling_started_at=start_datetime,
)
otp_value: OTPValue | None = None
if totp_verification_url:
otp_value = await _get_otp_value_from_url(
organization_id,
totp_verification_url,
org_token.token,
task_id=task_id,
LOG.info(
"Set 2FA waiting state for workflow run",
workflow_run_id=workflow_run_id,
verification_code_identifier=identifier_for_ui,
)
elif totp_identifier:
otp_value = await _get_otp_value_from_db(
organization_id,
totp_identifier,
try:
NotificationRegistryFactory.get_registry().publish(
organization_id,
{
"type": "verification_code_required",
"workflow_run_id": workflow_run_id,
"task_id": task_id,
"identifier": identifier_for_ui,
"polling_started_at": start_datetime.isoformat(),
},
)
except Exception:
LOG.warning("Failed to publish 2FA required notification for workflow run", exc_info=True)
except Exception:
LOG.warning("Failed to set 2FA waiting state for workflow run", exc_info=True)
elif task_id:
try:
await app.DATABASE.update_task_2fa_state(
task_id=task_id,
workflow_id=workflow_permanent_id,
workflow_run_id=workflow_run_id,
organization_id=organization_id,
waiting_for_verification_code=True,
verification_code_identifier=identifier_for_ui,
verification_code_polling_started_at=start_datetime,
)
if otp_value:
LOG.info("Got otp value", otp_value=otp_value)
return otp_value
LOG.info(
"Set 2FA waiting state for task",
task_id=task_id,
verification_code_identifier=identifier_for_ui,
)
try:
NotificationRegistryFactory.get_registry().publish(
organization_id,
{
"type": "verification_code_required",
"task_id": task_id,
"identifier": identifier_for_ui,
"polling_started_at": start_datetime.isoformat(),
},
)
except Exception:
LOG.warning("Failed to publish 2FA required notification for task", exc_info=True)
except Exception:
LOG.warning("Failed to set 2FA waiting state for task", exc_info=True)
try:
while True:
await asyncio.sleep(10)
# check timeout
if datetime.utcnow() > timeout_datetime:
LOG.warning("Polling otp value timed out")
raise NoTOTPVerificationCodeFound(
task_id=task_id,
workflow_run_id=workflow_run_id,
workflow_id=workflow_permanent_id,
totp_verification_url=totp_verification_url,
totp_identifier=totp_identifier,
)
otp_value: OTPValue | None = None
if totp_verification_url:
otp_value = await _get_otp_value_from_url(
organization_id,
totp_verification_url,
org_token.token,
task_id=task_id,
workflow_run_id=workflow_run_id,
)
elif totp_identifier:
otp_value = await _get_otp_value_from_db(
organization_id,
totp_identifier,
task_id=task_id,
workflow_id=workflow_id,
workflow_run_id=workflow_run_id,
)
if not otp_value:
otp_value = await _get_otp_value_by_run(
organization_id,
task_id=task_id,
workflow_run_id=workflow_run_id,
)
else:
# No pre-configured TOTP — poll for manually submitted codes by run context
otp_value = await _get_otp_value_by_run(
organization_id,
task_id=task_id,
workflow_run_id=workflow_run_id,
)
if otp_value:
LOG.info("Got otp value", otp_value=otp_value)
return otp_value
finally:
# Clear the waiting state when polling completes (success, timeout, or error)
if workflow_run_id:
try:
await app.DATABASE.update_workflow_run(
workflow_run_id=workflow_run_id,
waiting_for_verification_code=False,
)
LOG.info("Cleared 2FA waiting state for workflow run", workflow_run_id=workflow_run_id)
try:
NotificationRegistryFactory.get_registry().publish(
organization_id,
{"type": "verification_code_resolved", "workflow_run_id": workflow_run_id, "task_id": task_id},
)
except Exception:
LOG.warning("Failed to publish 2FA resolved notification for workflow run", exc_info=True)
except Exception:
LOG.warning("Failed to clear 2FA waiting state for workflow run", exc_info=True)
elif task_id:
try:
await app.DATABASE.update_task_2fa_state(
task_id=task_id,
organization_id=organization_id,
waiting_for_verification_code=False,
)
LOG.info("Cleared 2FA waiting state for task", task_id=task_id)
try:
NotificationRegistryFactory.get_registry().publish(
organization_id,
{"type": "verification_code_resolved", "task_id": task_id},
)
except Exception:
LOG.warning("Failed to publish 2FA resolved notification for task", exc_info=True)
except Exception:
LOG.warning("Failed to clear 2FA waiting state for task", exc_info=True)
async def _get_otp_value_from_url(
@@ -175,6 +285,28 @@ async def _get_otp_value_from_url(
return otp_value
async def _get_otp_value_by_run(
    organization_id: str,
    task_id: str | None = None,
    workflow_run_id: str | None = None,
) -> OTPValue | None:
    """Fetch the most recent OTP code stored for a specific run.

    Supports the manual 2FA flow: when no totp_identifier / verification URL
    is configured, users can still submit codes through the UI, and those
    codes are keyed by task_id / workflow_run_id instead.
    """
    matching_codes = await app.DATABASE.get_otp_codes_by_run(
        organization_id=organization_id,
        task_id=task_id,
        workflow_run_id=workflow_run_id,
        limit=1,
    )
    if not matching_codes:
        return None
    newest = matching_codes[0]
    return OTPValue(value=newest.code, type=newest.otp_type)
async def _get_otp_value_from_db(
organization_id: str,
totp_identifier: str,

View File

@@ -0,0 +1,106 @@
"""Tests for NotificationRegistry pub/sub and get_active_verification_requests (SKY-6)."""
import pytest
from skyvern.forge.sdk.db.agent_db import AgentDB
from skyvern.forge.sdk.notification.base import BaseNotificationRegistry
from skyvern.forge.sdk.notification.factory import NotificationRegistryFactory
from skyvern.forge.sdk.notification.local import LocalNotificationRegistry
# === Task 1: NotificationRegistry subscribe / publish / unsubscribe ===
@pytest.mark.asyncio
async def test_subscribe_and_publish():
    """A subscriber's queue receives messages published to its org."""
    reg = LocalNotificationRegistry()
    inbox = reg.subscribe("org_1")
    reg.publish("org_1", {"type": "verification_code_required", "task_id": "tsk_1"})
    received = inbox.get_nowait()
    assert received["type"] == "verification_code_required"
    assert received["task_id"] == "tsk_1"
@pytest.mark.asyncio
async def test_multiple_subscribers():
    """Every subscriber of an org gets a copy of each published message."""
    reg = LocalNotificationRegistry()
    first = reg.subscribe("org_1")
    second = reg.subscribe("org_1")
    reg.publish("org_1", {"type": "verification_code_required"})
    assert not first.empty()
    assert not second.empty()
    assert first.get_nowait() == second.get_nowait()
@pytest.mark.asyncio
async def test_publish_wrong_org_does_not_leak():
    """Publishing to one org must not deliver to another org's subscribers."""
    reg = LocalNotificationRegistry()
    inbox_a = reg.subscribe("org_a")
    inbox_b = reg.subscribe("org_b")
    reg.publish("org_a", {"type": "test"})
    assert not inbox_a.empty()
    assert inbox_b.empty()
@pytest.mark.asyncio
async def test_unsubscribe():
    """A queue removed via unsubscribe stops receiving messages."""
    reg = LocalNotificationRegistry()
    inbox = reg.subscribe("org_1")
    reg.unsubscribe("org_1", inbox)
    reg.publish("org_1", {"type": "test"})
    assert inbox.empty()
@pytest.mark.asyncio
async def test_unsubscribe_idempotent():
    """Unsubscribing a queue that was already removed is a no-op, not an error."""
    reg = LocalNotificationRegistry()
    inbox = reg.subscribe("org_1")
    # The second call targets an already-removed queue and must not raise.
    for _ in range(2):
        reg.unsubscribe("org_1", inbox)
# === Task: BaseNotificationRegistry ABC ===
def test_base_notification_registry_cannot_be_instantiated():
    """Instantiating the abstract base class directly must raise TypeError."""
    try:
        BaseNotificationRegistry()
    except TypeError:
        pass
    else:
        raise AssertionError("BaseNotificationRegistry() should have raised TypeError")
# === Task: NotificationRegistryFactory ===
@pytest.mark.asyncio
async def test_factory_returns_local_by_default():
    """Without explicit configuration the factory hands out the in-process registry."""
    default_registry = NotificationRegistryFactory.get_registry()
    assert isinstance(default_registry, LocalNotificationRegistry)
@pytest.mark.asyncio
async def test_factory_set_and_get():
    """set_registry swaps the active implementation; restore the original afterwards."""
    previous = NotificationRegistryFactory.get_registry()
    try:
        replacement = LocalNotificationRegistry()
        NotificationRegistryFactory.set_registry(replacement)
        assert NotificationRegistryFactory.get_registry() is replacement
    finally:
        # Always restore global state so other tests see the original registry.
        NotificationRegistryFactory.set_registry(previous)
# === Task 2: get_active_verification_requests DB method ===
def test_get_active_verification_requests_method_exists():
    """AgentDB must expose the get_active_verification_requests method."""
    assert getattr(AgentDB, "get_active_verification_requests", None) is not None

View File

@@ -0,0 +1,544 @@
"""Tests for manual 2FA input without pre-configured TOTP credentials (SKY-6)."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from skyvern.constants import SPECIAL_FIELD_VERIFICATION_CODE
from skyvern.forge.agent import ForgeAgent
from skyvern.forge.sdk.db.agent_db import AgentDB
from skyvern.forge.sdk.notification.local import LocalNotificationRegistry
from skyvern.forge.sdk.routes.credentials import send_totp_code
from skyvern.forge.sdk.schemas.totp_codes import TOTPCodeCreate
from skyvern.schemas.runs import RunEngine
from skyvern.services.otp_service import OTPValue, _get_otp_value_by_run, poll_otp_value
@pytest.mark.asyncio
async def test_get_otp_codes_by_run_exists():
    """AgentDB must expose a get_otp_codes_by_run method."""
    assert getattr(AgentDB, "get_otp_codes_by_run", None) is not None, (
        "AgentDB missing get_otp_codes_by_run method"
    )
@pytest.mark.asyncio
async def test_get_otp_codes_by_run_returns_empty_without_identifiers():
    """With neither task_id nor workflow_run_id the lookup short-circuits to []."""
    # __new__ bypasses __init__ so no real DB connection is required.
    bare_db = AgentDB.__new__(AgentDB)
    assert await bare_db.get_otp_codes_by_run(organization_id="org_1") == []
# === Task 2: _get_otp_value_by_run OTP service function ===
@pytest.mark.asyncio
async def test_get_otp_value_by_run_returns_code():
    """_get_otp_value_by_run should surface a stored code looked up by task_id."""
    stored = MagicMock()
    stored.code = "123456"
    stored.otp_type = "totp"
    fake_db = AsyncMock()
    fake_db.get_otp_codes_by_run.return_value = [stored]
    fake_app = MagicMock()
    fake_app.DATABASE = fake_db
    with patch("skyvern.services.otp_service.app", new=fake_app):
        value = await _get_otp_value_by_run(
            organization_id="org_1",
            task_id="tsk_1",
        )
    assert value is not None
    assert value.value == "123456"
@pytest.mark.asyncio
async def test_get_otp_value_by_run_returns_none_when_no_codes():
    """When no codes are stored for the run, the lookup yields None."""
    fake_db = AsyncMock()
    fake_db.get_otp_codes_by_run.return_value = []
    fake_app = MagicMock()
    fake_app.DATABASE = fake_db
    with patch("skyvern.services.otp_service.app", new=fake_app):
        value = await _get_otp_value_by_run(
            organization_id="org_1",
            task_id="tsk_1",
        )
    assert value is None
# === Task 3: poll_otp_value without identifier ===
@pytest.mark.asyncio
async def test_poll_otp_value_without_identifier_uses_run_lookup():
    """With no totp_identifier or URL, poll_otp_value falls back to run-scoped lookup."""
    stored = MagicMock()
    stored.code = "123456"
    stored.otp_type = "totp"
    fake_db = AsyncMock()
    fake_db.get_valid_org_auth_token.return_value = MagicMock(token="tok")
    fake_db.get_otp_codes_by_run.return_value = [stored]
    fake_db.update_task_2fa_state = AsyncMock()
    fake_app = MagicMock()
    fake_app.DATABASE = fake_db
    with (
        patch("skyvern.services.otp_service.app", new=fake_app),
        patch("skyvern.services.otp_service.asyncio.sleep", new_callable=AsyncMock),
    ):
        value = await poll_otp_value(
            organization_id="org_1",
            task_id="tsk_1",
        )
    assert value is not None
    assert value.value == "123456"
# === Task 6: Integration test — handle_potential_OTP_actions without TOTP config ===
@pytest.mark.asyncio
async def test_handle_potential_OTP_actions_without_totp_config():
    """LLM-detected 2FA should enter the verification flow even with no TOTP config."""
    agent = ForgeAgent.__new__(ForgeAgent)
    task = MagicMock()
    task.organization_id = "org_1"
    task.totp_verification_url = None
    task.totp_identifier = None
    task.task_id = "tsk_1"
    task.workflow_run_id = None
    step = MagicMock()
    scraped_page = MagicMock()
    browser_state = MagicMock()
    llm_json = {
        "should_enter_verification_code": True,
        "place_to_enter_verification_code": "input#otp-code",
        "actions": [],
    }
    with patch.object(agent, "handle_potential_verification_code", new_callable=AsyncMock) as handler:
        handler.return_value = {"actions": []}
        with patch("skyvern.forge.agent.parse_actions", return_value=[]):
            result_json, result_actions = await agent.handle_potential_OTP_actions(
                task, step, scraped_page, browser_state, llm_json
            )
    handler.assert_called_once()
@pytest.mark.asyncio
async def test_handle_potential_OTP_actions_skips_magic_link_without_totp_config():
    """Magic-link verification still requires explicit TOTP configuration."""
    agent = ForgeAgent.__new__(ForgeAgent)
    task = MagicMock()
    task.organization_id = "org_1"
    task.totp_verification_url = None
    task.totp_identifier = None
    step = MagicMock()
    scraped_page = MagicMock()
    browser_state = MagicMock()
    llm_json = {"should_verify_by_magic_link": True}
    with patch.object(agent, "handle_potential_magic_link", new_callable=AsyncMock) as handler:
        result_json, result_actions = await agent.handle_potential_OTP_actions(
            task, step, scraped_page, browser_state, llm_json
        )
    handler.assert_not_called()
    assert result_actions == []
# === Task 7: verification_code_check always True in LLM prompt ===
@pytest.mark.asyncio
async def test_verification_code_check_always_true_without_totp_config():
    """The prompt builder must get verification_code_check=True with no TOTP config."""
    agent = ForgeAgent.__new__(ForgeAgent)
    agent.async_operation_pool = MagicMock()
    task = MagicMock()
    task.totp_verification_url = None
    task.totp_identifier = None
    task.task_id = "tsk_1"
    task.workflow_run_id = None
    task.organization_id = "org_1"
    task.url = "https://example.com"
    step = MagicMock()
    step.step_id = "step_1"
    step.order = 0
    step.retry_index = 0
    scraped_page = MagicMock()
    scraped_page.elements = []
    browser_state = MagicMock()
    with (
        patch("skyvern.forge.agent.skyvern_context") as ctx,
        patch.object(agent, "_scrape_with_type", new_callable=AsyncMock, return_value=scraped_page),
        patch.object(
            agent,
            "_build_extract_action_prompt",
            new_callable=AsyncMock,
            return_value=("prompt", False, "extract_action"),
        ) as build_prompt,
    ):
        ctx.current.return_value = None
        await agent.build_and_record_step_prompt(
            task, step, browser_state, RunEngine.skyvern_v1, persist_artifacts=False
        )
    build_prompt.assert_called_once()
    _, kwargs = build_prompt.call_args
    assert kwargs["verification_code_check"] is True
@pytest.mark.asyncio
async def test_verification_code_check_always_true_with_totp_config():
    """verification_code_check stays True when the task DOES have TOTP config."""
    agent = ForgeAgent.__new__(ForgeAgent)
    agent.async_operation_pool = MagicMock()
    task = MagicMock()
    task.totp_verification_url = "https://otp.example.com"
    task.totp_identifier = "user@example.com"
    task.task_id = "tsk_2"
    task.workflow_run_id = None
    task.organization_id = "org_1"
    task.url = "https://example.com"
    step = MagicMock()
    step.step_id = "step_1"
    step.order = 0
    step.retry_index = 0
    scraped_page = MagicMock()
    scraped_page.elements = []
    browser_state = MagicMock()
    with (
        patch("skyvern.forge.agent.skyvern_context") as ctx,
        patch.object(agent, "_scrape_with_type", new_callable=AsyncMock, return_value=scraped_page),
        patch.object(
            agent,
            "_build_extract_action_prompt",
            new_callable=AsyncMock,
            return_value=("prompt", False, "extract_action"),
        ) as build_prompt,
    ):
        ctx.current.return_value = None
        await agent.build_and_record_step_prompt(
            task, step, browser_state, RunEngine.skyvern_v1, persist_artifacts=False
        )
    build_prompt.assert_called_once()
    _, kwargs = build_prompt.call_args
    assert kwargs["verification_code_check"] is True
# === Fix: poll_otp_value should pass workflow_id, not workflow_permanent_id ===
@pytest.mark.asyncio
async def test_poll_otp_value_passes_workflow_id_not_permanent_id():
    """_get_otp_value_from_db must receive the w_* workflow_id, not the wpid_* id."""
    fake_db = AsyncMock()
    fake_db.get_valid_org_auth_token.return_value = MagicMock(token="tok")
    fake_db.update_workflow_run = AsyncMock()
    fake_app = MagicMock()
    fake_app.DATABASE = fake_db
    with (
        patch("skyvern.services.otp_service.app", new=fake_app),
        patch("skyvern.services.otp_service.asyncio.sleep", new_callable=AsyncMock),
        patch(
            "skyvern.services.otp_service._get_otp_value_from_db",
            new_callable=AsyncMock,
            return_value=OTPValue(value="654321", type="totp"),
        ) as from_db,
    ):
        value = await poll_otp_value(
            organization_id="org_1",
            workflow_id="w_123",
            workflow_run_id="wr_789",
            workflow_permanent_id="wpid_456",
            totp_identifier="user@example.com",
        )
    assert value is not None
    assert value.value == "654321"
    from_db.assert_called_once_with(
        "org_1",
        "user@example.com",
        task_id=None,
        workflow_id="w_123",
        workflow_run_id="wr_789",
    )
# === Fix: send_totp_code should resolve wpid_* to w_* before storage ===
@pytest.mark.asyncio
async def test_send_totp_code_resolves_wpid_to_workflow_id():
    """A wpid_* workflow_id must be resolved to its w_* id before being stored."""
    resolved_workflow = MagicMock()
    resolved_workflow.workflow_id = "w_abc123"
    fake_db = AsyncMock()
    fake_db.get_workflow_by_permanent_id = AsyncMock(return_value=resolved_workflow)
    fake_db.create_otp_code = AsyncMock(return_value=MagicMock())
    fake_app = MagicMock()
    fake_app.DATABASE = fake_db
    payload = TOTPCodeCreate(
        totp_identifier="user@example.com",
        content="123456",
        workflow_id="wpid_xyz789",
    )
    org = MagicMock()
    org.organization_id = "org_1"
    with patch("skyvern.forge.sdk.routes.credentials.app", new=fake_app):
        await send_totp_code(data=payload, curr_org=org)
    fake_db.create_otp_code.assert_called_once()
    call_kwargs = fake_db.create_otp_code.call_args[1]
    assert call_kwargs["workflow_id"] == "w_abc123", f"Expected w_abc123 but got {call_kwargs['workflow_id']}"
@pytest.mark.asyncio
async def test_send_totp_code_w_format_passes_through():
    """A w_* workflow_id is resolved and stored unchanged."""
    resolved_workflow = MagicMock()
    resolved_workflow.workflow_id = "w_abc123"
    fake_db = AsyncMock()
    fake_db.get_workflow = AsyncMock(return_value=resolved_workflow)
    fake_db.create_otp_code = AsyncMock(return_value=MagicMock())
    fake_app = MagicMock()
    fake_app.DATABASE = fake_db
    payload = TOTPCodeCreate(
        totp_identifier="user@example.com",
        content="123456",
        workflow_id="w_abc123",
    )
    org = MagicMock()
    org.organization_id = "org_1"
    with patch("skyvern.forge.sdk.routes.credentials.app", new=fake_app):
        await send_totp_code(data=payload, curr_org=org)
    call_kwargs = fake_db.create_otp_code.call_args[1]
    assert call_kwargs["workflow_id"] == "w_abc123"
@pytest.mark.asyncio
async def test_send_totp_code_none_workflow_id():
    """Omitting workflow_id results in None being stored."""
    fake_db = AsyncMock()
    fake_db.create_otp_code = AsyncMock(return_value=MagicMock())
    fake_app = MagicMock()
    fake_app.DATABASE = fake_db
    payload = TOTPCodeCreate(
        totp_identifier="user@example.com",
        content="123456",
    )
    org = MagicMock()
    org.organization_id = "org_1"
    with patch("skyvern.forge.sdk.routes.credentials.app", new=fake_app):
        await send_totp_code(data=payload, curr_org=org)
    call_kwargs = fake_db.create_otp_code.call_args[1]
    assert call_kwargs["workflow_id"] is None
# === Fix: _build_navigation_payload should inject code without TOTP config ===
def test_build_navigation_payload_injects_code_without_totp_config():
    """The special verification-code field is injected even when the task has
    no totp_verification_url or totp_identifier (manual 2FA flow)."""
    agent = ForgeAgent.__new__(ForgeAgent)
    task = MagicMock()
    task.totp_verification_url = None
    task.totp_identifier = None
    task.task_id = "tsk_manual_2fa"
    task.workflow_run_id = "wr_123"
    task.navigation_payload = {"username": "user@example.com"}
    fake_context = MagicMock()
    fake_context.totp_codes = {"tsk_manual_2fa": "123456"}
    fake_context.has_magic_link_page.return_value = False
    with patch("skyvern.forge.agent.skyvern_context") as ctx_module:
        ctx_module.ensure_context.return_value = fake_context
        payload = agent._build_navigation_payload(task)
    assert isinstance(payload, dict)
    assert payload.get(SPECIAL_FIELD_VERIFICATION_CODE) == "123456"
    # The caller-provided payload must survive the injection untouched.
    assert payload["username"] == "user@example.com"
def test_build_navigation_payload_injects_code_when_payload_is_none():
    """With a None navigation_payload, a fresh dict carrying the code is created."""
    agent = ForgeAgent.__new__(ForgeAgent)
    task = MagicMock()
    task.totp_verification_url = None
    task.totp_identifier = None
    task.task_id = "tsk_manual_2fa"
    task.workflow_run_id = "wr_123"
    task.navigation_payload = None
    fake_context = MagicMock()
    fake_context.totp_codes = {"tsk_manual_2fa": "999999"}
    fake_context.has_magic_link_page.return_value = False
    with patch("skyvern.forge.agent.skyvern_context") as ctx_module:
        ctx_module.ensure_context.return_value = fake_context
        payload = agent._build_navigation_payload(task)
    assert isinstance(payload, dict)
    assert payload[SPECIAL_FIELD_VERIFICATION_CODE] == "999999"
def test_build_navigation_payload_no_code_no_injection():
    """With no code in the context, the payload is passed through unmodified."""
    agent = ForgeAgent.__new__(ForgeAgent)
    task = MagicMock()
    task.totp_verification_url = None
    task.totp_identifier = None
    task.task_id = "tsk_no_code"
    task.workflow_run_id = "wr_456"
    task.navigation_payload = {"field": "value"}
    fake_context = MagicMock()
    fake_context.totp_codes = {}  # nothing submitted for this task
    fake_context.has_magic_link_page.return_value = False
    with patch("skyvern.forge.agent.skyvern_context") as ctx_module:
        ctx_module.ensure_context.return_value = fake_context
        payload = agent._build_navigation_payload(task)
    assert isinstance(payload, dict)
    assert SPECIAL_FIELD_VERIFICATION_CODE not in payload
    assert payload["field"] == "value"
# === Task: poll_otp_value publishes 2FA events to notification registry ===
@pytest.mark.asyncio
async def test_poll_otp_value_publishes_required_event_for_task():
    """Task-scoped polling publishes both required and resolved notifications."""
    stored = MagicMock()
    stored.code = "123456"
    stored.otp_type = "totp"
    fake_db = AsyncMock()
    fake_db.get_valid_org_auth_token.return_value = MagicMock(token="tok")
    fake_db.get_otp_codes_by_run.return_value = [stored]
    fake_db.update_task_2fa_state = AsyncMock()
    fake_app = MagicMock()
    fake_app.DATABASE = fake_db
    reg = LocalNotificationRegistry()
    inbox = reg.subscribe("org_1")
    with (
        patch("skyvern.services.otp_service.app", new=fake_app),
        patch("skyvern.services.otp_service.asyncio.sleep", new_callable=AsyncMock),
        patch(
            "skyvern.forge.sdk.notification.factory.NotificationRegistryFactory._NotificationRegistryFactory__registry",
            new=reg,
        ),
    ):
        await poll_otp_value(organization_id="org_1", task_id="tsk_1")
    # Drain the queue: both lifecycle events should be present.
    received = []
    while not inbox.empty():
        received.append(inbox.get_nowait())
    seen_types = [m["type"] for m in received]
    assert "verification_code_required" in seen_types
    assert "verification_code_resolved" in seen_types
    required_msg = next(m for m in received if m["type"] == "verification_code_required")
    assert required_msg["task_id"] == "tsk_1"
    resolved_msg = next(m for m in received if m["type"] == "verification_code_resolved")
    assert resolved_msg["task_id"] == "tsk_1"
@pytest.mark.asyncio
async def test_poll_otp_value_publishes_required_event_for_workflow_run():
    """Workflow-run-scoped polling publishes required and resolved notifications."""
    stored = MagicMock()
    stored.code = "654321"
    stored.otp_type = "totp"
    fake_db = AsyncMock()
    fake_db.get_valid_org_auth_token.return_value = MagicMock(token="tok")
    fake_db.update_workflow_run = AsyncMock()
    fake_db.get_otp_codes_by_run.return_value = [stored]
    fake_app = MagicMock()
    fake_app.DATABASE = fake_db
    reg = LocalNotificationRegistry()
    inbox = reg.subscribe("org_1")
    with (
        patch("skyvern.services.otp_service.app", new=fake_app),
        patch("skyvern.services.otp_service.asyncio.sleep", new_callable=AsyncMock),
        patch(
            "skyvern.forge.sdk.notification.factory.NotificationRegistryFactory._NotificationRegistryFactory__registry",
            new=reg,
        ),
    ):
        await poll_otp_value(organization_id="org_1", workflow_run_id="wr_1")
    received = []
    while not inbox.empty():
        received.append(inbox.get_nowait())
    seen_types = [m["type"] for m in received]
    assert "verification_code_required" in seen_types
    assert "verification_code_resolved" in seen_types
    required_msg = next(m for m in received if m["type"] == "verification_code_required")
    assert required_msg["workflow_run_id"] == "wr_1"

View File

@@ -0,0 +1,22 @@
"""Tests for RedisClientFactory."""
from unittest.mock import MagicMock
from skyvern.forge.sdk.redis.factory import RedisClientFactory
def test_default_is_none():
    """With no client configured, get_client returns None."""
    # Force the factory back to its pristine state before asserting.
    RedisClientFactory.set_client(None)  # type: ignore[arg-type]
    assert RedisClientFactory.get_client() is None
def test_set_and_get():
    """Round-trip: set_client then get_client returns the same object.

    Cleanup runs in a finally block so the module-global factory state is
    restored even if the assertion fails — otherwise the mock client would
    leak into every subsequent test that touches RedisClientFactory.
    """
    mock_client = MagicMock()
    try:
        RedisClientFactory.set_client(mock_client)
        assert RedisClientFactory.get_client() is mock_client
    finally:
        # Cleanup: always reset the global factory state
        RedisClientFactory.set_client(None)  # type: ignore[arg-type]

View File

@@ -0,0 +1,237 @@
"""Tests for RedisNotificationRegistry (SKY-6).
All tests use a mock Redis client — no real Redis instance required.
"""
import asyncio
import json
from unittest.mock import AsyncMock, MagicMock
import pytest
from skyvern.forge.sdk.notification.redis import RedisNotificationRegistry
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_mock_redis() -> MagicMock:
"""Return a mock redis.asyncio.Redis client."""
redis = MagicMock()
redis.publish = AsyncMock()
redis.pubsub = MagicMock()
return redis
def _make_mock_pubsub(messages: list[dict] | None = None, *, block: bool = False) -> MagicMock:
"""Return a mock PubSub that yields *messages* from ``listen()``.
Each entry in *messages* should look like:
{"type": "message", "data": '{"key": "val"}'}
If *block* is True the async generator will hang forever after
exhausting *messages*, which keeps the listener task alive so that
cancellation semantics can be tested.
"""
pubsub = MagicMock()
pubsub.subscribe = AsyncMock()
pubsub.unsubscribe = AsyncMock()
pubsub.close = AsyncMock()
async def _listen():
for msg in messages or []:
yield msg
if block:
# Keep the listener alive until cancelled
await asyncio.Event().wait()
pubsub.listen = _listen
return pubsub
# ---------------------------------------------------------------------------
# Tests: subscribe
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_subscribe_creates_queue_and_starts_listener():
    """subscribe() hands back a queue and spawns one background listener task."""
    fake_redis = _make_mock_redis()
    fake_redis.pubsub.return_value = _make_mock_pubsub()
    registry = RedisNotificationRegistry(fake_redis)
    inbox = registry.subscribe("org_1")
    assert isinstance(inbox, asyncio.Queue)
    assert "org_1" in registry._listener_tasks
    assert isinstance(registry._listener_tasks["org_1"], asyncio.Task)
    await registry.close()
@pytest.mark.asyncio
async def test_subscribe_reuses_listener_for_same_org():
    """Subscribing twice to one org keeps a single shared listener task."""
    fake_redis = _make_mock_redis()
    fake_redis.pubsub.return_value = _make_mock_pubsub()
    registry = RedisNotificationRegistry(fake_redis)
    registry.subscribe("org_1")
    original_task = registry._listener_tasks["org_1"]
    registry.subscribe("org_1")
    assert registry._listener_tasks["org_1"] is original_task
    await registry.close()
# ---------------------------------------------------------------------------
# Tests: unsubscribe
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_unsubscribe_cancels_listener_when_last_subscriber_leaves():
    """Dropping the final subscriber cancels the org's listener task."""
    fake_redis = _make_mock_redis()
    fake_redis.pubsub.return_value = _make_mock_pubsub(block=True)
    registry = RedisNotificationRegistry(fake_redis)
    inbox = registry.subscribe("org_1")
    await asyncio.sleep(0)  # give the listener task a chance to start
    listener = registry._listener_tasks["org_1"]
    registry.unsubscribe("org_1", inbox)
    assert "org_1" not in registry._listener_tasks
    # Let the cancellation propagate fully before inspecting task state.
    await asyncio.gather(listener, return_exceptions=True)
    assert listener.cancelled()
    await registry.close()
@pytest.mark.asyncio
async def test_unsubscribe_keeps_listener_when_subscribers_remain():
    """The listener survives as long as at least one subscriber remains."""
    fake_redis = _make_mock_redis()
    fake_redis.pubsub.return_value = _make_mock_pubsub()
    registry = RedisNotificationRegistry(fake_redis)
    first_inbox = registry.subscribe("org_1")
    registry.subscribe("org_1")  # keep a second subscriber around
    registry.unsubscribe("org_1", first_inbox)
    assert "org_1" in registry._listener_tasks
    await registry.close()
# ---------------------------------------------------------------------------
# Tests: publish
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_publish_calls_redis_publish():
    """publish() results in exactly one Redis PUBLISH on the org's channel."""
    fake_redis = _make_mock_redis()
    registry = RedisNotificationRegistry(fake_redis)
    registry.publish("org_1", {"type": "verification_code_required"})
    await asyncio.sleep(0)  # let the fire-and-forget task run
    fake_redis.publish.assert_awaited_once_with(
        "skyvern:notifications:org_1",
        json.dumps({"type": "verification_code_required"}),
    )
    await registry.close()
# ---------------------------------------------------------------------------
# Tests: _dispatch_local
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_dispatch_local_fans_out_to_all_queues():
    """_dispatch_local delivers one copy of the message to each org queue."""
    fake_redis = _make_mock_redis()
    fake_redis.pubsub.return_value = _make_mock_pubsub()
    registry = RedisNotificationRegistry(fake_redis)
    first = registry.subscribe("org_1")
    second = registry.subscribe("org_1")
    payload = {"type": "test", "value": 42}
    registry._dispatch_local("org_1", payload)
    assert first.get_nowait() == payload
    assert second.get_nowait() == payload
    await registry.close()
@pytest.mark.asyncio
async def test_dispatch_local_does_not_leak_across_orgs():
    """Local dispatch for one org must leave other orgs' queues untouched."""
    fake_redis = _make_mock_redis()
    fake_redis.pubsub.return_value = _make_mock_pubsub()
    registry = RedisNotificationRegistry(fake_redis)
    inbox_a = registry.subscribe("org_a")
    inbox_b = registry.subscribe("org_b")
    registry._dispatch_local("org_a", {"type": "test"})
    assert not inbox_a.empty()
    assert inbox_b.empty()
    await registry.close()
# ---------------------------------------------------------------------------
# Tests: close
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_close_cancels_all_listeners_and_clears_state():
    """close() should cancel every listener task and empty subscriber maps."""
    mock_redis = _make_mock_redis()
    mock_redis.pubsub.return_value = _make_mock_pubsub(block=True)
    registry = RedisNotificationRegistry(mock_redis)
    for org_id in ("org_1", "org_2"):
        registry.subscribe(org_id)
    # Yield once so both listener tasks actually start running.
    await asyncio.sleep(0)
    listener_tasks = [registry._listener_tasks[org] for org in ("org_1", "org_2")]
    await registry.close()
    # Drain the cancelled tasks, then verify everything was torn down.
    await asyncio.gather(*listener_tasks, return_exceptions=True)
    assert all(task.cancelled() for task in listener_tasks)
    assert not registry._listener_tasks
    assert not registry._subscribers

View File

@@ -0,0 +1,262 @@
"""Tests for RedisPubSub (generic pub/sub layer).
All tests use a mock Redis client — no real Redis instance required.
"""
import asyncio
import json
from unittest.mock import AsyncMock, MagicMock
import pytest
from skyvern.forge.sdk.redis.pubsub import RedisPubSub
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_mock_redis() -> MagicMock:
"""Return a mock redis.asyncio.Redis client."""
redis = MagicMock()
redis.publish = AsyncMock()
redis.pubsub = MagicMock()
return redis
def _make_mock_pubsub(messages: list[dict] | None = None, *, block: bool = False) -> MagicMock:
"""Return a mock PubSub that yields *messages* from ``listen()``.
Each entry in *messages* should look like:
{"type": "message", "data": '{"key": "val"}'}
If *block* is True the async generator will hang forever after
exhausting *messages*, which keeps the listener task alive so that
cancellation semantics can be tested.
"""
pubsub = MagicMock()
pubsub.subscribe = AsyncMock()
pubsub.unsubscribe = AsyncMock()
pubsub.close = AsyncMock()
async def _listen():
for msg in messages or []:
yield msg
if block:
await asyncio.Event().wait()
pubsub.listen = _listen
return pubsub
PREFIX = "skyvern:test:"
# ---------------------------------------------------------------------------
# Tests: subscribe
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_subscribe_creates_queue_and_starts_listener():
    """subscribe() should return a queue and start a background listener task."""
    mock_redis = _make_mock_redis()
    mock_redis.pubsub.return_value = _make_mock_pubsub()
    ps = RedisPubSub(mock_redis, channel_prefix=PREFIX)
    queue = ps.subscribe("key_1")
    # A fresh asyncio.Queue comes back and a listener task is registered.
    assert isinstance(queue, asyncio.Queue)
    listener = ps._listener_tasks.get("key_1")
    assert isinstance(listener, asyncio.Task)
    await ps.close()
@pytest.mark.asyncio
async def test_subscribe_reuses_listener_for_same_key():
    """A second subscribe for the same key should NOT create a new listener task."""
    mock_redis = _make_mock_redis()
    mock_redis.pubsub.return_value = _make_mock_pubsub()
    ps = RedisPubSub(mock_redis, channel_prefix=PREFIX)
    ps.subscribe("key_1")
    original_task = ps._listener_tasks["key_1"]
    ps.subscribe("key_1")
    # Same task object: no duplicate listener was spawned.
    assert ps._listener_tasks["key_1"] is original_task
    await ps.close()
# ---------------------------------------------------------------------------
# Tests: unsubscribe
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_unsubscribe_cancels_listener_when_last_subscriber_leaves():
    """When the last subscriber unsubscribes, the listener task should be cancelled."""
    mock_redis = _make_mock_redis()
    mock_redis.pubsub.return_value = _make_mock_pubsub(block=True)
    ps = RedisPubSub(mock_redis, channel_prefix=PREFIX)
    queue = ps.subscribe("key_1")
    # Let the listener task start before tearing it down.
    await asyncio.sleep(0)
    listener = ps._listener_tasks["key_1"]
    ps.unsubscribe("key_1", queue)
    assert "key_1" not in ps._listener_tasks
    # Drain the cancelled task, then confirm the cancellation stuck.
    await asyncio.gather(listener, return_exceptions=True)
    assert listener.cancelled()
    await ps.close()
@pytest.mark.asyncio
async def test_unsubscribe_keeps_listener_when_subscribers_remain():
    """If other subscribers remain, the listener task should stay alive."""
    mock_redis = _make_mock_redis()
    mock_redis.pubsub.return_value = _make_mock_pubsub()
    ps = RedisPubSub(mock_redis, channel_prefix=PREFIX)
    first_queue = ps.subscribe("key_1")
    ps.subscribe("key_1")
    ps.unsubscribe("key_1", first_queue)
    # The second subscriber keeps the listener registered.
    assert "key_1" in ps._listener_tasks
    await ps.close()
# ---------------------------------------------------------------------------
# Tests: publish
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_publish_calls_redis_publish():
    """publish() should fire-and-forget a Redis PUBLISH with prefixed channel."""
    mock_redis = _make_mock_redis()
    ps = RedisPubSub(mock_redis, channel_prefix=PREFIX)
    payload = {"type": "event"}
    ps.publish("key_1", payload)
    # Yield once so the fire-and-forget publish task runs.
    await asyncio.sleep(0)
    mock_redis.publish.assert_awaited_once_with(PREFIX + "key_1", json.dumps(payload))
    await ps.close()
# ---------------------------------------------------------------------------
# Tests: _dispatch_local
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_dispatch_local_fans_out_to_all_queues():
    """_dispatch_local should put the message into every local queue for the key."""
    mock_redis = _make_mock_redis()
    mock_redis.pubsub.return_value = _make_mock_pubsub()
    ps = RedisPubSub(mock_redis, channel_prefix=PREFIX)
    queues = [ps.subscribe("key_1") for _ in range(2)]
    message = {"type": "test", "value": 42}
    ps._dispatch_local("key_1", message)
    # Every subscriber queue received its own copy of the message.
    for queue in queues:
        assert queue.get_nowait() == message
    await ps.close()
@pytest.mark.asyncio
async def test_dispatch_local_does_not_leak_across_keys():
    """Messages dispatched for key_a should not appear in key_b queues."""
    mock_redis = _make_mock_redis()
    mock_redis.pubsub.return_value = _make_mock_pubsub()
    ps = RedisPubSub(mock_redis, channel_prefix=PREFIX)
    queue_a = ps.subscribe("key_a")
    queue_b = ps.subscribe("key_b")
    ps._dispatch_local("key_a", {"type": "test"})
    # key_a got the message; key_b stayed untouched.
    assert not queue_a.empty()
    assert queue_b.empty()
    await ps.close()
# ---------------------------------------------------------------------------
# Tests: close
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_close_cancels_all_listeners_and_clears_state():
    """close() should cancel every listener task and empty subscriber maps."""
    mock_redis = _make_mock_redis()
    mock_redis.pubsub.return_value = _make_mock_pubsub(block=True)
    ps = RedisPubSub(mock_redis, channel_prefix=PREFIX)
    for key in ("key_1", "key_2"):
        ps.subscribe(key)
    # Give both listener tasks a chance to start.
    await asyncio.sleep(0)
    listeners = [ps._listener_tasks[key] for key in ("key_1", "key_2")]
    await ps.close()
    # Drain the cancelled tasks, then verify everything was torn down.
    await asyncio.gather(*listeners, return_exceptions=True)
    assert all(task.cancelled() for task in listeners)
    assert not ps._listener_tasks
    assert not ps._subscribers
# ---------------------------------------------------------------------------
# Tests: prefix isolation
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_different_prefixes_do_not_interfere():
    """Two RedisPubSub instances with different prefixes use separate channels."""
    mock_redis = _make_mock_redis()
    ps_a = RedisPubSub(mock_redis, channel_prefix="prefix_a:")
    ps_b = RedisPubSub(mock_redis, channel_prefix="prefix_b:")
    ps_a.publish("key_1", {"from": "a"})
    ps_b.publish("key_1", {"from": "b"})
    # Yield once so both fire-and-forget publish tasks run.
    await asyncio.sleep(0)
    publish_calls = mock_redis.publish.await_args_list
    assert len(publish_calls) == 2
    seen_channels = {call.args[0] for call in publish_calls}
    assert seen_channels == {"prefix_a:key_1", "prefix_b:key_1"}
    await ps_a.close()
    await ps_b.close()