Extract shared core from MCP tools, add CLI browser commands (#4768)

This commit is contained in:
Marc Kelechava
2026-02-17 11:24:56 -08:00
committed by GitHub
parent aacc612365
commit 7c5be8fefe
14 changed files with 1304 additions and 113 deletions

View File

@@ -0,0 +1,149 @@
"""Tests for CLI commands infrastructure: _state.py and _output.py."""
from __future__ import annotations
import json
from pathlib import Path
import pytest
import typer
from skyvern.cli.commands._state import CLIState, clear_state, load_state, save_state
# ---------------------------------------------------------------------------
# _state.py
# ---------------------------------------------------------------------------
def _patch_state_dir(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
monkeypatch.setattr("skyvern.cli.commands._state.STATE_DIR", tmp_path)
monkeypatch.setattr("skyvern.cli.commands._state.STATE_FILE", tmp_path / "state.json")
class TestCLIState:
def test_save_load_roundtrip(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
_patch_state_dir(monkeypatch, tmp_path)
save_state(CLIState(session_id="pbs_123", cdp_url=None, mode="cloud"))
loaded = load_state()
assert loaded is not None
assert loaded.session_id == "pbs_123"
assert loaded.cdp_url is None
assert loaded.mode == "cloud"
def test_save_load_roundtrip_cdp(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
_patch_state_dir(monkeypatch, tmp_path)
save_state(CLIState(session_id=None, cdp_url="ws://localhost:9222/devtools/browser/abc", mode="cdp"))
loaded = load_state()
assert loaded is not None
assert loaded.session_id is None
assert loaded.cdp_url == "ws://localhost:9222/devtools/browser/abc"
assert loaded.mode == "cdp"
def test_load_returns_none_when_missing(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr("skyvern.cli.commands._state.STATE_FILE", tmp_path / "nonexistent.json")
assert load_state() is None
def test_24h_ttl_expires(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
_patch_state_dir(monkeypatch, tmp_path)
state_file = tmp_path / "state.json"
state_file.write_text(
json.dumps(
{
"session_id": "pbs_old",
"mode": "cloud",
"created_at": "2020-01-01T00:00:00+00:00",
}
)
)
assert load_state() is None
def test_clear_state(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
_patch_state_dir(monkeypatch, tmp_path)
save_state(CLIState(session_id="pbs_123"))
clear_state()
assert not (tmp_path / "state.json").exists()
def test_load_ignores_corrupt_file(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
state_file = tmp_path / "state.json"
monkeypatch.setattr("skyvern.cli.commands._state.STATE_FILE", state_file)
state_file.write_text("not-json")
assert load_state() is None
# ---------------------------------------------------------------------------
# _output.py
# ---------------------------------------------------------------------------
class TestOutput:
def test_json_envelope(self, capsys: pytest.CaptureFixture) -> None:
from skyvern.cli.commands._output import output
output({"key": "value"}, action="test", json_mode=True)
parsed = json.loads(capsys.readouterr().out)
assert parsed["ok"] is True
assert parsed["action"] == "test"
assert parsed["data"]["key"] == "value"
def test_json_error(self, capsys: pytest.CaptureFixture) -> None:
from skyvern.cli.commands._output import output_error
with pytest.raises(SystemExit, match="1"):
output_error("bad thing", hint="fix it", json_mode=True)
parsed = json.loads(capsys.readouterr().out)
assert parsed["ok"] is False
assert parsed["error"]["message"] == "bad thing"
# ---------------------------------------------------------------------------
# Connection resolution
# ---------------------------------------------------------------------------
class TestResolveConnection:
def test_explicit_session_wins(self) -> None:
from skyvern.cli.commands.browser import _resolve_connection
result = _resolve_connection("pbs_explicit", None)
assert result.mode == "cloud"
assert result.session_id == "pbs_explicit"
assert result.cdp_url is None
def test_explicit_cdp_wins(self) -> None:
from skyvern.cli.commands.browser import _resolve_connection
result = _resolve_connection(None, "ws://localhost:9222/devtools/browser/abc")
assert result.mode == "cdp"
assert result.session_id is None
assert result.cdp_url == "ws://localhost:9222/devtools/browser/abc"
def test_rejects_both_connection_flags(self) -> None:
from skyvern.cli.commands.browser import _resolve_connection
with pytest.raises(typer.BadParameter, match="Pass only one of --session or --cdp"):
_resolve_connection("pbs_explicit", "ws://localhost:9222/devtools/browser/abc")
def test_state_fallback(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
from skyvern.cli.commands.browser import _resolve_connection
_patch_state_dir(monkeypatch, tmp_path)
save_state(CLIState(session_id="pbs_from_state", mode="cloud"))
result = _resolve_connection(None, None)
assert result.mode == "cloud"
assert result.session_id == "pbs_from_state"
def test_state_fallback_cdp(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
from skyvern.cli.commands.browser import _resolve_connection
_patch_state_dir(monkeypatch, tmp_path)
save_state(CLIState(session_id=None, cdp_url="ws://localhost:9222/devtools/browser/abc", mode="cdp"))
result = _resolve_connection(None, None)
assert result.mode == "cdp"
assert result.cdp_url == "ws://localhost:9222/devtools/browser/abc"
def test_no_session_raises(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
from skyvern.cli.commands.browser import _resolve_connection
monkeypatch.setattr("skyvern.cli.commands._state.STATE_FILE", tmp_path / "nonexistent.json")
with pytest.raises(typer.BadParameter, match="No active browser connection"):
_resolve_connection(None, None)

View File

@@ -0,0 +1,37 @@
from __future__ import annotations
import logging
import skyvern.cli.commands as cli_commands
def test_configure_cli_logging_is_idempotent(monkeypatch) -> None:
setup_calls: list[str] = []
monkeypatch.setattr(cli_commands, "_setup_logger", lambda: setup_calls.append("called"))
monkeypatch.setattr(cli_commands, "_cli_logging_configured", False)
logger_names = ("skyvern", "httpx", "litellm", "playwright", "httpcore")
previous_levels = {name: logging.getLogger(name).level for name in logger_names}
try:
cli_commands.configure_cli_logging()
assert setup_calls == ["called"]
for name in logger_names:
assert logging.getLogger(name).level == logging.WARNING
cli_commands.configure_cli_logging()
assert setup_calls == ["called"]
finally:
for name, level in previous_levels.items():
logging.getLogger(name).setLevel(level)
def test_cli_callback_configures_logging(monkeypatch) -> None:
called = False
def _fake_configure() -> None:
nonlocal called
called = True
monkeypatch.setattr(cli_commands, "configure_cli_logging", _fake_configure)
cli_commands.cli_callback()
assert called

View File

@@ -0,0 +1,234 @@
"""Tests for skyvern.cli.core shared modules (guards, browser_ops, session_ops)."""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
import pytest
from skyvern.cli.core.browser_ops import do_act, do_extract, do_navigate, do_screenshot, parse_extract_schema
from skyvern.cli.core.guards import (
GuardError,
check_js_password,
check_password_prompt,
resolve_ai_mode,
validate_button,
validate_wait_until,
)
from skyvern.cli.core.session_ops import do_session_close, do_session_create, do_session_list
# ---------------------------------------------------------------------------
# guards.py
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"text",
[
"enter your password",
"use credential to login",
"type the secret",
"enter passphrase",
"enter passcode",
"enter your pin code",
"type pwd here",
"enter passwd",
],
)
def test_password_guard_blocks_sensitive_text(text: str) -> None:
with pytest.raises(GuardError) as exc_info:
check_password_prompt(text)
assert exc_info.value.hint # hint should always be populated
@pytest.mark.parametrize("text", ["click the submit button", "fill in the email field", ""])
def test_password_guard_allows_normal_text(text: str) -> None:
check_password_prompt(text) # should not raise
def test_js_password_guard() -> None:
with pytest.raises(GuardError):
check_js_password('input[type=password].value = "secret"')
with pytest.raises(GuardError):
check_js_password('.type === "password"; el.value = "x"')
check_js_password("document.title") # allowed
@pytest.mark.parametrize("value", ["load", "domcontentloaded", "networkidle", "commit", None])
def test_wait_until_accepts_valid(value: str | None) -> None:
validate_wait_until(value)
def test_wait_until_rejects_invalid() -> None:
with pytest.raises(GuardError, match="Invalid wait_until"):
validate_wait_until("badvalue")
@pytest.mark.parametrize("value", ["left", "right", "middle", None])
def test_button_accepts_valid(value: str | None) -> None:
validate_button(value)
def test_button_rejects_invalid() -> None:
with pytest.raises(GuardError, match="Invalid button"):
validate_button("double")
@pytest.mark.parametrize(
"selector,intent,expected",
[
(None, "click it", ("proactive", None)),
("#btn", "click it", ("fallback", None)),
("#btn", None, (None, None)),
(None, None, (None, "INVALID_INPUT")),
],
)
def test_resolve_ai_mode(selector: str | None, intent: str | None, expected: tuple) -> None:
assert resolve_ai_mode(selector, intent) == expected
# ---------------------------------------------------------------------------
# browser_ops.py
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_do_navigate_success() -> None:
page = MagicMock()
page.goto = AsyncMock()
page.url = "https://example.com/final"
page.title = AsyncMock(return_value="Example")
result = await do_navigate(page, "https://example.com")
assert result.url == "https://example.com/final"
assert result.title == "Example"
@pytest.mark.asyncio
async def test_do_navigate_passes_wait_until_through() -> None:
page = MagicMock()
page.goto = AsyncMock()
page.url = "https://example.com/final"
page.title = AsyncMock(return_value="Example")
result = await do_navigate(page, "https://example.com", wait_until="badvalue")
assert result.url == "https://example.com/final"
page.goto.assert_awaited_once_with("https://example.com", timeout=30000, wait_until="badvalue")
@pytest.mark.asyncio
async def test_do_screenshot_full_page() -> None:
page = MagicMock()
page.screenshot = AsyncMock(return_value=b"png-data")
result = await do_screenshot(page, full_page=True)
assert result.data == b"png-data"
assert result.full_page is True
@pytest.mark.asyncio
async def test_do_screenshot_with_selector() -> None:
page = MagicMock()
element = MagicMock()
element.screenshot = AsyncMock(return_value=b"element-data")
page.locator.return_value = element
result = await do_screenshot(page, selector="#header")
assert result.data == b"element-data"
@pytest.mark.asyncio
async def test_do_act_success() -> None:
page = MagicMock()
page.act = AsyncMock()
result = await do_act(page, "enter the password")
assert result.prompt == "enter the password"
assert result.completed is True
@pytest.mark.asyncio
async def test_do_extract_rejects_bad_schema() -> None:
with pytest.raises(GuardError, match="Invalid JSON schema"):
await do_extract(MagicMock(), "get data", schema="not-json")
@pytest.mark.asyncio
async def test_do_extract_success() -> None:
page = MagicMock()
page.extract = AsyncMock(return_value={"title": "Example"})
result = await do_extract(page, "get the title")
assert result.extracted == {"title": "Example"}
def test_parse_extract_schema_accepts_preparsed_dict() -> None:
schema = {"type": "object", "properties": {"title": {"type": "string"}}}
parsed = parse_extract_schema(schema)
assert parsed is schema
@pytest.mark.asyncio
async def test_do_extract_accepts_preparsed_dict() -> None:
page = MagicMock()
page.extract = AsyncMock(return_value={"title": "Example"})
schema = {"type": "object", "properties": {"title": {"type": "string"}}}
result = await do_extract(page, "get the title", schema=schema)
assert result.extracted == {"title": "Example"}
page.extract.assert_awaited_once_with(prompt="get the title", schema=schema)
# ---------------------------------------------------------------------------
# session_ops.py
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_do_session_create_local() -> None:
skyvern = MagicMock()
skyvern.launch_local_browser = AsyncMock(return_value=MagicMock())
browser, result = await do_session_create(skyvern, local=True, headless=True)
assert result.local is True
assert result.session_id is None
@pytest.mark.asyncio
async def test_do_session_create_cloud() -> None:
skyvern = MagicMock()
browser_mock = MagicMock()
browser_mock.browser_session_id = "pbs_123"
skyvern.launch_cloud_browser = AsyncMock(return_value=browser_mock)
browser, result = await do_session_create(skyvern, timeout=30)
assert result.session_id == "pbs_123"
assert result.timeout_minutes == 30
@pytest.mark.asyncio
async def test_do_session_close() -> None:
skyvern = MagicMock()
skyvern.close_browser_session = AsyncMock()
result = await do_session_close(skyvern, "pbs_123")
assert result.session_id == "pbs_123"
assert result.closed is True
@pytest.mark.asyncio
async def test_do_session_list() -> None:
session = MagicMock()
session.browser_session_id = "pbs_1"
session.status = "active"
session.started_at = None
session.timeout = 60
session.runnable_id = None
session.browser_address = "ws://localhost:1234"
skyvern = MagicMock()
skyvern.get_browser_sessions = AsyncMock(return_value=[session])
result = await do_session_list(skyvern)
assert len(result) == 1
assert result[0].session_id == "pbs_1"
assert result[0].available is True

View File

@@ -0,0 +1,65 @@
"""Tests for MCP browser tool preflight validation behavior."""
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from skyvern.cli.core.result import BrowserContext
from skyvern.cli.mcp_tools import browser as mcp_browser
@pytest.mark.asyncio
async def test_skyvern_extract_invalid_schema_preflight_before_session(monkeypatch: pytest.MonkeyPatch) -> None:
get_page = AsyncMock(side_effect=AssertionError("get_page should not be called for invalid schema"))
monkeypatch.setattr(mcp_browser, "get_page", get_page)
result = await mcp_browser.skyvern_extract(prompt="extract data", schema="{invalid")
assert result["ok"] is False
assert result["error"]["code"] == mcp_browser.ErrorCode.INVALID_INPUT
assert "Invalid JSON schema" in result["error"]["message"]
get_page.assert_not_awaited()
@pytest.mark.asyncio
async def test_skyvern_extract_preparsed_schema_passed_to_core(monkeypatch: pytest.MonkeyPatch) -> None:
page = object()
context = BrowserContext(mode="cloud_session", session_id="pbs_test")
monkeypatch.setattr(mcp_browser, "get_page", AsyncMock(return_value=(page, context)))
do_extract = AsyncMock(return_value=SimpleNamespace(extracted={"ok": True}))
monkeypatch.setattr(mcp_browser, "do_extract", do_extract)
result = await mcp_browser.skyvern_extract(prompt="extract data", schema='{"type":"object"}')
assert result["ok"] is True
await_args = do_extract.await_args
assert await_args is not None
assert isinstance(await_args.kwargs["schema"], dict)
@pytest.mark.asyncio
async def test_skyvern_navigate_invalid_wait_until_preflight_before_session(monkeypatch: pytest.MonkeyPatch) -> None:
get_page = AsyncMock(side_effect=AssertionError("get_page should not be called for invalid wait_until"))
monkeypatch.setattr(mcp_browser, "get_page", get_page)
result = await mcp_browser.skyvern_navigate(url="https://example.com", wait_until="not-a-real-wait-until")
assert result["ok"] is False
assert result["error"]["code"] == mcp_browser.ErrorCode.INVALID_INPUT
get_page.assert_not_awaited()
@pytest.mark.asyncio
async def test_skyvern_act_password_prompt_preflight_before_session(monkeypatch: pytest.MonkeyPatch) -> None:
get_page = AsyncMock(side_effect=AssertionError("get_page should not be called for password prompt"))
monkeypatch.setattr(mcp_browser, "get_page", get_page)
result = await mcp_browser.skyvern_act(prompt="enter the password and submit")
assert result["ok"] is False
assert result["error"]["code"] == mcp_browser.ErrorCode.INVALID_INPUT
get_page.assert_not_awaited()