545 lines
18 KiB
Python
545 lines
18 KiB
Python
"""Tests for manual 2FA input without pre-configured TOTP credentials (SKY-6)."""
|
|
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from skyvern.constants import SPECIAL_FIELD_VERIFICATION_CODE
|
|
from skyvern.forge.agent import ForgeAgent
|
|
from skyvern.forge.sdk.db.agent_db import AgentDB
|
|
from skyvern.forge.sdk.notification.local import LocalNotificationRegistry
|
|
from skyvern.forge.sdk.routes.credentials import send_totp_code
|
|
from skyvern.forge.sdk.schemas.totp_codes import TOTPCodeCreate
|
|
from skyvern.schemas.runs import RunEngine
|
|
from skyvern.services.otp_service import OTPValue, _get_otp_value_by_run, poll_otp_value
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_otp_codes_by_run_exists():
    """AgentDB should expose a callable get_otp_codes_by_run method.

    ``hasattr`` alone would also accept a plain (non-callable) attribute, so the
    test additionally checks that the attribute can be called.
    """
    method = getattr(AgentDB, "get_otp_codes_by_run", None)
    assert method is not None, "AgentDB missing get_otp_codes_by_run method"
    assert callable(method), "AgentDB.get_otp_codes_by_run is not callable"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_otp_codes_by_run_returns_empty_without_identifiers():
    """Supplying only organization_id (no task_id, no workflow_run_id) should yield []."""
    # Construct without running __init__; the lookup is expected to return before
    # relying on any constructor-initialised state.
    agent_db = AgentDB.__new__(AgentDB)
    assert await agent_db.get_otp_codes_by_run(organization_id="org_1") == []
|
|
|
|
|
|
# === Task 2: _get_otp_value_by_run OTP service function ===
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_otp_value_by_run_returns_code():
    """A code stored for the task should be returned by _get_otp_value_by_run (lookup by task_id)."""
    stored_code = MagicMock(code="123456", otp_type="totp")

    db_stub = AsyncMock()
    db_stub.get_otp_codes_by_run.return_value = [stored_code]
    app_stub = MagicMock(DATABASE=db_stub)

    with patch("skyvern.services.otp_service.app", new=app_stub):
        otp = await _get_otp_value_by_run(organization_id="org_1", task_id="tsk_1")

    assert otp is not None
    assert otp.value == "123456"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_otp_value_by_run_returns_none_when_no_codes():
    """An empty list from get_otp_codes_by_run should make _get_otp_value_by_run return None."""
    db_stub = AsyncMock()
    db_stub.get_otp_codes_by_run.return_value = []
    app_stub = MagicMock(DATABASE=db_stub)

    with patch("skyvern.services.otp_service.app", new=app_stub):
        found = await _get_otp_value_by_run(organization_id="org_1", task_id="tsk_1")

    assert found is None
|
|
|
|
|
|
# === Task 3: poll_otp_value without identifier ===
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_poll_otp_value_without_identifier_uses_run_lookup():
    """poll_otp_value should use _get_otp_value_by_run when no identifier/URL provided.

    Besides checking the returned value, this verifies that the run-scoped DB lookup
    (get_otp_codes_by_run) was actually awaited, which is what the test name claims.
    """
    mock_code = MagicMock()
    mock_code.code = "123456"
    mock_code.otp_type = "totp"

    mock_db = AsyncMock()
    mock_db.get_valid_org_auth_token.return_value = MagicMock(token="tok")
    mock_db.get_otp_codes_by_run.return_value = [mock_code]
    mock_db.update_task_2fa_state = AsyncMock()

    mock_app = MagicMock()
    mock_app.DATABASE = mock_db

    with (
        patch("skyvern.services.otp_service.app", new=mock_app),
        # Skip real waiting between polling attempts.
        patch("skyvern.services.otp_service.asyncio.sleep", new_callable=AsyncMock),
    ):
        result = await poll_otp_value(
            organization_id="org_1",
            task_id="tsk_1",
        )

    assert result is not None
    assert result.value == "123456"
    # The code must have come from the run-scoped lookup (no identifier / URL was given).
    mock_db.get_otp_codes_by_run.assert_awaited()
