Files
Dorod-Sky/tests/unit/test_mcp_http_auth.py
Marc Kelechava 71f2b7a201 Add gated admin impersonation controls for MCP API-key auth (#4822)
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-02-19 18:56:06 -08:00

551 lines
20 KiB
Python

from __future__ import annotations

import asyncio
import time
from collections.abc import Iterator
from types import SimpleNamespace
from unittest.mock import AsyncMock

import httpx
import pytest
from fastapi import HTTPException
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.routing import Route

from skyvern.cli.core import client as client_mod
from skyvern.cli.core import mcp_http_auth
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
@pytest.fixture(autouse=True)
def _reset_auth_context(monkeypatch: pytest.MonkeyPatch) -> Iterator[None]:
    """Give every test a clean, deterministic MCP auth state, and undo it afterwards.

    Module tunables and the cached DB handle are patched through ``monkeypatch``,
    so their original values are restored when the test finishes instead of the
    test values leaking into other test modules. The API-key validation cache and
    all impersonation sessions are cleared on setup AND on teardown, so state left
    by one test (including the last test in this file) cannot bleed into the next.
    """
    client_mod._api_key_override.set(None)
    monkeypatch.setattr(mcp_http_auth, "_auth_db", None)
    monkeypatch.setattr(mcp_http_auth, "_API_KEY_CACHE_TTL_SECONDS", 30.0)
    monkeypatch.setattr(mcp_http_auth, "_API_KEY_CACHE_MAX_SIZE", 1024)
    monkeypatch.setattr(mcp_http_auth, "_MAX_VALIDATION_RETRIES", 2)
    monkeypatch.setattr(mcp_http_auth, "_RETRY_DELAY_SECONDS", 0.0)  # no delay in tests
    mcp_http_auth._api_key_validation_cache.clear()
    mcp_http_auth.clear_all_impersonation_sessions()

    yield

    # Teardown: drop anything the test cached; monkeypatch restores the tunables.
    mcp_http_auth._api_key_validation_cache.clear()
    mcp_http_auth.clear_all_impersonation_sessions()
async def _echo_request_context(request: Request) -> JSONResponse:
    """Echo back the auth context that the middleware attached to this request.

    Reports the currently active API key plus the organization fields stored on
    ``request.state``; any field that was never set is reported as ``None``.
    """
    state = request.state
    payload = {
        "api_key": client_mod.get_active_api_key(),
        "organization_id": getattr(state, "organization_id", None),
        "admin_organization_id": getattr(state, "admin_organization_id", None),
        "impersonation_target_organization_id": getattr(state, "impersonation_target_organization_id", None),
    }
    return JSONResponse(payload)
def _build_validation(
    organization_id: str,
    token_type: OrganizationAuthTokenType = OrganizationAuthTokenType.api,
) -> mcp_http_auth.MCPAPIKeyValidation:
    """Build the kind of result ``validate_mcp_api_key`` returns, for a given org and token type."""
    validation = mcp_http_auth.MCPAPIKeyValidation(organization_id=organization_id, token_type=token_type)
    return validation
def _build_resolved_validation(
    organization_id: str,
    token_type: OrganizationAuthTokenType = OrganizationAuthTokenType.api,
) -> SimpleNamespace:
    """Fake a resolver result exposing ``.organization.organization_id`` and ``.token.token_type``."""
    org = SimpleNamespace(organization_id=organization_id)
    token = SimpleNamespace(token_type=token_type)
    return SimpleNamespace(organization=org, token=token)
def _build_test_app() -> Starlette:
    """Build a minimal app: one ``POST /mcp`` echo route behind ``MCPAPIKeyMiddleware``."""
    echo_route = Route("/mcp", endpoint=_echo_request_context, methods=["POST"])
    auth_middleware = Middleware(mcp_http_auth.MCPAPIKeyMiddleware)
    return Starlette(routes=[echo_route], middleware=[auth_middleware])
@pytest.mark.asyncio
async def test_mcp_http_auth_rejects_missing_api_key() -> None:
    """A request with no ``x-api-key`` header gets a 401 that names the missing header."""
    transport = httpx.ASGITransport(app=_build_test_app())
    async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as http:
        resp = await http.post("/mcp", json={})

    assert resp.status_code == 401
    error = resp.json()["error"]
    assert error["code"] == "UNAUTHORIZED"
    assert "x-api-key" in error["message"]
@pytest.mark.asyncio
async def test_mcp_http_auth_allows_health_checks_without_api_key() -> None:
    """``GET /healthz`` succeeds without a key.

    The test app defines no ``/healthz`` route, so the ok response has to come
    from the middleware itself.
    """
    transport = httpx.ASGITransport(app=_build_test_app())
    async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as http:
        resp = await http.get("/healthz")

    assert resp.status_code == 200
    assert resp.json() == {"status": "ok"}
@pytest.mark.asyncio
async def test_mcp_http_auth_rejects_invalid_api_key(monkeypatch: pytest.MonkeyPatch) -> None:
    """A 403 raised during validation reaches the client as a 401 "Invalid API key"."""
    rejecting_validator = AsyncMock(side_effect=HTTPException(status_code=403, detail="Invalid credentials"))
    monkeypatch.setattr(mcp_http_auth, "validate_mcp_api_key", rejecting_validator)

    transport = httpx.ASGITransport(app=_build_test_app())
    async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as http:
        resp = await http.post("/mcp", headers={"x-api-key": "bad-key"}, json={})

    assert resp.status_code == 401
    error = resp.json()["error"]
    assert error["code"] == "UNAUTHORIZED"
    assert error["message"] == "Invalid API key"
@pytest.mark.asyncio
async def test_mcp_http_auth_returns_500_on_non_auth_http_exception(monkeypatch: pytest.MonkeyPatch) -> None:
    """An HTTPException with a non-auth status (500 here) maps to INTERNAL_ERROR."""
    failing_validator = AsyncMock(side_effect=HTTPException(status_code=500, detail="db down"))
    monkeypatch.setattr(mcp_http_auth, "validate_mcp_api_key", failing_validator)

    transport = httpx.ASGITransport(app=_build_test_app())
    async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as http:
        resp = await http.post("/mcp", headers={"x-api-key": "sk_live_abc"}, json={})

    assert resp.status_code == 500
    assert resp.json()["error"]["code"] == "INTERNAL_ERROR"
@pytest.mark.asyncio
async def test_mcp_http_auth_returns_503_on_transient_validation_exhaustion(monkeypatch: pytest.MonkeyPatch) -> None:
    """A 503 from validation is returned as SERVICE_UNAVAILABLE with its detail as the message."""
    unavailable = HTTPException(status_code=503, detail="API key validation temporarily unavailable")
    monkeypatch.setattr(mcp_http_auth, "validate_mcp_api_key", AsyncMock(side_effect=unavailable))

    transport = httpx.ASGITransport(app=_build_test_app())
    async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as http:
        resp = await http.post("/mcp", headers={"x-api-key": "sk_live_abc"}, json={})

    assert resp.status_code == 503
    error = resp.json()["error"]
    assert error["code"] == "SERVICE_UNAVAILABLE"
    assert error["message"] == "API key validation temporarily unavailable"
@pytest.mark.asyncio
async def test_mcp_http_auth_returns_500_on_unexpected_validation_error(monkeypatch: pytest.MonkeyPatch) -> None:
    """An arbitrary (non-HTTP) exception during validation becomes a 500 INTERNAL_ERROR."""
    exploding_validator = AsyncMock(side_effect=RuntimeError("boom"))
    monkeypatch.setattr(mcp_http_auth, "validate_mcp_api_key", exploding_validator)

    transport = httpx.ASGITransport(app=_build_test_app())
    async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as http:
        resp = await http.post("/mcp", headers={"x-api-key": "sk_live_abc"}, json={})

    assert resp.status_code == 500
    assert resp.json()["error"]["code"] == "INTERNAL_ERROR"
@pytest.mark.asyncio
async def test_mcp_http_auth_sets_request_scoped_api_key(monkeypatch: pytest.MonkeyPatch) -> None:
    """A valid key is the active API key while the request runs, and only then."""
    monkeypatch.setattr(mcp_http_auth, "validate_mcp_api_key", AsyncMock(return_value=_build_validation("org_123")))

    transport = httpx.ASGITransport(app=_build_test_app())
    async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as http:
        resp = await http.post("/mcp", headers={"x-api-key": "sk_live_abc"}, json={})

    expected_context = {
        "api_key": "sk_live_abc",
        "organization_id": "org_123",
        "admin_organization_id": None,
        "impersonation_target_organization_id": None,
    }
    assert resp.status_code == 200
    assert resp.json() == expected_context
    # Once the request has finished, its key must no longer be the active one here.
    assert client_mod.get_active_api_key() != "sk_live_abc"
@pytest.mark.asyncio
async def test_validate_mcp_api_key_uses_ttl_cache(monkeypatch: pytest.MonkeyPatch) -> None:
    """Two back-to-back validations of one key reach the resolver only once."""
    seen_keys: list[str] = []

    async def _resolve(_api_key: str, _db: object, **_: object) -> object:
        seen_keys.append(_api_key)
        return _build_resolved_validation("org_cached")

    monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
    monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())

    # Sequential awaits: the second call runs only after the first has cached its result.
    validations = [await mcp_http_auth.validate_mcp_api_key("sk_test_cache") for _ in range(2)]

    assert [v.organization_id for v in validations] == ["org_cached", "org_cached"]
    assert len(seen_keys) == 1
