align workflow CLI commands with MCP parity (#4792)
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
@@ -14,10 +15,13 @@ from skyvern.cli.mcp_tools import session as mcp_session
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_singletons() -> None:
|
||||
client_mod._skyvern_instance.set(None)
|
||||
client_mod._api_key_override.set(None)
|
||||
client_mod._global_skyvern_instance = None
|
||||
client_mod._api_key_clients.clear()
|
||||
|
||||
session_manager._current_session.set(None)
|
||||
session_manager._global_session = None
|
||||
session_manager.set_stateless_http_mode(False)
|
||||
mcp_session.set_current_session(mcp_session.SessionState())
|
||||
|
||||
|
||||
@@ -47,6 +51,115 @@ def test_get_skyvern_reuses_global_instance_across_contexts(monkeypatch: pytest.
|
||||
assert len(created) == 1
|
||||
|
||||
|
||||
def test_get_skyvern_reuses_override_instance_per_api_key(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
created_keys: list[str] = []
|
||||
|
||||
class FakeSkyvern:
|
||||
def __init__(self, *args: object, **kwargs: object) -> None:
|
||||
created_keys.append(kwargs["api_key"])
|
||||
|
||||
@classmethod
|
||||
def local(cls) -> FakeSkyvern:
|
||||
return cls(api_key="local")
|
||||
|
||||
async def aclose(self) -> None:
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(client_mod, "Skyvern", FakeSkyvern)
|
||||
monkeypatch.setattr(client_mod.settings, "SKYVERN_API_KEY", None)
|
||||
monkeypatch.setattr(client_mod.settings, "SKYVERN_BASE_URL", None)
|
||||
|
||||
token = client_mod.set_api_key_override("sk_key_a")
|
||||
try:
|
||||
first = client_mod.get_skyvern()
|
||||
client_mod._skyvern_instance.set(None)
|
||||
second = client_mod.get_skyvern()
|
||||
finally:
|
||||
client_mod.reset_api_key_override(token)
|
||||
|
||||
assert first is second
|
||||
assert created_keys == ["sk_key_a"]
|
||||
|
||||
|
||||
def test_get_skyvern_override_client_cache_uses_lru_eviction(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
created_keys: list[str] = []
|
||||
|
||||
class FakeSkyvern:
|
||||
def __init__(self, *args: object, **kwargs: object) -> None:
|
||||
created_keys.append(kwargs["api_key"])
|
||||
|
||||
@classmethod
|
||||
def local(cls) -> FakeSkyvern:
|
||||
return cls(api_key="local")
|
||||
|
||||
async def aclose(self) -> None:
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(client_mod, "Skyvern", FakeSkyvern)
|
||||
monkeypatch.setattr(client_mod.settings, "SKYVERN_API_KEY", None)
|
||||
monkeypatch.setattr(client_mod.settings, "SKYVERN_BASE_URL", None)
|
||||
monkeypatch.setattr(client_mod, "_API_KEY_CLIENT_CACHE_MAX", 2)
|
||||
|
||||
for key in ("sk_key_a", "sk_key_b"):
|
||||
token = client_mod.set_api_key_override(key)
|
||||
try:
|
||||
client_mod.get_skyvern()
|
||||
finally:
|
||||
client_mod.reset_api_key_override(token)
|
||||
|
||||
# Touch key_a so key_b becomes least-recently-used.
|
||||
token = client_mod.set_api_key_override("sk_key_a")
|
||||
try:
|
||||
client_mod._skyvern_instance.set(None)
|
||||
client_mod.get_skyvern()
|
||||
finally:
|
||||
client_mod.reset_api_key_override(token)
|
||||
|
||||
# Adding key_c should evict key_b.
|
||||
token = client_mod.set_api_key_override("sk_key_c")
|
||||
try:
|
||||
client_mod.get_skyvern()
|
||||
finally:
|
||||
client_mod.reset_api_key_override(token)
|
||||
|
||||
assert list(client_mod._api_key_clients.keys()) == [
|
||||
client_mod._cache_key("sk_key_a"),
|
||||
client_mod._cache_key("sk_key_c"),
|
||||
]
|
||||
# key_a, key_b, key_c were created exactly once each.
|
||||
assert created_keys == ["sk_key_a", "sk_key_b", "sk_key_c"]
|
||||
|
||||
|
||||
def test_get_skyvern_override_cache_closes_evicted_client(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
closed_keys: list[str] = []
|
||||
|
||||
class FakeSkyvern:
|
||||
def __init__(self, *args: object, **kwargs: object) -> None:
|
||||
self.api_key = kwargs["api_key"]
|
||||
|
||||
@classmethod
|
||||
def local(cls) -> FakeSkyvern:
|
||||
return cls(api_key="local")
|
||||
|
||||
async def aclose(self) -> None:
|
||||
closed_keys.append(self.api_key)
|
||||
|
||||
monkeypatch.setattr(client_mod, "Skyvern", FakeSkyvern)
|
||||
monkeypatch.setattr(client_mod.settings, "SKYVERN_API_KEY", None)
|
||||
monkeypatch.setattr(client_mod.settings, "SKYVERN_BASE_URL", None)
|
||||
monkeypatch.setattr(client_mod, "_API_KEY_CLIENT_CACHE_MAX", 1)
|
||||
|
||||
for key in ("sk_key_a", "sk_key_b"):
|
||||
token = client_mod.set_api_key_override(key)
|
||||
try:
|
||||
client_mod.get_skyvern()
|
||||
finally:
|
||||
client_mod.reset_api_key_override(token)
|
||||
|
||||
assert list(client_mod._api_key_clients.keys()) == [client_mod._cache_key("sk_key_b")]
|
||||
assert closed_keys == ["sk_key_a"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_skyvern_closes_singleton() -> None:
|
||||
fake = MagicMock()
|
||||
@@ -75,12 +188,53 @@ def test_get_current_session_falls_back_to_global_state() -> None:
|
||||
assert recovered is state
|
||||
|
||||
|
||||
def test_get_current_session_stateless_mode_ignores_global_state() -> None:
|
||||
global_state = session_manager.SessionState(
|
||||
browser=MagicMock(),
|
||||
context=BrowserContext(mode="cloud_session", session_id="pbs_999"),
|
||||
)
|
||||
session_manager._global_session = global_state
|
||||
session_manager._current_session.set(None)
|
||||
|
||||
session_manager.set_stateless_http_mode(True)
|
||||
try:
|
||||
recovered = session_manager.get_current_session()
|
||||
finally:
|
||||
session_manager.set_stateless_http_mode(False)
|
||||
|
||||
assert recovered is not global_state
|
||||
assert recovered.browser is None
|
||||
assert recovered.context is None
|
||||
|
||||
|
||||
def test_set_current_session_stateless_mode_does_not_override_global_state() -> None:
|
||||
global_state = session_manager.SessionState(
|
||||
browser=MagicMock(),
|
||||
context=BrowserContext(mode="cloud_session", session_id="pbs_global"),
|
||||
)
|
||||
session_manager._global_session = global_state
|
||||
replacement = session_manager.SessionState(
|
||||
browser=MagicMock(),
|
||||
context=BrowserContext(mode="cloud_session", session_id="pbs_request"),
|
||||
)
|
||||
|
||||
session_manager.set_stateless_http_mode(True)
|
||||
try:
|
||||
session_manager.set_current_session(replacement)
|
||||
finally:
|
||||
session_manager.set_stateless_http_mode(False)
|
||||
|
||||
assert session_manager._global_session is global_state
|
||||
assert session_manager._current_session.get() is replacement
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_browser_reuses_matching_cloud_session(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
current_browser = MagicMock()
|
||||
current_state = session_manager.SessionState(
|
||||
browser=current_browser,
|
||||
context=BrowserContext(mode="cloud_session", session_id="pbs_123"),
|
||||
api_key_hash=session_manager._api_key_hash(client_mod.get_active_api_key()),
|
||||
)
|
||||
session_manager.set_current_session(current_state)
|
||||
|
||||
@@ -95,6 +249,92 @@ async def test_resolve_browser_reuses_matching_cloud_session(monkeypatch: pytest
|
||||
fake_skyvern.connect_to_cloud_browser_session.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_browser_does_not_reuse_session_for_different_api_key(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
current_browser = MagicMock()
|
||||
session_manager.set_current_session(
|
||||
session_manager.SessionState(
|
||||
browser=current_browser,
|
||||
context=BrowserContext(mode="cloud_session", session_id="pbs_123"),
|
||||
api_key_hash=session_manager._api_key_hash("sk_key_a"),
|
||||
)
|
||||
)
|
||||
|
||||
replacement_browser = MagicMock()
|
||||
fake_skyvern = MagicMock()
|
||||
fake_skyvern.connect_to_cloud_browser_session = AsyncMock(return_value=replacement_browser)
|
||||
monkeypatch.setattr(session_manager, "get_skyvern", lambda: fake_skyvern)
|
||||
|
||||
token = client_mod.set_api_key_override("sk_key_b")
|
||||
try:
|
||||
browser, ctx = await session_manager.resolve_browser(session_id="pbs_123")
|
||||
finally:
|
||||
client_mod.reset_api_key_override(token)
|
||||
|
||||
assert browser is replacement_browser
|
||||
assert ctx.session_id == "pbs_123"
|
||||
fake_skyvern.connect_to_cloud_browser_session.assert_awaited_once_with("pbs_123")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_browser_stateless_mode_does_not_write_global_session(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
global_state = session_manager.SessionState(
|
||||
browser=MagicMock(),
|
||||
context=BrowserContext(mode="cloud_session", session_id="pbs_global"),
|
||||
)
|
||||
session_manager._global_session = global_state
|
||||
|
||||
replacement_browser = MagicMock()
|
||||
fake_skyvern = MagicMock()
|
||||
fake_skyvern.connect_to_cloud_browser_session = AsyncMock(return_value=replacement_browser)
|
||||
monkeypatch.setattr(session_manager, "get_skyvern", lambda: fake_skyvern)
|
||||
|
||||
session_manager.set_stateless_http_mode(True)
|
||||
try:
|
||||
browser, ctx = await session_manager.resolve_browser(session_id="pbs_123")
|
||||
finally:
|
||||
session_manager.set_stateless_http_mode(False)
|
||||
|
||||
assert browser is replacement_browser
|
||||
assert ctx.session_id == "pbs_123"
|
||||
assert session_manager._global_session is global_state
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_browser_blocks_implicit_session_in_stateless_mode() -> None:
|
||||
session_manager.set_current_session(
|
||||
session_manager.SessionState(
|
||||
browser=MagicMock(),
|
||||
context=BrowserContext(mode="cloud_session", session_id="pbs_123"),
|
||||
api_key_hash=session_manager._api_key_hash("sk_key_a"),
|
||||
)
|
||||
)
|
||||
session_manager.set_stateless_http_mode(True)
|
||||
try:
|
||||
with pytest.raises(session_manager.BrowserNotAvailableError):
|
||||
await session_manager.resolve_browser()
|
||||
finally:
|
||||
session_manager.set_stateless_http_mode(False)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_browser_raises_for_invalid_matching_state(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
session_manager.set_current_session(
|
||||
session_manager.SessionState(
|
||||
browser=None,
|
||||
context=BrowserContext(mode="cloud_session", session_id="pbs_123"),
|
||||
)
|
||||
)
|
||||
|
||||
monkeypatch.setattr(session_manager, "_matches_current", lambda *args, **kwargs: True)
|
||||
monkeypatch.setattr(session_manager, "get_skyvern", lambda: MagicMock())
|
||||
|
||||
with pytest.raises(RuntimeError, match="Expected active browser and context"):
|
||||
await session_manager.resolve_browser(session_id="pbs_123")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_close_with_matching_session_id_closes_browser_handle(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
current_browser = MagicMock()
|
||||
@@ -262,3 +502,73 @@ async def test_close_current_session_still_closes_browser_when_api_fails(monkeyp
|
||||
# _browser_session_id should NOT be cleared (API close failed, let browser.close() try)
|
||||
assert browser._browser_session_id == "pbs_fail"
|
||||
assert session_manager.get_current_session().browser is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests for stateless HTTP mode session creation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_create_stateless_mode_returns_session_without_persisting_browser(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
session_manager.set_stateless_http_mode(True)
|
||||
fake_skyvern = MagicMock()
|
||||
fake_skyvern.create_browser_session = AsyncMock(return_value=SimpleNamespace(browser_session_id="pbs_abc"))
|
||||
monkeypatch.setattr(mcp_session, "get_skyvern", lambda: fake_skyvern)
|
||||
do_session_create = AsyncMock()
|
||||
monkeypatch.setattr(mcp_session, "do_session_create", do_session_create)
|
||||
|
||||
try:
|
||||
result = await mcp_session.skyvern_session_create(timeout=45)
|
||||
finally:
|
||||
session_manager.set_stateless_http_mode(False)
|
||||
|
||||
assert result["ok"] is True
|
||||
assert result["data"] == {"session_id": "pbs_abc", "timeout_minutes": 45}
|
||||
do_session_create.assert_not_awaited()
|
||||
assert mcp_session.get_current_session().browser is None
|
||||
assert mcp_session.get_current_session().context is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_create_stateless_mode_rejects_local() -> None:
|
||||
session_manager.set_stateless_http_mode(True)
|
||||
try:
|
||||
result = await mcp_session.skyvern_session_create(local=True)
|
||||
finally:
|
||||
session_manager.set_stateless_http_mode(False)
|
||||
|
||||
assert result["ok"] is False
|
||||
assert result["error"]["code"] == mcp_session.ErrorCode.INVALID_INPUT
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_create_persists_active_api_key_hash_in_session_state(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
fake_skyvern = MagicMock()
|
||||
monkeypatch.setattr(mcp_session, "get_skyvern", lambda: fake_skyvern)
|
||||
|
||||
fake_browser = MagicMock()
|
||||
do_session_create = AsyncMock(
|
||||
return_value=(
|
||||
fake_browser,
|
||||
SimpleNamespace(local=False, session_id="pbs_123", timeout_minutes=60, headless=False),
|
||||
)
|
||||
)
|
||||
monkeypatch.setattr(mcp_session, "do_session_create", do_session_create)
|
||||
|
||||
token = client_mod.set_api_key_override("sk_key_create")
|
||||
try:
|
||||
result = await mcp_session.skyvern_session_create(timeout=60)
|
||||
finally:
|
||||
client_mod.reset_api_key_override(token)
|
||||
|
||||
assert result["ok"] is True
|
||||
current = mcp_session.get_current_session()
|
||||
assert current.browser is fake_browser
|
||||
assert current.context == BrowserContext(mode="cloud_session", session_id="pbs_123")
|
||||
assert current.api_key_hash == session_manager._api_key_hash("sk_key_create")
|
||||
assert current.api_key_hash != "sk_key_create"
|
||||
|
||||
Reference in New Issue
Block a user