|
|
|
|
|
|
# === Task 6: Integration test — handle_potential_OTP_actions without TOTP config ===
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_potential_OTP_actions_without_totp_config():
    """When LLM detects 2FA but no TOTP config exists, should still enter verification flow.

    The handler is an AsyncMock, so the test checks that it was *awaited* exactly once.
    ``assert_called_once`` would also pass if the coroutine were created but never awaited.
    """
    agent = ForgeAgent.__new__(ForgeAgent)

    # No totp_verification_url / totp_identifier: the manual 2FA path.
    task = MagicMock()
    task.organization_id = "org_1"
    task.totp_verification_url = None
    task.totp_identifier = None
    task.task_id = "tsk_1"
    task.workflow_run_id = None

    step = MagicMock()
    scraped_page = MagicMock()
    browser_state = MagicMock()

    json_response = {
        "should_enter_verification_code": True,
        "place_to_enter_verification_code": "input#otp-code",
        "actions": [],
    }

    with (
        patch.object(agent, "handle_potential_verification_code", new_callable=AsyncMock) as mock_handler,
        patch("skyvern.forge.agent.parse_actions", return_value=[]),
    ):
        mock_handler.return_value = {"actions": []}
        # Unpacking still asserts the method returns a 2-tuple; the values are not inspected.
        _result_json, _result_actions = await agent.handle_potential_OTP_actions(
            task, step, scraped_page, browser_state, json_response
        )

    mock_handler.assert_awaited_once()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_potential_OTP_actions_skips_magic_link_without_totp_config():
    """Magic-link verification still requires TOTP config.

    Without it, the magic-link handler must not be invoked and no actions are returned.
    """
    agent = ForgeAgent.__new__(ForgeAgent)

    task = MagicMock(
        organization_id="org_1",
        totp_verification_url=None,
        totp_identifier=None,
    )
    json_response = {"should_verify_by_magic_link": True}

    with patch.object(agent, "handle_potential_magic_link", new_callable=AsyncMock) as magic_link_handler:
        # Positional args: task, step, scraped_page, browser_state, json_response.
        _, actions = await agent.handle_potential_OTP_actions(
            task, MagicMock(), MagicMock(), MagicMock(), json_response
        )

    magic_link_handler.assert_not_called()
    assert actions == []
|
|
|
|
|
|
# === Task 7: verification_code_check always True in LLM prompt ===
|
|
|
|
|
|
def _make_prompt_task(
    task_id: str,
    *,
    totp_verification_url: str | None,
    totp_identifier: str | None,
) -> MagicMock:
    """Build the task mock shared by the verification_code_check tests.

    Only the TOTP fields and task_id vary between tests; the rest are fixed values.
    """
    task = MagicMock()
    task.totp_verification_url = totp_verification_url
    task.totp_identifier = totp_identifier
    task.task_id = task_id
    task.workflow_run_id = None
    task.organization_id = "org_1"
    task.url = "https://example.com"
    return task


async def _verification_code_check_passed_to_prompt(task: MagicMock) -> object:
    """Return the ``verification_code_check`` kwarg passed to _build_extract_action_prompt.

    Runs ForgeAgent.build_and_record_step_prompt for *task*, mocking out:
    - scraping (_scrape_with_type);
    - prompt construction (_build_extract_action_prompt);
    - skyvern_context.

    Fails if _build_extract_action_prompt is not called exactly once, and raises KeyError
    if it is called without a ``verification_code_check`` keyword argument.
    """
    agent = ForgeAgent.__new__(ForgeAgent)
    agent.async_operation_pool = MagicMock()

    step = MagicMock()
    step.step_id = "step_1"
    step.order = 0
    step.retry_index = 0

    scraped_page = MagicMock()
    scraped_page.elements = []

    browser_state = MagicMock()

    with (
        patch("skyvern.forge.agent.skyvern_context") as mock_ctx,
        patch.object(agent, "_scrape_with_type", new_callable=AsyncMock, return_value=scraped_page),
        patch.object(
            agent,
            "_build_extract_action_prompt",
            new_callable=AsyncMock,
            return_value=("prompt", False, "extract_action"),
        ) as mock_build,
    ):
        mock_ctx.current.return_value = None
        await agent.build_and_record_step_prompt(
            task, step, browser_state, RunEngine.skyvern_v1, persist_artifacts=False
        )

    mock_build.assert_called_once()
    _, kwargs = mock_build.call_args
    return kwargs["verification_code_check"]


