align workflow CLI commands with MCP parity (#4792)
This commit is contained in:
@@ -296,3 +296,175 @@ class TestBrowserCommands:
|
||||
parsed = json.loads(capsys.readouterr().out)
|
||||
assert parsed["ok"] is False
|
||||
assert "Invalid state" in parsed["error"]["message"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Workflow command behavior
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestWorkflowCommands:
|
||||
def test_workflow_get_outputs_mcp_envelope_in_json_mode(
|
||||
self, monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture
|
||||
) -> None:
|
||||
from skyvern.cli import workflow as workflow_cmd
|
||||
|
||||
expected = {
|
||||
"ok": True,
|
||||
"action": "skyvern_workflow_get",
|
||||
"browser_context": {"mode": "none", "session_id": None, "cdp_url": None},
|
||||
"data": {"workflow_permanent_id": "wpid_123"},
|
||||
"artifacts": [],
|
||||
"timing_ms": {},
|
||||
"warnings": [],
|
||||
"error": None,
|
||||
}
|
||||
tool = AsyncMock(return_value=expected)
|
||||
monkeypatch.setattr(workflow_cmd, "tool_workflow_get", tool)
|
||||
|
||||
workflow_cmd.workflow_get(workflow_id="wpid_123", version=2, json_output=True)
|
||||
|
||||
parsed = json.loads(capsys.readouterr().out)
|
||||
assert parsed == expected
|
||||
assert tool.await_args.kwargs == {"workflow_id": "wpid_123", "version": 2}
|
||||
|
||||
def test_workflow_create_reads_definition_from_file(
|
||||
self,
|
||||
tmp_path: Path,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
capsys: pytest.CaptureFixture,
|
||||
) -> None:
|
||||
from skyvern.cli import workflow as workflow_cmd
|
||||
|
||||
definition_file = tmp_path / "workflow.json"
|
||||
definition_text = '{"title": "Example", "workflow_definition": {"blocks": []}}'
|
||||
definition_file.write_text(definition_text)
|
||||
|
||||
tool = AsyncMock(
|
||||
return_value={
|
||||
"ok": True,
|
||||
"action": "skyvern_workflow_create",
|
||||
"browser_context": {"mode": "none", "session_id": None, "cdp_url": None},
|
||||
"data": {"workflow_permanent_id": "wpid_new"},
|
||||
"artifacts": [],
|
||||
"timing_ms": {},
|
||||
"warnings": [],
|
||||
"error": None,
|
||||
}
|
||||
)
|
||||
monkeypatch.setattr(workflow_cmd, "tool_workflow_create", tool)
|
||||
|
||||
workflow_cmd.workflow_create(
|
||||
definition=f"@{definition_file}",
|
||||
definition_format="json",
|
||||
folder_id="fld_123",
|
||||
json_output=True,
|
||||
)
|
||||
|
||||
assert tool.await_args.kwargs == {
|
||||
"definition": definition_text,
|
||||
"format": "json",
|
||||
"folder_id": "fld_123",
|
||||
}
|
||||
parsed = json.loads(capsys.readouterr().out)
|
||||
assert parsed["ok"] is True
|
||||
assert parsed["data"]["workflow_permanent_id"] == "wpid_new"
|
||||
|
||||
def test_workflow_run_reads_params_file_and_maps_options(
|
||||
self,
|
||||
tmp_path: Path,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
capsys: pytest.CaptureFixture,
|
||||
) -> None:
|
||||
from skyvern.cli import workflow as workflow_cmd
|
||||
|
||||
params_file = tmp_path / "params.json"
|
||||
params_file.write_text('{"company": "Acme"}')
|
||||
|
||||
tool = AsyncMock(
|
||||
return_value={
|
||||
"ok": True,
|
||||
"action": "skyvern_workflow_run",
|
||||
"browser_context": {"mode": "none", "session_id": None, "cdp_url": None},
|
||||
"data": {"run_id": "wr_123", "status": "queued"},
|
||||
"artifacts": [],
|
||||
"timing_ms": {},
|
||||
"warnings": [],
|
||||
"error": None,
|
||||
}
|
||||
)
|
||||
monkeypatch.setattr(workflow_cmd, "tool_workflow_run", tool)
|
||||
|
||||
workflow_cmd.workflow_run(
|
||||
workflow_id="wpid_123",
|
||||
params=f"@{params_file}",
|
||||
session="pbs_456",
|
||||
webhook="https://example.com/webhook",
|
||||
proxy="RESIDENTIAL",
|
||||
wait=True,
|
||||
timeout=450,
|
||||
json_output=True,
|
||||
)
|
||||
|
||||
assert tool.await_args.kwargs == {
|
||||
"workflow_id": "wpid_123",
|
||||
"parameters": '{"company": "Acme"}',
|
||||
"browser_session_id": "pbs_456",
|
||||
"webhook_url": "https://example.com/webhook",
|
||||
"proxy_location": "RESIDENTIAL",
|
||||
"wait": True,
|
||||
"timeout_seconds": 450,
|
||||
}
|
||||
parsed = json.loads(capsys.readouterr().out)
|
||||
assert parsed["ok"] is True
|
||||
assert parsed["data"]["run_id"] == "wr_123"
|
||||
|
||||
def test_workflow_status_json_error_exits_nonzero(
|
||||
self, monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture
|
||||
) -> None:
|
||||
from skyvern.cli import workflow as workflow_cmd
|
||||
|
||||
tool = AsyncMock(
|
||||
return_value={
|
||||
"ok": False,
|
||||
"action": "skyvern_workflow_status",
|
||||
"browser_context": {"mode": "none", "session_id": None, "cdp_url": None},
|
||||
"data": None,
|
||||
"artifacts": [],
|
||||
"timing_ms": {},
|
||||
"warnings": [],
|
||||
"error": {
|
||||
"code": "RUN_NOT_FOUND",
|
||||
"message": "Run 'wr_missing' not found",
|
||||
"hint": "Verify the run ID",
|
||||
"details": {},
|
||||
},
|
||||
}
|
||||
)
|
||||
monkeypatch.setattr(workflow_cmd, "tool_workflow_status", tool)
|
||||
|
||||
with pytest.raises(SystemExit, match="1"):
|
||||
workflow_cmd.workflow_status(run_id="wr_missing", json_output=True)
|
||||
|
||||
parsed = json.loads(capsys.readouterr().out)
|
||||
assert parsed["ok"] is False
|
||||
assert parsed["error"]["code"] == "RUN_NOT_FOUND"
|
||||
|
||||
def test_workflow_update_missing_definition_file_raises_bad_parameter(
|
||||
self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
from skyvern.cli import workflow as workflow_cmd
|
||||
|
||||
tool = AsyncMock()
|
||||
monkeypatch.setattr(workflow_cmd, "tool_workflow_update", tool)
|
||||
missing_file = tmp_path / "missing-definition.json"
|
||||
|
||||
with pytest.raises(typer.BadParameter, match="Unable to read definition file"):
|
||||
workflow_cmd.workflow_update(
|
||||
workflow_id="wpid_123",
|
||||
definition=f"@{missing_file}",
|
||||
definition_format="json",
|
||||
json_output=False,
|
||||
)
|
||||
|
||||
tool.assert_not_called()
|
||||
|
||||
304
tests/unit/test_mcp_http_auth.py
Normal file
304
tests/unit/test_mcp_http_auth.py
Normal file
@@ -0,0 +1,304 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from starlette.applications import Starlette
|
||||
from starlette.middleware import Middleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.routing import Route
|
||||
|
||||
from skyvern.cli.core import client as client_mod
|
||||
from skyvern.cli.core import mcp_http_auth
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_auth_context() -> None:
|
||||
client_mod._api_key_override.set(None)
|
||||
mcp_http_auth._auth_db = None
|
||||
mcp_http_auth._api_key_validation_cache.clear()
|
||||
mcp_http_auth._API_KEY_CACHE_TTL_SECONDS = 30.0
|
||||
mcp_http_auth._API_KEY_CACHE_MAX_SIZE = 1024
|
||||
mcp_http_auth._MAX_VALIDATION_RETRIES = 2
|
||||
mcp_http_auth._RETRY_DELAY_SECONDS = 0.0 # no delay in tests
|
||||
|
||||
|
||||
async def _echo_request_context(request: Request) -> JSONResponse:
|
||||
return JSONResponse(
|
||||
{
|
||||
"api_key": client_mod.get_active_api_key(),
|
||||
"organization_id": getattr(request.state, "organization_id", None),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _build_test_app() -> Starlette:
|
||||
return Starlette(
|
||||
routes=[Route("/mcp", endpoint=_echo_request_context, methods=["POST"])],
|
||||
middleware=[Middleware(mcp_http_auth.MCPAPIKeyMiddleware)],
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_http_auth_rejects_missing_api_key() -> None:
|
||||
app = _build_test_app()
|
||||
|
||||
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
|
||||
response = await client.post("/mcp", json={})
|
||||
|
||||
assert response.status_code == 401
|
||||
assert response.json()["error"]["code"] == "UNAUTHORIZED"
|
||||
assert "x-api-key" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_http_auth_allows_health_checks_without_api_key() -> None:
|
||||
app = _build_test_app()
|
||||
|
||||
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
|
||||
response = await client.get("/healthz")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"status": "ok"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_http_auth_rejects_invalid_api_key(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
mcp_http_auth,
|
||||
"validate_mcp_api_key",
|
||||
AsyncMock(side_effect=HTTPException(status_code=403, detail="Invalid credentials")),
|
||||
)
|
||||
app = _build_test_app()
|
||||
|
||||
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
|
||||
response = await client.post("/mcp", headers={"x-api-key": "bad-key"}, json={})
|
||||
|
||||
assert response.status_code == 401
|
||||
assert response.json()["error"]["code"] == "UNAUTHORIZED"
|
||||
assert response.json()["error"]["message"] == "Invalid API key"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_http_auth_returns_500_on_non_auth_http_exception(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
mcp_http_auth,
|
||||
"validate_mcp_api_key",
|
||||
AsyncMock(side_effect=HTTPException(status_code=500, detail="db down")),
|
||||
)
|
||||
app = _build_test_app()
|
||||
|
||||
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
|
||||
response = await client.post("/mcp", headers={"x-api-key": "sk_live_abc"}, json={})
|
||||
|
||||
assert response.status_code == 500
|
||||
assert response.json()["error"]["code"] == "INTERNAL_ERROR"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_http_auth_returns_503_on_transient_validation_exhaustion(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
mcp_http_auth,
|
||||
"validate_mcp_api_key",
|
||||
AsyncMock(side_effect=HTTPException(status_code=503, detail="API key validation temporarily unavailable")),
|
||||
)
|
||||
app = _build_test_app()
|
||||
|
||||
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
|
||||
response = await client.post("/mcp", headers={"x-api-key": "sk_live_abc"}, json={})
|
||||
|
||||
assert response.status_code == 503
|
||||
assert response.json()["error"]["code"] == "SERVICE_UNAVAILABLE"
|
||||
assert response.json()["error"]["message"] == "API key validation temporarily unavailable"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_http_auth_returns_500_on_unexpected_validation_error(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
mcp_http_auth,
|
||||
"validate_mcp_api_key",
|
||||
AsyncMock(side_effect=RuntimeError("boom")),
|
||||
)
|
||||
app = _build_test_app()
|
||||
|
||||
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
|
||||
response = await client.post("/mcp", headers={"x-api-key": "sk_live_abc"}, json={})
|
||||
|
||||
assert response.status_code == 500
|
||||
assert response.json()["error"]["code"] == "INTERNAL_ERROR"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_http_auth_sets_request_scoped_api_key(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
mcp_http_auth,
|
||||
"validate_mcp_api_key",
|
||||
AsyncMock(return_value="org_123"),
|
||||
)
|
||||
app = _build_test_app()
|
||||
|
||||
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
|
||||
response = await client.post("/mcp", headers={"x-api-key": "sk_live_abc"}, json={})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"api_key": "sk_live_abc",
|
||||
"organization_id": "org_123",
|
||||
}
|
||||
assert client_mod.get_active_api_key() != "sk_live_abc"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_mcp_api_key_uses_ttl_cache(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
calls = 0
|
||||
|
||||
async def _resolve(_api_key: str, _db: object) -> object:
|
||||
nonlocal calls
|
||||
calls += 1
|
||||
return SimpleNamespace(organization=SimpleNamespace(organization_id="org_cached"))
|
||||
|
||||
monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
|
||||
monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())
|
||||
|
||||
first = await mcp_http_auth.validate_mcp_api_key("sk_test_cache")
|
||||
second = await mcp_http_auth.validate_mcp_api_key("sk_test_cache")
|
||||
|
||||
assert first == "org_cached"
|
||||
assert second == "org_cached"
|
||||
assert calls == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_mcp_api_key_cache_expires(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
calls = 0
|
||||
|
||||
async def _resolve(_api_key: str, _db: object) -> object:
|
||||
nonlocal calls
|
||||
calls += 1
|
||||
return SimpleNamespace(organization=SimpleNamespace(organization_id=f"org_{calls}"))
|
||||
|
||||
monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
|
||||
monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())
|
||||
|
||||
first = await mcp_http_auth.validate_mcp_api_key("sk_test_cache_expire")
|
||||
cache_key = mcp_http_auth._cache_key("sk_test_cache_expire")
|
||||
mcp_http_auth._api_key_validation_cache[cache_key] = (first, 0.0)
|
||||
second = await mcp_http_auth.validate_mcp_api_key("sk_test_cache_expire")
|
||||
|
||||
assert first == "org_1"
|
||||
assert second == "org_2"
|
||||
assert calls == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_mcp_api_key_negative_caches_auth_failures(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
calls = 0
|
||||
|
||||
async def _resolve(_api_key: str, _db: object) -> object:
|
||||
nonlocal calls
|
||||
calls += 1
|
||||
raise HTTPException(status_code=401, detail="Invalid credentials")
|
||||
|
||||
monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
|
||||
monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())
|
||||
|
||||
with pytest.raises(HTTPException, match="Invalid credentials"):
|
||||
await mcp_http_auth.validate_mcp_api_key("sk_test_auth_failure")
|
||||
|
||||
with pytest.raises(HTTPException, match="Invalid API key"):
|
||||
await mcp_http_auth.validate_mcp_api_key("sk_test_auth_failure")
|
||||
|
||||
assert calls == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_mcp_api_key_retries_transient_failure_without_negative_cache(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
calls = 0
|
||||
|
||||
async def _resolve(_api_key: str, _db: object) -> object:
|
||||
nonlocal calls
|
||||
calls += 1
|
||||
if calls == 1:
|
||||
raise RuntimeError("transient db error")
|
||||
return SimpleNamespace(organization=SimpleNamespace(organization_id="org_recovered"))
|
||||
|
||||
monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
|
||||
monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())
|
||||
|
||||
recovered_org = await mcp_http_auth.validate_mcp_api_key("sk_test_transient")
|
||||
|
||||
cache_key = mcp_http_auth._cache_key("sk_test_transient")
|
||||
assert mcp_http_auth._api_key_validation_cache[cache_key][0] == "org_recovered"
|
||||
|
||||
assert recovered_org == "org_recovered"
|
||||
assert calls == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_mcp_api_key_concurrent_callers_all_succeed(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Multiple concurrent callers for the same key all succeed; the cache
|
||||
collapses subsequent calls after the first one populates it."""
|
||||
calls = 0
|
||||
|
||||
async def _resolve(_api_key: str, _db: object) -> object:
|
||||
nonlocal calls
|
||||
calls += 1
|
||||
return SimpleNamespace(organization=SimpleNamespace(organization_id="org_concurrent"))
|
||||
|
||||
monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
|
||||
monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())
|
||||
|
||||
results = await asyncio.gather(*[mcp_http_auth.validate_mcp_api_key("sk_test_concurrent") for _ in range(5)])
|
||||
assert all(r == "org_concurrent" for r in results)
|
||||
# First call populates cache; remaining may or may not hit DB depending on
|
||||
# scheduling, but all must succeed.
|
||||
assert calls >= 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_mcp_api_key_returns_503_after_retry_exhaustion(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
calls = 0
|
||||
|
||||
async def _resolve(_api_key: str, _db: object) -> object:
|
||||
nonlocal calls
|
||||
calls += 1
|
||||
raise RuntimeError("persistent db outage")
|
||||
|
||||
monkeypatch.setattr(mcp_http_auth, "_MAX_VALIDATION_RETRIES", 2)
|
||||
monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
|
||||
monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())
|
||||
|
||||
with pytest.raises(HTTPException, match="temporarily unavailable") as exc_info:
|
||||
await mcp_http_auth.validate_mcp_api_key("sk_test_transient_exhausted")
|
||||
|
||||
assert exc_info.value.status_code == 503
|
||||
assert calls == 3 # initial + 2 retries
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_auth_db_disposes_engine() -> None:
|
||||
dispose = AsyncMock()
|
||||
mcp_http_auth._auth_db = SimpleNamespace(engine=SimpleNamespace(dispose=dispose))
|
||||
mcp_http_auth._api_key_validation_cache["k"] = ("org", 123.0)
|
||||
|
||||
await mcp_http_auth.close_auth_db()
|
||||
|
||||
dispose.assert_awaited_once()
|
||||
assert mcp_http_auth._auth_db is None
|
||||
assert mcp_http_auth._api_key_validation_cache == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_auth_db_noop_when_uninitialized() -> None:
|
||||
mcp_http_auth._auth_db = None
|
||||
await mcp_http_auth.close_auth_db()
|
||||
assert mcp_http_auth._auth_db is None
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
@@ -14,10 +15,13 @@ from skyvern.cli.mcp_tools import session as mcp_session
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_singletons() -> None:
|
||||
client_mod._skyvern_instance.set(None)
|
||||
client_mod._api_key_override.set(None)
|
||||
client_mod._global_skyvern_instance = None
|
||||
client_mod._api_key_clients.clear()
|
||||
|
||||
session_manager._current_session.set(None)
|
||||
session_manager._global_session = None
|
||||
session_manager.set_stateless_http_mode(False)
|
||||
mcp_session.set_current_session(mcp_session.SessionState())
|
||||
|
||||
|
||||
@@ -47,6 +51,115 @@ def test_get_skyvern_reuses_global_instance_across_contexts(monkeypatch: pytest.
|
||||
assert len(created) == 1
|
||||
|
||||
|
||||
def test_get_skyvern_reuses_override_instance_per_api_key(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
created_keys: list[str] = []
|
||||
|
||||
class FakeSkyvern:
|
||||
def __init__(self, *args: object, **kwargs: object) -> None:
|
||||
created_keys.append(kwargs["api_key"])
|
||||
|
||||
@classmethod
|
||||
def local(cls) -> FakeSkyvern:
|
||||
return cls(api_key="local")
|
||||
|
||||
async def aclose(self) -> None:
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(client_mod, "Skyvern", FakeSkyvern)
|
||||
monkeypatch.setattr(client_mod.settings, "SKYVERN_API_KEY", None)
|
||||
monkeypatch.setattr(client_mod.settings, "SKYVERN_BASE_URL", None)
|
||||
|
||||
token = client_mod.set_api_key_override("sk_key_a")
|
||||
try:
|
||||
first = client_mod.get_skyvern()
|
||||
client_mod._skyvern_instance.set(None)
|
||||
second = client_mod.get_skyvern()
|
||||
finally:
|
||||
client_mod.reset_api_key_override(token)
|
||||
|
||||
assert first is second
|
||||
assert created_keys == ["sk_key_a"]
|
||||
|
||||
|
||||
def test_get_skyvern_override_client_cache_uses_lru_eviction(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
created_keys: list[str] = []
|
||||
|
||||
class FakeSkyvern:
|
||||
def __init__(self, *args: object, **kwargs: object) -> None:
|
||||
created_keys.append(kwargs["api_key"])
|
||||
|
||||
@classmethod
|
||||
def local(cls) -> FakeSkyvern:
|
||||
return cls(api_key="local")
|
||||
|
||||
async def aclose(self) -> None:
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(client_mod, "Skyvern", FakeSkyvern)
|
||||
monkeypatch.setattr(client_mod.settings, "SKYVERN_API_KEY", None)
|
||||
monkeypatch.setattr(client_mod.settings, "SKYVERN_BASE_URL", None)
|
||||
monkeypatch.setattr(client_mod, "_API_KEY_CLIENT_CACHE_MAX", 2)
|
||||
|
||||
for key in ("sk_key_a", "sk_key_b"):
|
||||
token = client_mod.set_api_key_override(key)
|
||||
try:
|
||||
client_mod.get_skyvern()
|
||||
finally:
|
||||
client_mod.reset_api_key_override(token)
|
||||
|
||||
# Touch key_a so key_b becomes least-recently-used.
|
||||
token = client_mod.set_api_key_override("sk_key_a")
|
||||
try:
|
||||
client_mod._skyvern_instance.set(None)
|
||||
client_mod.get_skyvern()
|
||||
finally:
|
||||
client_mod.reset_api_key_override(token)
|
||||
|
||||
# Adding key_c should evict key_b.
|
||||
token = client_mod.set_api_key_override("sk_key_c")
|
||||
try:
|
||||
client_mod.get_skyvern()
|
||||
finally:
|
||||
client_mod.reset_api_key_override(token)
|
||||
|
||||
assert list(client_mod._api_key_clients.keys()) == [
|
||||
client_mod._cache_key("sk_key_a"),
|
||||
client_mod._cache_key("sk_key_c"),
|
||||
]
|
||||
# key_a, key_b, key_c were created exactly once each.
|
||||
assert created_keys == ["sk_key_a", "sk_key_b", "sk_key_c"]
|
||||
|
||||
|
||||
def test_get_skyvern_override_cache_closes_evicted_client(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
closed_keys: list[str] = []
|
||||
|
||||
class FakeSkyvern:
|
||||
def __init__(self, *args: object, **kwargs: object) -> None:
|
||||
self.api_key = kwargs["api_key"]
|
||||
|
||||
@classmethod
|
||||
def local(cls) -> FakeSkyvern:
|
||||
return cls(api_key="local")
|
||||
|
||||
async def aclose(self) -> None:
|
||||
closed_keys.append(self.api_key)
|
||||
|
||||
monkeypatch.setattr(client_mod, "Skyvern", FakeSkyvern)
|
||||
monkeypatch.setattr(client_mod.settings, "SKYVERN_API_KEY", None)
|
||||
monkeypatch.setattr(client_mod.settings, "SKYVERN_BASE_URL", None)
|
||||
monkeypatch.setattr(client_mod, "_API_KEY_CLIENT_CACHE_MAX", 1)
|
||||
|
||||
for key in ("sk_key_a", "sk_key_b"):
|
||||
token = client_mod.set_api_key_override(key)
|
||||
try:
|
||||
client_mod.get_skyvern()
|
||||
finally:
|
||||
client_mod.reset_api_key_override(token)
|
||||
|
||||
assert list(client_mod._api_key_clients.keys()) == [client_mod._cache_key("sk_key_b")]
|
||||
assert closed_keys == ["sk_key_a"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_skyvern_closes_singleton() -> None:
|
||||
fake = MagicMock()
|
||||
@@ -75,12 +188,53 @@ def test_get_current_session_falls_back_to_global_state() -> None:
|
||||
assert recovered is state
|
||||
|
||||
|
||||
def test_get_current_session_stateless_mode_ignores_global_state() -> None:
|
||||
global_state = session_manager.SessionState(
|
||||
browser=MagicMock(),
|
||||
context=BrowserContext(mode="cloud_session", session_id="pbs_999"),
|
||||
)
|
||||
session_manager._global_session = global_state
|
||||
session_manager._current_session.set(None)
|
||||
|
||||
session_manager.set_stateless_http_mode(True)
|
||||
try:
|
||||
recovered = session_manager.get_current_session()
|
||||
finally:
|
||||
session_manager.set_stateless_http_mode(False)
|
||||
|
||||
assert recovered is not global_state
|
||||
assert recovered.browser is None
|
||||
assert recovered.context is None
|
||||
|
||||
|
||||
def test_set_current_session_stateless_mode_does_not_override_global_state() -> None:
|
||||
global_state = session_manager.SessionState(
|
||||
browser=MagicMock(),
|
||||
context=BrowserContext(mode="cloud_session", session_id="pbs_global"),
|
||||
)
|
||||
session_manager._global_session = global_state
|
||||
replacement = session_manager.SessionState(
|
||||
browser=MagicMock(),
|
||||
context=BrowserContext(mode="cloud_session", session_id="pbs_request"),
|
||||
)
|
||||
|
||||
session_manager.set_stateless_http_mode(True)
|
||||
try:
|
||||
session_manager.set_current_session(replacement)
|
||||
finally:
|
||||
session_manager.set_stateless_http_mode(False)
|
||||
|
||||
assert session_manager._global_session is global_state
|
||||
assert session_manager._current_session.get() is replacement
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_browser_reuses_matching_cloud_session(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
current_browser = MagicMock()
|
||||
current_state = session_manager.SessionState(
|
||||
browser=current_browser,
|
||||
context=BrowserContext(mode="cloud_session", session_id="pbs_123"),
|
||||
api_key_hash=session_manager._api_key_hash(client_mod.get_active_api_key()),
|
||||
)
|
||||
session_manager.set_current_session(current_state)
|
||||
|
||||
@@ -95,6 +249,92 @@ async def test_resolve_browser_reuses_matching_cloud_session(monkeypatch: pytest
|
||||
fake_skyvern.connect_to_cloud_browser_session.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_browser_does_not_reuse_session_for_different_api_key(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
current_browser = MagicMock()
|
||||
session_manager.set_current_session(
|
||||
session_manager.SessionState(
|
||||
browser=current_browser,
|
||||
context=BrowserContext(mode="cloud_session", session_id="pbs_123"),
|
||||
api_key_hash=session_manager._api_key_hash("sk_key_a"),
|
||||
)
|
||||
)
|
||||
|
||||
replacement_browser = MagicMock()
|
||||
fake_skyvern = MagicMock()
|
||||
fake_skyvern.connect_to_cloud_browser_session = AsyncMock(return_value=replacement_browser)
|
||||
monkeypatch.setattr(session_manager, "get_skyvern", lambda: fake_skyvern)
|
||||
|
||||
token = client_mod.set_api_key_override("sk_key_b")
|
||||
try:
|
||||
browser, ctx = await session_manager.resolve_browser(session_id="pbs_123")
|
||||
finally:
|
||||
client_mod.reset_api_key_override(token)
|
||||
|
||||
assert browser is replacement_browser
|
||||
assert ctx.session_id == "pbs_123"
|
||||
fake_skyvern.connect_to_cloud_browser_session.assert_awaited_once_with("pbs_123")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_browser_stateless_mode_does_not_write_global_session(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
global_state = session_manager.SessionState(
|
||||
browser=MagicMock(),
|
||||
context=BrowserContext(mode="cloud_session", session_id="pbs_global"),
|
||||
)
|
||||
session_manager._global_session = global_state
|
||||
|
||||
replacement_browser = MagicMock()
|
||||
fake_skyvern = MagicMock()
|
||||
fake_skyvern.connect_to_cloud_browser_session = AsyncMock(return_value=replacement_browser)
|
||||
monkeypatch.setattr(session_manager, "get_skyvern", lambda: fake_skyvern)
|
||||
|
||||
session_manager.set_stateless_http_mode(True)
|
||||
try:
|
||||
browser, ctx = await session_manager.resolve_browser(session_id="pbs_123")
|
||||
finally:
|
||||
session_manager.set_stateless_http_mode(False)
|
||||
|
||||
assert browser is replacement_browser
|
||||
assert ctx.session_id == "pbs_123"
|
||||
assert session_manager._global_session is global_state
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_browser_blocks_implicit_session_in_stateless_mode() -> None:
|
||||
session_manager.set_current_session(
|
||||
session_manager.SessionState(
|
||||
browser=MagicMock(),
|
||||
context=BrowserContext(mode="cloud_session", session_id="pbs_123"),
|
||||
api_key_hash=session_manager._api_key_hash("sk_key_a"),
|
||||
)
|
||||
)
|
||||
session_manager.set_stateless_http_mode(True)
|
||||
try:
|
||||
with pytest.raises(session_manager.BrowserNotAvailableError):
|
||||
await session_manager.resolve_browser()
|
||||
finally:
|
||||
session_manager.set_stateless_http_mode(False)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_browser_raises_for_invalid_matching_state(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
session_manager.set_current_session(
|
||||
session_manager.SessionState(
|
||||
browser=None,
|
||||
context=BrowserContext(mode="cloud_session", session_id="pbs_123"),
|
||||
)
|
||||
)
|
||||
|
||||
monkeypatch.setattr(session_manager, "_matches_current", lambda *args, **kwargs: True)
|
||||
monkeypatch.setattr(session_manager, "get_skyvern", lambda: MagicMock())
|
||||
|
||||
with pytest.raises(RuntimeError, match="Expected active browser and context"):
|
||||
await session_manager.resolve_browser(session_id="pbs_123")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_close_with_matching_session_id_closes_browser_handle(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
current_browser = MagicMock()
|
||||
@@ -262,3 +502,73 @@ async def test_close_current_session_still_closes_browser_when_api_fails(monkeyp
|
||||
# _browser_session_id should NOT be cleared (API close failed, let browser.close() try)
|
||||
assert browser._browser_session_id == "pbs_fail"
|
||||
assert session_manager.get_current_session().browser is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests for stateless HTTP mode session creation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_create_stateless_mode_returns_session_without_persisting_browser(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
session_manager.set_stateless_http_mode(True)
|
||||
fake_skyvern = MagicMock()
|
||||
fake_skyvern.create_browser_session = AsyncMock(return_value=SimpleNamespace(browser_session_id="pbs_abc"))
|
||||
monkeypatch.setattr(mcp_session, "get_skyvern", lambda: fake_skyvern)
|
||||
do_session_create = AsyncMock()
|
||||
monkeypatch.setattr(mcp_session, "do_session_create", do_session_create)
|
||||
|
||||
try:
|
||||
result = await mcp_session.skyvern_session_create(timeout=45)
|
||||
finally:
|
||||
session_manager.set_stateless_http_mode(False)
|
||||
|
||||
assert result["ok"] is True
|
||||
assert result["data"] == {"session_id": "pbs_abc", "timeout_minutes": 45}
|
||||
do_session_create.assert_not_awaited()
|
||||
assert mcp_session.get_current_session().browser is None
|
||||
assert mcp_session.get_current_session().context is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_create_stateless_mode_rejects_local() -> None:
|
||||
session_manager.set_stateless_http_mode(True)
|
||||
try:
|
||||
result = await mcp_session.skyvern_session_create(local=True)
|
||||
finally:
|
||||
session_manager.set_stateless_http_mode(False)
|
||||
|
||||
assert result["ok"] is False
|
||||
assert result["error"]["code"] == mcp_session.ErrorCode.INVALID_INPUT
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_create_persists_active_api_key_hash_in_session_state(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
fake_skyvern = MagicMock()
|
||||
monkeypatch.setattr(mcp_session, "get_skyvern", lambda: fake_skyvern)
|
||||
|
||||
fake_browser = MagicMock()
|
||||
do_session_create = AsyncMock(
|
||||
return_value=(
|
||||
fake_browser,
|
||||
SimpleNamespace(local=False, session_id="pbs_123", timeout_minutes=60, headless=False),
|
||||
)
|
||||
)
|
||||
monkeypatch.setattr(mcp_session, "do_session_create", do_session_create)
|
||||
|
||||
token = client_mod.set_api_key_override("sk_key_create")
|
||||
try:
|
||||
result = await mcp_session.skyvern_session_create(timeout=60)
|
||||
finally:
|
||||
client_mod.reset_api_key_override(token)
|
||||
|
||||
assert result["ok"] is True
|
||||
current = mcp_session.get_current_session()
|
||||
assert current.browser is fake_browser
|
||||
assert current.context == BrowserContext(mode="cloud_session", session_id="pbs_123")
|
||||
assert current.api_key_hash == session_manager._api_key_hash("sk_key_create")
|
||||
assert current.api_key_hash != "sk_key_create"
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from unittest.mock import AsyncMock, MagicMock, call
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -13,6 +13,42 @@ def _reset_cleanup_state() -> None:
|
||||
run_commands._mcp_cleanup_done = False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_mcp_resources_closes_auth_db(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
close_current_session = AsyncMock()
|
||||
close_skyvern = AsyncMock()
|
||||
close_auth_db = AsyncMock()
|
||||
|
||||
monkeypatch.setattr(run_commands, "close_current_session", close_current_session)
|
||||
monkeypatch.setattr(run_commands, "close_skyvern", close_skyvern)
|
||||
monkeypatch.setattr(run_commands, "close_auth_db", close_auth_db)
|
||||
|
||||
await run_commands._cleanup_mcp_resources()
|
||||
|
||||
close_current_session.assert_awaited_once()
|
||||
close_skyvern.assert_awaited_once()
|
||||
close_auth_db.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_mcp_resources_closes_auth_db_on_skyvern_close_error(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
close_current_session = AsyncMock()
|
||||
close_auth_db = AsyncMock()
|
||||
|
||||
async def _failing_close_skyvern() -> None:
|
||||
raise RuntimeError("close failed")
|
||||
|
||||
monkeypatch.setattr(run_commands, "close_current_session", close_current_session)
|
||||
monkeypatch.setattr(run_commands, "close_skyvern", _failing_close_skyvern)
|
||||
monkeypatch.setattr(run_commands, "close_auth_db", close_auth_db)
|
||||
|
||||
with pytest.raises(RuntimeError, match="close failed"):
|
||||
await run_commands._cleanup_mcp_resources()
|
||||
|
||||
close_current_session.assert_awaited_once()
|
||||
close_auth_db.assert_awaited_once()
|
||||
|
||||
|
||||
def test_cleanup_mcp_resources_sync_runs_without_running_loop(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
cleanup = AsyncMock()
|
||||
monkeypatch.setattr(run_commands, "_cleanup_mcp_resources", cleanup)
|
||||
@@ -50,14 +86,56 @@ def test_run_mcp_calls_blocking_cleanup_in_finally(monkeypatch: pytest.MonkeyPat
|
||||
cleanup_blocking = MagicMock()
|
||||
register = MagicMock()
|
||||
run = MagicMock(side_effect=RuntimeError("boom"))
|
||||
set_stateless = MagicMock()
|
||||
|
||||
monkeypatch.setattr(run_commands, "_cleanup_mcp_resources_blocking", cleanup_blocking)
|
||||
monkeypatch.setattr(run_commands.atexit, "register", register)
|
||||
monkeypatch.setattr(run_commands.mcp, "run", run)
|
||||
monkeypatch.setattr(run_commands, "set_stateless_http_mode", set_stateless)
|
||||
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
run_commands.run_mcp()
|
||||
|
||||
register.assert_called_once_with(run_commands._cleanup_mcp_resources_sync)
|
||||
run.assert_called_once_with(transport="stdio")
|
||||
set_stateless.assert_has_calls([call(False), call(False)])
|
||||
cleanup_blocking.assert_called_once()
|
||||
|
||||
|
||||
def test_run_mcp_http_transport_wires_auth_middleware(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
cleanup_blocking = MagicMock()
|
||||
register = MagicMock()
|
||||
run = MagicMock()
|
||||
set_stateless = MagicMock()
|
||||
|
||||
monkeypatch.setattr(run_commands, "_cleanup_mcp_resources_blocking", cleanup_blocking)
|
||||
monkeypatch.setattr(run_commands.atexit, "register", register)
|
||||
monkeypatch.setattr(run_commands.mcp, "run", run)
|
||||
monkeypatch.setattr(run_commands, "set_stateless_http_mode", set_stateless)
|
||||
|
||||
run_commands.run_mcp(
|
||||
transport="streamable-http",
|
||||
host="127.0.0.1",
|
||||
port=9010,
|
||||
path="mcp",
|
||||
stateless_http=True,
|
||||
)
|
||||
|
||||
register.assert_called_once_with(run_commands._cleanup_mcp_resources_sync)
|
||||
run.assert_called_once()
|
||||
kwargs = run.call_args.kwargs
|
||||
assert kwargs["transport"] == "streamable-http"
|
||||
assert kwargs["host"] == "127.0.0.1"
|
||||
assert kwargs["port"] == 9010
|
||||
assert kwargs["path"] == "/mcp"
|
||||
assert kwargs["stateless_http"] is True
|
||||
middleware = kwargs["middleware"]
|
||||
assert len(middleware) == 1
|
||||
assert middleware[0].cls is run_commands.MCPAPIKeyMiddleware
|
||||
set_stateless.assert_has_calls([call(True), call(False)])
|
||||
cleanup_blocking.assert_called_once()
|
||||
|
||||
|
||||
def test_run_task_tool_registration_points_to_browser_module() -> None:
|
||||
tool = run_commands.mcp._tool_manager._tools["skyvern_run_task"] # type: ignore[attr-defined]
|
||||
assert tool.fn.__module__ == "skyvern.cli.mcp_tools.browser"
|
||||
|
||||
Reference in New Issue
Block a user