[SKY-7980] Patch Credential TOTP Over Webhook Logic (#4811)

This commit is contained in:
Aaron Perez
2026-02-19 18:14:44 -05:00
committed by GitHub
parent f80451f37a
commit f8f9d2a17f
4 changed files with 154 additions and 4 deletions

View File

@@ -4529,7 +4529,7 @@ class ForgeAgent:
# Try credential TOTP first (highest priority, doesn't need totp_url/totp_identifier) # Try credential TOTP first (highest priority, doesn't need totp_url/totp_identifier)
otp_value = try_generate_totp_from_credential(task.workflow_run_id) otp_value = try_generate_totp_from_credential(task.workflow_run_id)
# Fall back to webhook/totp_identifier # Fall back to webhook/totp_identifier
if not otp_value and (task.totp_verification_url or task.totp_identifier): if not otp_value:
workflow_id = workflow_permanent_id = None workflow_id = workflow_permanent_id = None
if task.workflow_run_id: if task.workflow_run_id:
workflow_run = await app.DATABASE.get_workflow_run(task.workflow_run_id) workflow_run = await app.DATABASE.get_workflow_run(task.workflow_run_id)

View File

@@ -9,7 +9,7 @@ from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK
from skyvern.config import settings from skyvern.config import settings
from skyvern.forge import app from skyvern.forge import app
from skyvern.forge.sdk.notification.factory import NotificationRegistryFactory from skyvern.forge.sdk.notification.factory import NotificationRegistryFactory
from skyvern.forge.sdk.routes.routers import base_router from skyvern.forge.sdk.routes.routers import base_router, legacy_base_router
from skyvern.forge.sdk.routes.streaming.auth import _auth as local_auth from skyvern.forge.sdk.routes.streaming.auth import _auth as local_auth
from skyvern.forge.sdk.routes.streaming.auth import auth as real_auth from skyvern.forge.sdk.routes.streaming.auth import auth as real_auth
@@ -22,6 +22,23 @@ async def notification_stream(
websocket: WebSocket, websocket: WebSocket,
apikey: str | None = None, apikey: str | None = None,
token: str | None = None, token: str | None = None,
) -> None:
return await _notification_stream_handler(websocket=websocket, apikey=apikey, token=token)
@legacy_base_router.websocket("/stream/notifications")
async def notification_stream_legacy(
websocket: WebSocket,
apikey: str | None = None,
token: str | None = None,
) -> None:
return await _notification_stream_handler(websocket=websocket, apikey=apikey, token=token)
async def _notification_stream_handler(
websocket: WebSocket,
apikey: str | None = None,
token: str | None = None,
) -> None: ) -> None:
auth = local_auth if settings.ENV == "local" else real_auth auth = local_auth if settings.ENV == "local" else real_auth
organization_id = await auth(apikey=apikey, token=token, websocket=websocket) organization_id = await auth(apikey=apikey, token=token, websocket=websocket)

View File