@pytest.mark.asyncio
async def test_verification_code_check_always_true_without_totp_config():
    """_build_extract_action_prompt should receive verification_code_check=True even when task has no TOTP config."""
    task = _make_prompt_task("tsk_1", totp_verification_url=None, totp_identifier=None)
    assert await _verification_code_check_passed_to_prompt(task) is True


@pytest.mark.asyncio
async def test_verification_code_check_always_true_with_totp_config():
    """_build_extract_action_prompt should receive verification_code_check=True when task HAS TOTP config (unchanged)."""
    task = _make_prompt_task(
        "tsk_2",
        totp_verification_url="https://otp.example.com",
        totp_identifier="user@example.com",
    )
    assert await _verification_code_check_passed_to_prompt(task) is True
|
|
|
|
|
|
# === Fix: poll_otp_value should pass workflow_id, not workflow_permanent_id ===
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_poll_otp_value_passes_workflow_id_not_permanent_id():
    """poll_otp_value must hand the w_* workflow_id to _get_otp_value_from_db.

    The wpid_* workflow_permanent_id must not be passed instead.
    """
    db_stub = AsyncMock()
    db_stub.get_valid_org_auth_token.return_value = MagicMock(token="tok")
    db_stub.update_workflow_run = AsyncMock()
    app_stub = MagicMock(DATABASE=db_stub)

    canned_otp = OTPValue(value="654321", type="totp")

    with (
        patch("skyvern.services.otp_service.app", new=app_stub),
        patch("skyvern.services.otp_service.asyncio.sleep", new_callable=AsyncMock),
        patch(
            "skyvern.services.otp_service._get_otp_value_from_db",
            new_callable=AsyncMock,
            return_value=canned_otp,
        ) as fetch_from_db,
    ):
        otp = await poll_otp_value(
            organization_id="org_1",
            workflow_id="w_123",
            workflow_run_id="wr_789",
            workflow_permanent_id="wpid_456",
            totp_identifier="user@example.com",
        )

    assert otp is not None
    assert otp.value == "654321"
    # The lookup must be keyed on workflow_id="w_123"; "wpid_456" must not appear.
    fetch_from_db.assert_called_once_with(
        "org_1",
        "user@example.com",
        task_id=None,
        workflow_id="w_123",
        workflow_run_id="wr_789",
    )
|
|
|
|
|
|
# === Fix: send_totp_code should resolve wpid_* to w_* before storage ===
|
|
|
|
|
|
async def _send_totp_code_and_get_create_kwargs(mock_db: AsyncMock, data: TOTPCodeCreate) -> dict:
    """Call send_totp_code against *mock_db* as org_1 and return create_otp_code's kwargs.

    First asserts that create_otp_code was called exactly once. Without that check, a
    missing call would fail with an opaque TypeError, because ``call_args`` is None and
    cannot be indexed.
    """
    mock_app = MagicMock()
    mock_app.DATABASE = mock_db

    curr_org = MagicMock()
    curr_org.organization_id = "org_1"

    with patch("skyvern.forge.sdk.routes.credentials.app", new=mock_app):
        await send_totp_code(data=data, curr_org=curr_org)

    mock_db.create_otp_code.assert_called_once()
    return mock_db.create_otp_code.call_args[1]


@pytest.mark.asyncio
async def test_send_totp_code_resolves_wpid_to_workflow_id():
    """send_totp_code should resolve wpid_* to w_* before storing in DB."""
    mock_workflow = MagicMock()
    mock_workflow.workflow_id = "w_abc123"

    mock_db = AsyncMock()
    mock_db.get_workflow_by_permanent_id = AsyncMock(return_value=mock_workflow)
    mock_db.create_otp_code = AsyncMock(return_value=MagicMock())

    data = TOTPCodeCreate(
        totp_identifier="user@example.com",
        content="123456",
        workflow_id="wpid_xyz789",
    )

    call_kwargs = await _send_totp_code_and_get_create_kwargs(mock_db, data)
    assert call_kwargs["workflow_id"] == "w_abc123", f"Expected w_abc123 but got {call_kwargs['workflow_id']}"