@pytest.mark.asyncio
async def test_validate_mcp_api_key_cache_expires(monkeypatch: pytest.MonkeyPatch) -> None:
    """A cache entry that is no longer fresh forces a fresh resolver lookup."""
    resolved_orgs: list[str] = []

    async def _resolve(_api_key: str, _db: object, **_: object) -> object:
        org_id = f"org_{len(resolved_orgs) + 1}"
        resolved_orgs.append(org_id)
        return _build_resolved_validation(org_id)

    monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
    monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())

    first = await mcp_http_auth.validate_mcp_api_key("sk_test_cache_expire")
    # Rewrite the entry with a 0.0 timestamp so the next lookup sees it as stale.
    entry_key = mcp_http_auth.cache_key("sk_test_cache_expire")
    mcp_http_auth._api_key_validation_cache[entry_key] = (first, 0.0)
    second = await mcp_http_auth.validate_mcp_api_key("sk_test_cache_expire")

    assert first.organization_id == "org_1"
    assert second.organization_id == "org_2"
    assert len(resolved_orgs) == 2
@pytest.mark.asyncio
async def test_validate_mcp_api_key_negative_caches_auth_failures(monkeypatch: pytest.MonkeyPatch) -> None:
    """A rejected key is remembered, so a repeat attempt fails without calling the resolver again."""
    resolver_hits = 0

    async def _resolve(_api_key: str, _db: object, **_: object) -> object:
        nonlocal resolver_hits
        resolver_hits += 1
        raise HTTPException(status_code=401, detail="Invalid credentials")

    monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
    monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())

    # First attempt: the resolver's own error propagates.
    with pytest.raises(HTTPException, match="Invalid credentials"):
        await mcp_http_auth.validate_mcp_api_key("sk_test_auth_failure")
    # Second attempt: answered from the negative cache with the generic message.
    with pytest.raises(HTTPException, match="Invalid API key"):
        await mcp_http_auth.validate_mcp_api_key("sk_test_auth_failure")

    assert resolver_hits == 1
