[SKY-6] Backend: Enable 2FA code detection without TOTP credentials (#4786)
This commit is contained in:
@@ -849,6 +849,102 @@ class AgentDB(BaseAlchemyDB):
|
||||
LOG.error("UnexpectedError", exc_info=True)
|
||||
raise
|
||||
|
||||
async def update_task_2fa_state(
|
||||
self,
|
||||
task_id: str,
|
||||
organization_id: str,
|
||||
waiting_for_verification_code: bool,
|
||||
verification_code_identifier: str | None = None,
|
||||
verification_code_polling_started_at: datetime | None = None,
|
||||
) -> Task:
|
||||
"""Update task 2FA verification code waiting state."""
|
||||
try:
|
||||
async with self.Session() as session:
|
||||
if task := (
|
||||
await session.scalars(
|
||||
select(TaskModel).filter_by(task_id=task_id).filter_by(organization_id=organization_id)
|
||||
)
|
||||
).first():
|
||||
task.waiting_for_verification_code = waiting_for_verification_code
|
||||
if verification_code_identifier is not None:
|
||||
task.verification_code_identifier = verification_code_identifier
|
||||
if verification_code_polling_started_at is not None:
|
||||
task.verification_code_polling_started_at = verification_code_polling_started_at
|
||||
if not waiting_for_verification_code:
|
||||
# Clear identifiers when no longer waiting
|
||||
task.verification_code_identifier = None
|
||||
task.verification_code_polling_started_at = None
|
||||
await session.commit()
|
||||
updated_task = await self.get_task(task_id, organization_id=organization_id)
|
||||
if not updated_task:
|
||||
raise NotFoundError("Task not found")
|
||||
return updated_task
|
||||
else:
|
||||
raise NotFoundError("Task not found")
|
||||
except SQLAlchemyError:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
raise
|
||||
|
||||
@read_retry()
|
||||
async def get_active_verification_requests(self, organization_id: str) -> list[dict]:
|
||||
"""Return active 2FA verification requests for an organization.
|
||||
|
||||
Queries both tasks and workflow runs where waiting_for_verification_code=True.
|
||||
Used to provide initial state when a WebSocket notification client connects.
|
||||
"""
|
||||
results: list[dict] = []
|
||||
async with self.Session() as session:
|
||||
# Tasks waiting for verification (exclude finalized tasks)
|
||||
finalized_task_statuses = [s.value for s in TaskStatus if s.is_final()]
|
||||
task_rows = (
|
||||
await session.scalars(
|
||||
select(TaskModel)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter_by(waiting_for_verification_code=True)
|
||||
.filter_by(workflow_run_id=None)
|
||||
.filter(TaskModel.status.not_in(finalized_task_statuses))
|
||||
.filter(TaskModel.created_at > datetime.utcnow() - timedelta(hours=1))
|
||||
)
|
||||
).all()
|
||||
for t in task_rows:
|
||||
results.append(
|
||||
{
|
||||
"task_id": t.task_id,
|
||||
"workflow_run_id": None,
|
||||
"verification_code_identifier": t.verification_code_identifier,
|
||||
"verification_code_polling_started_at": (
|
||||
t.verification_code_polling_started_at.isoformat()
|
||||
if t.verification_code_polling_started_at
|
||||
else None
|
||||
),
|
||||
}
|
||||
)
|
||||
# Workflow runs waiting for verification (exclude finalized runs)
|
||||
finalized_wr_statuses = [s.value for s in WorkflowRunStatus if s.is_final()]
|
||||
wr_rows = (
|
||||
await session.scalars(
|
||||
select(WorkflowRunModel)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter_by(waiting_for_verification_code=True)
|
||||
.filter(WorkflowRunModel.status.not_in(finalized_wr_statuses))
|
||||
.filter(WorkflowRunModel.created_at > datetime.utcnow() - timedelta(hours=1))
|
||||
)
|
||||
).all()
|
||||
for wr in wr_rows:
|
||||
results.append(
|
||||
{
|
||||
"task_id": None,
|
||||
"workflow_run_id": wr.workflow_run_id,
|
||||
"verification_code_identifier": wr.verification_code_identifier,
|
||||
"verification_code_polling_started_at": (
|
||||
wr.verification_code_polling_started_at.isoformat()
|
||||
if wr.verification_code_polling_started_at
|
||||
else None
|
||||
),
|
||||
}
|
||||
)
|
||||
return results
|
||||
|
||||
async def bulk_update_tasks(
|
||||
self,
|
||||
task_ids: list[str],
|
||||
@@ -2794,6 +2890,9 @@ class AgentDB(BaseAlchemyDB):
|
||||
ai_fallback: bool | None = None,
|
||||
depends_on_workflow_run_id: str | None = None,
|
||||
browser_session_id: str | None = None,
|
||||
waiting_for_verification_code: bool | None = None,
|
||||
verification_code_identifier: str | None = None,
|
||||
verification_code_polling_started_at: datetime | None = None,
|
||||
) -> WorkflowRun:
|
||||
async with self.Session() as session:
|
||||
workflow_run = (
|
||||
@@ -2826,6 +2925,17 @@ class AgentDB(BaseAlchemyDB):
|
||||
workflow_run.depends_on_workflow_run_id = depends_on_workflow_run_id
|
||||
if browser_session_id:
|
||||
workflow_run.browser_session_id = browser_session_id
|
||||
# 2FA verification code waiting state updates
|
||||
if waiting_for_verification_code is not None:
|
||||
workflow_run.waiting_for_verification_code = waiting_for_verification_code
|
||||
if verification_code_identifier is not None:
|
||||
workflow_run.verification_code_identifier = verification_code_identifier
|
||||
if verification_code_polling_started_at is not None:
|
||||
workflow_run.verification_code_polling_started_at = verification_code_polling_started_at
|
||||
if waiting_for_verification_code is not None and not waiting_for_verification_code:
|
||||
# Clear related fields when waiting is set to False
|
||||
workflow_run.verification_code_identifier = None
|
||||
workflow_run.verification_code_polling_started_at = None
|
||||
await session.commit()
|
||||
await save_workflow_run_logs(workflow_run_id)
|
||||
await session.refresh(workflow_run)
|
||||
@@ -3995,6 +4105,35 @@ class AgentDB(BaseAlchemyDB):
|
||||
totp_code = (await session.scalars(query)).all()
|
||||
return [TOTPCode.model_validate(totp_code) for totp_code in totp_code]
|
||||
|
||||
async def get_otp_codes_by_run(
|
||||
self,
|
||||
organization_id: str,
|
||||
task_id: str | None = None,
|
||||
workflow_run_id: str | None = None,
|
||||
valid_lifespan_minutes: int = settings.TOTP_LIFESPAN_MINUTES,
|
||||
limit: int = 1,
|
||||
) -> list[TOTPCode]:
|
||||
"""Get OTP codes matching a specific task or workflow run (no totp_identifier required).
|
||||
|
||||
Used when the agent detects a 2FA page but no TOTP credentials are pre-configured.
|
||||
The user submits codes manually via the UI, and this method finds them by run context.
|
||||
"""
|
||||
if not workflow_run_id and not task_id:
|
||||
return []
|
||||
async with self.Session() as session:
|
||||
query = (
|
||||
select(TOTPCodeModel)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter(TOTPCodeModel.created_at > datetime.utcnow() - timedelta(minutes=valid_lifespan_minutes))
|
||||
)
|
||||
if workflow_run_id:
|
||||
query = query.filter(TOTPCodeModel.workflow_run_id == workflow_run_id)
|
||||
elif task_id:
|
||||
query = query.filter(TOTPCodeModel.task_id == task_id)
|
||||
query = query.order_by(TOTPCodeModel.created_at.desc()).limit(limit)
|
||||
results = (await session.scalars(query)).all()
|
||||
return [TOTPCode.model_validate(r) for r in results]
|
||||
|
||||
async def get_recent_otp_codes(
|
||||
self,
|
||||
organization_id: str,
|
||||
|
||||
@@ -116,6 +116,10 @@ class TaskModel(Base):
|
||||
model = Column(JSON, nullable=True)
|
||||
browser_address = Column(String, nullable=True)
|
||||
download_timeout = Column(Numeric, nullable=True)
|
||||
# 2FA verification code waiting state fields
|
||||
waiting_for_verification_code = Column(Boolean, nullable=False, default=False)
|
||||
verification_code_identifier = Column(String, nullable=True)
|
||||
verification_code_polling_started_at = Column(DateTime, nullable=True)
|
||||
|
||||
|
||||
class StepModel(Base):
|
||||
@@ -350,6 +354,10 @@ class WorkflowRunModel(Base):
|
||||
debug_session_id: Column = Column(String, nullable=True)
|
||||
ai_fallback = Column(Boolean, nullable=True)
|
||||
code_gen = Column(Boolean, nullable=True)
|
||||
# 2FA verification code waiting state fields
|
||||
waiting_for_verification_code = Column(Boolean, nullable=False, default=False)
|
||||
verification_code_identifier = Column(String, nullable=True)
|
||||
verification_code_polling_started_at = Column(DateTime, nullable=True)
|
||||
|
||||
queued_at = Column(DateTime, nullable=True)
|
||||
started_at = Column(DateTime, nullable=True)
|
||||
|
||||
@@ -211,6 +211,9 @@ def convert_to_task(task_obj: TaskModel, debug_enabled: bool = False, workflow_p
|
||||
browser_session_id=task_obj.browser_session_id,
|
||||
browser_address=task_obj.browser_address,
|
||||
download_timeout=task_obj.download_timeout,
|
||||
waiting_for_verification_code=task_obj.waiting_for_verification_code or False,
|
||||
verification_code_identifier=task_obj.verification_code_identifier,
|
||||
verification_code_polling_started_at=task_obj.verification_code_polling_started_at,
|
||||
)
|
||||
return task
|
||||
|
||||
@@ -424,6 +427,9 @@ def convert_to_workflow_run(
|
||||
run_with=workflow_run_model.run_with,
|
||||
code_gen=workflow_run_model.code_gen,
|
||||
ai_fallback=workflow_run_model.ai_fallback,
|
||||
waiting_for_verification_code=workflow_run_model.waiting_for_verification_code or False,
|
||||
verification_code_identifier=workflow_run_model.verification_code_identifier,
|
||||
verification_code_polling_started_at=workflow_run_model.verification_code_polling_started_at,
|
||||
)
|
||||
|
||||
|
||||
|
||||
21
skyvern/forge/sdk/notification/base.py
Normal file
21
skyvern/forge/sdk/notification/base.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""Abstract base for notification registries."""
|
||||
|
||||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class BaseNotificationRegistry(ABC):
|
||||
"""Abstract pub/sub registry scoped by organization.
|
||||
|
||||
Implementations must fan-out: a single publish call delivers the
|
||||
message to every active subscriber for that organization.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def subscribe(self, organization_id: str) -> asyncio.Queue[dict]: ...
|
||||
|
||||
@abstractmethod
|
||||
def unsubscribe(self, organization_id: str, queue: asyncio.Queue[dict]) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
def publish(self, organization_id: str, message: dict) -> None: ...
|
||||
14
skyvern/forge/sdk/notification/factory.py
Normal file
14
skyvern/forge/sdk/notification/factory.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from skyvern.forge.sdk.notification.base import BaseNotificationRegistry
|
||||
from skyvern.forge.sdk.notification.local import LocalNotificationRegistry
|
||||
|
||||
|
||||
class NotificationRegistryFactory:
|
||||
__registry: BaseNotificationRegistry = LocalNotificationRegistry()
|
||||
|
||||
@staticmethod
|
||||
def set_registry(registry: BaseNotificationRegistry) -> None:
|
||||
NotificationRegistryFactory.__registry = registry
|
||||
|
||||
@staticmethod
|
||||
def get_registry() -> BaseNotificationRegistry:
|
||||
return NotificationRegistryFactory.__registry
|
||||
45
skyvern/forge/sdk/notification/local.py
Normal file
45
skyvern/forge/sdk/notification/local.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""In-process notification registry using asyncio queues (single-pod only)."""
|
||||
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
|
||||
import structlog
|
||||
|
||||
from skyvern.forge.sdk.notification.base import BaseNotificationRegistry
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class LocalNotificationRegistry(BaseNotificationRegistry):
|
||||
"""In-process fan-out pub/sub using asyncio queues. Single-pod only."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._subscribers: dict[str, list[asyncio.Queue[dict]]] = defaultdict(list)
|
||||
|
||||
def subscribe(self, organization_id: str) -> asyncio.Queue[dict]:
|
||||
queue: asyncio.Queue[dict] = asyncio.Queue()
|
||||
self._subscribers[organization_id].append(queue)
|
||||
LOG.info("Notification subscriber added", organization_id=organization_id)
|
||||
return queue
|
||||
|
||||
def unsubscribe(self, organization_id: str, queue: asyncio.Queue[dict]) -> None:
|
||||
queues = self._subscribers.get(organization_id)
|
||||
if queues:
|
||||
try:
|
||||
queues.remove(queue)
|
||||
except ValueError:
|
||||
pass
|
||||
if not queues:
|
||||
del self._subscribers[organization_id]
|
||||
LOG.info("Notification subscriber removed", organization_id=organization_id)
|
||||
|
||||
def publish(self, organization_id: str, message: dict) -> None:
|
||||
queues = self._subscribers.get(organization_id, [])
|
||||
for queue in queues:
|
||||
try:
|
||||
queue.put_nowait(message)
|
||||
except asyncio.QueueFull:
|
||||
LOG.warning(
|
||||
"Notification queue full, dropping message",
|
||||
organization_id=organization_id,
|
||||
)
|
||||
55
skyvern/forge/sdk/notification/redis.py
Normal file
55
skyvern/forge/sdk/notification/redis.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""Redis-backed notification registry for multi-pod deployments.
|
||||
|
||||
Thin adapter around :class:`RedisPubSub` — all Redis pub/sub logic
|
||||
lives in the generic layer; this class maps the ``organization_id``
|
||||
domain concept onto generic string keys.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
from redis.asyncio import Redis
|
||||
|
||||
from skyvern.forge.sdk.notification.base import BaseNotificationRegistry
|
||||
from skyvern.forge.sdk.redis.pubsub import RedisPubSub
|
||||
|
||||
|
||||
class RedisNotificationRegistry(BaseNotificationRegistry):
|
||||
"""Fan-out pub/sub backed by Redis. One Redis PubSub channel per org."""
|
||||
|
||||
def __init__(self, redis_client: Redis) -> None:
|
||||
self._pubsub = RedisPubSub(redis_client, channel_prefix="skyvern:notifications:")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Property accessors (used by existing tests)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@property
|
||||
def _listener_tasks(self) -> dict[str, asyncio.Task[None]]:
|
||||
return self._pubsub._listener_tasks
|
||||
|
||||
@property
|
||||
def _subscribers(self) -> dict[str, list[asyncio.Queue[dict]]]:
|
||||
return self._pubsub._subscribers
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public interface
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def subscribe(self, organization_id: str) -> asyncio.Queue[dict]:
|
||||
return self._pubsub.subscribe(organization_id)
|
||||
|
||||
def unsubscribe(self, organization_id: str, queue: asyncio.Queue[dict]) -> None:
|
||||
self._pubsub.unsubscribe(organization_id, queue)
|
||||
|
||||
def publish(self, organization_id: str, message: dict) -> None:
|
||||
self._pubsub.publish(organization_id, message)
|
||||
|
||||
async def close(self) -> None:
|
||||
await self._pubsub.close()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal helper (exposed for tests)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _dispatch_local(self, organization_id: str, message: dict) -> None:
|
||||
self._pubsub._dispatch_local(organization_id, message)
|
||||
0
skyvern/forge/sdk/redis/__init__.py
Normal file
0
skyvern/forge/sdk/redis/__init__.py
Normal file
21
skyvern/forge/sdk/redis/factory.py
Normal file
21
skyvern/forge/sdk/redis/factory.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from redis.asyncio import Redis
|
||||
|
||||
|
||||
class RedisClientFactory:
|
||||
"""Singleton factory for a shared async Redis client.
|
||||
|
||||
Follows the same static set/get pattern as ``CacheFactory``.
|
||||
Defaults to ``None`` (no Redis in local/OSS mode).
|
||||
"""
|
||||
|
||||
__client: Redis | None = None
|
||||
|
||||
@staticmethod
|
||||
def set_client(client: Redis) -> None:
|
||||
RedisClientFactory.__client = client
|
||||
|
||||
@staticmethod
|
||||
def get_client() -> Redis | None:
|
||||
return RedisClientFactory.__client
|
||||
130
skyvern/forge/sdk/redis/pubsub.py
Normal file
130
skyvern/forge/sdk/redis/pubsub.py
Normal file
@@ -0,0 +1,130 @@
|
||||
"""Generic Redis pub/sub layer.
|
||||
|
||||
Extracted from ``RedisNotificationRegistry`` so that any feature
|
||||
(notifications, events, cache invalidation, etc.) can reuse the same
|
||||
pattern with its own channel prefix.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from collections import defaultdict
|
||||
|
||||
import structlog
|
||||
from redis.asyncio import Redis
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class RedisPubSub:
|
||||
"""Fan-out pub/sub backed by Redis. One Redis PubSub channel per key."""
|
||||
|
||||
def __init__(self, redis_client: Redis, channel_prefix: str) -> None:
|
||||
self._redis = redis_client
|
||||
self._channel_prefix = channel_prefix
|
||||
self._subscribers: dict[str, list[asyncio.Queue[dict]]] = defaultdict(list)
|
||||
# One listener task per key channel
|
||||
self._listener_tasks: dict[str, asyncio.Task[None]] = {}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public interface
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def subscribe(self, key: str) -> asyncio.Queue[dict]:
|
||||
queue: asyncio.Queue[dict] = asyncio.Queue()
|
||||
self._subscribers[key].append(queue)
|
||||
|
||||
# Spin up a Redis listener if this is the first local subscriber
|
||||
if key not in self._listener_tasks:
|
||||
task = asyncio.get_running_loop().create_task(self._listen(key))
|
||||
self._listener_tasks[key] = task
|
||||
|
||||
LOG.info("PubSub subscriber added", key=key, channel_prefix=self._channel_prefix)
|
||||
return queue
|
||||
|
||||
def unsubscribe(self, key: str, queue: asyncio.Queue[dict]) -> None:
|
||||
queues = self._subscribers.get(key)
|
||||
if queues:
|
||||
try:
|
||||
queues.remove(queue)
|
||||
except ValueError:
|
||||
pass
|
||||
if not queues:
|
||||
del self._subscribers[key]
|
||||
self._cancel_listener(key)
|
||||
LOG.info("PubSub subscriber removed", key=key, channel_prefix=self._channel_prefix)
|
||||
|
||||
def publish(self, key: str, message: dict) -> None:
|
||||
"""Fire-and-forget Redis PUBLISH."""
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
loop.create_task(self._publish_to_redis(key, message))
|
||||
except RuntimeError:
|
||||
LOG.warning(
|
||||
"No running event loop; cannot publish via Redis",
|
||||
key=key,
|
||||
channel_prefix=self._channel_prefix,
|
||||
)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Cancel all listener tasks and clear state. Call on shutdown."""
|
||||
for key in list(self._listener_tasks):
|
||||
self._cancel_listener(key)
|
||||
self._subscribers.clear()
|
||||
LOG.info("RedisPubSub closed", channel_prefix=self._channel_prefix)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _publish_to_redis(self, key: str, message: dict) -> None:
|
||||
channel = f"{self._channel_prefix}{key}"
|
||||
try:
|
||||
await self._redis.publish(channel, json.dumps(message))
|
||||
except Exception:
|
||||
LOG.exception("Failed to publish to Redis", key=key, channel_prefix=self._channel_prefix)
|
||||
|
||||
async def _listen(self, key: str) -> None:
|
||||
"""Subscribe to a Redis channel and fan out messages locally."""
|
||||
channel = f"{self._channel_prefix}{key}"
|
||||
pubsub = self._redis.pubsub()
|
||||
try:
|
||||
await pubsub.subscribe(channel)
|
||||
LOG.info("Redis listener started", channel=channel)
|
||||
async for raw_message in pubsub.listen():
|
||||
if raw_message["type"] != "message":
|
||||
continue
|
||||
try:
|
||||
data = json.loads(raw_message["data"])
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
LOG.warning("Invalid JSON on Redis channel", channel=channel)
|
||||
continue
|
||||
self._dispatch_local(key, data)
|
||||
except asyncio.CancelledError:
|
||||
LOG.info("Redis listener cancelled", channel=channel)
|
||||
raise
|
||||
except Exception:
|
||||
LOG.exception("Redis listener error", channel=channel)
|
||||
finally:
|
||||
try:
|
||||
await pubsub.unsubscribe(channel)
|
||||
await pubsub.close()
|
||||
except Exception:
|
||||
LOG.warning("Error closing Redis pubsub", channel=channel)
|
||||
|
||||
def _dispatch_local(self, key: str, message: dict) -> None:
|
||||
"""Fan out a message to all local asyncio queues for this key."""
|
||||
queues = self._subscribers.get(key, [])
|
||||
for queue in queues:
|
||||
try:
|
||||
queue.put_nowait(message)
|
||||
except asyncio.QueueFull:
|
||||
LOG.warning(
|
||||
"Queue full, dropping message",
|
||||
key=key,
|
||||
channel_prefix=self._channel_prefix,
|
||||
)
|
||||
|
||||
def _cancel_listener(self, key: str) -> None:
|
||||
task = self._listener_tasks.pop(key, None)
|
||||
if task and not task.done():
|
||||
task.cancel()
|
||||
@@ -11,5 +11,6 @@ from skyvern.forge.sdk.routes import sdk # noqa: F401
|
||||
from skyvern.forge.sdk.routes import webhooks # noqa: F401
|
||||
from skyvern.forge.sdk.routes import workflow_copilot # noqa: F401
|
||||
from skyvern.forge.sdk.routes.streaming import messages # noqa: F401
|
||||
from skyvern.forge.sdk.routes.streaming import notifications # noqa: F401
|
||||
from skyvern.forge.sdk.routes.streaming import screenshot # noqa: F401
|
||||
from skyvern.forge.sdk.routes.streaming import vnc # noqa: F401
|
||||
|
||||
@@ -131,10 +131,15 @@ async def send_totp_code(
|
||||
task = await app.DATABASE.get_task(data.task_id, curr_org.organization_id)
|
||||
if not task:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid task id: {data.task_id}")
|
||||
workflow_id_for_storage: str | None = None
|
||||
if data.workflow_id:
|
||||
workflow = await app.DATABASE.get_workflow(data.workflow_id, curr_org.organization_id)
|
||||
if data.workflow_id.startswith("wpid_"):
|
||||
workflow = await app.DATABASE.get_workflow_by_permanent_id(data.workflow_id, curr_org.organization_id)
|
||||
else:
|
||||
workflow = await app.DATABASE.get_workflow(data.workflow_id, curr_org.organization_id)
|
||||
if not workflow:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid workflow id: {data.workflow_id}")
|
||||
workflow_id_for_storage = workflow.workflow_id
|
||||
if data.workflow_run_id:
|
||||
workflow_run = await app.DATABASE.get_workflow_run(data.workflow_run_id, curr_org.organization_id)
|
||||
if not workflow_run:
|
||||
@@ -162,7 +167,7 @@ async def send_totp_code(
|
||||
content=data.content,
|
||||
code=otp_value.value,
|
||||
task_id=data.task_id,
|
||||
workflow_id=data.workflow_id,
|
||||
workflow_id=workflow_id_for_storage,
|
||||
workflow_run_id=data.workflow_run_id,
|
||||
source=data.source,
|
||||
expired_at=data.expired_at,
|
||||
|
||||
101
skyvern/forge/sdk/routes/streaming/notifications.py
Normal file
101
skyvern/forge/sdk/routes/streaming/notifications.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""WebSocket endpoint for streaming global 2FA verification code notifications."""
|
||||
|
||||
import asyncio
|
||||
|
||||
import structlog
|
||||
from fastapi import WebSocket, WebSocketDisconnect
|
||||
from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK
|
||||
|
||||
from skyvern.config import settings
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.notification.factory import NotificationRegistryFactory
|
||||
from skyvern.forge.sdk.routes.routers import base_router
|
||||
from skyvern.forge.sdk.routes.streaming.auth import _auth as local_auth
|
||||
from skyvern.forge.sdk.routes.streaming.auth import auth as real_auth
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
HEARTBEAT_INTERVAL = 60
|
||||
|
||||
|
||||
@base_router.websocket("/stream/notifications")
|
||||
async def notification_stream(
|
||||
websocket: WebSocket,
|
||||
apikey: str | None = None,
|
||||
token: str | None = None,
|
||||
) -> None:
|
||||
auth = local_auth if settings.ENV == "local" else real_auth
|
||||
organization_id = await auth(apikey=apikey, token=token, websocket=websocket)
|
||||
if not organization_id:
|
||||
LOG.info("Notifications: Authentication failed")
|
||||
return
|
||||
|
||||
LOG.info("Notifications: Started streaming", organization_id=organization_id)
|
||||
registry = NotificationRegistryFactory.get_registry()
|
||||
queue = registry.subscribe(organization_id)
|
||||
|
||||
try:
|
||||
# Send initial state: all currently active verification requests
|
||||
active_requests = await app.DATABASE.get_active_verification_requests(organization_id)
|
||||
for req in active_requests:
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "verification_code_required",
|
||||
"task_id": req.get("task_id"),
|
||||
"workflow_run_id": req.get("workflow_run_id"),
|
||||
"identifier": req.get("verification_code_identifier"),
|
||||
"polling_started_at": req.get("verification_code_polling_started_at"),
|
||||
}
|
||||
)
|
||||
|
||||
# Watch for client disconnect while streaming events
|
||||
disconnect_event = asyncio.Event()
|
||||
|
||||
async def _watch_disconnect() -> None:
|
||||
try:
|
||||
while True:
|
||||
await websocket.receive()
|
||||
except (WebSocketDisconnect, ConnectionClosedOK, ConnectionClosedError):
|
||||
disconnect_event.set()
|
||||
|
||||
watcher = asyncio.create_task(_watch_disconnect())
|
||||
try:
|
||||
while not disconnect_event.is_set():
|
||||
queue_task = asyncio.ensure_future(asyncio.wait_for(queue.get(), timeout=HEARTBEAT_INTERVAL))
|
||||
disconnect_wait = asyncio.ensure_future(disconnect_event.wait())
|
||||
done, pending = await asyncio.wait({queue_task, disconnect_wait}, return_when=asyncio.FIRST_COMPLETED)
|
||||
for p in pending:
|
||||
p.cancel()
|
||||
|
||||
if disconnect_event.is_set():
|
||||
return
|
||||
|
||||
try:
|
||||
message = queue_task.result()
|
||||
await websocket.send_json(message)
|
||||
except TimeoutError:
|
||||
try:
|
||||
await websocket.send_json({"type": "heartbeat"})
|
||||
except Exception:
|
||||
LOG.info(
|
||||
"Notifications: Client unreachable during heartbeat. Closing.",
|
||||
organization_id=organization_id,
|
||||
)
|
||||
return
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
finally:
|
||||
watcher.cancel()
|
||||
|
||||
except WebSocketDisconnect:
|
||||
LOG.info("Notifications: WebSocket disconnected", organization_id=organization_id)
|
||||
except ConnectionClosedOK:
|
||||
LOG.info("Notifications: ConnectionClosedOK", organization_id=organization_id)
|
||||
except ConnectionClosedError:
|
||||
LOG.warning(
|
||||
"Notifications: ConnectionClosedError (client likely disconnected)", organization_id=organization_id
|
||||
)
|
||||
except Exception:
|
||||
LOG.warning("Notifications: Error while streaming", organization_id=organization_id, exc_info=True)
|
||||
finally:
|
||||
registry.unsubscribe(organization_id, queue)
|
||||
LOG.info("Notifications: Connection closed", organization_id=organization_id)
|
||||
@@ -288,6 +288,10 @@ class Task(TaskBase):
|
||||
queued_at: datetime | None = None
|
||||
started_at: datetime | None = None
|
||||
finished_at: datetime | None = None
|
||||
# 2FA verification code waiting state fields
|
||||
waiting_for_verification_code: bool = False
|
||||
verification_code_identifier: str | None = None
|
||||
verification_code_polling_started_at: datetime | None = None
|
||||
|
||||
@property
|
||||
def llm_key(self) -> str | None:
|
||||
@@ -365,6 +369,9 @@ class Task(TaskBase):
|
||||
max_screenshot_scrolls=self.max_screenshot_scrolls,
|
||||
step_count=step_count,
|
||||
browser_session_id=self.browser_session_id,
|
||||
waiting_for_verification_code=self.waiting_for_verification_code,
|
||||
verification_code_identifier=self.verification_code_identifier,
|
||||
verification_code_polling_started_at=self.verification_code_polling_started_at,
|
||||
)
|
||||
|
||||
|
||||
@@ -392,6 +399,10 @@ class TaskResponse(BaseModel):
|
||||
max_screenshot_scrolls: int | None = None
|
||||
step_count: int | None = None
|
||||
browser_session_id: str | None = None
|
||||
# 2FA verification code waiting state fields
|
||||
waiting_for_verification_code: bool = False
|
||||
verification_code_identifier: str | None = None
|
||||
verification_code_polling_started_at: datetime | None = None
|
||||
|
||||
|
||||
class TaskOutput(BaseModel):
|
||||
|
||||
@@ -172,6 +172,10 @@ class WorkflowRun(BaseModel):
|
||||
sequential_key: str | None = None
|
||||
ai_fallback: bool | None = None
|
||||
code_gen: bool | None = None
|
||||
# 2FA verification code waiting state fields
|
||||
waiting_for_verification_code: bool = False
|
||||
verification_code_identifier: str | None = None
|
||||
verification_code_polling_started_at: datetime | None = None
|
||||
|
||||
queued_at: datetime | None = None
|
||||
started_at: datetime | None = None
|
||||
@@ -226,6 +230,10 @@ class WorkflowRunResponseBase(BaseModel):
|
||||
browser_address: str | None = None
|
||||
script_run: ScriptRunResponse | None = None
|
||||
errors: list[dict[str, Any]] | None = None
|
||||
# 2FA verification code waiting state fields
|
||||
waiting_for_verification_code: bool = False
|
||||
verification_code_identifier: str | None = None
|
||||
verification_code_polling_started_at: datetime | None = None
|
||||
|
||||
|
||||
class WorkflowRunWithWorkflowResponse(WorkflowRunResponseBase):
|
||||
|
||||
@@ -3019,6 +3019,10 @@ class WorkflowService:
|
||||
browser_address=workflow_run.browser_address,
|
||||
script_run=workflow_run.script_run,
|
||||
errors=errors,
|
||||
# 2FA verification code waiting state fields
|
||||
waiting_for_verification_code=workflow_run.waiting_for_verification_code,
|
||||
verification_code_identifier=workflow_run.verification_code_identifier,
|
||||
verification_code_polling_started_at=workflow_run.verification_code_polling_started_at,
|
||||
)
|
||||
|
||||
async def clean_up_workflow(
|
||||
|
||||
Reference in New Issue
Block a user