@pytest.mark.asyncio
async def test_send_totp_code_w_format_passes_through():
    """send_totp_code should resolve and store w_* format workflow_id correctly."""
    mock_workflow = MagicMock()
    mock_workflow.workflow_id = "w_abc123"

    mock_db = AsyncMock()
    mock_db.get_workflow = AsyncMock(return_value=mock_workflow)
    mock_db.create_otp_code = AsyncMock(return_value=MagicMock())

    data = TOTPCodeCreate(
        totp_identifier="user@example.com",
        content="123456",
        workflow_id="w_abc123",
    )

    call_kwargs = await _send_totp_code_and_get_create_kwargs(mock_db, data)
    assert call_kwargs["workflow_id"] == "w_abc123"


@pytest.mark.asyncio
async def test_send_totp_code_none_workflow_id():
    """send_totp_code should pass None workflow_id when not provided."""
    mock_db = AsyncMock()
    mock_db.create_otp_code = AsyncMock(return_value=MagicMock())

    data = TOTPCodeCreate(
        totp_identifier="user@example.com",
        content="123456",
    )

    call_kwargs = await _send_totp_code_and_get_create_kwargs(mock_db, data)
    assert call_kwargs["workflow_id"] is None
|
|
|
|
|
|
# === Fix: _build_navigation_payload should inject code without TOTP config ===
|
|
|
|
|
|
def test_build_navigation_payload_injects_code_without_totp_config():
    """_build_navigation_payload should add SPECIAL_FIELD_VERIFICATION_CODE from the context.

    This applies to the manual 2FA flow, where the task has neither
    totp_verification_url nor totp_identifier.
    """
    agent = ForgeAgent.__new__(ForgeAgent)

    task = MagicMock(
        totp_verification_url=None,
        totp_identifier=None,
        task_id="tsk_manual_2fa",
        workflow_run_id="wr_123",
        navigation_payload={"username": "user@example.com"},
    )

    ctx = MagicMock(totp_codes={"tsk_manual_2fa": "123456"})
    ctx.has_magic_link_page.return_value = False

    with patch("skyvern.forge.agent.skyvern_context") as skyvern_ctx_stub:
        skyvern_ctx_stub.ensure_context.return_value = ctx
        payload = agent._build_navigation_payload(task)

    assert isinstance(payload, dict)
    assert SPECIAL_FIELD_VERIFICATION_CODE in payload
    assert payload[SPECIAL_FIELD_VERIFICATION_CODE] == "123456"
    # Keys already present in navigation_payload must survive the injection.
    assert payload["username"] == "user@example.com"
|
|
|
|
|
|
def test_build_navigation_payload_injects_code_when_payload_is_none():
    """With navigation_payload=None, the result should still be a dict carrying the code."""
    agent = ForgeAgent.__new__(ForgeAgent)

    task = MagicMock(
        totp_verification_url=None,
        totp_identifier=None,
        task_id="tsk_manual_2fa",
        workflow_run_id="wr_123",
        navigation_payload=None,
    )

    ctx = MagicMock(totp_codes={"tsk_manual_2fa": "999999"})
    ctx.has_magic_link_page.return_value = False

    with patch("skyvern.forge.agent.skyvern_context") as skyvern_ctx_stub:
        skyvern_ctx_stub.ensure_context.return_value = ctx
        payload = agent._build_navigation_payload(task)

    assert isinstance(payload, dict)
    assert payload[SPECIAL_FIELD_VERIFICATION_CODE] == "999999"
