align workflow CLI commands with MCP parity (#4792)

This commit is contained in:
Marc Kelechava
2026-02-18 11:34:12 -08:00
committed by GitHub
parent 2f6850ce20
commit 46a7ec1d26
12 changed files with 1609 additions and 151 deletions

View File

@@ -296,3 +296,175 @@ class TestBrowserCommands:
parsed = json.loads(capsys.readouterr().out)
assert parsed["ok"] is False
assert "Invalid state" in parsed["error"]["message"]
# ---------------------------------------------------------------------------
# Workflow command behavior
# ---------------------------------------------------------------------------
class TestWorkflowCommands:
def test_workflow_get_outputs_mcp_envelope_in_json_mode(
self, monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture
) -> None:
from skyvern.cli import workflow as workflow_cmd
expected = {
"ok": True,
"action": "skyvern_workflow_get",
"browser_context": {"mode": "none", "session_id": None, "cdp_url": None},
"data": {"workflow_permanent_id": "wpid_123"},
"artifacts": [],
"timing_ms": {},
"warnings": [],
"error": None,
}
tool = AsyncMock(return_value=expected)
monkeypatch.setattr(workflow_cmd, "tool_workflow_get", tool)
workflow_cmd.workflow_get(workflow_id="wpid_123", version=2, json_output=True)
parsed = json.loads(capsys.readouterr().out)
assert parsed == expected
assert tool.await_args.kwargs == {"workflow_id": "wpid_123", "version": 2}
def test_workflow_create_reads_definition_from_file(
self,
tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
capsys: pytest.CaptureFixture,
) -> None:
from skyvern.cli import workflow as workflow_cmd
definition_file = tmp_path / "workflow.json"
definition_text = '{"title": "Example", "workflow_definition": {"blocks": []}}'
definition_file.write_text(definition_text)
tool = AsyncMock(
return_value={
"ok": True,
"action": "skyvern_workflow_create",
"browser_context": {"mode": "none", "session_id": None, "cdp_url": None},
"data": {"workflow_permanent_id": "wpid_new"},
"artifacts": [],
"timing_ms": {},
"warnings": [],
"error": None,
}
)
monkeypatch.setattr(workflow_cmd, "tool_workflow_create", tool)
workflow_cmd.workflow_create(
definition=f"@{definition_file}",
definition_format="json",
folder_id="fld_123",
json_output=True,
)
assert tool.await_args.kwargs == {
"definition": definition_text,
"format": "json",
"folder_id": "fld_123",
}
parsed = json.loads(capsys.readouterr().out)
assert parsed["ok"] is True
assert parsed["data"]["workflow_permanent_id"] == "wpid_new"
def test_workflow_run_reads_params_file_and_maps_options(
self,
tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
capsys: pytest.CaptureFixture,
) -> None:
from skyvern.cli import workflow as workflow_cmd
params_file = tmp_path / "params.json"
params_file.write_text('{"company": "Acme"}')
tool = AsyncMock(
return_value={
"ok": True,
"action": "skyvern_workflow_run",
"browser_context": {"mode": "none", "session_id": None, "cdp_url": None},
"data": {"run_id": "wr_123", "status": "queued"},
"artifacts": [],
"timing_ms": {},
"warnings": [],
"error": None,
}
)
monkeypatch.setattr(workflow_cmd, "tool_workflow_run", tool)
workflow_cmd.workflow_run(
workflow_id="wpid_123",
params=f"@{params_file}",
session="pbs_456",
webhook="https://example.com/webhook",
proxy="RESIDENTIAL",
wait=True,
timeout=450,
json_output=True,
)
assert tool.await_args.kwargs == {
"workflow_id": "wpid_123",
"parameters": '{"company": "Acme"}',
"browser_session_id": "pbs_456",
"webhook_url": "https://example.com/webhook",
"proxy_location": "RESIDENTIAL",
"wait": True,
"timeout_seconds": 450,
}
parsed = json.loads(capsys.readouterr().out)
assert parsed["ok"] is True
assert parsed["data"]["run_id"] == "wr_123"
def test_workflow_status_json_error_exits_nonzero(
self, monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture
) -> None:
from skyvern.cli import workflow as workflow_cmd
tool = AsyncMock(
return_value={
"ok": False,
"action": "skyvern_workflow_status",
"browser_context": {"mode": "none", "session_id": None, "cdp_url": None},
"data": None,
"artifacts": [],
"timing_ms": {},
"warnings": [],
"error": {
"code": "RUN_NOT_FOUND",
"message": "Run 'wr_missing' not found",
"hint": "Verify the run ID",
"details": {},
},
}
)
monkeypatch.setattr(workflow_cmd, "tool_workflow_status", tool)
with pytest.raises(SystemExit, match="1"):
workflow_cmd.workflow_status(run_id="wr_missing", json_output=True)
parsed = json.loads(capsys.readouterr().out)
assert parsed["ok"] is False
assert parsed["error"]["code"] == "RUN_NOT_FOUND"
def test_workflow_update_missing_definition_file_raises_bad_parameter(
self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
from skyvern.cli import workflow as workflow_cmd
tool = AsyncMock()
monkeypatch.setattr(workflow_cmd, "tool_workflow_update", tool)
missing_file = tmp_path / "missing-definition.json"
with pytest.raises(typer.BadParameter, match="Unable to read definition file"):
workflow_cmd.workflow_update(
workflow_id="wpid_123",
definition=f"@{missing_file}",
definition_format="json",
json_output=False,
)
tool.assert_not_called()

View File

@@ -0,0 +1,304 @@
from __future__ import annotations
import asyncio
from types import SimpleNamespace
from unittest.mock import AsyncMock
import httpx
import pytest
from fastapi import HTTPException
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.routing import Route
from skyvern.cli.core import client as client_mod
from skyvern.cli.core import mcp_http_auth
@pytest.fixture(autouse=True)
def _reset_auth_context() -> None:
client_mod._api_key_override.set(None)
mcp_http_auth._auth_db = None
mcp_http_auth._api_key_validation_cache.clear()
mcp_http_auth._API_KEY_CACHE_TTL_SECONDS = 30.0
mcp_http_auth._API_KEY_CACHE_MAX_SIZE = 1024
mcp_http_auth._MAX_VALIDATION_RETRIES = 2
mcp_http_auth._RETRY_DELAY_SECONDS = 0.0 # no delay in tests
async def _echo_request_context(request: Request) -> JSONResponse:
return JSONResponse(
{
"api_key": client_mod.get_active_api_key(),
"organization_id": getattr(request.state, "organization_id", None),
}
)
def _build_test_app() -> Starlette:
return Starlette(
routes=[Route("/mcp", endpoint=_echo_request_context, methods=["POST"])],
middleware=[Middleware(mcp_http_auth.MCPAPIKeyMiddleware)],
)
@pytest.mark.asyncio
async def test_mcp_http_auth_rejects_missing_api_key() -> None:
app = _build_test_app()
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
response = await client.post("/mcp", json={})
assert response.status_code == 401
assert response.json()["error"]["code"] == "UNAUTHORIZED"
assert "x-api-key" in response.json()["error"]["message"]
@pytest.mark.asyncio
async def test_mcp_http_auth_allows_health_checks_without_api_key() -> None:
app = _build_test_app()
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
response = await client.get("/healthz")
assert response.status_code == 200
assert response.json() == {"status": "ok"}
@pytest.mark.asyncio
async def test_mcp_http_auth_rejects_invalid_api_key(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
mcp_http_auth,
"validate_mcp_api_key",
AsyncMock(side_effect=HTTPException(status_code=403, detail="Invalid credentials")),
)
app = _build_test_app()
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
response = await client.post("/mcp", headers={"x-api-key": "bad-key"}, json={})
assert response.status_code == 401
assert response.json()["error"]["code"] == "UNAUTHORIZED"
assert response.json()["error"]["message"] == "Invalid API key"
@pytest.mark.asyncio
async def test_mcp_http_auth_returns_500_on_non_auth_http_exception(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
mcp_http_auth,
"validate_mcp_api_key",
AsyncMock(side_effect=HTTPException(status_code=500, detail="db down")),
)
app = _build_test_app()
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
response = await client.post("/mcp", headers={"x-api-key": "sk_live_abc"}, json={})
assert response.status_code == 500
assert response.json()["error"]["code"] == "INTERNAL_ERROR"
@pytest.mark.asyncio
async def test_mcp_http_auth_returns_503_on_transient_validation_exhaustion(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
mcp_http_auth,
"validate_mcp_api_key",
AsyncMock(side_effect=HTTPException(status_code=503, detail="API key validation temporarily unavailable")),
)
app = _build_test_app()
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
response = await client.post("/mcp", headers={"x-api-key": "sk_live_abc"}, json={})
assert response.status_code == 503
assert response.json()["error"]["code"] == "SERVICE_UNAVAILABLE"
assert response.json()["error"]["message"] == "API key validation temporarily unavailable"
@pytest.mark.asyncio
async def test_mcp_http_auth_returns_500_on_unexpected_validation_error(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
mcp_http_auth,
"validate_mcp_api_key",
AsyncMock(side_effect=RuntimeError("boom")),
)
app = _build_test_app()
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
response = await client.post("/mcp", headers={"x-api-key": "sk_live_abc"}, json={})
assert response.status_code == 500
assert response.json()["error"]["code"] == "INTERNAL_ERROR"
@pytest.mark.asyncio
async def test_mcp_http_auth_sets_request_scoped_api_key(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
mcp_http_auth,
"validate_mcp_api_key",
AsyncMock(return_value="org_123"),
)
app = _build_test_app()
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
response = await client.post("/mcp", headers={"x-api-key": "sk_live_abc"}, json={})
assert response.status_code == 200
assert response.json() == {
"api_key": "sk_live_abc",
"organization_id": "org_123",
}
assert client_mod.get_active_api_key() != "sk_live_abc"
@pytest.mark.asyncio
async def test_validate_mcp_api_key_uses_ttl_cache(monkeypatch: pytest.MonkeyPatch) -> None:
calls = 0
async def _resolve(_api_key: str, _db: object) -> object:
nonlocal calls
calls += 1
return SimpleNamespace(organization=SimpleNamespace(organization_id="org_cached"))
monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())
first = await mcp_http_auth.validate_mcp_api_key("sk_test_cache")
second = await mcp_http_auth.validate_mcp_api_key("sk_test_cache")
assert first == "org_cached"
assert second == "org_cached"
assert calls == 1
@pytest.mark.asyncio
async def test_validate_mcp_api_key_cache_expires(monkeypatch: pytest.MonkeyPatch) -> None:
calls = 0
async def _resolve(_api_key: str, _db: object) -> object:
nonlocal calls
calls += 1
return SimpleNamespace(organization=SimpleNamespace(organization_id=f"org_{calls}"))
monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())
first = await mcp_http_auth.validate_mcp_api_key("sk_test_cache_expire")
cache_key = mcp_http_auth._cache_key("sk_test_cache_expire")
mcp_http_auth._api_key_validation_cache[cache_key] = (first, 0.0)
second = await mcp_http_auth.validate_mcp_api_key("sk_test_cache_expire")
assert first == "org_1"
assert second == "org_2"
assert calls == 2
@pytest.mark.asyncio
async def test_validate_mcp_api_key_negative_caches_auth_failures(monkeypatch: pytest.MonkeyPatch) -> None:
calls = 0
async def _resolve(_api_key: str, _db: object) -> object:
nonlocal calls
calls += 1
raise HTTPException(status_code=401, detail="Invalid credentials")
monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())
with pytest.raises(HTTPException, match="Invalid credentials"):
await mcp_http_auth.validate_mcp_api_key("sk_test_auth_failure")
with pytest.raises(HTTPException, match="Invalid API key"):
await mcp_http_auth.validate_mcp_api_key("sk_test_auth_failure")
assert calls == 1
@pytest.mark.asyncio
async def test_validate_mcp_api_key_retries_transient_failure_without_negative_cache(
monkeypatch: pytest.MonkeyPatch,
) -> None:
calls = 0
async def _resolve(_api_key: str, _db: object) -> object:
nonlocal calls
calls += 1
if calls == 1:
raise RuntimeError("transient db error")
return SimpleNamespace(organization=SimpleNamespace(organization_id="org_recovered"))
monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())
recovered_org = await mcp_http_auth.validate_mcp_api_key("sk_test_transient")
cache_key = mcp_http_auth._cache_key("sk_test_transient")
assert mcp_http_auth._api_key_validation_cache[cache_key][0] == "org_recovered"
assert recovered_org == "org_recovered"
assert calls == 2
@pytest.mark.asyncio
async def test_validate_mcp_api_key_concurrent_callers_all_succeed(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Multiple concurrent callers for the same key all succeed; the cache
collapses subsequent calls after the first one populates it."""
calls = 0
async def _resolve(_api_key: str, _db: object) -> object:
nonlocal calls
calls += 1
return SimpleNamespace(organization=SimpleNamespace(organization_id="org_concurrent"))
monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())
results = await asyncio.gather(*[mcp_http_auth.validate_mcp_api_key("sk_test_concurrent") for _ in range(5)])
assert all(r == "org_concurrent" for r in results)
# First call populates cache; remaining may or may not hit DB depending on
# scheduling, but all must succeed.
assert calls >= 1
@pytest.mark.asyncio
async def test_validate_mcp_api_key_returns_503_after_retry_exhaustion(monkeypatch: pytest.MonkeyPatch) -> None:
calls = 0
async def _resolve(_api_key: str, _db: object) -> object:
nonlocal calls
calls += 1
raise RuntimeError("persistent db outage")
monkeypatch.setattr(mcp_http_auth, "_MAX_VALIDATION_RETRIES", 2)
monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())
with pytest.raises(HTTPException, match="temporarily unavailable") as exc_info:
await mcp_http_auth.validate_mcp_api_key("sk_test_transient_exhausted")
assert exc_info.value.status_code == 503
assert calls == 3 # initial + 2 retries
@pytest.mark.asyncio
async def test_close_auth_db_disposes_engine() -> None:
dispose = AsyncMock()
mcp_http_auth._auth_db = SimpleNamespace(engine=SimpleNamespace(dispose=dispose))
mcp_http_auth._api_key_validation_cache["k"] = ("org", 123.0)
await mcp_http_auth.close_auth_db()
dispose.assert_awaited_once()
assert mcp_http_auth._auth_db is None
assert mcp_http_auth._api_key_validation_cache == {}
@pytest.mark.asyncio
async def test_close_auth_db_noop_when_uninitialized() -> None:
mcp_http_auth._auth_db = None
await mcp_http_auth.close_auth_db()
assert mcp_http_auth._auth_db is None

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
import pytest
@@ -14,10 +15,13 @@ from skyvern.cli.mcp_tools import session as mcp_session
@pytest.fixture(autouse=True)
def _reset_singletons() -> None:
client_mod._skyvern_instance.set(None)
client_mod._api_key_override.set(None)
client_mod._global_skyvern_instance = None
client_mod._api_key_clients.clear()
session_manager._current_session.set(None)
session_manager._global_session = None
session_manager.set_stateless_http_mode(False)
mcp_session.set_current_session(mcp_session.SessionState())
@@ -47,6 +51,115 @@ def test_get_skyvern_reuses_global_instance_across_contexts(monkeypatch: pytest.
assert len(created) == 1
def test_get_skyvern_reuses_override_instance_per_api_key(monkeypatch: pytest.MonkeyPatch) -> None:
created_keys: list[str] = []
class FakeSkyvern:
def __init__(self, *args: object, **kwargs: object) -> None:
created_keys.append(kwargs["api_key"])
@classmethod
def local(cls) -> FakeSkyvern:
return cls(api_key="local")
async def aclose(self) -> None:
return None
monkeypatch.setattr(client_mod, "Skyvern", FakeSkyvern)
monkeypatch.setattr(client_mod.settings, "SKYVERN_API_KEY", None)
monkeypatch.setattr(client_mod.settings, "SKYVERN_BASE_URL", None)
token = client_mod.set_api_key_override("sk_key_a")
try:
first = client_mod.get_skyvern()
client_mod._skyvern_instance.set(None)
second = client_mod.get_skyvern()
finally:
client_mod.reset_api_key_override(token)
assert first is second
assert created_keys == ["sk_key_a"]
def test_get_skyvern_override_client_cache_uses_lru_eviction(monkeypatch: pytest.MonkeyPatch) -> None:
created_keys: list[str] = []
class FakeSkyvern:
def __init__(self, *args: object, **kwargs: object) -> None:
created_keys.append(kwargs["api_key"])
@classmethod
def local(cls) -> FakeSkyvern:
return cls(api_key="local")
async def aclose(self) -> None:
return None
monkeypatch.setattr(client_mod, "Skyvern", FakeSkyvern)
monkeypatch.setattr(client_mod.settings, "SKYVERN_API_KEY", None)
monkeypatch.setattr(client_mod.settings, "SKYVERN_BASE_URL", None)
monkeypatch.setattr(client_mod, "_API_KEY_CLIENT_CACHE_MAX", 2)
for key in ("sk_key_a", "sk_key_b"):
token = client_mod.set_api_key_override(key)
try:
client_mod.get_skyvern()
finally:
client_mod.reset_api_key_override(token)
# Touch key_a so key_b becomes least-recently-used.
token = client_mod.set_api_key_override("sk_key_a")
try:
client_mod._skyvern_instance.set(None)
client_mod.get_skyvern()
finally:
client_mod.reset_api_key_override(token)
# Adding key_c should evict key_b.
token = client_mod.set_api_key_override("sk_key_c")
try:
client_mod.get_skyvern()
finally:
client_mod.reset_api_key_override(token)
assert list(client_mod._api_key_clients.keys()) == [
client_mod._cache_key("sk_key_a"),
client_mod._cache_key("sk_key_c"),
]
# key_a, key_b, key_c were created exactly once each.
assert created_keys == ["sk_key_a", "sk_key_b", "sk_key_c"]
def test_get_skyvern_override_cache_closes_evicted_client(monkeypatch: pytest.MonkeyPatch) -> None:
closed_keys: list[str] = []
class FakeSkyvern:
def __init__(self, *args: object, **kwargs: object) -> None:
self.api_key = kwargs["api_key"]
@classmethod
def local(cls) -> FakeSkyvern:
return cls(api_key="local")
async def aclose(self) -> None:
closed_keys.append(self.api_key)
monkeypatch.setattr(client_mod, "Skyvern", FakeSkyvern)
monkeypatch.setattr(client_mod.settings, "SKYVERN_API_KEY", None)
monkeypatch.setattr(client_mod.settings, "SKYVERN_BASE_URL", None)
monkeypatch.setattr(client_mod, "_API_KEY_CLIENT_CACHE_MAX", 1)
for key in ("sk_key_a", "sk_key_b"):
token = client_mod.set_api_key_override(key)
try:
client_mod.get_skyvern()
finally:
client_mod.reset_api_key_override(token)
assert list(client_mod._api_key_clients.keys()) == [client_mod._cache_key("sk_key_b")]
assert closed_keys == ["sk_key_a"]
@pytest.mark.asyncio
async def test_close_skyvern_closes_singleton() -> None:
fake = MagicMock()
@@ -75,12 +188,53 @@ def test_get_current_session_falls_back_to_global_state() -> None:
assert recovered is state
def test_get_current_session_stateless_mode_ignores_global_state() -> None:
global_state = session_manager.SessionState(
browser=MagicMock(),
context=BrowserContext(mode="cloud_session", session_id="pbs_999"),
)
session_manager._global_session = global_state
session_manager._current_session.set(None)
session_manager.set_stateless_http_mode(True)
try:
recovered = session_manager.get_current_session()
finally:
session_manager.set_stateless_http_mode(False)
assert recovered is not global_state
assert recovered.browser is None
assert recovered.context is None
def test_set_current_session_stateless_mode_does_not_override_global_state() -> None:
global_state = session_manager.SessionState(
browser=MagicMock(),
context=BrowserContext(mode="cloud_session", session_id="pbs_global"),
)
session_manager._global_session = global_state
replacement = session_manager.SessionState(
browser=MagicMock(),
context=BrowserContext(mode="cloud_session", session_id="pbs_request"),
)
session_manager.set_stateless_http_mode(True)
try:
session_manager.set_current_session(replacement)
finally:
session_manager.set_stateless_http_mode(False)
assert session_manager._global_session is global_state
assert session_manager._current_session.get() is replacement
@pytest.mark.asyncio
async def test_resolve_browser_reuses_matching_cloud_session(monkeypatch: pytest.MonkeyPatch) -> None:
current_browser = MagicMock()
current_state = session_manager.SessionState(
browser=current_browser,
context=BrowserContext(mode="cloud_session", session_id="pbs_123"),
api_key_hash=session_manager._api_key_hash(client_mod.get_active_api_key()),
)
session_manager.set_current_session(current_state)
@@ -95,6 +249,92 @@ async def test_resolve_browser_reuses_matching_cloud_session(monkeypatch: pytest
fake_skyvern.connect_to_cloud_browser_session.assert_not_awaited()
@pytest.mark.asyncio
async def test_resolve_browser_does_not_reuse_session_for_different_api_key(
monkeypatch: pytest.MonkeyPatch,
) -> None:
current_browser = MagicMock()
session_manager.set_current_session(
session_manager.SessionState(
browser=current_browser,
context=BrowserContext(mode="cloud_session", session_id="pbs_123"),
api_key_hash=session_manager._api_key_hash("sk_key_a"),
)
)
replacement_browser = MagicMock()
fake_skyvern = MagicMock()
fake_skyvern.connect_to_cloud_browser_session = AsyncMock(return_value=replacement_browser)
monkeypatch.setattr(session_manager, "get_skyvern", lambda: fake_skyvern)
token = client_mod.set_api_key_override("sk_key_b")
try:
browser, ctx = await session_manager.resolve_browser(session_id="pbs_123")
finally:
client_mod.reset_api_key_override(token)
assert browser is replacement_browser
assert ctx.session_id == "pbs_123"
fake_skyvern.connect_to_cloud_browser_session.assert_awaited_once_with("pbs_123")
@pytest.mark.asyncio
async def test_resolve_browser_stateless_mode_does_not_write_global_session(monkeypatch: pytest.MonkeyPatch) -> None:
global_state = session_manager.SessionState(
browser=MagicMock(),
context=BrowserContext(mode="cloud_session", session_id="pbs_global"),
)
session_manager._global_session = global_state
replacement_browser = MagicMock()
fake_skyvern = MagicMock()
fake_skyvern.connect_to_cloud_browser_session = AsyncMock(return_value=replacement_browser)
monkeypatch.setattr(session_manager, "get_skyvern", lambda: fake_skyvern)
session_manager.set_stateless_http_mode(True)
try:
browser, ctx = await session_manager.resolve_browser(session_id="pbs_123")
finally:
session_manager.set_stateless_http_mode(False)
assert browser is replacement_browser
assert ctx.session_id == "pbs_123"
assert session_manager._global_session is global_state
@pytest.mark.asyncio
async def test_resolve_browser_blocks_implicit_session_in_stateless_mode() -> None:
session_manager.set_current_session(
session_manager.SessionState(
browser=MagicMock(),
context=BrowserContext(mode="cloud_session", session_id="pbs_123"),
api_key_hash=session_manager._api_key_hash("sk_key_a"),
)
)
session_manager.set_stateless_http_mode(True)
try:
with pytest.raises(session_manager.BrowserNotAvailableError):
await session_manager.resolve_browser()
finally:
session_manager.set_stateless_http_mode(False)
@pytest.mark.asyncio
async def test_resolve_browser_raises_for_invalid_matching_state(monkeypatch: pytest.MonkeyPatch) -> None:
session_manager.set_current_session(
session_manager.SessionState(
browser=None,
context=BrowserContext(mode="cloud_session", session_id="pbs_123"),
)
)
monkeypatch.setattr(session_manager, "_matches_current", lambda *args, **kwargs: True)
monkeypatch.setattr(session_manager, "get_skyvern", lambda: MagicMock())
with pytest.raises(RuntimeError, match="Expected active browser and context"):
await session_manager.resolve_browser(session_id="pbs_123")
@pytest.mark.asyncio
async def test_session_close_with_matching_session_id_closes_browser_handle(monkeypatch: pytest.MonkeyPatch) -> None:
current_browser = MagicMock()
@@ -262,3 +502,73 @@ async def test_close_current_session_still_closes_browser_when_api_fails(monkeyp
# _browser_session_id should NOT be cleared (API close failed, let browser.close() try)
assert browser._browser_session_id == "pbs_fail"
assert session_manager.get_current_session().browser is None
# ---------------------------------------------------------------------------
# Tests for stateless HTTP mode session creation
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_session_create_stateless_mode_returns_session_without_persisting_browser(
monkeypatch: pytest.MonkeyPatch,
) -> None:
session_manager.set_stateless_http_mode(True)
fake_skyvern = MagicMock()
fake_skyvern.create_browser_session = AsyncMock(return_value=SimpleNamespace(browser_session_id="pbs_abc"))
monkeypatch.setattr(mcp_session, "get_skyvern", lambda: fake_skyvern)
do_session_create = AsyncMock()
monkeypatch.setattr(mcp_session, "do_session_create", do_session_create)
try:
result = await mcp_session.skyvern_session_create(timeout=45)
finally:
session_manager.set_stateless_http_mode(False)
assert result["ok"] is True
assert result["data"] == {"session_id": "pbs_abc", "timeout_minutes": 45}
do_session_create.assert_not_awaited()
assert mcp_session.get_current_session().browser is None
assert mcp_session.get_current_session().context is None
@pytest.mark.asyncio
async def test_session_create_stateless_mode_rejects_local() -> None:
session_manager.set_stateless_http_mode(True)
try:
result = await mcp_session.skyvern_session_create(local=True)
finally:
session_manager.set_stateless_http_mode(False)
assert result["ok"] is False
assert result["error"]["code"] == mcp_session.ErrorCode.INVALID_INPUT
@pytest.mark.asyncio
async def test_session_create_persists_active_api_key_hash_in_session_state(
monkeypatch: pytest.MonkeyPatch,
) -> None:
fake_skyvern = MagicMock()
monkeypatch.setattr(mcp_session, "get_skyvern", lambda: fake_skyvern)
fake_browser = MagicMock()
do_session_create = AsyncMock(
return_value=(
fake_browser,
SimpleNamespace(local=False, session_id="pbs_123", timeout_minutes=60, headless=False),
)
)
monkeypatch.setattr(mcp_session, "do_session_create", do_session_create)
token = client_mod.set_api_key_override("sk_key_create")
try:
result = await mcp_session.skyvern_session_create(timeout=60)
finally:
client_mod.reset_api_key_override(token)
assert result["ok"] is True
current = mcp_session.get_current_session()
assert current.browser is fake_browser
assert current.context == BrowserContext(mode="cloud_session", session_id="pbs_123")
assert current.api_key_hash == session_manager._api_key_hash("sk_key_create")
assert current.api_key_hash != "sk_key_create"

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
import asyncio
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import AsyncMock, MagicMock, call
import pytest
@@ -13,6 +13,42 @@ def _reset_cleanup_state() -> None:
run_commands._mcp_cleanup_done = False
@pytest.mark.asyncio
async def test_cleanup_mcp_resources_closes_auth_db(monkeypatch: pytest.MonkeyPatch) -> None:
close_current_session = AsyncMock()
close_skyvern = AsyncMock()
close_auth_db = AsyncMock()
monkeypatch.setattr(run_commands, "close_current_session", close_current_session)
monkeypatch.setattr(run_commands, "close_skyvern", close_skyvern)
monkeypatch.setattr(run_commands, "close_auth_db", close_auth_db)
await run_commands._cleanup_mcp_resources()
close_current_session.assert_awaited_once()
close_skyvern.assert_awaited_once()
close_auth_db.assert_awaited_once()
@pytest.mark.asyncio
async def test_cleanup_mcp_resources_closes_auth_db_on_skyvern_close_error(monkeypatch: pytest.MonkeyPatch) -> None:
close_current_session = AsyncMock()
close_auth_db = AsyncMock()
async def _failing_close_skyvern() -> None:
raise RuntimeError("close failed")
monkeypatch.setattr(run_commands, "close_current_session", close_current_session)
monkeypatch.setattr(run_commands, "close_skyvern", _failing_close_skyvern)
monkeypatch.setattr(run_commands, "close_auth_db", close_auth_db)
with pytest.raises(RuntimeError, match="close failed"):
await run_commands._cleanup_mcp_resources()
close_current_session.assert_awaited_once()
close_auth_db.assert_awaited_once()
def test_cleanup_mcp_resources_sync_runs_without_running_loop(monkeypatch: pytest.MonkeyPatch) -> None:
cleanup = AsyncMock()
monkeypatch.setattr(run_commands, "_cleanup_mcp_resources", cleanup)
@@ -50,14 +86,56 @@ def test_run_mcp_calls_blocking_cleanup_in_finally(monkeypatch: pytest.MonkeyPat
cleanup_blocking = MagicMock()
register = MagicMock()
run = MagicMock(side_effect=RuntimeError("boom"))
set_stateless = MagicMock()
monkeypatch.setattr(run_commands, "_cleanup_mcp_resources_blocking", cleanup_blocking)
monkeypatch.setattr(run_commands.atexit, "register", register)
monkeypatch.setattr(run_commands.mcp, "run", run)
monkeypatch.setattr(run_commands, "set_stateless_http_mode", set_stateless)
with pytest.raises(RuntimeError, match="boom"):
run_commands.run_mcp()
register.assert_called_once_with(run_commands._cleanup_mcp_resources_sync)
run.assert_called_once_with(transport="stdio")
set_stateless.assert_has_calls([call(False), call(False)])
cleanup_blocking.assert_called_once()
def test_run_mcp_http_transport_wires_auth_middleware(monkeypatch: pytest.MonkeyPatch) -> None:
cleanup_blocking = MagicMock()
register = MagicMock()
run = MagicMock()
set_stateless = MagicMock()
monkeypatch.setattr(run_commands, "_cleanup_mcp_resources_blocking", cleanup_blocking)
monkeypatch.setattr(run_commands.atexit, "register", register)
monkeypatch.setattr(run_commands.mcp, "run", run)
monkeypatch.setattr(run_commands, "set_stateless_http_mode", set_stateless)
run_commands.run_mcp(
transport="streamable-http",
host="127.0.0.1",
port=9010,
path="mcp",
stateless_http=True,
)
register.assert_called_once_with(run_commands._cleanup_mcp_resources_sync)
run.assert_called_once()
kwargs = run.call_args.kwargs
assert kwargs["transport"] == "streamable-http"
assert kwargs["host"] == "127.0.0.1"
assert kwargs["port"] == 9010
assert kwargs["path"] == "/mcp"
assert kwargs["stateless_http"] is True
middleware = kwargs["middleware"]
assert len(middleware) == 1
assert middleware[0].cls is run_commands.MCPAPIKeyMiddleware
set_stateless.assert_has_calls([call(True), call(False)])
cleanup_blocking.assert_called_once()
def test_run_task_tool_registration_points_to_browser_module() -> None:
tool = run_commands.mcp._tool_manager._tools["skyvern_run_task"] # type: ignore[attr-defined]
assert tool.fn.__module__ == "skyvern.cli.mcp_tools.browser"