@pytest.mark.asyncio
async def test_validate_mcp_api_key_retries_transient_failure_without_negative_cache(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A one-off resolver error is retried, and the later success (not the failure) is cached."""
    attempts = 0

    async def _resolve(_api_key: str, _db: object, **_: object) -> object:
        nonlocal attempts
        attempts += 1
        if attempts > 1:
            return _build_resolved_validation("org_recovered")
        raise RuntimeError("transient db error")

    monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
    monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())

    recovered = await mcp_http_auth.validate_mcp_api_key("sk_test_transient")
    cached_entry = mcp_http_auth._api_key_validation_cache[mcp_http_auth.cache_key("sk_test_transient")]

    assert cached_entry[0].organization_id == "org_recovered"
    assert recovered.organization_id == "org_recovered"
    assert attempts == 2
@pytest.mark.asyncio
async def test_validate_mcp_api_key_concurrent_callers_all_succeed(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Concurrent callers validating the same key all get the same org.

    Whether the cache collapses the concurrent lookups depends on how the event
    loop interleaves the callers, so the test only bounds the number of resolver
    calls: at least one, and at most one per caller. The resolver never fails, so
    no retries should add extra calls.
    """
    calls = 0
    num_callers = 5

    async def _resolve(_api_key: str, _db: object, **_: object) -> object:
        nonlocal calls
        calls += 1
        return _build_resolved_validation("org_concurrent")

    monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
    monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())

    results = await asyncio.gather(
        *(mcp_http_auth.validate_mcp_api_key("sk_test_concurrent") for _ in range(num_callers))
    )

    assert len(results) == num_callers
    assert all(r.organization_id == "org_concurrent" for r in results)
    # A lower bound alone would miss duplicated or retried lookups.
    assert 1 <= calls <= num_callers