|
|
|
|
|
|
def test_build_navigation_payload_no_code_no_injection():
    """If the context holds no code for the task, no verification-code key is added."""
    agent = ForgeAgent.__new__(ForgeAgent)

    task = MagicMock(
        totp_verification_url=None,
        totp_identifier=None,
        task_id="tsk_no_code",
        workflow_run_id="wr_456",
        navigation_payload={"field": "value"},
    )

    # Empty mapping: there is nothing to inject for this task.
    ctx = MagicMock(totp_codes={})
    ctx.has_magic_link_page.return_value = False

    with patch("skyvern.forge.agent.skyvern_context") as skyvern_ctx_stub:
        skyvern_ctx_stub.ensure_context.return_value = ctx
        payload = agent._build_navigation_payload(task)

    assert isinstance(payload, dict)
    assert SPECIAL_FIELD_VERIFICATION_CODE not in payload
    assert payload["field"] == "value"
|
|
|
|
|
|
# === Task: poll_otp_value publishes 2FA events to notification registry ===
|
|
|
|
|
|
# Name-mangled private attribute on NotificationRegistryFactory. Presumably it holds the
# active registry; it is patched so that published events land in a local registry the
# test can read.
_NOTIFICATION_REGISTRY_ATTR = (
    "skyvern.forge.sdk.notification.factory.NotificationRegistryFactory._NotificationRegistryFactory__registry"
)


def _mock_app_returning_code(code: str) -> MagicMock:
    """Return an ``app`` stub whose DATABASE reports a single TOTP code from get_otp_codes_by_run."""
    otp_code = MagicMock()
    otp_code.code = code
    otp_code.otp_type = "totp"

    mock_db = AsyncMock()
    mock_db.get_valid_org_auth_token.return_value = MagicMock(token="tok")
    mock_db.get_otp_codes_by_run.return_value = [otp_code]
    # Explicit async stubs for the waiting-state writes on the task and workflow-run paths.
    mock_db.update_task_2fa_state = AsyncMock()
    mock_db.update_workflow_run = AsyncMock()

    mock_app = MagicMock()
    mock_app.DATABASE = mock_db
    return mock_app


def _drain(queue) -> list[dict]:
    """Pop every message already buffered in *queue* without waiting."""
    messages = []
    while not queue.empty():
        messages.append(queue.get_nowait())
    return messages


async def _poll_and_collect_events(mock_app: MagicMock, **poll_kwargs) -> list[dict]:
    """Run poll_otp_value for org_1 and return every message published to org_1.

    Events are routed to a LocalNotificationRegistry. *poll_kwargs* (e.g. task_id or
    workflow_run_id) are forwarded to poll_otp_value.
    """
    registry = LocalNotificationRegistry()
    queue = registry.subscribe("org_1")

    with (
        patch("skyvern.services.otp_service.app", new=mock_app),
        patch("skyvern.services.otp_service.asyncio.sleep", new_callable=AsyncMock),
        patch(_NOTIFICATION_REGISTRY_ATTR, new=registry),
    ):
        await poll_otp_value(organization_id="org_1", **poll_kwargs)

    return _drain(queue)


def _first_event(messages: list[dict], event_type: str) -> dict:
    """Return the first message whose ``type`` equals *event_type*.

    Callers must already have asserted that such a message is present.
    """
    return next(m for m in messages if m["type"] == event_type)


@pytest.mark.asyncio
async def test_poll_otp_value_publishes_required_event_for_task():
    """poll_otp_value should publish verification_code_required when task waiting state is set."""
    messages = await _poll_and_collect_events(_mock_app_returning_code("123456"), task_id="tsk_1")

    # Should have received required + resolved messages
    types = [m["type"] for m in messages]
    assert "verification_code_required" in types
    assert "verification_code_resolved" in types

    required = _first_event(messages, "verification_code_required")
    assert required["task_id"] == "tsk_1"

    resolved = _first_event(messages, "verification_code_resolved")
    assert resolved["task_id"] == "tsk_1"


@pytest.mark.asyncio
async def test_poll_otp_value_publishes_required_event_for_workflow_run():
    """poll_otp_value should publish verification_code_required when workflow run waiting state is set."""
    messages = await _poll_and_collect_events(_mock_app_returning_code("654321"), workflow_run_id="wr_1")

    types = [m["type"] for m in messages]
    assert "verification_code_required" in types
    assert "verification_code_resolved" in types

    required = _first_event(messages, "verification_code_required")
    assert required["workflow_run_id"] == "wr_1"
|