[SKY-6] Backend: Enable 2FA code detection without TOTP credentials (#4786)
This commit is contained in:
@@ -141,3 +141,12 @@ SKYVERN_AUTH_BITWARDEN_CLIENT_SECRET=your-client-secret-here
|
|||||||
|
|
||||||
# Timeout in seconds for Bitwarden operations
|
# Timeout in seconds for Bitwarden operations
|
||||||
# BITWARDEN_TIMEOUT_SECONDS=60
|
# BITWARDEN_TIMEOUT_SECONDS=60
|
||||||
|
|
||||||
|
# Shared Redis URL used by any service that needs Redis (pub/sub, cache, etc.)
|
||||||
|
# REDIS_URL=redis://localhost:6379/0
|
||||||
|
|
||||||
|
# Notification registry type: "local" (default, in-process) or "redis" (multi-pod)
|
||||||
|
# NOTIFICATION_REGISTRY_TYPE=local
|
||||||
|
|
||||||
|
# Optional: override Redis URL specifically for notifications (falls back to REDIS_URL)
|
||||||
|
# NOTIFICATION_REDIS_URL=
|
||||||
@@ -0,0 +1,39 @@
|
|||||||
|
"""add 2fa waiting state fields to workflow_runs and tasks
|
||||||
|
|
||||||
|
Revision ID: a1b2c3d4e5f6
|
||||||
|
Revises: 43217e31df12
|
||||||
|
Create Date: 2026-02-13 00:00:00.000000+00:00
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "a1b2c3d4e5f6"
|
||||||
|
down_revision: Union[str, None] = "43217e31df12"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.execute(
|
||||||
|
"ALTER TABLE workflow_runs ADD COLUMN IF NOT EXISTS waiting_for_verification_code BOOLEAN NOT NULL DEFAULT false"
|
||||||
|
)
|
||||||
|
op.execute("ALTER TABLE workflow_runs ADD COLUMN IF NOT EXISTS verification_code_identifier VARCHAR")
|
||||||
|
op.execute("ALTER TABLE workflow_runs ADD COLUMN IF NOT EXISTS verification_code_polling_started_at TIMESTAMP")
|
||||||
|
op.execute(
|
||||||
|
"ALTER TABLE tasks ADD COLUMN IF NOT EXISTS waiting_for_verification_code BOOLEAN NOT NULL DEFAULT false"
|
||||||
|
)
|
||||||
|
op.execute("ALTER TABLE tasks ADD COLUMN IF NOT EXISTS verification_code_identifier VARCHAR")
|
||||||
|
op.execute("ALTER TABLE tasks ADD COLUMN IF NOT EXISTS verification_code_polling_started_at TIMESTAMP")
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.execute("ALTER TABLE tasks DROP COLUMN IF EXISTS verification_code_polling_started_at")
|
||||||
|
op.execute("ALTER TABLE tasks DROP COLUMN IF EXISTS verification_code_identifier")
|
||||||
|
op.execute("ALTER TABLE tasks DROP COLUMN IF EXISTS waiting_for_verification_code")
|
||||||
|
op.execute("ALTER TABLE workflow_runs DROP COLUMN IF EXISTS verification_code_polling_started_at")
|
||||||
|
op.execute("ALTER TABLE workflow_runs DROP COLUMN IF EXISTS verification_code_identifier")
|
||||||
|
op.execute("ALTER TABLE workflow_runs DROP COLUMN IF EXISTS waiting_for_verification_code")
|
||||||
@@ -98,6 +98,13 @@ class Settings(BaseSettings):
|
|||||||
# Supported storage types: local, s3cloud, azureblob
|
# Supported storage types: local, s3cloud, azureblob
|
||||||
SKYVERN_STORAGE_TYPE: str = "local"
|
SKYVERN_STORAGE_TYPE: str = "local"
|
||||||
|
|
||||||
|
# Shared Redis URL (used by any service that needs Redis)
|
||||||
|
REDIS_URL: str = "redis://localhost:6379/0"
|
||||||
|
|
||||||
|
# Notification registry settings ("local" or "redis")
|
||||||
|
NOTIFICATION_REGISTRY_TYPE: str = "local"
|
||||||
|
NOTIFICATION_REDIS_URL: str | None = None # Deprecated: falls back to REDIS_URL
|
||||||
|
|
||||||
# S3/AWS settings
|
# S3/AWS settings
|
||||||
AWS_REGION: str = "us-east-1"
|
AWS_REGION: str = "us-east-1"
|
||||||
MAX_UPLOAD_FILE_SIZE: int = 10 * 1024 * 1024 # 10 MB
|
MAX_UPLOAD_FILE_SIZE: int = 10 * 1024 * 1024 # 10 MB
|
||||||
|
|||||||
@@ -2544,7 +2544,7 @@ class ForgeAgent:
|
|||||||
step,
|
step,
|
||||||
browser_state,
|
browser_state,
|
||||||
scraped_page,
|
scraped_page,
|
||||||
verification_code_check=bool(task.totp_verification_url or task.totp_identifier),
|
verification_code_check=True,
|
||||||
expire_verification_code=True,
|
expire_verification_code=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -3169,7 +3169,7 @@ class ForgeAgent:
|
|||||||
|
|
||||||
current_context = skyvern_context.ensure_context()
|
current_context = skyvern_context.ensure_context()
|
||||||
verification_code = current_context.totp_codes.get(task.task_id)
|
verification_code = current_context.totp_codes.get(task.task_id)
|
||||||
if (task.totp_verification_url or task.totp_identifier) and verification_code:
|
if verification_code:
|
||||||
if (
|
if (
|
||||||
isinstance(final_navigation_payload, dict)
|
isinstance(final_navigation_payload, dict)
|
||||||
and SPECIAL_FIELD_VERIFICATION_CODE not in final_navigation_payload
|
and SPECIAL_FIELD_VERIFICATION_CODE not in final_navigation_payload
|
||||||
@@ -4444,13 +4444,11 @@ class ForgeAgent:
|
|||||||
if not task.organization_id:
|
if not task.organization_id:
|
||||||
return json_response, []
|
return json_response, []
|
||||||
|
|
||||||
if not task.totp_verification_url and not task.totp_identifier:
|
|
||||||
return json_response, []
|
|
||||||
|
|
||||||
should_verify_by_magic_link = json_response.get("should_verify_by_magic_link")
|
should_verify_by_magic_link = json_response.get("should_verify_by_magic_link")
|
||||||
place_to_enter_verification_code = json_response.get("place_to_enter_verification_code")
|
place_to_enter_verification_code = json_response.get("place_to_enter_verification_code")
|
||||||
should_enter_verification_code = json_response.get("should_enter_verification_code")
|
should_enter_verification_code = json_response.get("should_enter_verification_code")
|
||||||
|
|
||||||
|
# If no OTP verification needed, return early to avoid unnecessary processing
|
||||||
if (
|
if (
|
||||||
not should_verify_by_magic_link
|
not should_verify_by_magic_link
|
||||||
and not place_to_enter_verification_code
|
and not place_to_enter_verification_code
|
||||||
@@ -4466,8 +4464,10 @@ class ForgeAgent:
|
|||||||
return json_response, actions
|
return json_response, actions
|
||||||
|
|
||||||
if should_verify_by_magic_link:
|
if should_verify_by_magic_link:
|
||||||
actions = await self.handle_potential_magic_link(task, step, scraped_page, browser_state, json_response)
|
# Magic links still require TOTP config (need a source to poll the link from)
|
||||||
return json_response, actions
|
if task.totp_verification_url or task.totp_identifier:
|
||||||
|
actions = await self.handle_potential_magic_link(task, step, scraped_page, browser_state, json_response)
|
||||||
|
return json_response, actions
|
||||||
|
|
||||||
return json_response, []
|
return json_response, []
|
||||||
|
|
||||||
@@ -4524,12 +4524,7 @@ class ForgeAgent:
|
|||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
place_to_enter_verification_code = json_response.get("place_to_enter_verification_code")
|
place_to_enter_verification_code = json_response.get("place_to_enter_verification_code")
|
||||||
should_enter_verification_code = json_response.get("should_enter_verification_code")
|
should_enter_verification_code = json_response.get("should_enter_verification_code")
|
||||||
if (
|
if place_to_enter_verification_code and should_enter_verification_code and task.organization_id:
|
||||||
place_to_enter_verification_code
|
|
||||||
and should_enter_verification_code
|
|
||||||
and (task.totp_verification_url or task.totp_identifier)
|
|
||||||
and task.organization_id
|
|
||||||
):
|
|
||||||
LOG.info("Need verification code")
|
LOG.info("Need verification code")
|
||||||
workflow_id = workflow_permanent_id = None
|
workflow_id = workflow_permanent_id = None
|
||||||
if task.workflow_run_id:
|
if task.workflow_run_id:
|
||||||
|
|||||||
@@ -80,6 +80,20 @@ async def lifespan(fastapi_app: FastAPI) -> AsyncGenerator[None, Any]:
|
|||||||
# Stop cleanup scheduler
|
# Stop cleanup scheduler
|
||||||
await stop_cleanup_scheduler()
|
await stop_cleanup_scheduler()
|
||||||
|
|
||||||
|
# Close notification registry (e.g. cancel Redis listener tasks)
|
||||||
|
from skyvern.forge.sdk.notification.factory import NotificationRegistryFactory
|
||||||
|
|
||||||
|
registry = NotificationRegistryFactory.get_registry()
|
||||||
|
if hasattr(registry, "close"):
|
||||||
|
await registry.close()
|
||||||
|
|
||||||
|
# Close shared Redis client (after registry so listener tasks drain first)
|
||||||
|
from skyvern.forge.sdk.redis.factory import RedisClientFactory
|
||||||
|
|
||||||
|
redis_client = RedisClientFactory.get_client()
|
||||||
|
if redis_client is not None:
|
||||||
|
await redis_client.close()
|
||||||
|
|
||||||
if forge_app.api_app_shutdown_event:
|
if forge_app.api_app_shutdown_event:
|
||||||
LOG.info("Calling api app shutdown event")
|
LOG.info("Calling api app shutdown event")
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -110,6 +110,18 @@ def create_forge_app() -> ForgeApp:
|
|||||||
StorageFactory.set_storage(AzureStorage())
|
StorageFactory.set_storage(AzureStorage())
|
||||||
app.STORAGE = StorageFactory.get_storage()
|
app.STORAGE = StorageFactory.get_storage()
|
||||||
app.CACHE = CacheFactory.get_cache()
|
app.CACHE = CacheFactory.get_cache()
|
||||||
|
|
||||||
|
if settings.NOTIFICATION_REGISTRY_TYPE == "redis":
|
||||||
|
from redis.asyncio import from_url as redis_from_url
|
||||||
|
|
||||||
|
from skyvern.forge.sdk.notification.factory import NotificationRegistryFactory
|
||||||
|
from skyvern.forge.sdk.notification.redis import RedisNotificationRegistry
|
||||||
|
from skyvern.forge.sdk.redis.factory import RedisClientFactory
|
||||||
|
|
||||||
|
redis_url = settings.NOTIFICATION_REDIS_URL or settings.REDIS_URL
|
||||||
|
redis_client = redis_from_url(redis_url, decode_responses=True)
|
||||||
|
RedisClientFactory.set_client(redis_client)
|
||||||
|
NotificationRegistryFactory.set_registry(RedisNotificationRegistry(redis_client))
|
||||||
app.ARTIFACT_MANAGER = ArtifactManager()
|
app.ARTIFACT_MANAGER = ArtifactManager()
|
||||||
app.BROWSER_MANAGER = RealBrowserManager()
|
app.BROWSER_MANAGER = RealBrowserManager()
|
||||||
app.EXPERIMENTATION_PROVIDER = NoOpExperimentationProvider()
|
app.EXPERIMENTATION_PROVIDER = NoOpExperimentationProvider()
|
||||||
|
|||||||
@@ -849,6 +849,102 @@ class AgentDB(BaseAlchemyDB):
|
|||||||
LOG.error("UnexpectedError", exc_info=True)
|
LOG.error("UnexpectedError", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
async def update_task_2fa_state(
|
||||||
|
self,
|
||||||
|
task_id: str,
|
||||||
|
organization_id: str,
|
||||||
|
waiting_for_verification_code: bool,
|
||||||
|
verification_code_identifier: str | None = None,
|
||||||
|
verification_code_polling_started_at: datetime | None = None,
|
||||||
|
) -> Task:
|
||||||
|
"""Update task 2FA verification code waiting state."""
|
||||||
|
try:
|
||||||
|
async with self.Session() as session:
|
||||||
|
if task := (
|
||||||
|
await session.scalars(
|
||||||
|
select(TaskModel).filter_by(task_id=task_id).filter_by(organization_id=organization_id)
|
||||||
|
)
|
||||||
|
).first():
|
||||||
|
task.waiting_for_verification_code = waiting_for_verification_code
|
||||||
|
if verification_code_identifier is not None:
|
||||||
|
task.verification_code_identifier = verification_code_identifier
|
||||||
|
if verification_code_polling_started_at is not None:
|
||||||
|
task.verification_code_polling_started_at = verification_code_polling_started_at
|
||||||
|
if not waiting_for_verification_code:
|
||||||
|
# Clear identifiers when no longer waiting
|
||||||
|
task.verification_code_identifier = None
|
||||||
|
task.verification_code_polling_started_at = None
|
||||||
|
await session.commit()
|
||||||
|
updated_task = await self.get_task(task_id, organization_id=organization_id)
|
||||||
|
if not updated_task:
|
||||||
|
raise NotFoundError("Task not found")
|
||||||
|
return updated_task
|
||||||
|
else:
|
||||||
|
raise NotFoundError("Task not found")
|
||||||
|
except SQLAlchemyError:
|
||||||
|
LOG.error("SQLAlchemyError", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
@read_retry()
|
||||||
|
async def get_active_verification_requests(self, organization_id: str) -> list[dict]:
|
||||||
|
"""Return active 2FA verification requests for an organization.
|
||||||
|
|
||||||
|
Queries both tasks and workflow runs where waiting_for_verification_code=True.
|
||||||
|
Used to provide initial state when a WebSocket notification client connects.
|
||||||
|
"""
|
||||||
|
results: list[dict] = []
|
||||||
|
async with self.Session() as session:
|
||||||
|
# Tasks waiting for verification (exclude finalized tasks)
|
||||||
|
finalized_task_statuses = [s.value for s in TaskStatus if s.is_final()]
|
||||||
|
task_rows = (
|
||||||
|
await session.scalars(
|
||||||
|
select(TaskModel)
|
||||||
|
.filter_by(organization_id=organization_id)
|
||||||
|
.filter_by(waiting_for_verification_code=True)
|
||||||
|
.filter_by(workflow_run_id=None)
|
||||||
|
.filter(TaskModel.status.not_in(finalized_task_statuses))
|
||||||
|
.filter(TaskModel.created_at > datetime.utcnow() - timedelta(hours=1))
|
||||||
|
)
|
||||||
|
).all()
|
||||||
|
for t in task_rows:
|
||||||
|
results.append(
|
||||||
|
{
|
||||||
|
"task_id": t.task_id,
|
||||||
|
"workflow_run_id": None,
|
||||||
|
"verification_code_identifier": t.verification_code_identifier,
|
||||||
|
"verification_code_polling_started_at": (
|
||||||
|
t.verification_code_polling_started_at.isoformat()
|
||||||
|
if t.verification_code_polling_started_at
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
# Workflow runs waiting for verification (exclude finalized runs)
|
||||||
|
finalized_wr_statuses = [s.value for s in WorkflowRunStatus if s.is_final()]
|
||||||
|
wr_rows = (
|
||||||
|
await session.scalars(
|
||||||
|
select(WorkflowRunModel)
|
||||||
|
.filter_by(organization_id=organization_id)
|
||||||
|
.filter_by(waiting_for_verification_code=True)
|
||||||
|
.filter(WorkflowRunModel.status.not_in(finalized_wr_statuses))
|
||||||
|
.filter(WorkflowRunModel.created_at > datetime.utcnow() - timedelta(hours=1))
|
||||||
|
)
|
||||||
|
).all()
|
||||||
|
for wr in wr_rows:
|
||||||
|
results.append(
|
||||||
|
{
|
||||||
|
"task_id": None,
|
||||||
|
"workflow_run_id": wr.workflow_run_id,
|
||||||
|
"verification_code_identifier": wr.verification_code_identifier,
|
||||||
|
"verification_code_polling_started_at": (
|
||||||
|
wr.verification_code_polling_started_at.isoformat()
|
||||||
|
if wr.verification_code_polling_started_at
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return results
|
||||||
|
|
||||||
async def bulk_update_tasks(
|
async def bulk_update_tasks(
|
||||||
self,
|
self,
|
||||||
task_ids: list[str],
|
task_ids: list[str],
|
||||||
@@ -2794,6 +2890,9 @@ class AgentDB(BaseAlchemyDB):
|
|||||||
ai_fallback: bool | None = None,
|
ai_fallback: bool | None = None,
|
||||||
depends_on_workflow_run_id: str | None = None,
|
depends_on_workflow_run_id: str | None = None,
|
||||||
browser_session_id: str | None = None,
|
browser_session_id: str | None = None,
|
||||||
|
waiting_for_verification_code: bool | None = None,
|
||||||
|
verification_code_identifier: str | None = None,
|
||||||
|
verification_code_polling_started_at: datetime | None = None,
|
||||||
) -> WorkflowRun:
|
) -> WorkflowRun:
|
||||||
async with self.Session() as session:
|
async with self.Session() as session:
|
||||||
workflow_run = (
|
workflow_run = (
|
||||||
@@ -2826,6 +2925,17 @@ class AgentDB(BaseAlchemyDB):
|
|||||||
workflow_run.depends_on_workflow_run_id = depends_on_workflow_run_id
|
workflow_run.depends_on_workflow_run_id = depends_on_workflow_run_id
|
||||||
if browser_session_id:
|
if browser_session_id:
|
||||||
workflow_run.browser_session_id = browser_session_id
|
workflow_run.browser_session_id = browser_session_id
|
||||||
|
# 2FA verification code waiting state updates
|
||||||
|
if waiting_for_verification_code is not None:
|
||||||
|
workflow_run.waiting_for_verification_code = waiting_for_verification_code
|
||||||
|
if verification_code_identifier is not None:
|
||||||
|
workflow_run.verification_code_identifier = verification_code_identifier
|
||||||
|
if verification_code_polling_started_at is not None:
|
||||||
|
workflow_run.verification_code_polling_started_at = verification_code_polling_started_at
|
||||||
|
if waiting_for_verification_code is not None and not waiting_for_verification_code:
|
||||||
|
# Clear related fields when waiting is set to False
|
||||||
|
workflow_run.verification_code_identifier = None
|
||||||
|
workflow_run.verification_code_polling_started_at = None
|
||||||
await session.commit()
|
await session.commit()
|
||||||
await save_workflow_run_logs(workflow_run_id)
|
await save_workflow_run_logs(workflow_run_id)
|
||||||
await session.refresh(workflow_run)
|
await session.refresh(workflow_run)
|
||||||
@@ -3995,6 +4105,35 @@ class AgentDB(BaseAlchemyDB):
|
|||||||
totp_code = (await session.scalars(query)).all()
|
totp_code = (await session.scalars(query)).all()
|
||||||
return [TOTPCode.model_validate(totp_code) for totp_code in totp_code]
|
return [TOTPCode.model_validate(totp_code) for totp_code in totp_code]
|
||||||
|
|
||||||
|
async def get_otp_codes_by_run(
|
||||||
|
self,
|
||||||
|
organization_id: str,
|
||||||
|
task_id: str | None = None,
|
||||||
|
workflow_run_id: str | None = None,
|
||||||
|
valid_lifespan_minutes: int = settings.TOTP_LIFESPAN_MINUTES,
|
||||||
|
limit: int = 1,
|
||||||
|
) -> list[TOTPCode]:
|
||||||
|
"""Get OTP codes matching a specific task or workflow run (no totp_identifier required).
|
||||||
|
|
||||||
|
Used when the agent detects a 2FA page but no TOTP credentials are pre-configured.
|
||||||
|
The user submits codes manually via the UI, and this method finds them by run context.
|
||||||
|
"""
|
||||||
|
if not workflow_run_id and not task_id:
|
||||||
|
return []
|
||||||
|
async with self.Session() as session:
|
||||||
|
query = (
|
||||||
|
select(TOTPCodeModel)
|
||||||
|
.filter_by(organization_id=organization_id)
|
||||||
|
.filter(TOTPCodeModel.created_at > datetime.utcnow() - timedelta(minutes=valid_lifespan_minutes))
|
||||||
|
)
|
||||||
|
if workflow_run_id:
|
||||||
|
query = query.filter(TOTPCodeModel.workflow_run_id == workflow_run_id)
|
||||||
|
elif task_id:
|
||||||
|
query = query.filter(TOTPCodeModel.task_id == task_id)
|
||||||
|
query = query.order_by(TOTPCodeModel.created_at.desc()).limit(limit)
|
||||||
|
results = (await session.scalars(query)).all()
|
||||||
|
return [TOTPCode.model_validate(r) for r in results]
|
||||||
|
|
||||||
async def get_recent_otp_codes(
|
async def get_recent_otp_codes(
|
||||||
self,
|
self,
|
||||||
organization_id: str,
|
organization_id: str,
|
||||||
|
|||||||
@@ -116,6 +116,10 @@ class TaskModel(Base):
|
|||||||
model = Column(JSON, nullable=True)
|
model = Column(JSON, nullable=True)
|
||||||
browser_address = Column(String, nullable=True)
|
browser_address = Column(String, nullable=True)
|
||||||
download_timeout = Column(Numeric, nullable=True)
|
download_timeout = Column(Numeric, nullable=True)
|
||||||
|
# 2FA verification code waiting state fields
|
||||||
|
waiting_for_verification_code = Column(Boolean, nullable=False, default=False)
|
||||||
|
verification_code_identifier = Column(String, nullable=True)
|
||||||
|
verification_code_polling_started_at = Column(DateTime, nullable=True)
|
||||||
|
|
||||||
|
|
||||||
class StepModel(Base):
|
class StepModel(Base):
|
||||||
@@ -350,6 +354,10 @@ class WorkflowRunModel(Base):
|
|||||||
debug_session_id: Column = Column(String, nullable=True)
|
debug_session_id: Column = Column(String, nullable=True)
|
||||||
ai_fallback = Column(Boolean, nullable=True)
|
ai_fallback = Column(Boolean, nullable=True)
|
||||||
code_gen = Column(Boolean, nullable=True)
|
code_gen = Column(Boolean, nullable=True)
|
||||||
|
# 2FA verification code waiting state fields
|
||||||
|
waiting_for_verification_code = Column(Boolean, nullable=False, default=False)
|
||||||
|
verification_code_identifier = Column(String, nullable=True)
|
||||||
|
verification_code_polling_started_at = Column(DateTime, nullable=True)
|
||||||
|
|
||||||
queued_at = Column(DateTime, nullable=True)
|
queued_at = Column(DateTime, nullable=True)
|
||||||
started_at = Column(DateTime, nullable=True)
|
started_at = Column(DateTime, nullable=True)
|
||||||
|
|||||||
@@ -211,6 +211,9 @@ def convert_to_task(task_obj: TaskModel, debug_enabled: bool = False, workflow_p
|
|||||||
browser_session_id=task_obj.browser_session_id,
|
browser_session_id=task_obj.browser_session_id,
|
||||||
browser_address=task_obj.browser_address,
|
browser_address=task_obj.browser_address,
|
||||||
download_timeout=task_obj.download_timeout,
|
download_timeout=task_obj.download_timeout,
|
||||||
|
waiting_for_verification_code=task_obj.waiting_for_verification_code or False,
|
||||||
|
verification_code_identifier=task_obj.verification_code_identifier,
|
||||||
|
verification_code_polling_started_at=task_obj.verification_code_polling_started_at,
|
||||||
)
|
)
|
||||||
return task
|
return task
|
||||||
|
|
||||||
@@ -424,6 +427,9 @@ def convert_to_workflow_run(
|
|||||||
run_with=workflow_run_model.run_with,
|
run_with=workflow_run_model.run_with,
|
||||||
code_gen=workflow_run_model.code_gen,
|
code_gen=workflow_run_model.code_gen,
|
||||||
ai_fallback=workflow_run_model.ai_fallback,
|
ai_fallback=workflow_run_model.ai_fallback,
|
||||||
|
waiting_for_verification_code=workflow_run_model.waiting_for_verification_code or False,
|
||||||
|
verification_code_identifier=workflow_run_model.verification_code_identifier,
|
||||||
|
verification_code_polling_started_at=workflow_run_model.verification_code_polling_started_at,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
21
skyvern/forge/sdk/notification/base.py
Normal file
21
skyvern/forge/sdk/notification/base.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
"""Abstract base for notification registries."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
|
||||||
|
class BaseNotificationRegistry(ABC):
|
||||||
|
"""Abstract pub/sub registry scoped by organization.
|
||||||
|
|
||||||
|
Implementations must fan-out: a single publish call delivers the
|
||||||
|
message to every active subscriber for that organization.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def subscribe(self, organization_id: str) -> asyncio.Queue[dict]: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def unsubscribe(self, organization_id: str, queue: asyncio.Queue[dict]) -> None: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def publish(self, organization_id: str, message: dict) -> None: ...
|
||||||
14
skyvern/forge/sdk/notification/factory.py
Normal file
14
skyvern/forge/sdk/notification/factory.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
from skyvern.forge.sdk.notification.base import BaseNotificationRegistry
|
||||||
|
from skyvern.forge.sdk.notification.local import LocalNotificationRegistry
|
||||||
|
|
||||||
|
|
||||||
|
class NotificationRegistryFactory:
|
||||||
|
__registry: BaseNotificationRegistry = LocalNotificationRegistry()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def set_registry(registry: BaseNotificationRegistry) -> None:
|
||||||
|
NotificationRegistryFactory.__registry = registry
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_registry() -> BaseNotificationRegistry:
|
||||||
|
return NotificationRegistryFactory.__registry
|
||||||
45
skyvern/forge/sdk/notification/local.py
Normal file
45
skyvern/forge/sdk/notification/local.py
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
"""In-process notification registry using asyncio queues (single-pod only)."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
import structlog
|
||||||
|
|
||||||
|
from skyvern.forge.sdk.notification.base import BaseNotificationRegistry
|
||||||
|
|
||||||
|
LOG = structlog.get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class LocalNotificationRegistry(BaseNotificationRegistry):
|
||||||
|
"""In-process fan-out pub/sub using asyncio queues. Single-pod only."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._subscribers: dict[str, list[asyncio.Queue[dict]]] = defaultdict(list)
|
||||||
|
|
||||||
|
def subscribe(self, organization_id: str) -> asyncio.Queue[dict]:
|
||||||
|
queue: asyncio.Queue[dict] = asyncio.Queue()
|
||||||
|
self._subscribers[organization_id].append(queue)
|
||||||
|
LOG.info("Notification subscriber added", organization_id=organization_id)
|
||||||
|
return queue
|
||||||
|
|
||||||
|
def unsubscribe(self, organization_id: str, queue: asyncio.Queue[dict]) -> None:
|
||||||
|
queues = self._subscribers.get(organization_id)
|
||||||
|
if queues:
|
||||||
|
try:
|
||||||
|
queues.remove(queue)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
if not queues:
|
||||||
|
del self._subscribers[organization_id]
|
||||||
|
LOG.info("Notification subscriber removed", organization_id=organization_id)
|
||||||
|
|
||||||
|
def publish(self, organization_id: str, message: dict) -> None:
|
||||||
|
queues = self._subscribers.get(organization_id, [])
|
||||||
|
for queue in queues:
|
||||||
|
try:
|
||||||
|
queue.put_nowait(message)
|
||||||
|
except asyncio.QueueFull:
|
||||||
|
LOG.warning(
|
||||||
|
"Notification queue full, dropping message",
|
||||||
|
organization_id=organization_id,
|
||||||
|
)
|
||||||
55
skyvern/forge/sdk/notification/redis.py
Normal file
55
skyvern/forge/sdk/notification/redis.py
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
"""Redis-backed notification registry for multi-pod deployments.
|
||||||
|
|
||||||
|
Thin adapter around :class:`RedisPubSub` — all Redis pub/sub logic
|
||||||
|
lives in the generic layer; this class maps the ``organization_id``
|
||||||
|
domain concept onto generic string keys.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from redis.asyncio import Redis
|
||||||
|
|
||||||
|
from skyvern.forge.sdk.notification.base import BaseNotificationRegistry
|
||||||
|
from skyvern.forge.sdk.redis.pubsub import RedisPubSub
|
||||||
|
|
||||||
|
|
||||||
|
class RedisNotificationRegistry(BaseNotificationRegistry):
|
||||||
|
"""Fan-out pub/sub backed by Redis. One Redis PubSub channel per org."""
|
||||||
|
|
||||||
|
def __init__(self, redis_client: Redis) -> None:
|
||||||
|
self._pubsub = RedisPubSub(redis_client, channel_prefix="skyvern:notifications:")
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Property accessors (used by existing tests)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _listener_tasks(self) -> dict[str, asyncio.Task[None]]:
|
||||||
|
return self._pubsub._listener_tasks
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _subscribers(self) -> dict[str, list[asyncio.Queue[dict]]]:
|
||||||
|
return self._pubsub._subscribers
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Public interface
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def subscribe(self, organization_id: str) -> asyncio.Queue[dict]:
|
||||||
|
return self._pubsub.subscribe(organization_id)
|
||||||
|
|
||||||
|
def unsubscribe(self, organization_id: str, queue: asyncio.Queue[dict]) -> None:
|
||||||
|
self._pubsub.unsubscribe(organization_id, queue)
|
||||||
|
|
||||||
|
def publish(self, organization_id: str, message: dict) -> None:
|
||||||
|
self._pubsub.publish(organization_id, message)
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
await self._pubsub.close()
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Internal helper (exposed for tests)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _dispatch_local(self, organization_id: str, message: dict) -> None:
|
||||||
|
self._pubsub._dispatch_local(organization_id, message)
|
||||||
0
skyvern/forge/sdk/redis/__init__.py
Normal file
0
skyvern/forge/sdk/redis/__init__.py
Normal file
21
skyvern/forge/sdk/redis/factory.py
Normal file
21
skyvern/forge/sdk/redis/factory.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from redis.asyncio import Redis
|
||||||
|
|
||||||
|
|
||||||
|
class RedisClientFactory:
|
||||||
|
"""Singleton factory for a shared async Redis client.
|
||||||
|
|
||||||
|
Follows the same static set/get pattern as ``CacheFactory``.
|
||||||
|
Defaults to ``None`` (no Redis in local/OSS mode).
|
||||||
|
"""
|
||||||
|
|
||||||
|
__client: Redis | None = None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def set_client(client: Redis) -> None:
|
||||||
|
RedisClientFactory.__client = client
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_client() -> Redis | None:
|
||||||
|
return RedisClientFactory.__client
|
||||||
130
skyvern/forge/sdk/redis/pubsub.py
Normal file
130
skyvern/forge/sdk/redis/pubsub.py
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
"""Generic Redis pub/sub layer.
|
||||||
|
|
||||||
|
Extracted from ``RedisNotificationRegistry`` so that any feature
|
||||||
|
(notifications, events, cache invalidation, etc.) can reuse the same
|
||||||
|
pattern with its own channel prefix.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
import structlog
|
||||||
|
from redis.asyncio import Redis
|
||||||
|
|
||||||
|
LOG = structlog.get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class RedisPubSub:
|
||||||
|
"""Fan-out pub/sub backed by Redis. One Redis PubSub channel per key."""
|
||||||
|
|
||||||
|
def __init__(self, redis_client: Redis, channel_prefix: str) -> None:
|
||||||
|
self._redis = redis_client
|
||||||
|
self._channel_prefix = channel_prefix
|
||||||
|
self._subscribers: dict[str, list[asyncio.Queue[dict]]] = defaultdict(list)
|
||||||
|
# One listener task per key channel
|
||||||
|
self._listener_tasks: dict[str, asyncio.Task[None]] = {}
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Public interface
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def subscribe(self, key: str) -> asyncio.Queue[dict]:
|
||||||
|
queue: asyncio.Queue[dict] = asyncio.Queue()
|
||||||
|
self._subscribers[key].append(queue)
|
||||||
|
|
||||||
|
# Spin up a Redis listener if this is the first local subscriber
|
||||||
|
if key not in self._listener_tasks:
|
||||||
|
task = asyncio.get_running_loop().create_task(self._listen(key))
|
||||||
|
self._listener_tasks[key] = task
|
||||||
|
|
||||||
|
LOG.info("PubSub subscriber added", key=key, channel_prefix=self._channel_prefix)
|
||||||
|
return queue
|
||||||
|
|
||||||
|
def unsubscribe(self, key: str, queue: asyncio.Queue[dict]) -> None:
|
||||||
|
queues = self._subscribers.get(key)
|
||||||
|
if queues:
|
||||||
|
try:
|
||||||
|
queues.remove(queue)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
if not queues:
|
||||||
|
del self._subscribers[key]
|
||||||
|
self._cancel_listener(key)
|
||||||
|
LOG.info("PubSub subscriber removed", key=key, channel_prefix=self._channel_prefix)
|
||||||
|
|
||||||
|
def publish(self, key: str, message: dict) -> None:
|
||||||
|
"""Fire-and-forget Redis PUBLISH."""
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
loop.create_task(self._publish_to_redis(key, message))
|
||||||
|
except RuntimeError:
|
||||||
|
LOG.warning(
|
||||||
|
"No running event loop; cannot publish via Redis",
|
||||||
|
key=key,
|
||||||
|
channel_prefix=self._channel_prefix,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""Cancel all listener tasks and clear state. Call on shutdown."""
|
||||||
|
for key in list(self._listener_tasks):
|
||||||
|
self._cancel_listener(key)
|
||||||
|
self._subscribers.clear()
|
||||||
|
LOG.info("RedisPubSub closed", channel_prefix=self._channel_prefix)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Internal helpers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _publish_to_redis(self, key: str, message: dict) -> None:
|
||||||
|
channel = f"{self._channel_prefix}{key}"
|
||||||
|
try:
|
||||||
|
await self._redis.publish(channel, json.dumps(message))
|
||||||
|
except Exception:
|
||||||
|
LOG.exception("Failed to publish to Redis", key=key, channel_prefix=self._channel_prefix)
|
||||||
|
|
||||||
|
async def _listen(self, key: str) -> None:
|
||||||
|
"""Subscribe to a Redis channel and fan out messages locally."""
|
||||||
|
channel = f"{self._channel_prefix}{key}"
|
||||||
|
pubsub = self._redis.pubsub()
|
||||||
|
try:
|
||||||
|
await pubsub.subscribe(channel)
|
||||||
|
LOG.info("Redis listener started", channel=channel)
|
||||||
|
async for raw_message in pubsub.listen():
|
||||||
|
if raw_message["type"] != "message":
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
data = json.loads(raw_message["data"])
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
LOG.warning("Invalid JSON on Redis channel", channel=channel)
|
||||||
|
continue
|
||||||
|
self._dispatch_local(key, data)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
LOG.info("Redis listener cancelled", channel=channel)
|
||||||
|
raise
|
||||||
|
except Exception:
|
||||||
|
LOG.exception("Redis listener error", channel=channel)
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
await pubsub.unsubscribe(channel)
|
||||||
|
await pubsub.close()
|
||||||
|
except Exception:
|
||||||
|
LOG.warning("Error closing Redis pubsub", channel=channel)
|
||||||
|
|
||||||
|
def _dispatch_local(self, key: str, message: dict) -> None:
|
||||||
|
"""Fan out a message to all local asyncio queues for this key."""
|
||||||
|
queues = self._subscribers.get(key, [])
|
||||||
|
for queue in queues:
|
||||||
|
try:
|
||||||
|
queue.put_nowait(message)
|
||||||
|
except asyncio.QueueFull:
|
||||||
|
LOG.warning(
|
||||||
|
"Queue full, dropping message",
|
||||||
|
key=key,
|
||||||
|
channel_prefix=self._channel_prefix,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _cancel_listener(self, key: str) -> None:
|
||||||
|
task = self._listener_tasks.pop(key, None)
|
||||||
|
if task and not task.done():
|
||||||
|
task.cancel()
|
||||||
@@ -11,5 +11,6 @@ from skyvern.forge.sdk.routes import sdk # noqa: F401
|
|||||||
from skyvern.forge.sdk.routes import webhooks # noqa: F401
|
from skyvern.forge.sdk.routes import webhooks # noqa: F401
|
||||||
from skyvern.forge.sdk.routes import workflow_copilot # noqa: F401
|
from skyvern.forge.sdk.routes import workflow_copilot # noqa: F401
|
||||||
from skyvern.forge.sdk.routes.streaming import messages # noqa: F401
|
from skyvern.forge.sdk.routes.streaming import messages # noqa: F401
|
||||||
|
from skyvern.forge.sdk.routes.streaming import notifications # noqa: F401
|
||||||
from skyvern.forge.sdk.routes.streaming import screenshot # noqa: F401
|
from skyvern.forge.sdk.routes.streaming import screenshot # noqa: F401
|
||||||
from skyvern.forge.sdk.routes.streaming import vnc # noqa: F401
|
from skyvern.forge.sdk.routes.streaming import vnc # noqa: F401
|
||||||
|
|||||||
@@ -131,10 +131,15 @@ async def send_totp_code(
|
|||||||
task = await app.DATABASE.get_task(data.task_id, curr_org.organization_id)
|
task = await app.DATABASE.get_task(data.task_id, curr_org.organization_id)
|
||||||
if not task:
|
if not task:
|
||||||
raise HTTPException(status_code=400, detail=f"Invalid task id: {data.task_id}")
|
raise HTTPException(status_code=400, detail=f"Invalid task id: {data.task_id}")
|
||||||
|
workflow_id_for_storage: str | None = None
|
||||||
if data.workflow_id:
|
if data.workflow_id:
|
||||||
workflow = await app.DATABASE.get_workflow(data.workflow_id, curr_org.organization_id)
|
if data.workflow_id.startswith("wpid_"):
|
||||||
|
workflow = await app.DATABASE.get_workflow_by_permanent_id(data.workflow_id, curr_org.organization_id)
|
||||||
|
else:
|
||||||
|
workflow = await app.DATABASE.get_workflow(data.workflow_id, curr_org.organization_id)
|
||||||
if not workflow:
|
if not workflow:
|
||||||
raise HTTPException(status_code=400, detail=f"Invalid workflow id: {data.workflow_id}")
|
raise HTTPException(status_code=400, detail=f"Invalid workflow id: {data.workflow_id}")
|
||||||
|
workflow_id_for_storage = workflow.workflow_id
|
||||||
if data.workflow_run_id:
|
if data.workflow_run_id:
|
||||||
workflow_run = await app.DATABASE.get_workflow_run(data.workflow_run_id, curr_org.organization_id)
|
workflow_run = await app.DATABASE.get_workflow_run(data.workflow_run_id, curr_org.organization_id)
|
||||||
if not workflow_run:
|
if not workflow_run:
|
||||||
@@ -162,7 +167,7 @@ async def send_totp_code(
|
|||||||
content=data.content,
|
content=data.content,
|
||||||
code=otp_value.value,
|
code=otp_value.value,
|
||||||
task_id=data.task_id,
|
task_id=data.task_id,
|
||||||
workflow_id=data.workflow_id,
|
workflow_id=workflow_id_for_storage,
|
||||||
workflow_run_id=data.workflow_run_id,
|
workflow_run_id=data.workflow_run_id,
|
||||||
source=data.source,
|
source=data.source,
|
||||||
expired_at=data.expired_at,
|
expired_at=data.expired_at,
|
||||||
|
|||||||
101
skyvern/forge/sdk/routes/streaming/notifications.py
Normal file
101
skyvern/forge/sdk/routes/streaming/notifications.py
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
"""WebSocket endpoint for streaming global 2FA verification code notifications."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import structlog
|
||||||
|
from fastapi import WebSocket, WebSocketDisconnect
|
||||||
|
from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK
|
||||||
|
|
||||||
|
from skyvern.config import settings
|
||||||
|
from skyvern.forge import app
|
||||||
|
from skyvern.forge.sdk.notification.factory import NotificationRegistryFactory
|
||||||
|
from skyvern.forge.sdk.routes.routers import base_router
|
||||||
|
from skyvern.forge.sdk.routes.streaming.auth import _auth as local_auth
|
||||||
|
from skyvern.forge.sdk.routes.streaming.auth import auth as real_auth
|
||||||
|
|
||||||
|
LOG = structlog.get_logger()
|
||||||
|
HEARTBEAT_INTERVAL = 60
|
||||||
|
|
||||||
|
|
||||||
|
@base_router.websocket("/stream/notifications")
|
||||||
|
async def notification_stream(
|
||||||
|
websocket: WebSocket,
|
||||||
|
apikey: str | None = None,
|
||||||
|
token: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
auth = local_auth if settings.ENV == "local" else real_auth
|
||||||
|
organization_id = await auth(apikey=apikey, token=token, websocket=websocket)
|
||||||
|
if not organization_id:
|
||||||
|
LOG.info("Notifications: Authentication failed")
|
||||||
|
return
|
||||||
|
|
||||||
|
LOG.info("Notifications: Started streaming", organization_id=organization_id)
|
||||||
|
registry = NotificationRegistryFactory.get_registry()
|
||||||
|
queue = registry.subscribe(organization_id)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Send initial state: all currently active verification requests
|
||||||
|
active_requests = await app.DATABASE.get_active_verification_requests(organization_id)
|
||||||
|
for req in active_requests:
|
||||||
|
await websocket.send_json(
|
||||||
|
{
|
||||||
|
"type": "verification_code_required",
|
||||||
|
"task_id": req.get("task_id"),
|
||||||
|
"workflow_run_id": req.get("workflow_run_id"),
|
||||||
|
"identifier": req.get("verification_code_identifier"),
|
||||||
|
"polling_started_at": req.get("verification_code_polling_started_at"),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Watch for client disconnect while streaming events
|
||||||
|
disconnect_event = asyncio.Event()
|
||||||
|
|
||||||
|
async def _watch_disconnect() -> None:
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
await websocket.receive()
|
||||||
|
except (WebSocketDisconnect, ConnectionClosedOK, ConnectionClosedError):
|
||||||
|
disconnect_event.set()
|
||||||
|
|
||||||
|
watcher = asyncio.create_task(_watch_disconnect())
|
||||||
|
try:
|
||||||
|
while not disconnect_event.is_set():
|
||||||
|
queue_task = asyncio.ensure_future(asyncio.wait_for(queue.get(), timeout=HEARTBEAT_INTERVAL))
|
||||||
|
disconnect_wait = asyncio.ensure_future(disconnect_event.wait())
|
||||||
|
done, pending = await asyncio.wait({queue_task, disconnect_wait}, return_when=asyncio.FIRST_COMPLETED)
|
||||||
|
for p in pending:
|
||||||
|
p.cancel()
|
||||||
|
|
||||||
|
if disconnect_event.is_set():
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
message = queue_task.result()
|
||||||
|
await websocket.send_json(message)
|
||||||
|
except TimeoutError:
|
||||||
|
try:
|
||||||
|
await websocket.send_json({"type": "heartbeat"})
|
||||||
|
except Exception:
|
||||||
|
LOG.info(
|
||||||
|
"Notifications: Client unreachable during heartbeat. Closing.",
|
||||||
|
organization_id=organization_id,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
return
|
||||||
|
finally:
|
||||||
|
watcher.cancel()
|
||||||
|
|
||||||
|
except WebSocketDisconnect:
|
||||||
|
LOG.info("Notifications: WebSocket disconnected", organization_id=organization_id)
|
||||||
|
except ConnectionClosedOK:
|
||||||
|
LOG.info("Notifications: ConnectionClosedOK", organization_id=organization_id)
|
||||||
|
except ConnectionClosedError:
|
||||||
|
LOG.warning(
|
||||||
|
"Notifications: ConnectionClosedError (client likely disconnected)", organization_id=organization_id
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
LOG.warning("Notifications: Error while streaming", organization_id=organization_id, exc_info=True)
|
||||||
|
finally:
|
||||||
|
registry.unsubscribe(organization_id, queue)
|
||||||
|
LOG.info("Notifications: Connection closed", organization_id=organization_id)
|
||||||
@@ -288,6 +288,10 @@ class Task(TaskBase):
|
|||||||
queued_at: datetime | None = None
|
queued_at: datetime | None = None
|
||||||
started_at: datetime | None = None
|
started_at: datetime | None = None
|
||||||
finished_at: datetime | None = None
|
finished_at: datetime | None = None
|
||||||
|
# 2FA verification code waiting state fields
|
||||||
|
waiting_for_verification_code: bool = False
|
||||||
|
verification_code_identifier: str | None = None
|
||||||
|
verification_code_polling_started_at: datetime | None = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def llm_key(self) -> str | None:
|
def llm_key(self) -> str | None:
|
||||||
@@ -365,6 +369,9 @@ class Task(TaskBase):
|
|||||||
max_screenshot_scrolls=self.max_screenshot_scrolls,
|
max_screenshot_scrolls=self.max_screenshot_scrolls,
|
||||||
step_count=step_count,
|
step_count=step_count,
|
||||||
browser_session_id=self.browser_session_id,
|
browser_session_id=self.browser_session_id,
|
||||||
|
waiting_for_verification_code=self.waiting_for_verification_code,
|
||||||
|
verification_code_identifier=self.verification_code_identifier,
|
||||||
|
verification_code_polling_started_at=self.verification_code_polling_started_at,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -392,6 +399,10 @@ class TaskResponse(BaseModel):
|
|||||||
max_screenshot_scrolls: int | None = None
|
max_screenshot_scrolls: int | None = None
|
||||||
step_count: int | None = None
|
step_count: int | None = None
|
||||||
browser_session_id: str | None = None
|
browser_session_id: str | None = None
|
||||||
|
# 2FA verification code waiting state fields
|
||||||
|
waiting_for_verification_code: bool = False
|
||||||
|
verification_code_identifier: str | None = None
|
||||||
|
verification_code_polling_started_at: datetime | None = None
|
||||||
|
|
||||||
|
|
||||||
class TaskOutput(BaseModel):
|
class TaskOutput(BaseModel):
|
||||||
|
|||||||
@@ -172,6 +172,10 @@ class WorkflowRun(BaseModel):
|
|||||||
sequential_key: str | None = None
|
sequential_key: str | None = None
|
||||||
ai_fallback: bool | None = None
|
ai_fallback: bool | None = None
|
||||||
code_gen: bool | None = None
|
code_gen: bool | None = None
|
||||||
|
# 2FA verification code waiting state fields
|
||||||
|
waiting_for_verification_code: bool = False
|
||||||
|
verification_code_identifier: str | None = None
|
||||||
|
verification_code_polling_started_at: datetime | None = None
|
||||||
|
|
||||||
queued_at: datetime | None = None
|
queued_at: datetime | None = None
|
||||||
started_at: datetime | None = None
|
started_at: datetime | None = None
|
||||||
@@ -226,6 +230,10 @@ class WorkflowRunResponseBase(BaseModel):
|
|||||||
browser_address: str | None = None
|
browser_address: str | None = None
|
||||||
script_run: ScriptRunResponse | None = None
|
script_run: ScriptRunResponse | None = None
|
||||||
errors: list[dict[str, Any]] | None = None
|
errors: list[dict[str, Any]] | None = None
|
||||||
|
# 2FA verification code waiting state fields
|
||||||
|
waiting_for_verification_code: bool = False
|
||||||
|
verification_code_identifier: str | None = None
|
||||||
|
verification_code_polling_started_at: datetime | None = None
|
||||||
|
|
||||||
|
|
||||||
class WorkflowRunWithWorkflowResponse(WorkflowRunResponseBase):
|
class WorkflowRunWithWorkflowResponse(WorkflowRunResponseBase):
|
||||||
|
|||||||
@@ -3019,6 +3019,10 @@ class WorkflowService:
|
|||||||
browser_address=workflow_run.browser_address,
|
browser_address=workflow_run.browser_address,
|
||||||
script_run=workflow_run.script_run,
|
script_run=workflow_run.script_run,
|
||||||
errors=errors,
|
errors=errors,
|
||||||
|
# 2FA verification code waiting state fields
|
||||||
|
waiting_for_verification_code=workflow_run.waiting_for_verification_code,
|
||||||
|
verification_code_identifier=workflow_run.verification_code_identifier,
|
||||||
|
verification_code_polling_started_at=workflow_run.verification_code_polling_started_at,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def clean_up_workflow(
|
async def clean_up_workflow(
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from skyvern.forge.prompts import prompt_engine
|
|||||||
from skyvern.forge.sdk.core.aiohttp_helper import aiohttp_post
|
from skyvern.forge.sdk.core.aiohttp_helper import aiohttp_post
|
||||||
from skyvern.forge.sdk.core.security import generate_skyvern_webhook_signature
|
from skyvern.forge.sdk.core.security import generate_skyvern_webhook_signature
|
||||||
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
|
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
|
||||||
|
from skyvern.forge.sdk.notification.factory import NotificationRegistryFactory
|
||||||
from skyvern.forge.sdk.schemas.totp_codes import OTPType
|
from skyvern.forge.sdk.schemas.totp_codes import OTPType
|
||||||
|
|
||||||
LOG = structlog.get_logger()
|
LOG = structlog.get_logger()
|
||||||
@@ -80,38 +81,147 @@ async def poll_otp_value(
|
|||||||
totp_verification_url=totp_verification_url,
|
totp_verification_url=totp_verification_url,
|
||||||
totp_identifier=totp_identifier,
|
totp_identifier=totp_identifier,
|
||||||
)
|
)
|
||||||
while True:
|
|
||||||
await asyncio.sleep(10)
|
# Set the waiting state in the database when polling starts
|
||||||
# check timeout
|
identifier_for_ui = totp_identifier
|
||||||
if datetime.utcnow() > timeout_datetime:
|
if workflow_run_id:
|
||||||
LOG.warning("Polling otp value timed out")
|
try:
|
||||||
raise NoTOTPVerificationCodeFound(
|
await app.DATABASE.update_workflow_run(
|
||||||
task_id=task_id,
|
|
||||||
workflow_run_id=workflow_run_id,
|
workflow_run_id=workflow_run_id,
|
||||||
workflow_id=workflow_permanent_id,
|
waiting_for_verification_code=True,
|
||||||
totp_verification_url=totp_verification_url,
|
verification_code_identifier=identifier_for_ui,
|
||||||
totp_identifier=totp_identifier,
|
verification_code_polling_started_at=start_datetime,
|
||||||
)
|
)
|
||||||
otp_value: OTPValue | None = None
|
LOG.info(
|
||||||
if totp_verification_url:
|
"Set 2FA waiting state for workflow run",
|
||||||
otp_value = await _get_otp_value_from_url(
|
|
||||||
organization_id,
|
|
||||||
totp_verification_url,
|
|
||||||
org_token.token,
|
|
||||||
task_id=task_id,
|
|
||||||
workflow_run_id=workflow_run_id,
|
workflow_run_id=workflow_run_id,
|
||||||
|
verification_code_identifier=identifier_for_ui,
|
||||||
)
|
)
|
||||||
elif totp_identifier:
|
try:
|
||||||
otp_value = await _get_otp_value_from_db(
|
NotificationRegistryFactory.get_registry().publish(
|
||||||
organization_id,
|
organization_id,
|
||||||
totp_identifier,
|
{
|
||||||
|
"type": "verification_code_required",
|
||||||
|
"workflow_run_id": workflow_run_id,
|
||||||
|
"task_id": task_id,
|
||||||
|
"identifier": identifier_for_ui,
|
||||||
|
"polling_started_at": start_datetime.isoformat(),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
LOG.warning("Failed to publish 2FA required notification for workflow run", exc_info=True)
|
||||||
|
except Exception:
|
||||||
|
LOG.warning("Failed to set 2FA waiting state for workflow run", exc_info=True)
|
||||||
|
elif task_id:
|
||||||
|
try:
|
||||||
|
await app.DATABASE.update_task_2fa_state(
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
workflow_id=workflow_permanent_id,
|
organization_id=organization_id,
|
||||||
workflow_run_id=workflow_run_id,
|
waiting_for_verification_code=True,
|
||||||
|
verification_code_identifier=identifier_for_ui,
|
||||||
|
verification_code_polling_started_at=start_datetime,
|
||||||
)
|
)
|
||||||
if otp_value:
|
LOG.info(
|
||||||
LOG.info("Got otp value", otp_value=otp_value)
|
"Set 2FA waiting state for task",
|
||||||
return otp_value
|
task_id=task_id,
|
||||||
|
verification_code_identifier=identifier_for_ui,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
NotificationRegistryFactory.get_registry().publish(
|
||||||
|
organization_id,
|
||||||
|
{
|
||||||
|
"type": "verification_code_required",
|
||||||
|
"task_id": task_id,
|
||||||
|
"identifier": identifier_for_ui,
|
||||||
|
"polling_started_at": start_datetime.isoformat(),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
LOG.warning("Failed to publish 2FA required notification for task", exc_info=True)
|
||||||
|
except Exception:
|
||||||
|
LOG.warning("Failed to set 2FA waiting state for task", exc_info=True)
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
await asyncio.sleep(10)
|
||||||
|
# check timeout
|
||||||
|
if datetime.utcnow() > timeout_datetime:
|
||||||
|
LOG.warning("Polling otp value timed out")
|
||||||
|
raise NoTOTPVerificationCodeFound(
|
||||||
|
task_id=task_id,
|
||||||
|
workflow_run_id=workflow_run_id,
|
||||||
|
workflow_id=workflow_permanent_id,
|
||||||
|
totp_verification_url=totp_verification_url,
|
||||||
|
totp_identifier=totp_identifier,
|
||||||
|
)
|
||||||
|
otp_value: OTPValue | None = None
|
||||||
|
if totp_verification_url:
|
||||||
|
otp_value = await _get_otp_value_from_url(
|
||||||
|
organization_id,
|
||||||
|
totp_verification_url,
|
||||||
|
org_token.token,
|
||||||
|
task_id=task_id,
|
||||||
|
workflow_run_id=workflow_run_id,
|
||||||
|
)
|
||||||
|
elif totp_identifier:
|
||||||
|
otp_value = await _get_otp_value_from_db(
|
||||||
|
organization_id,
|
||||||
|
totp_identifier,
|
||||||
|
task_id=task_id,
|
||||||
|
workflow_id=workflow_id,
|
||||||
|
workflow_run_id=workflow_run_id,
|
||||||
|
)
|
||||||
|
if not otp_value:
|
||||||
|
otp_value = await _get_otp_value_by_run(
|
||||||
|
organization_id,
|
||||||
|
task_id=task_id,
|
||||||
|
workflow_run_id=workflow_run_id,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# No pre-configured TOTP — poll for manually submitted codes by run context
|
||||||
|
otp_value = await _get_otp_value_by_run(
|
||||||
|
organization_id,
|
||||||
|
task_id=task_id,
|
||||||
|
workflow_run_id=workflow_run_id,
|
||||||
|
)
|
||||||
|
if otp_value:
|
||||||
|
LOG.info("Got otp value", otp_value=otp_value)
|
||||||
|
return otp_value
|
||||||
|
finally:
|
||||||
|
# Clear the waiting state when polling completes (success, timeout, or error)
|
||||||
|
if workflow_run_id:
|
||||||
|
try:
|
||||||
|
await app.DATABASE.update_workflow_run(
|
||||||
|
workflow_run_id=workflow_run_id,
|
||||||
|
waiting_for_verification_code=False,
|
||||||
|
)
|
||||||
|
LOG.info("Cleared 2FA waiting state for workflow run", workflow_run_id=workflow_run_id)
|
||||||
|
try:
|
||||||
|
NotificationRegistryFactory.get_registry().publish(
|
||||||
|
organization_id,
|
||||||
|
{"type": "verification_code_resolved", "workflow_run_id": workflow_run_id, "task_id": task_id},
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
LOG.warning("Failed to publish 2FA resolved notification for workflow run", exc_info=True)
|
||||||
|
except Exception:
|
||||||
|
LOG.warning("Failed to clear 2FA waiting state for workflow run", exc_info=True)
|
||||||
|
elif task_id:
|
||||||
|
try:
|
||||||
|
await app.DATABASE.update_task_2fa_state(
|
||||||
|
task_id=task_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
waiting_for_verification_code=False,
|
||||||
|
)
|
||||||
|
LOG.info("Cleared 2FA waiting state for task", task_id=task_id)
|
||||||
|
try:
|
||||||
|
NotificationRegistryFactory.get_registry().publish(
|
||||||
|
organization_id,
|
||||||
|
{"type": "verification_code_resolved", "task_id": task_id},
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
LOG.warning("Failed to publish 2FA resolved notification for task", exc_info=True)
|
||||||
|
except Exception:
|
||||||
|
LOG.warning("Failed to clear 2FA waiting state for task", exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
async def _get_otp_value_from_url(
|
async def _get_otp_value_from_url(
|
||||||
@@ -175,6 +285,28 @@ async def _get_otp_value_from_url(
|
|||||||
return otp_value
|
return otp_value
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_otp_value_by_run(
|
||||||
|
organization_id: str,
|
||||||
|
task_id: str | None = None,
|
||||||
|
workflow_run_id: str | None = None,
|
||||||
|
) -> OTPValue | None:
|
||||||
|
"""Look up OTP codes by task_id/workflow_run_id when no totp_identifier is configured.
|
||||||
|
|
||||||
|
Used for the manual 2FA input flow where users submit codes through the UI
|
||||||
|
without pre-configured TOTP credentials.
|
||||||
|
"""
|
||||||
|
codes = await app.DATABASE.get_otp_codes_by_run(
|
||||||
|
organization_id=organization_id,
|
||||||
|
task_id=task_id,
|
||||||
|
workflow_run_id=workflow_run_id,
|
||||||
|
limit=1,
|
||||||
|
)
|
||||||
|
if codes:
|
||||||
|
code = codes[0]
|
||||||
|
return OTPValue(value=code.code, type=code.otp_type)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def _get_otp_value_from_db(
|
async def _get_otp_value_from_db(
|
||||||
organization_id: str,
|
organization_id: str,
|
||||||
totp_identifier: str,
|
totp_identifier: str,
|
||||||
|
|||||||
106
tests/unit_tests/test_notification_registry.py
Normal file
106
tests/unit_tests/test_notification_registry.py
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
"""Tests for NotificationRegistry pub/sub and get_active_verification_requests (SKY-6)."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from skyvern.forge.sdk.db.agent_db import AgentDB
|
||||||
|
from skyvern.forge.sdk.notification.base import BaseNotificationRegistry
|
||||||
|
from skyvern.forge.sdk.notification.factory import NotificationRegistryFactory
|
||||||
|
from skyvern.forge.sdk.notification.local import LocalNotificationRegistry
|
||||||
|
|
||||||
|
# === Task 1: NotificationRegistry subscribe / publish / unsubscribe ===
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_subscribe_and_publish():
|
||||||
|
"""Published messages should be received by subscribers."""
|
||||||
|
registry = LocalNotificationRegistry()
|
||||||
|
queue = registry.subscribe("org_1")
|
||||||
|
|
||||||
|
registry.publish("org_1", {"type": "verification_code_required", "task_id": "tsk_1"})
|
||||||
|
msg = queue.get_nowait()
|
||||||
|
assert msg["type"] == "verification_code_required"
|
||||||
|
assert msg["task_id"] == "tsk_1"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_multiple_subscribers():
|
||||||
|
"""All subscribers for an org should receive the same message."""
|
||||||
|
registry = LocalNotificationRegistry()
|
||||||
|
q1 = registry.subscribe("org_1")
|
||||||
|
q2 = registry.subscribe("org_1")
|
||||||
|
|
||||||
|
registry.publish("org_1", {"type": "verification_code_required"})
|
||||||
|
assert not q1.empty()
|
||||||
|
assert not q2.empty()
|
||||||
|
assert q1.get_nowait() == q2.get_nowait()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_publish_wrong_org_does_not_leak():
|
||||||
|
"""Messages for org_A should not appear in org_B's queue."""
|
||||||
|
registry = LocalNotificationRegistry()
|
||||||
|
q_a = registry.subscribe("org_a")
|
||||||
|
q_b = registry.subscribe("org_b")
|
||||||
|
|
||||||
|
registry.publish("org_a", {"type": "test"})
|
||||||
|
assert not q_a.empty()
|
||||||
|
assert q_b.empty()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_unsubscribe():
|
||||||
|
"""After unsubscribe, the queue should no longer receive messages."""
|
||||||
|
registry = LocalNotificationRegistry()
|
||||||
|
queue = registry.subscribe("org_1")
|
||||||
|
|
||||||
|
registry.unsubscribe("org_1", queue)
|
||||||
|
registry.publish("org_1", {"type": "test"})
|
||||||
|
assert queue.empty()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_unsubscribe_idempotent():
|
||||||
|
"""Unsubscribing a queue that's already removed should not raise."""
|
||||||
|
registry = LocalNotificationRegistry()
|
||||||
|
queue = registry.subscribe("org_1")
|
||||||
|
registry.unsubscribe("org_1", queue)
|
||||||
|
registry.unsubscribe("org_1", queue) # should not raise
|
||||||
|
|
||||||
|
|
||||||
|
# === Task: BaseNotificationRegistry ABC ===
|
||||||
|
|
||||||
|
|
||||||
|
def test_base_notification_registry_cannot_be_instantiated():
|
||||||
|
"""ABC should not be directly instantiable."""
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
BaseNotificationRegistry()
|
||||||
|
|
||||||
|
|
||||||
|
# === Task: NotificationRegistryFactory ===
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_factory_returns_local_by_default():
|
||||||
|
"""Factory should return a LocalNotificationRegistry by default."""
|
||||||
|
registry = NotificationRegistryFactory.get_registry()
|
||||||
|
assert isinstance(registry, LocalNotificationRegistry)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_factory_set_and_get():
|
||||||
|
"""Factory should allow swapping the registry implementation."""
|
||||||
|
original = NotificationRegistryFactory.get_registry()
|
||||||
|
try:
|
||||||
|
custom = LocalNotificationRegistry()
|
||||||
|
NotificationRegistryFactory.set_registry(custom)
|
||||||
|
assert NotificationRegistryFactory.get_registry() is custom
|
||||||
|
finally:
|
||||||
|
NotificationRegistryFactory.set_registry(original)
|
||||||
|
|
||||||
|
|
||||||
|
# === Task 2: get_active_verification_requests DB method ===
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_active_verification_requests_method_exists():
|
||||||
|
"""AgentDB should have get_active_verification_requests method."""
|
||||||
|
assert hasattr(AgentDB, "get_active_verification_requests")
|
||||||
544
tests/unit_tests/test_otp_no_config.py
Normal file
544
tests/unit_tests/test_otp_no_config.py
Normal file
@@ -0,0 +1,544 @@
|
|||||||
|
"""Tests for manual 2FA input without pre-configured TOTP credentials (SKY-6)."""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from skyvern.constants import SPECIAL_FIELD_VERIFICATION_CODE
|
||||||
|
from skyvern.forge.agent import ForgeAgent
|
||||||
|
from skyvern.forge.sdk.db.agent_db import AgentDB
|
||||||
|
from skyvern.forge.sdk.notification.local import LocalNotificationRegistry
|
||||||
|
from skyvern.forge.sdk.routes.credentials import send_totp_code
|
||||||
|
from skyvern.forge.sdk.schemas.totp_codes import TOTPCodeCreate
|
||||||
|
from skyvern.schemas.runs import RunEngine
|
||||||
|
from skyvern.services.otp_service import OTPValue, _get_otp_value_by_run, poll_otp_value
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_otp_codes_by_run_exists():
|
||||||
|
"""get_otp_codes_by_run should exist on AgentDB."""
|
||||||
|
assert hasattr(AgentDB, "get_otp_codes_by_run"), "AgentDB missing get_otp_codes_by_run method"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_otp_codes_by_run_returns_empty_without_identifiers():
|
||||||
|
"""get_otp_codes_by_run should return [] when neither task_id nor workflow_run_id is given."""
|
||||||
|
db = AgentDB.__new__(AgentDB)
|
||||||
|
result = await db.get_otp_codes_by_run(
|
||||||
|
organization_id="org_1",
|
||||||
|
)
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
|
||||||
|
# === Task 2: _get_otp_value_by_run OTP service function ===
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_otp_value_by_run_returns_code():
|
||||||
|
"""_get_otp_value_by_run should find OTP codes by task_id."""
|
||||||
|
mock_code = MagicMock()
|
||||||
|
mock_code.code = "123456"
|
||||||
|
mock_code.otp_type = "totp"
|
||||||
|
|
||||||
|
mock_db = AsyncMock()
|
||||||
|
mock_db.get_otp_codes_by_run.return_value = [mock_code]
|
||||||
|
|
||||||
|
mock_app = MagicMock()
|
||||||
|
mock_app.DATABASE = mock_db
|
||||||
|
|
||||||
|
with patch("skyvern.services.otp_service.app", new=mock_app):
|
||||||
|
result = await _get_otp_value_by_run(
|
||||||
|
organization_id="org_1",
|
||||||
|
task_id="tsk_1",
|
||||||
|
)
|
||||||
|
assert result is not None
|
||||||
|
assert result.value == "123456"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_otp_value_by_run_returns_none_when_no_codes():
|
||||||
|
"""_get_otp_value_by_run should return None when no codes found."""
|
||||||
|
mock_db = AsyncMock()
|
||||||
|
mock_db.get_otp_codes_by_run.return_value = []
|
||||||
|
|
||||||
|
mock_app = MagicMock()
|
||||||
|
mock_app.DATABASE = mock_db
|
||||||
|
|
||||||
|
with patch("skyvern.services.otp_service.app", new=mock_app):
|
||||||
|
result = await _get_otp_value_by_run(
|
||||||
|
organization_id="org_1",
|
||||||
|
task_id="tsk_1",
|
||||||
|
)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
# === Task 3: poll_otp_value without identifier ===
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_poll_otp_value_without_identifier_uses_run_lookup():
|
||||||
|
"""poll_otp_value should use _get_otp_value_by_run when no identifier/URL provided."""
|
||||||
|
mock_code = MagicMock()
|
||||||
|
mock_code.code = "123456"
|
||||||
|
mock_code.otp_type = "totp"
|
||||||
|
|
||||||
|
mock_db = AsyncMock()
|
||||||
|
mock_db.get_valid_org_auth_token.return_value = MagicMock(token="tok")
|
||||||
|
mock_db.get_otp_codes_by_run.return_value = [mock_code]
|
||||||
|
mock_db.update_task_2fa_state = AsyncMock()
|
||||||
|
|
||||||
|
mock_app = MagicMock()
|
||||||
|
mock_app.DATABASE = mock_db
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("skyvern.services.otp_service.app", new=mock_app),
|
||||||
|
patch("skyvern.services.otp_service.asyncio.sleep", new_callable=AsyncMock),
|
||||||
|
):
|
||||||
|
result = await poll_otp_value(
|
||||||
|
organization_id="org_1",
|
||||||
|
task_id="tsk_1",
|
||||||
|
)
|
||||||
|
assert result is not None
|
||||||
|
assert result.value == "123456"
|
||||||
|
|
||||||
|
|
||||||
|
# === Task 6: Integration test — handle_potential_OTP_actions without TOTP config ===
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_potential_OTP_actions_without_totp_config():
|
||||||
|
"""When LLM detects 2FA but no TOTP config exists, should still enter verification flow."""
|
||||||
|
agent = ForgeAgent.__new__(ForgeAgent)
|
||||||
|
|
||||||
|
task = MagicMock()
|
||||||
|
task.organization_id = "org_1"
|
||||||
|
task.totp_verification_url = None
|
||||||
|
task.totp_identifier = None
|
||||||
|
task.task_id = "tsk_1"
|
||||||
|
task.workflow_run_id = None
|
||||||
|
|
||||||
|
step = MagicMock()
|
||||||
|
scraped_page = MagicMock()
|
||||||
|
browser_state = MagicMock()
|
||||||
|
|
||||||
|
json_response = {
|
||||||
|
"should_enter_verification_code": True,
|
||||||
|
"place_to_enter_verification_code": "input#otp-code",
|
||||||
|
"actions": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.object(agent, "handle_potential_verification_code", new_callable=AsyncMock) as mock_handler:
|
||||||
|
mock_handler.return_value = {"actions": []}
|
||||||
|
with patch("skyvern.forge.agent.parse_actions", return_value=[]):
|
||||||
|
result_json, result_actions = await agent.handle_potential_OTP_actions(
|
||||||
|
task, step, scraped_page, browser_state, json_response
|
||||||
|
)
|
||||||
|
mock_handler.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_potential_OTP_actions_skips_magic_link_without_totp_config():
|
||||||
|
"""Magic links should still require TOTP config."""
|
||||||
|
agent = ForgeAgent.__new__(ForgeAgent)
|
||||||
|
|
||||||
|
task = MagicMock()
|
||||||
|
task.organization_id = "org_1"
|
||||||
|
task.totp_verification_url = None
|
||||||
|
task.totp_identifier = None
|
||||||
|
|
||||||
|
step = MagicMock()
|
||||||
|
scraped_page = MagicMock()
|
||||||
|
browser_state = MagicMock()
|
||||||
|
|
||||||
|
json_response = {
|
||||||
|
"should_verify_by_magic_link": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.object(agent, "handle_potential_magic_link", new_callable=AsyncMock) as mock_handler:
|
||||||
|
result_json, result_actions = await agent.handle_potential_OTP_actions(
|
||||||
|
task, step, scraped_page, browser_state, json_response
|
||||||
|
)
|
||||||
|
mock_handler.assert_not_called()
|
||||||
|
assert result_actions == []
|
||||||
|
|
||||||
|
|
||||||
|
# === Task 7: verification_code_check always True in LLM prompt ===
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_verification_code_check_always_true_without_totp_config():
|
||||||
|
"""_build_extract_action_prompt should receive verification_code_check=True even when task has no TOTP config."""
|
||||||
|
agent = ForgeAgent.__new__(ForgeAgent)
|
||||||
|
agent.async_operation_pool = MagicMock()
|
||||||
|
|
||||||
|
task = MagicMock()
|
||||||
|
task.totp_verification_url = None
|
||||||
|
task.totp_identifier = None
|
||||||
|
task.task_id = "tsk_1"
|
||||||
|
task.workflow_run_id = None
|
||||||
|
task.organization_id = "org_1"
|
||||||
|
task.url = "https://example.com"
|
||||||
|
|
||||||
|
step = MagicMock()
|
||||||
|
step.step_id = "step_1"
|
||||||
|
step.order = 0
|
||||||
|
step.retry_index = 0
|
||||||
|
|
||||||
|
scraped_page = MagicMock()
|
||||||
|
scraped_page.elements = []
|
||||||
|
|
||||||
|
browser_state = MagicMock()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("skyvern.forge.agent.skyvern_context") as mock_ctx,
|
||||||
|
patch.object(agent, "_scrape_with_type", new_callable=AsyncMock, return_value=scraped_page),
|
||||||
|
patch.object(
|
||||||
|
agent,
|
||||||
|
"_build_extract_action_prompt",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=("prompt", False, "extract_action"),
|
||||||
|
) as mock_build,
|
||||||
|
):
|
||||||
|
mock_ctx.current.return_value = None
|
||||||
|
await agent.build_and_record_step_prompt(
|
||||||
|
task, step, browser_state, RunEngine.skyvern_v1, persist_artifacts=False
|
||||||
|
)
|
||||||
|
mock_build.assert_called_once()
|
||||||
|
_, kwargs = mock_build.call_args
|
||||||
|
assert kwargs["verification_code_check"] is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_verification_code_check_always_true_with_totp_config():
|
||||||
|
"""_build_extract_action_prompt should receive verification_code_check=True when task HAS TOTP config (unchanged)."""
|
||||||
|
agent = ForgeAgent.__new__(ForgeAgent)
|
||||||
|
agent.async_operation_pool = MagicMock()
|
||||||
|
|
||||||
|
task = MagicMock()
|
||||||
|
task.totp_verification_url = "https://otp.example.com"
|
||||||
|
task.totp_identifier = "user@example.com"
|
||||||
|
task.task_id = "tsk_2"
|
||||||
|
task.workflow_run_id = None
|
||||||
|
task.organization_id = "org_1"
|
||||||
|
task.url = "https://example.com"
|
||||||
|
|
||||||
|
step = MagicMock()
|
||||||
|
step.step_id = "step_1"
|
||||||
|
step.order = 0
|
||||||
|
step.retry_index = 0
|
||||||
|
|
||||||
|
scraped_page = MagicMock()
|
||||||
|
scraped_page.elements = []
|
||||||
|
|
||||||
|
browser_state = MagicMock()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("skyvern.forge.agent.skyvern_context") as mock_ctx,
|
||||||
|
patch.object(agent, "_scrape_with_type", new_callable=AsyncMock, return_value=scraped_page),
|
||||||
|
patch.object(
|
||||||
|
agent,
|
||||||
|
"_build_extract_action_prompt",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=("prompt", False, "extract_action"),
|
||||||
|
) as mock_build,
|
||||||
|
):
|
||||||
|
mock_ctx.current.return_value = None
|
||||||
|
await agent.build_and_record_step_prompt(
|
||||||
|
task, step, browser_state, RunEngine.skyvern_v1, persist_artifacts=False
|
||||||
|
)
|
||||||
|
mock_build.assert_called_once()
|
||||||
|
_, kwargs = mock_build.call_args
|
||||||
|
assert kwargs["verification_code_check"] is True
|
||||||
|
|
||||||
|
|
||||||
|
# === Fix: poll_otp_value should pass workflow_id, not workflow_permanent_id ===
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_poll_otp_value_passes_workflow_id_not_permanent_id():
|
||||||
|
"""poll_otp_value should pass workflow_id (w_* format) to _get_otp_value_from_db, not workflow_permanent_id."""
|
||||||
|
mock_db = AsyncMock()
|
||||||
|
mock_db.get_valid_org_auth_token.return_value = MagicMock(token="tok")
|
||||||
|
mock_db.update_workflow_run = AsyncMock()
|
||||||
|
|
||||||
|
mock_app = MagicMock()
|
||||||
|
mock_app.DATABASE = mock_db
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("skyvern.services.otp_service.app", new=mock_app),
|
||||||
|
patch("skyvern.services.otp_service.asyncio.sleep", new_callable=AsyncMock),
|
||||||
|
patch(
|
||||||
|
"skyvern.services.otp_service._get_otp_value_from_db",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=OTPValue(value="654321", type="totp"),
|
||||||
|
) as mock_get_from_db,
|
||||||
|
):
|
||||||
|
result = await poll_otp_value(
|
||||||
|
organization_id="org_1",
|
||||||
|
workflow_id="w_123",
|
||||||
|
workflow_run_id="wr_789",
|
||||||
|
workflow_permanent_id="wpid_456",
|
||||||
|
totp_identifier="user@example.com",
|
||||||
|
)
|
||||||
|
assert result is not None
|
||||||
|
assert result.value == "654321"
|
||||||
|
mock_get_from_db.assert_called_once_with(
|
||||||
|
"org_1",
|
||||||
|
"user@example.com",
|
||||||
|
task_id=None,
|
||||||
|
workflow_id="w_123",
|
||||||
|
workflow_run_id="wr_789",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# === Fix: send_totp_code should resolve wpid_* to w_* before storage ===
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_totp_code_resolves_wpid_to_workflow_id():
|
||||||
|
"""send_totp_code should resolve wpid_* to w_* before storing in DB."""
|
||||||
|
mock_workflow = MagicMock()
|
||||||
|
mock_workflow.workflow_id = "w_abc123"
|
||||||
|
|
||||||
|
mock_totp_code = MagicMock()
|
||||||
|
|
||||||
|
mock_db = AsyncMock()
|
||||||
|
mock_db.get_workflow_by_permanent_id = AsyncMock(return_value=mock_workflow)
|
||||||
|
mock_db.create_otp_code = AsyncMock(return_value=mock_totp_code)
|
||||||
|
|
||||||
|
mock_app = MagicMock()
|
||||||
|
mock_app.DATABASE = mock_db
|
||||||
|
|
||||||
|
data = TOTPCodeCreate(
|
||||||
|
totp_identifier="user@example.com",
|
||||||
|
content="123456",
|
||||||
|
workflow_id="wpid_xyz789",
|
||||||
|
)
|
||||||
|
curr_org = MagicMock()
|
||||||
|
curr_org.organization_id = "org_1"
|
||||||
|
|
||||||
|
with patch("skyvern.forge.sdk.routes.credentials.app", new=mock_app):
|
||||||
|
await send_totp_code(data=data, curr_org=curr_org)
|
||||||
|
|
||||||
|
mock_db.create_otp_code.assert_called_once()
|
||||||
|
call_kwargs = mock_db.create_otp_code.call_args[1]
|
||||||
|
assert call_kwargs["workflow_id"] == "w_abc123", f"Expected w_abc123 but got {call_kwargs['workflow_id']}"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_totp_code_w_format_passes_through():
|
||||||
|
"""send_totp_code should resolve and store w_* format workflow_id correctly."""
|
||||||
|
mock_workflow = MagicMock()
|
||||||
|
mock_workflow.workflow_id = "w_abc123"
|
||||||
|
|
||||||
|
mock_totp_code = MagicMock()
|
||||||
|
|
||||||
|
mock_db = AsyncMock()
|
||||||
|
mock_db.get_workflow = AsyncMock(return_value=mock_workflow)
|
||||||
|
mock_db.create_otp_code = AsyncMock(return_value=mock_totp_code)
|
||||||
|
|
||||||
|
mock_app = MagicMock()
|
||||||
|
mock_app.DATABASE = mock_db
|
||||||
|
|
||||||
|
data = TOTPCodeCreate(
|
||||||
|
totp_identifier="user@example.com",
|
||||||
|
content="123456",
|
||||||
|
workflow_id="w_abc123",
|
||||||
|
)
|
||||||
|
curr_org = MagicMock()
|
||||||
|
curr_org.organization_id = "org_1"
|
||||||
|
|
||||||
|
with patch("skyvern.forge.sdk.routes.credentials.app", new=mock_app):
|
||||||
|
await send_totp_code(data=data, curr_org=curr_org)
|
||||||
|
|
||||||
|
call_kwargs = mock_db.create_otp_code.call_args[1]
|
||||||
|
assert call_kwargs["workflow_id"] == "w_abc123"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_totp_code_none_workflow_id():
|
||||||
|
"""send_totp_code should pass None workflow_id when not provided."""
|
||||||
|
mock_totp_code = MagicMock()
|
||||||
|
|
||||||
|
mock_db = AsyncMock()
|
||||||
|
mock_db.create_otp_code = AsyncMock(return_value=mock_totp_code)
|
||||||
|
|
||||||
|
mock_app = MagicMock()
|
||||||
|
mock_app.DATABASE = mock_db
|
||||||
|
|
||||||
|
data = TOTPCodeCreate(
|
||||||
|
totp_identifier="user@example.com",
|
||||||
|
content="123456",
|
||||||
|
)
|
||||||
|
curr_org = MagicMock()
|
||||||
|
curr_org.organization_id = "org_1"
|
||||||
|
|
||||||
|
with patch("skyvern.forge.sdk.routes.credentials.app", new=mock_app):
|
||||||
|
await send_totp_code(data=data, curr_org=curr_org)
|
||||||
|
|
||||||
|
call_kwargs = mock_db.create_otp_code.call_args[1]
|
||||||
|
assert call_kwargs["workflow_id"] is None
|
||||||
|
|
||||||
|
|
||||||
|
# === Fix: _build_navigation_payload should inject code without TOTP config ===
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_navigation_payload_injects_code_without_totp_config():
|
||||||
|
"""_build_navigation_payload should inject SPECIAL_FIELD_VERIFICATION_CODE even when
|
||||||
|
task has no totp_verification_url or totp_identifier (manual 2FA flow)."""
|
||||||
|
agent = ForgeAgent.__new__(ForgeAgent)
|
||||||
|
|
||||||
|
task = MagicMock()
|
||||||
|
task.totp_verification_url = None
|
||||||
|
task.totp_identifier = None
|
||||||
|
task.task_id = "tsk_manual_2fa"
|
||||||
|
task.workflow_run_id = "wr_123"
|
||||||
|
task.navigation_payload = {"username": "user@example.com"}
|
||||||
|
|
||||||
|
mock_context = MagicMock()
|
||||||
|
mock_context.totp_codes = {"tsk_manual_2fa": "123456"}
|
||||||
|
mock_context.has_magic_link_page.return_value = False
|
||||||
|
|
||||||
|
with patch("skyvern.forge.agent.skyvern_context") as mock_skyvern_ctx:
|
||||||
|
mock_skyvern_ctx.ensure_context.return_value = mock_context
|
||||||
|
result = agent._build_navigation_payload(task)
|
||||||
|
|
||||||
|
assert isinstance(result, dict)
|
||||||
|
assert SPECIAL_FIELD_VERIFICATION_CODE in result
|
||||||
|
assert result[SPECIAL_FIELD_VERIFICATION_CODE] == "123456"
|
||||||
|
# Original payload preserved
|
||||||
|
assert result["username"] == "user@example.com"
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_navigation_payload_injects_code_when_payload_is_none():
|
||||||
|
"""_build_navigation_payload should create a dict with the code when payload is None."""
|
||||||
|
agent = ForgeAgent.__new__(ForgeAgent)
|
||||||
|
|
||||||
|
task = MagicMock()
|
||||||
|
task.totp_verification_url = None
|
||||||
|
task.totp_identifier = None
|
||||||
|
task.task_id = "tsk_manual_2fa"
|
||||||
|
task.workflow_run_id = "wr_123"
|
||||||
|
task.navigation_payload = None
|
||||||
|
|
||||||
|
mock_context = MagicMock()
|
||||||
|
mock_context.totp_codes = {"tsk_manual_2fa": "999999"}
|
||||||
|
mock_context.has_magic_link_page.return_value = False
|
||||||
|
|
||||||
|
with patch("skyvern.forge.agent.skyvern_context") as mock_skyvern_ctx:
|
||||||
|
mock_skyvern_ctx.ensure_context.return_value = mock_context
|
||||||
|
result = agent._build_navigation_payload(task)
|
||||||
|
|
||||||
|
assert isinstance(result, dict)
|
||||||
|
assert result[SPECIAL_FIELD_VERIFICATION_CODE] == "999999"
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_navigation_payload_no_code_no_injection():
|
||||||
|
"""_build_navigation_payload should NOT inject anything when no code in context."""
|
||||||
|
agent = ForgeAgent.__new__(ForgeAgent)
|
||||||
|
|
||||||
|
task = MagicMock()
|
||||||
|
task.totp_verification_url = None
|
||||||
|
task.totp_identifier = None
|
||||||
|
task.task_id = "tsk_no_code"
|
||||||
|
task.workflow_run_id = "wr_456"
|
||||||
|
task.navigation_payload = {"field": "value"}
|
||||||
|
|
||||||
|
mock_context = MagicMock()
|
||||||
|
mock_context.totp_codes = {} # No code in context
|
||||||
|
mock_context.has_magic_link_page.return_value = False
|
||||||
|
|
||||||
|
with patch("skyvern.forge.agent.skyvern_context") as mock_skyvern_ctx:
|
||||||
|
mock_skyvern_ctx.ensure_context.return_value = mock_context
|
||||||
|
result = agent._build_navigation_payload(task)
|
||||||
|
|
||||||
|
assert isinstance(result, dict)
|
||||||
|
assert SPECIAL_FIELD_VERIFICATION_CODE not in result
|
||||||
|
assert result["field"] == "value"
|
||||||
|
|
||||||
|
|
||||||
|
# === Task: poll_otp_value publishes 2FA events to notification registry ===
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_poll_otp_value_publishes_required_event_for_task():
|
||||||
|
"""poll_otp_value should publish verification_code_required when task waiting state is set."""
|
||||||
|
mock_code = MagicMock()
|
||||||
|
mock_code.code = "123456"
|
||||||
|
mock_code.otp_type = "totp"
|
||||||
|
|
||||||
|
mock_db = AsyncMock()
|
||||||
|
mock_db.get_valid_org_auth_token.return_value = MagicMock(token="tok")
|
||||||
|
mock_db.get_otp_codes_by_run.return_value = [mock_code]
|
||||||
|
mock_db.update_task_2fa_state = AsyncMock()
|
||||||
|
|
||||||
|
mock_app = MagicMock()
|
||||||
|
mock_app.DATABASE = mock_db
|
||||||
|
|
||||||
|
registry = LocalNotificationRegistry()
|
||||||
|
queue = registry.subscribe("org_1")
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("skyvern.services.otp_service.app", new=mock_app),
|
||||||
|
patch("skyvern.services.otp_service.asyncio.sleep", new_callable=AsyncMock),
|
||||||
|
patch(
|
||||||
|
"skyvern.forge.sdk.notification.factory.NotificationRegistryFactory._NotificationRegistryFactory__registry",
|
||||||
|
new=registry,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
await poll_otp_value(organization_id="org_1", task_id="tsk_1")
|
||||||
|
|
||||||
|
# Should have received required + resolved messages
|
||||||
|
messages = []
|
||||||
|
while not queue.empty():
|
||||||
|
messages.append(queue.get_nowait())
|
||||||
|
|
||||||
|
types = [m["type"] for m in messages]
|
||||||
|
assert "verification_code_required" in types
|
||||||
|
assert "verification_code_resolved" in types
|
||||||
|
|
||||||
|
required = next(m for m in messages if m["type"] == "verification_code_required")
|
||||||
|
assert required["task_id"] == "tsk_1"
|
||||||
|
|
||||||
|
resolved = next(m for m in messages if m["type"] == "verification_code_resolved")
|
||||||
|
assert resolved["task_id"] == "tsk_1"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_poll_otp_value_publishes_required_event_for_workflow_run():
|
||||||
|
"""poll_otp_value should publish verification_code_required when workflow run waiting state is set."""
|
||||||
|
mock_code = MagicMock()
|
||||||
|
mock_code.code = "654321"
|
||||||
|
mock_code.otp_type = "totp"
|
||||||
|
|
||||||
|
mock_db = AsyncMock()
|
||||||
|
mock_db.get_valid_org_auth_token.return_value = MagicMock(token="tok")
|
||||||
|
mock_db.update_workflow_run = AsyncMock()
|
||||||
|
mock_db.get_otp_codes_by_run.return_value = [mock_code]
|
||||||
|
|
||||||
|
mock_app = MagicMock()
|
||||||
|
mock_app.DATABASE = mock_db
|
||||||
|
|
||||||
|
registry = LocalNotificationRegistry()
|
||||||
|
queue = registry.subscribe("org_1")
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("skyvern.services.otp_service.app", new=mock_app),
|
||||||
|
patch("skyvern.services.otp_service.asyncio.sleep", new_callable=AsyncMock),
|
||||||
|
patch(
|
||||||
|
"skyvern.forge.sdk.notification.factory.NotificationRegistryFactory._NotificationRegistryFactory__registry",
|
||||||
|
new=registry,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
await poll_otp_value(organization_id="org_1", workflow_run_id="wr_1")
|
||||||
|
|
||||||
|
messages = []
|
||||||
|
while not queue.empty():
|
||||||
|
messages.append(queue.get_nowait())
|
||||||
|
|
||||||
|
types = [m["type"] for m in messages]
|
||||||
|
assert "verification_code_required" in types
|
||||||
|
assert "verification_code_resolved" in types
|
||||||
|
|
||||||
|
required = next(m for m in messages if m["type"] == "verification_code_required")
|
||||||
|
assert required["workflow_run_id"] == "wr_1"
|
||||||
22
tests/unit_tests/test_redis_client_factory.py
Normal file
22
tests/unit_tests/test_redis_client_factory.py
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
"""Tests for RedisClientFactory."""
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
from skyvern.forge.sdk.redis.factory import RedisClientFactory
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_is_none():
|
||||||
|
"""Factory returns None when no client has been set."""
|
||||||
|
# Reset to default state
|
||||||
|
RedisClientFactory.set_client(None) # type: ignore[arg-type]
|
||||||
|
assert RedisClientFactory.get_client() is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_and_get():
|
||||||
|
"""Round-trip: set_client then get_client returns the same object."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
RedisClientFactory.set_client(mock_client)
|
||||||
|
assert RedisClientFactory.get_client() is mock_client
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
RedisClientFactory.set_client(None) # type: ignore[arg-type]
|
||||||
237
tests/unit_tests/test_redis_notification_registry.py
Normal file
237
tests/unit_tests/test_redis_notification_registry.py
Normal file
@@ -0,0 +1,237 @@
|
|||||||
|
"""Tests for RedisNotificationRegistry (SKY-6).
|
||||||
|
|
||||||
|
All tests use a mock Redis client — no real Redis instance required.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from skyvern.forge.sdk.notification.redis import RedisNotificationRegistry
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_mock_redis() -> MagicMock:
|
||||||
|
"""Return a mock redis.asyncio.Redis client."""
|
||||||
|
redis = MagicMock()
|
||||||
|
redis.publish = AsyncMock()
|
||||||
|
redis.pubsub = MagicMock()
|
||||||
|
return redis
|
||||||
|
|
||||||
|
|
||||||
|
def _make_mock_pubsub(messages: list[dict] | None = None, *, block: bool = False) -> MagicMock:
|
||||||
|
"""Return a mock PubSub that yields *messages* from ``listen()``.
|
||||||
|
|
||||||
|
Each entry in *messages* should look like:
|
||||||
|
{"type": "message", "data": '{"key": "val"}'}
|
||||||
|
|
||||||
|
If *block* is True the async generator will hang forever after
|
||||||
|
exhausting *messages*, which keeps the listener task alive so that
|
||||||
|
cancellation semantics can be tested.
|
||||||
|
"""
|
||||||
|
pubsub = MagicMock()
|
||||||
|
pubsub.subscribe = AsyncMock()
|
||||||
|
pubsub.unsubscribe = AsyncMock()
|
||||||
|
pubsub.close = AsyncMock()
|
||||||
|
|
||||||
|
async def _listen():
|
||||||
|
for msg in messages or []:
|
||||||
|
yield msg
|
||||||
|
if block:
|
||||||
|
# Keep the listener alive until cancelled
|
||||||
|
await asyncio.Event().wait()
|
||||||
|
|
||||||
|
pubsub.listen = _listen
|
||||||
|
return pubsub
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests: subscribe
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_subscribe_creates_queue_and_starts_listener():
|
||||||
|
"""subscribe() should return a queue and start a background listener task."""
|
||||||
|
redis = _make_mock_redis()
|
||||||
|
pubsub = _make_mock_pubsub()
|
||||||
|
redis.pubsub.return_value = pubsub
|
||||||
|
|
||||||
|
registry = RedisNotificationRegistry(redis)
|
||||||
|
queue = registry.subscribe("org_1")
|
||||||
|
|
||||||
|
assert isinstance(queue, asyncio.Queue)
|
||||||
|
assert "org_1" in registry._listener_tasks
|
||||||
|
task = registry._listener_tasks["org_1"]
|
||||||
|
assert isinstance(task, asyncio.Task)
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
await registry.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_subscribe_reuses_listener_for_same_org():
|
||||||
|
"""A second subscribe for the same org should NOT create a new listener task."""
|
||||||
|
redis = _make_mock_redis()
|
||||||
|
pubsub = _make_mock_pubsub()
|
||||||
|
redis.pubsub.return_value = pubsub
|
||||||
|
|
||||||
|
registry = RedisNotificationRegistry(redis)
|
||||||
|
registry.subscribe("org_1")
|
||||||
|
first_task = registry._listener_tasks["org_1"]
|
||||||
|
|
||||||
|
registry.subscribe("org_1")
|
||||||
|
assert registry._listener_tasks["org_1"] is first_task
|
||||||
|
|
||||||
|
await registry.close()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests: unsubscribe
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_unsubscribe_cancels_listener_when_last_subscriber_leaves():
|
||||||
|
"""When the last subscriber unsubscribes, the listener task should be cancelled."""
|
||||||
|
redis = _make_mock_redis()
|
||||||
|
pubsub = _make_mock_pubsub(block=True)
|
||||||
|
redis.pubsub.return_value = pubsub
|
||||||
|
|
||||||
|
registry = RedisNotificationRegistry(redis)
|
||||||
|
queue = registry.subscribe("org_1")
|
||||||
|
|
||||||
|
# Let the listener task start running
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
task = registry._listener_tasks["org_1"]
|
||||||
|
|
||||||
|
registry.unsubscribe("org_1", queue)
|
||||||
|
assert "org_1" not in registry._listener_tasks
|
||||||
|
|
||||||
|
# Wait for the task to fully complete after cancellation
|
||||||
|
await asyncio.gather(task, return_exceptions=True)
|
||||||
|
assert task.cancelled()
|
||||||
|
|
||||||
|
await registry.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_unsubscribe_keeps_listener_when_subscribers_remain():
|
||||||
|
"""If other subscribers remain, the listener task should stay alive."""
|
||||||
|
redis = _make_mock_redis()
|
||||||
|
pubsub = _make_mock_pubsub()
|
||||||
|
redis.pubsub.return_value = pubsub
|
||||||
|
|
||||||
|
registry = RedisNotificationRegistry(redis)
|
||||||
|
q1 = registry.subscribe("org_1")
|
||||||
|
registry.subscribe("org_1") # second subscriber
|
||||||
|
|
||||||
|
registry.unsubscribe("org_1", q1)
|
||||||
|
assert "org_1" in registry._listener_tasks
|
||||||
|
|
||||||
|
await registry.close()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests: publish
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_publish_calls_redis_publish():
|
||||||
|
"""publish() should fire-and-forget a Redis PUBLISH."""
|
||||||
|
redis = _make_mock_redis()
|
||||||
|
registry = RedisNotificationRegistry(redis)
|
||||||
|
|
||||||
|
registry.publish("org_1", {"type": "verification_code_required"})
|
||||||
|
|
||||||
|
# Allow the fire-and-forget task to execute
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
redis.publish.assert_awaited_once_with(
|
||||||
|
"skyvern:notifications:org_1",
|
||||||
|
json.dumps({"type": "verification_code_required"}),
|
||||||
|
)
|
||||||
|
|
||||||
|
await registry.close()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests: _dispatch_local
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_dispatch_local_fans_out_to_all_queues():
|
||||||
|
"""_dispatch_local should put the message into every local queue for the org."""
|
||||||
|
redis = _make_mock_redis()
|
||||||
|
pubsub = _make_mock_pubsub()
|
||||||
|
redis.pubsub.return_value = pubsub
|
||||||
|
|
||||||
|
registry = RedisNotificationRegistry(redis)
|
||||||
|
q1 = registry.subscribe("org_1")
|
||||||
|
q2 = registry.subscribe("org_1")
|
||||||
|
|
||||||
|
msg = {"type": "test", "value": 42}
|
||||||
|
registry._dispatch_local("org_1", msg)
|
||||||
|
|
||||||
|
assert q1.get_nowait() == msg
|
||||||
|
assert q2.get_nowait() == msg
|
||||||
|
|
||||||
|
await registry.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_dispatch_local_does_not_leak_across_orgs():
|
||||||
|
"""Messages dispatched for org_a should not appear in org_b queues."""
|
||||||
|
redis = _make_mock_redis()
|
||||||
|
pubsub = _make_mock_pubsub()
|
||||||
|
redis.pubsub.return_value = pubsub
|
||||||
|
|
||||||
|
registry = RedisNotificationRegistry(redis)
|
||||||
|
q_a = registry.subscribe("org_a")
|
||||||
|
q_b = registry.subscribe("org_b")
|
||||||
|
|
||||||
|
registry._dispatch_local("org_a", {"type": "test"})
|
||||||
|
assert not q_a.empty()
|
||||||
|
assert q_b.empty()
|
||||||
|
|
||||||
|
await registry.close()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests: close
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_close_cancels_all_listeners_and_clears_state():
|
||||||
|
"""close() should cancel every listener task and empty subscriber maps."""
|
||||||
|
redis = _make_mock_redis()
|
||||||
|
pubsub = _make_mock_pubsub(block=True)
|
||||||
|
redis.pubsub.return_value = pubsub
|
||||||
|
|
||||||
|
registry = RedisNotificationRegistry(redis)
|
||||||
|
registry.subscribe("org_1")
|
||||||
|
registry.subscribe("org_2")
|
||||||
|
|
||||||
|
# Let the listener tasks start running
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
task_1 = registry._listener_tasks["org_1"]
|
||||||
|
task_2 = registry._listener_tasks["org_2"]
|
||||||
|
|
||||||
|
await registry.close()
|
||||||
|
|
||||||
|
# Wait for the tasks to fully complete after cancellation
|
||||||
|
await asyncio.gather(task_1, task_2, return_exceptions=True)
|
||||||
|
assert task_1.cancelled()
|
||||||
|
assert task_2.cancelled()
|
||||||
|
assert len(registry._listener_tasks) == 0
|
||||||
|
assert len(registry._subscribers) == 0
|
||||||
262
tests/unit_tests/test_redis_pubsub.py
Normal file
262
tests/unit_tests/test_redis_pubsub.py
Normal file
@@ -0,0 +1,262 @@
|
|||||||
|
"""Tests for RedisPubSub (generic pub/sub layer).
|
||||||
|
|
||||||
|
All tests use a mock Redis client — no real Redis instance required.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from skyvern.forge.sdk.redis.pubsub import RedisPubSub
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_mock_redis() -> MagicMock:
|
||||||
|
"""Return a mock redis.asyncio.Redis client."""
|
||||||
|
redis = MagicMock()
|
||||||
|
redis.publish = AsyncMock()
|
||||||
|
redis.pubsub = MagicMock()
|
||||||
|
return redis
|
||||||
|
|
||||||
|
|
||||||
|
def _make_mock_pubsub(messages: list[dict] | None = None, *, block: bool = False) -> MagicMock:
|
||||||
|
"""Return a mock PubSub that yields *messages* from ``listen()``.
|
||||||
|
|
||||||
|
Each entry in *messages* should look like:
|
||||||
|
{"type": "message", "data": '{"key": "val"}'}
|
||||||
|
|
||||||
|
If *block* is True the async generator will hang forever after
|
||||||
|
exhausting *messages*, which keeps the listener task alive so that
|
||||||
|
cancellation semantics can be tested.
|
||||||
|
"""
|
||||||
|
pubsub = MagicMock()
|
||||||
|
pubsub.subscribe = AsyncMock()
|
||||||
|
pubsub.unsubscribe = AsyncMock()
|
||||||
|
pubsub.close = AsyncMock()
|
||||||
|
|
||||||
|
async def _listen():
|
||||||
|
for msg in messages or []:
|
||||||
|
yield msg
|
||||||
|
if block:
|
||||||
|
await asyncio.Event().wait()
|
||||||
|
|
||||||
|
pubsub.listen = _listen
|
||||||
|
return pubsub
|
||||||
|
|
||||||
|
|
||||||
|
PREFIX = "skyvern:test:"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests: subscribe
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_subscribe_creates_queue_and_starts_listener():
|
||||||
|
"""subscribe() should return a queue and start a background listener task."""
|
||||||
|
redis = _make_mock_redis()
|
||||||
|
pubsub = _make_mock_pubsub()
|
||||||
|
redis.pubsub.return_value = pubsub
|
||||||
|
|
||||||
|
ps = RedisPubSub(redis, channel_prefix=PREFIX)
|
||||||
|
queue = ps.subscribe("key_1")
|
||||||
|
|
||||||
|
assert isinstance(queue, asyncio.Queue)
|
||||||
|
assert "key_1" in ps._listener_tasks
|
||||||
|
task = ps._listener_tasks["key_1"]
|
||||||
|
assert isinstance(task, asyncio.Task)
|
||||||
|
|
||||||
|
await ps.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_subscribe_reuses_listener_for_same_key():
|
||||||
|
"""A second subscribe for the same key should NOT create a new listener task."""
|
||||||
|
redis = _make_mock_redis()
|
||||||
|
pubsub = _make_mock_pubsub()
|
||||||
|
redis.pubsub.return_value = pubsub
|
||||||
|
|
||||||
|
ps = RedisPubSub(redis, channel_prefix=PREFIX)
|
||||||
|
ps.subscribe("key_1")
|
||||||
|
first_task = ps._listener_tasks["key_1"]
|
||||||
|
|
||||||
|
ps.subscribe("key_1")
|
||||||
|
assert ps._listener_tasks["key_1"] is first_task
|
||||||
|
|
||||||
|
await ps.close()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests: unsubscribe
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_unsubscribe_cancels_listener_when_last_subscriber_leaves():
|
||||||
|
"""When the last subscriber unsubscribes, the listener task should be cancelled."""
|
||||||
|
redis = _make_mock_redis()
|
||||||
|
pubsub = _make_mock_pubsub(block=True)
|
||||||
|
redis.pubsub.return_value = pubsub
|
||||||
|
|
||||||
|
ps = RedisPubSub(redis, channel_prefix=PREFIX)
|
||||||
|
queue = ps.subscribe("key_1")
|
||||||
|
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
task = ps._listener_tasks["key_1"]
|
||||||
|
|
||||||
|
ps.unsubscribe("key_1", queue)
|
||||||
|
assert "key_1" not in ps._listener_tasks
|
||||||
|
|
||||||
|
await asyncio.gather(task, return_exceptions=True)
|
||||||
|
assert task.cancelled()
|
||||||
|
|
||||||
|
await ps.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_unsubscribe_keeps_listener_when_subscribers_remain():
|
||||||
|
"""If other subscribers remain, the listener task should stay alive."""
|
||||||
|
redis = _make_mock_redis()
|
||||||
|
pubsub = _make_mock_pubsub()
|
||||||
|
redis.pubsub.return_value = pubsub
|
||||||
|
|
||||||
|
ps = RedisPubSub(redis, channel_prefix=PREFIX)
|
||||||
|
q1 = ps.subscribe("key_1")
|
||||||
|
ps.subscribe("key_1")
|
||||||
|
|
||||||
|
ps.unsubscribe("key_1", q1)
|
||||||
|
assert "key_1" in ps._listener_tasks
|
||||||
|
|
||||||
|
await ps.close()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests: publish
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_publish_calls_redis_publish():
|
||||||
|
"""publish() should fire-and-forget a Redis PUBLISH with prefixed channel."""
|
||||||
|
redis = _make_mock_redis()
|
||||||
|
ps = RedisPubSub(redis, channel_prefix=PREFIX)
|
||||||
|
|
||||||
|
ps.publish("key_1", {"type": "event"})
|
||||||
|
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
redis.publish.assert_awaited_once_with(
|
||||||
|
f"{PREFIX}key_1",
|
||||||
|
json.dumps({"type": "event"}),
|
||||||
|
)
|
||||||
|
|
||||||
|
await ps.close()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests: _dispatch_local
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_dispatch_local_fans_out_to_all_queues():
|
||||||
|
"""_dispatch_local should put the message into every local queue for the key."""
|
||||||
|
redis = _make_mock_redis()
|
||||||
|
pubsub = _make_mock_pubsub()
|
||||||
|
redis.pubsub.return_value = pubsub
|
||||||
|
|
||||||
|
ps = RedisPubSub(redis, channel_prefix=PREFIX)
|
||||||
|
q1 = ps.subscribe("key_1")
|
||||||
|
q2 = ps.subscribe("key_1")
|
||||||
|
|
||||||
|
msg = {"type": "test", "value": 42}
|
||||||
|
ps._dispatch_local("key_1", msg)
|
||||||
|
|
||||||
|
assert q1.get_nowait() == msg
|
||||||
|
assert q2.get_nowait() == msg
|
||||||
|
|
||||||
|
await ps.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_dispatch_local_does_not_leak_across_keys():
|
||||||
|
"""Messages dispatched for key_a should not appear in key_b queues."""
|
||||||
|
redis = _make_mock_redis()
|
||||||
|
pubsub = _make_mock_pubsub()
|
||||||
|
redis.pubsub.return_value = pubsub
|
||||||
|
|
||||||
|
ps = RedisPubSub(redis, channel_prefix=PREFIX)
|
||||||
|
q_a = ps.subscribe("key_a")
|
||||||
|
q_b = ps.subscribe("key_b")
|
||||||
|
|
||||||
|
ps._dispatch_local("key_a", {"type": "test"})
|
||||||
|
assert not q_a.empty()
|
||||||
|
assert q_b.empty()
|
||||||
|
|
||||||
|
await ps.close()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests: close
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_close_cancels_all_listeners_and_clears_state():
|
||||||
|
"""close() should cancel every listener task and empty subscriber maps."""
|
||||||
|
redis = _make_mock_redis()
|
||||||
|
pubsub = _make_mock_pubsub(block=True)
|
||||||
|
redis.pubsub.return_value = pubsub
|
||||||
|
|
||||||
|
ps = RedisPubSub(redis, channel_prefix=PREFIX)
|
||||||
|
ps.subscribe("key_1")
|
||||||
|
ps.subscribe("key_2")
|
||||||
|
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
task_1 = ps._listener_tasks["key_1"]
|
||||||
|
task_2 = ps._listener_tasks["key_2"]
|
||||||
|
|
||||||
|
await ps.close()
|
||||||
|
|
||||||
|
await asyncio.gather(task_1, task_2, return_exceptions=True)
|
||||||
|
assert task_1.cancelled()
|
||||||
|
assert task_2.cancelled()
|
||||||
|
assert len(ps._listener_tasks) == 0
|
||||||
|
assert len(ps._subscribers) == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests: prefix isolation
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_different_prefixes_do_not_interfere():
|
||||||
|
"""Two RedisPubSub instances with different prefixes use separate channels."""
|
||||||
|
redis = _make_mock_redis()
|
||||||
|
|
||||||
|
ps_a = RedisPubSub(redis, channel_prefix="prefix_a:")
|
||||||
|
ps_b = RedisPubSub(redis, channel_prefix="prefix_b:")
|
||||||
|
|
||||||
|
ps_a.publish("key_1", {"from": "a"})
|
||||||
|
ps_b.publish("key_1", {"from": "b"})
|
||||||
|
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
calls = redis.publish.await_args_list
|
||||||
|
assert len(calls) == 2
|
||||||
|
|
||||||
|
channels = {call.args[0] for call in calls}
|
||||||
|
assert "prefix_a:key_1" in channels
|
||||||
|
assert "prefix_b:key_1" in channels
|
||||||
|
|
||||||
|
await ps_a.close()
|
||||||
|
await ps_b.close()
|
||||||
Reference in New Issue
Block a user