[SKY-6] Backend: Enable 2FA code detection without TOTP credentials (#4786)

This commit is contained in:
Aaron Perez
2026-02-18 17:21:58 -05:00
committed by GitHub
parent b48bf707c3
commit e3b6d22fb6
28 changed files with 1989 additions and 41 deletions

View File

@@ -849,6 +849,102 @@ class AgentDB(BaseAlchemyDB):
LOG.error("UnexpectedError", exc_info=True)
raise
async def update_task_2fa_state(
self,
task_id: str,
organization_id: str,
waiting_for_verification_code: bool,
verification_code_identifier: str | None = None,
verification_code_polling_started_at: datetime | None = None,
) -> Task:
"""Update task 2FA verification code waiting state."""
try:
async with self.Session() as session:
if task := (
await session.scalars(
select(TaskModel).filter_by(task_id=task_id).filter_by(organization_id=organization_id)
)
).first():
task.waiting_for_verification_code = waiting_for_verification_code
if verification_code_identifier is not None:
task.verification_code_identifier = verification_code_identifier
if verification_code_polling_started_at is not None:
task.verification_code_polling_started_at = verification_code_polling_started_at
if not waiting_for_verification_code:
# Clear identifiers when no longer waiting
task.verification_code_identifier = None
task.verification_code_polling_started_at = None
await session.commit()
updated_task = await self.get_task(task_id, organization_id=organization_id)
if not updated_task:
raise NotFoundError("Task not found")
return updated_task
else:
raise NotFoundError("Task not found")
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
@read_retry()
async def get_active_verification_requests(self, organization_id: str) -> list[dict]:
"""Return active 2FA verification requests for an organization.
Queries both tasks and workflow runs where waiting_for_verification_code=True.
Used to provide initial state when a WebSocket notification client connects.
"""
results: list[dict] = []
async with self.Session() as session:
# Tasks waiting for verification (exclude finalized tasks)
finalized_task_statuses = [s.value for s in TaskStatus if s.is_final()]
task_rows = (
await session.scalars(
select(TaskModel)
.filter_by(organization_id=organization_id)
.filter_by(waiting_for_verification_code=True)
.filter_by(workflow_run_id=None)
.filter(TaskModel.status.not_in(finalized_task_statuses))
.filter(TaskModel.created_at > datetime.utcnow() - timedelta(hours=1))
)
).all()
for t in task_rows:
results.append(
{
"task_id": t.task_id,
"workflow_run_id": None,
"verification_code_identifier": t.verification_code_identifier,
"verification_code_polling_started_at": (
t.verification_code_polling_started_at.isoformat()
if t.verification_code_polling_started_at
else None
),
}
)
# Workflow runs waiting for verification (exclude finalized runs)
finalized_wr_statuses = [s.value for s in WorkflowRunStatus if s.is_final()]
wr_rows = (
await session.scalars(
select(WorkflowRunModel)
.filter_by(organization_id=organization_id)
.filter_by(waiting_for_verification_code=True)
.filter(WorkflowRunModel.status.not_in(finalized_wr_statuses))
.filter(WorkflowRunModel.created_at > datetime.utcnow() - timedelta(hours=1))
)
).all()
for wr in wr_rows:
results.append(
{
"task_id": None,
"workflow_run_id": wr.workflow_run_id,
"verification_code_identifier": wr.verification_code_identifier,
"verification_code_polling_started_at": (
wr.verification_code_polling_started_at.isoformat()
if wr.verification_code_polling_started_at
else None
),
}
)
return results
async def bulk_update_tasks(
self,
task_ids: list[str],
@@ -2794,6 +2890,9 @@ class AgentDB(BaseAlchemyDB):
ai_fallback: bool | None = None,
depends_on_workflow_run_id: str | None = None,
browser_session_id: str | None = None,
waiting_for_verification_code: bool | None = None,
verification_code_identifier: str | None = None,
verification_code_polling_started_at: datetime | None = None,
) -> WorkflowRun:
async with self.Session() as session:
workflow_run = (
@@ -2826,6 +2925,17 @@ class AgentDB(BaseAlchemyDB):
workflow_run.depends_on_workflow_run_id = depends_on_workflow_run_id
if browser_session_id:
workflow_run.browser_session_id = browser_session_id
# 2FA verification code waiting state updates
if waiting_for_verification_code is not None:
workflow_run.waiting_for_verification_code = waiting_for_verification_code
if verification_code_identifier is not None:
workflow_run.verification_code_identifier = verification_code_identifier
if verification_code_polling_started_at is not None:
workflow_run.verification_code_polling_started_at = verification_code_polling_started_at
if waiting_for_verification_code is not None and not waiting_for_verification_code:
# Clear related fields when waiting is set to False
workflow_run.verification_code_identifier = None
workflow_run.verification_code_polling_started_at = None
await session.commit()
await save_workflow_run_logs(workflow_run_id)
await session.refresh(workflow_run)
@@ -3995,6 +4105,35 @@ class AgentDB(BaseAlchemyDB):
totp_code = (await session.scalars(query)).all()
return [TOTPCode.model_validate(totp_code) for totp_code in totp_code]
async def get_otp_codes_by_run(
self,
organization_id: str,
task_id: str | None = None,
workflow_run_id: str | None = None,
valid_lifespan_minutes: int = settings.TOTP_LIFESPAN_MINUTES,
limit: int = 1,
) -> list[TOTPCode]:
"""Get OTP codes matching a specific task or workflow run (no totp_identifier required).
Used when the agent detects a 2FA page but no TOTP credentials are pre-configured.
The user submits codes manually via the UI, and this method finds them by run context.
"""
if not workflow_run_id and not task_id:
return []
async with self.Session() as session:
query = (
select(TOTPCodeModel)
.filter_by(organization_id=organization_id)
.filter(TOTPCodeModel.created_at > datetime.utcnow() - timedelta(minutes=valid_lifespan_minutes))
)
if workflow_run_id:
query = query.filter(TOTPCodeModel.workflow_run_id == workflow_run_id)
elif task_id:
query = query.filter(TOTPCodeModel.task_id == task_id)
query = query.order_by(TOTPCodeModel.created_at.desc()).limit(limit)
results = (await session.scalars(query)).all()
return [TOTPCode.model_validate(r) for r in results]
async def get_recent_otp_codes(
self,
organization_id: str,