[SKY-7980] Patch Credential TOTP Over Webhook Logic (#4811)
This commit is contained in:
@@ -4529,7 +4529,7 @@ class ForgeAgent:
|
|||||||
# Try credential TOTP first (highest priority, doesn't need totp_url/totp_identifier)
|
# Try credential TOTP first (highest priority, doesn't need totp_url/totp_identifier)
|
||||||
otp_value = try_generate_totp_from_credential(task.workflow_run_id)
|
otp_value = try_generate_totp_from_credential(task.workflow_run_id)
|
||||||
# Fall back to webhook/totp_identifier
|
# Fall back to webhook/totp_identifier
|
||||||
if not otp_value and (task.totp_verification_url or task.totp_identifier):
|
if not otp_value:
|
||||||
workflow_id = workflow_permanent_id = None
|
workflow_id = workflow_permanent_id = None
|
||||||
if task.workflow_run_id:
|
if task.workflow_run_id:
|
||||||
workflow_run = await app.DATABASE.get_workflow_run(task.workflow_run_id)
|
workflow_run = await app.DATABASE.get_workflow_run(task.workflow_run_id)
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK
|
|||||||
from skyvern.config import settings
|
from skyvern.config import settings
|
||||||
from skyvern.forge import app
|
from skyvern.forge import app
|
||||||
from skyvern.forge.sdk.notification.factory import NotificationRegistryFactory
|
from skyvern.forge.sdk.notification.factory import NotificationRegistryFactory
|
||||||
from skyvern.forge.sdk.routes.routers import base_router
|
from skyvern.forge.sdk.routes.routers import base_router, legacy_base_router
|
||||||
from skyvern.forge.sdk.routes.streaming.auth import _auth as local_auth
|
from skyvern.forge.sdk.routes.streaming.auth import _auth as local_auth
|
||||||
from skyvern.forge.sdk.routes.streaming.auth import auth as real_auth
|
from skyvern.forge.sdk.routes.streaming.auth import auth as real_auth
|
||||||
|
|
||||||
@@ -22,6 +22,23 @@ async def notification_stream(
|
|||||||
websocket: WebSocket,
|
websocket: WebSocket,
|
||||||
apikey: str | None = None,
|
apikey: str | None = None,
|
||||||
token: str | None = None,
|
token: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
return await _notification_stream_handler(websocket=websocket, apikey=apikey, token=token)
|
||||||
|
|
||||||
|
|
||||||
|
@legacy_base_router.websocket("/stream/notifications")
|
||||||
|
async def notification_stream_legacy(
|
||||||
|
websocket: WebSocket,
|
||||||
|
apikey: str | None = None,
|
||||||
|
token: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
return await _notification_stream_handler(websocket=websocket, apikey=apikey, token=token)
|
||||||
|
|
||||||
|
|
||||||
|
async def _notification_stream_handler(
|
||||||
|
websocket: WebSocket,
|
||||||
|
apikey: str | None = None,
|
||||||
|
token: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
auth = local_auth if settings.ENV == "local" else real_auth
|
auth = local_auth if settings.ENV == "local" else real_auth
|
||||||
organization_id = await auth(apikey=apikey, token=token, websocket=websocket)
|
organization_id = await auth(apikey=apikey, token=token, websocket=websocket)
|
||||||
|
|||||||
@@ -916,7 +916,7 @@ async def generate_cua_fallback_actions(
|
|||||||
# Try credential TOTP first (highest priority, doesn't need totp_url/totp_identifier)
|
# Try credential TOTP first (highest priority, doesn't need totp_url/totp_identifier)
|
||||||
otp_value = try_generate_totp_from_credential(task.workflow_run_id)
|
otp_value = try_generate_totp_from_credential(task.workflow_run_id)
|
||||||
# Fall back to webhook/totp_identifier
|
# Fall back to webhook/totp_identifier
|
||||||
if not otp_value and (task.totp_verification_url or task.totp_identifier) and task.organization_id:
|
if not otp_value and task.organization_id:
|
||||||
LOG.info(
|
LOG.info(
|
||||||
"Getting verification code for CUA",
|
"Getting verification code for CUA",
|
||||||
task_id=task.task_id,
|
task_id=task.task_id,
|
||||||
|
|||||||
@@ -5,9 +5,10 @@ from credential secrets stored in workflow run context, and that callers check
|
|||||||
credential TOTP before falling back to poll_otp_value.
|
credential TOTP before falling back to poll_otp_value.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pyotp
|
import pyotp
|
||||||
|
import pytest
|
||||||
|
|
||||||
from skyvern.forge.sdk.schemas.totp_codes import OTPType
|
from skyvern.forge.sdk.schemas.totp_codes import OTPType
|
||||||
from skyvern.services.otp_service import OTPValue, try_generate_totp_from_credential
|
from skyvern.services.otp_service import OTPValue, try_generate_totp_from_credential
|
||||||
@@ -165,3 +166,135 @@ class TestTryGenerateTotpFromCredential:
|
|||||||
mock_app.WORKFLOW_CONTEXT_MANAGER.get_workflow_run_context.return_value = ctx
|
mock_app.WORKFLOW_CONTEXT_MANAGER.get_workflow_run_context.return_value = ctx
|
||||||
result = try_generate_totp_from_credential("wfr_123")
|
result = try_generate_totp_from_credential("wfr_123")
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestPollOtpCalledWithoutTotpConfig:
|
||||||
|
"""Tests that poll_otp_value is called even when totp_verification_url and
|
||||||
|
totp_identifier are both None — i.e. the manual 2FA code submission path."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_poll_otp_called_when_no_totp_config_cua_fallback(self) -> None:
|
||||||
|
"""When credential TOTP returns None and no totp_verification_url or
|
||||||
|
totp_identifier is set, poll_otp_value should still be called via
|
||||||
|
generate_cua_fallback_actions so the manual code submission flow works."""
|
||||||
|
mock_task = MagicMock()
|
||||||
|
mock_task.totp_verification_url = None
|
||||||
|
mock_task.totp_identifier = None
|
||||||
|
mock_task.organization_id = "org_123"
|
||||||
|
mock_task.task_id = "task_456"
|
||||||
|
mock_task.workflow_run_id = None
|
||||||
|
mock_task.navigation_goal = "test goal"
|
||||||
|
|
||||||
|
mock_step = MagicMock()
|
||||||
|
mock_step.step_id = "step_789"
|
||||||
|
mock_step.order = 0
|
||||||
|
|
||||||
|
mock_otp_value = MagicMock(spec=OTPValue)
|
||||||
|
mock_otp_value.get_otp_type.return_value = OTPType.TOTP
|
||||||
|
mock_otp_value.value = "123456"
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"skyvern.webeye.actions.parse_actions.try_generate_totp_from_credential",
|
||||||
|
return_value=None,
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"skyvern.webeye.actions.parse_actions.poll_otp_value",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=mock_otp_value,
|
||||||
|
) as mock_poll,
|
||||||
|
patch("skyvern.webeye.actions.parse_actions.app") as mock_app,
|
||||||
|
patch("skyvern.webeye.actions.parse_actions.prompt_engine") as mock_prompt_engine,
|
||||||
|
):
|
||||||
|
mock_prompt_engine.load_prompt.return_value = "test prompt"
|
||||||
|
# LLM returns get_verification_code action
|
||||||
|
mock_app.LLM_API_HANDLER = AsyncMock(
|
||||||
|
return_value={"action": "get_verification_code", "useful_information": "Need 2FA code"},
|
||||||
|
)
|
||||||
|
|
||||||
|
from skyvern.webeye.actions.parse_actions import generate_cua_fallback_actions
|
||||||
|
|
||||||
|
actions = await generate_cua_fallback_actions(
|
||||||
|
task=mock_task,
|
||||||
|
step=mock_step,
|
||||||
|
assistant_message="Enter verification code",
|
||||||
|
reasoning="Need 2FA code",
|
||||||
|
)
|
||||||
|
|
||||||
|
# poll_otp_value should have been called even though
|
||||||
|
# totp_verification_url and totp_identifier are both None
|
||||||
|
mock_poll.assert_called_once_with(
|
||||||
|
organization_id="org_123",
|
||||||
|
task_id="task_456",
|
||||||
|
workflow_run_id=None,
|
||||||
|
totp_verification_url=None,
|
||||||
|
totp_identifier=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify we got a VerificationCodeAction back
|
||||||
|
assert len(actions) == 1
|
||||||
|
from skyvern.webeye.actions.actions import VerificationCodeAction
|
||||||
|
|
||||||
|
assert isinstance(actions[0], VerificationCodeAction)
|
||||||
|
assert actions[0].verification_code == "123456"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_poll_otp_called_when_no_totp_config_agent(self) -> None:
|
||||||
|
"""When credential TOTP returns None and no totp_verification_url or
|
||||||
|
totp_identifier is set, poll_otp_value should still be called via the
|
||||||
|
agent's handle_potential_verification_code so the manual code submission
|
||||||
|
flow works."""
|
||||||
|
# Return None from poll_otp_value so we hit the early return at line 4548
|
||||||
|
# (no valid OTP) — this avoids needing to mock the deeper context/LLM calls
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"skyvern.forge.agent.try_generate_totp_from_credential",
|
||||||
|
return_value=None,
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"skyvern.forge.agent.poll_otp_value",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=None,
|
||||||
|
) as mock_poll,
|
||||||
|
patch("skyvern.forge.agent.app") as mock_app,
|
||||||
|
):
|
||||||
|
mock_app.DATABASE.get_workflow_run = AsyncMock(return_value=None)
|
||||||
|
|
||||||
|
from skyvern.forge.agent import ForgeAgent
|
||||||
|
|
||||||
|
agent = ForgeAgent.__new__(ForgeAgent)
|
||||||
|
|
||||||
|
mock_task = MagicMock()
|
||||||
|
mock_task.totp_verification_url = None
|
||||||
|
mock_task.totp_identifier = None
|
||||||
|
mock_task.organization_id = "org_123"
|
||||||
|
mock_task.task_id = "task_456"
|
||||||
|
mock_task.workflow_run_id = None
|
||||||
|
|
||||||
|
json_response = {
|
||||||
|
"place_to_enter_verification_code": True,
|
||||||
|
"should_enter_verification_code": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await agent.handle_potential_verification_code(
|
||||||
|
task=mock_task,
|
||||||
|
step=MagicMock(),
|
||||||
|
scraped_page=MagicMock(),
|
||||||
|
browser_state=MagicMock(),
|
||||||
|
json_response=json_response,
|
||||||
|
)
|
||||||
|
|
||||||
|
# poll_otp_value should have been called even though
|
||||||
|
# totp_verification_url and totp_identifier are both None
|
||||||
|
mock_poll.assert_called_once_with(
|
||||||
|
organization_id="org_123",
|
||||||
|
task_id="task_456",
|
||||||
|
workflow_id=None,
|
||||||
|
workflow_run_id=None,
|
||||||
|
workflow_permanent_id=None,
|
||||||
|
totp_verification_url=None,
|
||||||
|
totp_identifier=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# When poll_otp_value returns None, the method returns json_response unchanged
|
||||||
|
assert result == json_response
|
||||||
|
|||||||
Reference in New Issue
Block a user