@pytest.mark.asyncio
async def test_validate_mcp_api_key_returns_503_after_retry_exhaustion(monkeypatch: pytest.MonkeyPatch) -> None:
    """When every attempt fails with a non-HTTP error, validation gives up with a 503."""
    attempts = 0

    async def _resolve(_api_key: str, _db: object, **_: object) -> object:
        nonlocal attempts
        attempts += 1
        raise RuntimeError("persistent db outage")

    monkeypatch.setattr(mcp_http_auth, "_MAX_VALIDATION_RETRIES", 2)
    monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
    monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())

    with pytest.raises(HTTPException, match="temporarily unavailable") as raised:
        await mcp_http_auth.validate_mcp_api_key("sk_test_transient_exhausted")

    assert raised.value.status_code == 503
    assert attempts == 3  # one initial attempt + two retries
@pytest.mark.asyncio
async def test_close_auth_db_disposes_engine() -> None:
    """``close_auth_db`` awaits the engine's ``dispose``, drops the handle and empties the key cache."""
    engine_dispose = AsyncMock()
    fake_engine = SimpleNamespace(dispose=engine_dispose)
    mcp_http_auth._auth_db = SimpleNamespace(engine=fake_engine)
    mcp_http_auth._api_key_validation_cache["k"] = ("org", 123.0)

    await mcp_http_auth.close_auth_db()

    engine_dispose.assert_awaited_once()
    assert mcp_http_auth._auth_db is None
    assert mcp_http_auth._api_key_validation_cache == {}
@pytest.mark.asyncio
async def test_close_auth_db_noop_when_uninitialized() -> None:
    """Closing when no DB handle exists must not fail and leaves the handle unset."""
    auth = mcp_http_auth
    auth._auth_db = None

    await auth.close_auth_db()

    assert auth._auth_db is None
@pytest.mark.asyncio
async def test_mcp_http_auth_denies_target_org_when_feature_disabled(monkeypatch: pytest.MonkeyPatch) -> None:
    """With impersonation disabled, an ``x-target-org-id`` header is refused even for an admin key."""
    admin_validation = _build_validation("org_admin", OrganizationAuthTokenType.mcp_admin_impersonation)
    validator = AsyncMock(return_value=admin_validation)
    monkeypatch.setattr(mcp_http_auth, "validate_mcp_api_key", validator)
    monkeypatch.setattr(mcp_http_auth, "_is_admin_impersonation_enabled", lambda: False)

    request_headers = {"x-api-key": "sk_live_admin", "x-target-org-id": "org_target"}
    transport = httpx.ASGITransport(app=_build_test_app())
    async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as http:
        resp = await http.post("/mcp", headers=request_headers, json={})

    assert resp.status_code == 401
    error = resp.json()["error"]
    assert error["code"] == "UNAUTHORIZED"
    assert error["message"] == "Impersonation not allowed"
    # The key itself is still validated exactly once.
    validator.assert_awaited_once_with("sk_live_admin")