@@ -916,7 +916,7 @@ async def generate_cua_fallback_actions(
# Try credential TOTP first (highest priority, doesn't need totp_url/totp_identifier) # Try credential TOTP first (highest priority, doesn't need totp_url/totp_identifier)
otp_value = try_generate_totp_from_credential(task.workflow_run_id) otp_value = try_generate_totp_from_credential(task.workflow_run_id)
# Fall back to webhook/totp_identifier # Fall back to webhook/totp_identifier
if not otp_value and (task.totp_verification_url or task.totp_identifier) and task.organization_id: if not otp_value and task.organization_id:
LOG.info( LOG.info(
"Getting verification code for CUA", "Getting verification code for CUA",
task_id=task.task_id, task_id=task.task_id,

View File

@@ -5,9 +5,10 @@ from credential secrets stored in workflow run context, and that callers check
credential TOTP before falling back to poll_otp_value. credential TOTP before falling back to poll_otp_value.
""" """
from unittest.mock import MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
import pyotp import pyotp
import pytest
from skyvern.forge.sdk.schemas.totp_codes import OTPType from skyvern.forge.sdk.schemas.totp_codes import OTPType
from skyvern.services.otp_service import OTPValue, try_generate_totp_from_credential from skyvern.services.otp_service import OTPValue, try_generate_totp_from_credential
@@ -165,3 +166,135 @@ class TestTryGenerateTotpFromCredential:
mock_app.WORKFLOW_CONTEXT_MANAGER.get_workflow_run_context.return_value = ctx mock_app.WORKFLOW_CONTEXT_MANAGER.get_workflow_run_context.return_value = ctx
result = try_generate_totp_from_credential("wfr_123") result = try_generate_totp_from_credential("wfr_123")
assert result is None assert result is None
class TestPollOtpCalledWithoutTotpConfig:
"""Tests that poll_otp_value is called even when totp_verification_url and
totp_identifier are both None — i.e. the manual 2FA code submission path."""
@pytest.mark.asyncio
async def test_poll_otp_called_when_no_totp_config_cua_fallback(self) -> None:
"""When credential TOTP returns None and no totp_verification_url or
totp_identifier is set, poll_otp_value should still be called via
generate_cua_fallback_actions so the manual code submission flow works."""
mock_task = MagicMock()
mock_task.totp_verification_url = None
mock_task.totp_identifier = None
mock_task.organization_id = "org_123"
mock_task.task_id = "task_456"
mock_task.workflow_run_id = None
mock_task.navigation_goal = "test goal"
mock_step = MagicMock()
mock_step.step_id = "step_789"
mock_step.order = 0
mock_otp_value = MagicMock(spec=OTPValue)
mock_otp_value.get_otp_type.return_value = OTPType.TOTP
mock_otp_value.value = "123456"
with (
patch(
"skyvern.webeye.actions.parse_actions.try_generate_totp_from_credential",
return_value=None,
),
patch(
"skyvern.webeye.actions.parse_actions.poll_otp_value",
new_callable=AsyncMock,
return_value=mock_otp_value,
) as mock_poll,
patch("skyvern.webeye.actions.parse_actions.app") as mock_app,
patch("skyvern.webeye.actions.parse_actions.prompt_engine") as mock_prompt_engine,
):
mock_prompt_engine.load_prompt.return_value = "test prompt"
# LLM returns get_verification_code action
mock_app.LLM_API_HANDLER = AsyncMock(
return_value={"action": "get_verification_code", "useful_information": "Need 2FA code"},
)
from skyvern.webeye.actions.parse_actions import generate_cua_fallback_actions
actions = await generate_cua_fallback_actions(
task=mock_task,
step=mock_step,
assistant_message="Enter verification code",
reasoning="Need 2FA code",
)
# poll_otp_value should have been called even though
# totp_verification_url and totp_identifier are both None
mock_poll.assert_called_once_with(
organization_id="org_123",
task_id="task_456",
workflow_run_id=None,
totp_verification_url=None,
totp_identifier=None,
)
# Verify we got a VerificationCodeAction back
assert len(actions) == 1
from skyvern.webeye.actions.actions import VerificationCodeAction
assert isinstance(actions[0], VerificationCodeAction)
assert actions[0].verification_code == "123456"
@pytest.mark.asyncio
async def test_poll_otp_called_when_no_totp_config_agent(self) -> None:
"""When credential TOTP returns None and no totp_verification_url or
totp_identifier is set, poll_otp_value should still be called via the
agent's handle_potential_verification_code so the manual code submission
flow works."""
# Return None from poll_otp_value so we hit the early return at line 4548
# (no valid OTP) — this avoids needing to mock the deeper context/LLM calls
with (
patch(
"skyvern.forge.agent.try_generate_totp_from_credential",
return_value=None,
),
patch(
"skyvern.forge.agent.poll_otp_value",
new_callable=AsyncMock,
return_value=None,
) as mock_poll,
patch("skyvern.forge.agent.app") as mock_app,
):
mock_app.DATABASE.get_workflow_run = AsyncMock(return_value=None)
from skyvern.forge.agent import ForgeAgent
agent = ForgeAgent.__new__(ForgeAgent)
mock_task = MagicMock()
mock_task.totp_verification_url = None
mock_task.totp_identifier = None
mock_task.organization_id = "org_123"
mock_task.task_id = "task_456"
mock_task.workflow_run_id = None
json_response = {
"place_to_enter_verification_code": True,
"should_enter_verification_code": True,
}
result = await agent.handle_potential_verification_code(
task=mock_task,
step=MagicMock(),
scraped_page=MagicMock(),
browser_state=MagicMock(),
json_response=json_response,
)
# poll_otp_value should have been called even though
# totp_verification_url and totp_identifier are both None
mock_poll.assert_called_once_with(
organization_id="org_123",
task_id="task_456",
workflow_id=None,
workflow_run_id=None,
workflow_permanent_id=None,
totp_verification_url=None,
totp_identifier=None,
)
# When poll_otp_value returns None, the method returns json_response unchanged
assert result == json_response