[SKY-6] Backend: Enable 2FA code detection without TOTP credentials (#4786)
This commit is contained in:
@@ -849,6 +849,102 @@ class AgentDB(BaseAlchemyDB):
|
||||
LOG.error("UnexpectedError", exc_info=True)
|
||||
raise
|
||||
|
||||
async def update_task_2fa_state(
|
||||
self,
|
||||
task_id: str,
|
||||
organization_id: str,
|
||||
waiting_for_verification_code: bool,
|
||||
verification_code_identifier: str | None = None,
|
||||
verification_code_polling_started_at: datetime | None = None,
|
||||
) -> Task:
|
||||
"""Update task 2FA verification code waiting state."""
|
||||
try:
|
||||
async with self.Session() as session:
|
||||
if task := (
|
||||
await session.scalars(
|
||||
select(TaskModel).filter_by(task_id=task_id).filter_by(organization_id=organization_id)
|
||||
)
|
||||
).first():
|
||||
task.waiting_for_verification_code = waiting_for_verification_code
|
||||
if verification_code_identifier is not None:
|
||||
task.verification_code_identifier = verification_code_identifier
|
||||
if verification_code_polling_started_at is not None:
|
||||
task.verification_code_polling_started_at = verification_code_polling_started_at
|
||||
if not waiting_for_verification_code:
|
||||
# Clear identifiers when no longer waiting
|
||||
task.verification_code_identifier = None
|
||||
task.verification_code_polling_started_at = None
|
||||
await session.commit()
|
||||
updated_task = await self.get_task(task_id, organization_id=organization_id)
|
||||
if not updated_task:
|
||||
raise NotFoundError("Task not found")
|
||||
return updated_task
|
||||
else:
|
||||
raise NotFoundError("Task not found")
|
||||
except SQLAlchemyError:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
raise
|
||||
|
||||
@read_retry()
|
||||
async def get_active_verification_requests(self, organization_id: str) -> list[dict]:
|
||||
"""Return active 2FA verification requests for an organization.
|
||||
|
||||
Queries both tasks and workflow runs where waiting_for_verification_code=True.
|
||||
Used to provide initial state when a WebSocket notification client connects.
|
||||
"""
|
||||
results: list[dict] = []
|
||||
async with self.Session() as session:
|
||||
# Tasks waiting for verification (exclude finalized tasks)
|
||||
finalized_task_statuses = [s.value for s in TaskStatus if s.is_final()]
|
||||
task_rows = (
|
||||
await session.scalars(
|
||||
select(TaskModel)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter_by(waiting_for_verification_code=True)
|
||||
.filter_by(workflow_run_id=None)
|
||||
.filter(TaskModel.status.not_in(finalized_task_statuses))
|
||||
.filter(TaskModel.created_at > datetime.utcnow() - timedelta(hours=1))
|
||||
)
|
||||
).all()
|
||||
for t in task_rows:
|
||||
results.append(
|
||||
{
|
||||
"task_id": t.task_id,
|
||||
"workflow_run_id": None,
|
||||
"verification_code_identifier": t.verification_code_identifier,
|
||||
"verification_code_polling_started_at": (
|
||||
t.verification_code_polling_started_at.isoformat()
|
||||
if t.verification_code_polling_started_at
|
||||
else None
|
||||
),
|
||||
}
|
||||
)
|
||||
# Workflow runs waiting for verification (exclude finalized runs)
|
||||
finalized_wr_statuses = [s.value for s in WorkflowRunStatus if s.is_final()]
|
||||
wr_rows = (
|
||||
await session.scalars(
|
||||
select(WorkflowRunModel)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter_by(waiting_for_verification_code=True)
|
||||
.filter(WorkflowRunModel.status.not_in(finalized_wr_statuses))
|
||||
.filter(WorkflowRunModel.created_at > datetime.utcnow() - timedelta(hours=1))
|
||||
)
|
||||
).all()
|
||||
for wr in wr_rows:
|
||||
results.append(
|
||||
{
|
||||
"task_id": None,
|
||||
"workflow_run_id": wr.workflow_run_id,
|
||||
"verification_code_identifier": wr.verification_code_identifier,
|
||||
"verification_code_polling_started_at": (
|
||||
wr.verification_code_polling_started_at.isoformat()
|
||||
if wr.verification_code_polling_started_at
|
||||
else None
|
||||
),
|
||||
}
|
||||
)
|
||||
return results
|
||||
|
||||
async def bulk_update_tasks(
|
||||
self,
|
||||
task_ids: list[str],
|
||||
@@ -2794,6 +2890,9 @@ class AgentDB(BaseAlchemyDB):
|
||||
ai_fallback: bool | None = None,
|
||||
depends_on_workflow_run_id: str | None = None,
|
||||
browser_session_id: str | None = None,
|
||||
waiting_for_verification_code: bool | None = None,
|
||||
verification_code_identifier: str | None = None,
|
||||
verification_code_polling_started_at: datetime | None = None,
|
||||
) -> WorkflowRun:
|
||||
async with self.Session() as session:
|
||||
workflow_run = (
|
||||
@@ -2826,6 +2925,17 @@ class AgentDB(BaseAlchemyDB):
|
||||
workflow_run.depends_on_workflow_run_id = depends_on_workflow_run_id
|
||||
if browser_session_id:
|
||||
workflow_run.browser_session_id = browser_session_id
|
||||
# 2FA verification code waiting state updates
|
||||
if waiting_for_verification_code is not None:
|
||||
workflow_run.waiting_for_verification_code = waiting_for_verification_code
|
||||
if verification_code_identifier is not None:
|
||||
workflow_run.verification_code_identifier = verification_code_identifier
|
||||
if verification_code_polling_started_at is not None:
|
||||
workflow_run.verification_code_polling_started_at = verification_code_polling_started_at
|
||||
if waiting_for_verification_code is not None and not waiting_for_verification_code:
|
||||
# Clear related fields when waiting is set to False
|
||||
workflow_run.verification_code_identifier = None
|
||||
workflow_run.verification_code_polling_started_at = None
|
||||
await session.commit()
|
||||
await save_workflow_run_logs(workflow_run_id)
|
||||
await session.refresh(workflow_run)
|
||||
@@ -3995,6 +4105,35 @@ class AgentDB(BaseAlchemyDB):
|
||||
totp_code = (await session.scalars(query)).all()
|
||||
return [TOTPCode.model_validate(totp_code) for totp_code in totp_code]
|
||||
|
||||
async def get_otp_codes_by_run(
|
||||
self,
|
||||
organization_id: str,
|
||||
task_id: str | None = None,
|
||||
workflow_run_id: str | None = None,
|
||||
valid_lifespan_minutes: int = settings.TOTP_LIFESPAN_MINUTES,
|
||||
limit: int = 1,
|
||||
) -> list[TOTPCode]:
|
||||
"""Get OTP codes matching a specific task or workflow run (no totp_identifier required).
|
||||
|
||||
Used when the agent detects a 2FA page but no TOTP credentials are pre-configured.
|
||||
The user submits codes manually via the UI, and this method finds them by run context.
|
||||
"""
|
||||
if not workflow_run_id and not task_id:
|
||||
return []
|
||||
async with self.Session() as session:
|
||||
query = (
|
||||
select(TOTPCodeModel)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter(TOTPCodeModel.created_at > datetime.utcnow() - timedelta(minutes=valid_lifespan_minutes))
|
||||
)
|
||||
if workflow_run_id:
|
||||
query = query.filter(TOTPCodeModel.workflow_run_id == workflow_run_id)
|
||||
elif task_id:
|
||||
query = query.filter(TOTPCodeModel.task_id == task_id)
|
||||
query = query.order_by(TOTPCodeModel.created_at.desc()).limit(limit)
|
||||
results = (await session.scalars(query)).all()
|
||||
return [TOTPCode.model_validate(r) for r in results]
|
||||
|
||||
async def get_recent_otp_codes(
|
||||
self,
|
||||
organization_id: str,
|
||||
|
||||
Reference in New Issue
Block a user