@pytest.mark.asyncio
async def test_mcp_http_auth_validates_api_key_before_feature_flag_denial(monkeypatch: pytest.MonkeyPatch) -> None:
    """A bad key is reported as "Invalid API key", not as an impersonation denial."""
    validator = AsyncMock(side_effect=HTTPException(status_code=403, detail="Invalid credentials"))
    monkeypatch.setattr(mcp_http_auth, "validate_mcp_api_key", validator)
    monkeypatch.setattr(mcp_http_auth, "_is_admin_impersonation_enabled", lambda: False)

    request_headers = {"x-api-key": "bad-key", "x-target-org-id": "org_target"}
    transport = httpx.ASGITransport(app=_build_test_app())
    async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as http:
        resp = await http.post("/mcp", headers=request_headers, json={})

    assert resp.status_code == 401
    error = resp.json()["error"]
    assert error["code"] == "UNAUTHORIZED"
    assert error["message"] == "Invalid API key"
    validator.assert_awaited_once_with("bad-key")
@pytest.mark.parametrize("target_org_id", ["", " \t "])
@pytest.mark.asyncio
async def test_mcp_http_auth_denies_empty_or_whitespace_target_org_id_header(
    monkeypatch: pytest.MonkeyPatch,
    target_org_id: str,
) -> None:
    """A blank or whitespace-only target org header is refused even with impersonation enabled."""
    admin_validation = _build_validation("org_admin", OrganizationAuthTokenType.mcp_admin_impersonation)
    validator = AsyncMock(return_value=admin_validation)
    monkeypatch.setattr(mcp_http_auth, "validate_mcp_api_key", validator)
    monkeypatch.setattr(mcp_http_auth, "_is_admin_impersonation_enabled", lambda: True)

    request_headers = {"x-api-key": "sk_live_admin", "x-target-org-id": target_org_id}
    transport = httpx.ASGITransport(app=_build_test_app())
    async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as http:
        resp = await http.post("/mcp", headers=request_headers, json={})

    assert resp.status_code == 401
    error = resp.json()["error"]
    assert error["code"] == "UNAUTHORIZED"
    assert error["message"] == "Impersonation not allowed"
    validator.assert_awaited_once_with("sk_live_admin")
# ---------------------------------------------------------------------------
# Session-based impersonation tests
# ---------------------------------------------------------------------------
def test_impersonation_session_lifecycle() -> None:
    """A stored session can be fetched, then cleared, after which it is gone (set, get, clear)."""
    key_hash = mcp_http_auth.cache_key("sk_admin_key")
    mcp_http_auth.set_impersonation_session(
        mcp_http_auth.ImpersonationSession(
            admin_api_key_hash=key_hash,
            admin_org_id="org_admin",
            target_org_id="org_target",
            target_api_key="sk_target_key",
            expires_at=time.monotonic() + 600,
            ttl_minutes=10,
        )
    )

    active = mcp_http_auth.get_active_impersonation(key_hash)
    assert active is not None
    assert active.target_org_id == "org_target"

    removed = mcp_http_auth.clear_impersonation_session(key_hash)
    assert removed is not None
    assert removed.target_org_id == "org_target"

    assert mcp_http_auth.get_active_impersonation(key_hash) is None
def test_impersonation_session_auto_expiry() -> None:
    """A session whose expiry is already in the past is not returned by ``get_active_impersonation``."""
    key_hash = mcp_http_auth.cache_key("sk_admin_key")
    stale_session = mcp_http_auth.ImpersonationSession(
        admin_api_key_hash=key_hash,
        admin_org_id="org_admin",
        target_org_id="org_target",
        target_api_key="sk_target_key",
        expires_at=time.monotonic() - 1,  # expired one second ago
        ttl_minutes=1,
    )

    mcp_http_auth.set_impersonation_session(stale_session)

    assert mcp_http_auth.get_active_impersonation(key_hash) is None
@pytest.mark.asyncio
async def test_middleware_applies_session_impersonation(monkeypatch: pytest.MonkeyPatch) -> None:
    """With an active session, the middleware impersonates the target org without any header."""
    admin_validation = _build_validation("org_admin", OrganizationAuthTokenType.mcp_admin_impersonation)
    monkeypatch.setattr(mcp_http_auth, "validate_mcp_api_key", AsyncMock(return_value=admin_validation))

    mcp_http_auth.set_impersonation_session(
        mcp_http_auth.ImpersonationSession(
            admin_api_key_hash=mcp_http_auth.cache_key("sk_live_admin"),
            admin_org_id="org_admin",
            target_org_id="org_target",
            target_api_key="sk_live_target_key",
            expires_at=time.monotonic() + 600,
            ttl_minutes=10,
        )
    )

    transport = httpx.ASGITransport(app=_build_test_app())
    async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as http:
        resp = await http.post("/mcp", headers={"x-api-key": "sk_live_admin"}, json={})

    assert resp.status_code == 200
    # The echo route returns exactly these four keys, so whole-dict equality is
    # equivalent to checking each field.
    assert resp.json() == {
        "api_key": "sk_live_target_key",
        "organization_id": "org_target",
        "admin_organization_id": "org_admin",
        "impersonation_target_organization_id": "org_target",
    }
@pytest.mark.asyncio
async def test_middleware_ignores_expired_session(monkeypatch: pytest.MonkeyPatch) -> None:
    """An expired session is ignored: the middleware falls back to the admin's own org.

    Every field of the echoed context is checked, including
    ``impersonation_target_organization_id``, so a stale session that still
    leaked its target org onto the request would be caught.
    """
    monkeypatch.setattr(
        mcp_http_auth,
        "validate_mcp_api_key",
        AsyncMock(
            return_value=_build_validation(
                "org_admin",
                OrganizationAuthTokenType.mcp_admin_impersonation,
            )
        ),
    )
    admin_hash = mcp_http_auth.cache_key("sk_live_admin")
    session = mcp_http_auth.ImpersonationSession(
        admin_api_key_hash=admin_hash,
        admin_org_id="org_admin",
        target_org_id="org_target",
        target_api_key="sk_live_target_key",
        expires_at=time.monotonic() - 1,  # expired
        ttl_minutes=1,
    )
    mcp_http_auth.set_impersonation_session(session)
    app = _build_test_app()
    async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
        response = await client.post("/mcp", headers={"x-api-key": "sk_live_admin"}, json={})
    assert response.status_code == 200
    assert response.json() == {
        "api_key": "sk_live_admin",
        "organization_id": "org_admin",
        "admin_organization_id": None,
        "impersonation_target_organization_id": None,
    }
@pytest.mark.asyncio
async def test_close_auth_db_clears_impersonation_sessions() -> None:
    """Shutting down the auth DB also drops every stored impersonation session."""
    key_hash = mcp_http_auth.cache_key("sk_admin_key")
    mcp_http_auth.set_impersonation_session(
        mcp_http_auth.ImpersonationSession(
            admin_api_key_hash=key_hash,
            admin_org_id="org_admin",
            target_org_id="org_target",
            target_api_key="sk_target_key",
            expires_at=time.monotonic() + 600,
            ttl_minutes=10,
        )
    )
    fake_engine = SimpleNamespace(dispose=AsyncMock())
    mcp_http_auth._auth_db = SimpleNamespace(engine=fake_engine)

    await mcp_http_auth.close_auth_db()

    assert mcp_http_auth.get_active_impersonation(key_hash) is None