Add gated admin impersonation controls for MCP API-key auth (#4822)
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
@@ -15,6 +16,7 @@ from starlette.routing import Route
|
||||
|
||||
from skyvern.cli.core import client as client_mod
|
||||
from skyvern.cli.core import mcp_http_auth
|
||||
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
@@ -26,6 +28,7 @@ def _reset_auth_context() -> None:
|
||||
mcp_http_auth._API_KEY_CACHE_MAX_SIZE = 1024
|
||||
mcp_http_auth._MAX_VALIDATION_RETRIES = 2
|
||||
mcp_http_auth._RETRY_DELAY_SECONDS = 0.0 # no delay in tests
|
||||
mcp_http_auth.clear_all_impersonation_sessions()
|
||||
|
||||
|
||||
async def _echo_request_context(request: Request) -> JSONResponse:
|
||||
@@ -33,10 +36,34 @@ async def _echo_request_context(request: Request) -> JSONResponse:
|
||||
{
|
||||
"api_key": client_mod.get_active_api_key(),
|
||||
"organization_id": getattr(request.state, "organization_id", None),
|
||||
"admin_organization_id": getattr(request.state, "admin_organization_id", None),
|
||||
"impersonation_target_organization_id": getattr(
|
||||
request.state, "impersonation_target_organization_id", None
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _build_validation(
|
||||
organization_id: str,
|
||||
token_type: OrganizationAuthTokenType = OrganizationAuthTokenType.api,
|
||||
) -> mcp_http_auth.MCPAPIKeyValidation:
|
||||
return mcp_http_auth.MCPAPIKeyValidation(
|
||||
organization_id=organization_id,
|
||||
token_type=token_type,
|
||||
)
|
||||
|
||||
|
||||
def _build_resolved_validation(
|
||||
organization_id: str,
|
||||
token_type: OrganizationAuthTokenType = OrganizationAuthTokenType.api,
|
||||
) -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
organization=SimpleNamespace(organization_id=organization_id),
|
||||
token=SimpleNamespace(token_type=token_type),
|
||||
)
|
||||
|
||||
|
||||
def _build_test_app() -> Starlette:
|
||||
return Starlette(
|
||||
routes=[Route("/mcp", endpoint=_echo_request_context, methods=["POST"])],
|
||||
@@ -138,7 +165,7 @@ async def test_mcp_http_auth_sets_request_scoped_api_key(monkeypatch: pytest.Mon
|
||||
monkeypatch.setattr(
|
||||
mcp_http_auth,
|
||||
"validate_mcp_api_key",
|
||||
AsyncMock(return_value="org_123"),
|
||||
AsyncMock(return_value=_build_validation("org_123")),
|
||||
)
|
||||
app = _build_test_app()
|
||||
|
||||
@@ -149,6 +176,8 @@ async def test_mcp_http_auth_sets_request_scoped_api_key(monkeypatch: pytest.Mon
|
||||
assert response.json() == {
|
||||
"api_key": "sk_live_abc",
|
||||
"organization_id": "org_123",
|
||||
"admin_organization_id": None,
|
||||
"impersonation_target_organization_id": None,
|
||||
}
|
||||
assert client_mod.get_active_api_key() != "sk_live_abc"
|
||||
|
||||
@@ -157,10 +186,10 @@ async def test_mcp_http_auth_sets_request_scoped_api_key(monkeypatch: pytest.Mon
|
||||
async def test_validate_mcp_api_key_uses_ttl_cache(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
calls = 0
|
||||
|
||||
async def _resolve(_api_key: str, _db: object) -> object:
|
||||
async def _resolve(_api_key: str, _db: object, **_: object) -> object:
|
||||
nonlocal calls
|
||||
calls += 1
|
||||
return SimpleNamespace(organization=SimpleNamespace(organization_id="org_cached"))
|
||||
return _build_resolved_validation("org_cached")
|
||||
|
||||
monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
|
||||
monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())
|
||||
@@ -168,8 +197,8 @@ async def test_validate_mcp_api_key_uses_ttl_cache(monkeypatch: pytest.MonkeyPat
|
||||
first = await mcp_http_auth.validate_mcp_api_key("sk_test_cache")
|
||||
second = await mcp_http_auth.validate_mcp_api_key("sk_test_cache")
|
||||
|
||||
assert first == "org_cached"
|
||||
assert second == "org_cached"
|
||||
assert first.organization_id == "org_cached"
|
||||
assert second.organization_id == "org_cached"
|
||||
assert calls == 1
|
||||
|
||||
|
||||
@@ -177,21 +206,21 @@ async def test_validate_mcp_api_key_uses_ttl_cache(monkeypatch: pytest.MonkeyPat
|
||||
async def test_validate_mcp_api_key_cache_expires(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
calls = 0
|
||||
|
||||
async def _resolve(_api_key: str, _db: object) -> object:
|
||||
async def _resolve(_api_key: str, _db: object, **_: object) -> object:
|
||||
nonlocal calls
|
||||
calls += 1
|
||||
return SimpleNamespace(organization=SimpleNamespace(organization_id=f"org_{calls}"))
|
||||
return _build_resolved_validation(f"org_{calls}")
|
||||
|
||||
monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
|
||||
monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())
|
||||
|
||||
first = await mcp_http_auth.validate_mcp_api_key("sk_test_cache_expire")
|
||||
cache_key = mcp_http_auth._cache_key("sk_test_cache_expire")
|
||||
cache_key = mcp_http_auth.cache_key("sk_test_cache_expire")
|
||||
mcp_http_auth._api_key_validation_cache[cache_key] = (first, 0.0)
|
||||
second = await mcp_http_auth.validate_mcp_api_key("sk_test_cache_expire")
|
||||
|
||||
assert first == "org_1"
|
||||
assert second == "org_2"
|
||||
assert first.organization_id == "org_1"
|
||||
assert second.organization_id == "org_2"
|
||||
assert calls == 2
|
||||
|
||||
|
||||
@@ -199,7 +228,7 @@ async def test_validate_mcp_api_key_cache_expires(monkeypatch: pytest.MonkeyPatc
|
||||
async def test_validate_mcp_api_key_negative_caches_auth_failures(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
calls = 0
|
||||
|
||||
async def _resolve(_api_key: str, _db: object) -> object:
|
||||
async def _resolve(_api_key: str, _db: object, **_: object) -> object:
|
||||
nonlocal calls
|
||||
calls += 1
|
||||
raise HTTPException(status_code=401, detail="Invalid credentials")
|
||||
@@ -222,22 +251,22 @@ async def test_validate_mcp_api_key_retries_transient_failure_without_negative_c
|
||||
) -> None:
|
||||
calls = 0
|
||||
|
||||
async def _resolve(_api_key: str, _db: object) -> object:
|
||||
async def _resolve(_api_key: str, _db: object, **_: object) -> object:
|
||||
nonlocal calls
|
||||
calls += 1
|
||||
if calls == 1:
|
||||
raise RuntimeError("transient db error")
|
||||
return SimpleNamespace(organization=SimpleNamespace(organization_id="org_recovered"))
|
||||
return _build_resolved_validation("org_recovered")
|
||||
|
||||
monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
|
||||
monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())
|
||||
|
||||
recovered_org = await mcp_http_auth.validate_mcp_api_key("sk_test_transient")
|
||||
|
||||
cache_key = mcp_http_auth._cache_key("sk_test_transient")
|
||||
assert mcp_http_auth._api_key_validation_cache[cache_key][0] == "org_recovered"
|
||||
cache_key = mcp_http_auth.cache_key("sk_test_transient")
|
||||
assert mcp_http_auth._api_key_validation_cache[cache_key][0].organization_id == "org_recovered"
|
||||
|
||||
assert recovered_org == "org_recovered"
|
||||
assert recovered_org.organization_id == "org_recovered"
|
||||
assert calls == 2
|
||||
|
||||
|
||||
@@ -249,16 +278,16 @@ async def test_validate_mcp_api_key_concurrent_callers_all_succeed(
|
||||
collapses subsequent calls after the first one populates it."""
|
||||
calls = 0
|
||||
|
||||
async def _resolve(_api_key: str, _db: object) -> object:
|
||||
async def _resolve(_api_key: str, _db: object, **_: object) -> object:
|
||||
nonlocal calls
|
||||
calls += 1
|
||||
return SimpleNamespace(organization=SimpleNamespace(organization_id="org_concurrent"))
|
||||
return _build_resolved_validation("org_concurrent")
|
||||
|
||||
monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
|
||||
monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())
|
||||
|
||||
results = await asyncio.gather(*[mcp_http_auth.validate_mcp_api_key("sk_test_concurrent") for _ in range(5)])
|
||||
assert all(r == "org_concurrent" for r in results)
|
||||
assert all(r.organization_id == "org_concurrent" for r in results)
|
||||
# First call populates cache; remaining may or may not hit DB depending on
|
||||
# scheduling, but all must succeed.
|
||||
assert calls >= 1
|
||||
@@ -268,7 +297,7 @@ async def test_validate_mcp_api_key_concurrent_callers_all_succeed(
|
||||
async def test_validate_mcp_api_key_returns_503_after_retry_exhaustion(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
calls = 0
|
||||
|
||||
async def _resolve(_api_key: str, _db: object) -> object:
|
||||
async def _resolve(_api_key: str, _db: object, **_: object) -> object:
|
||||
nonlocal calls
|
||||
calls += 1
|
||||
raise RuntimeError("persistent db outage")
|
||||
@@ -302,3 +331,220 @@ async def test_close_auth_db_noop_when_uninitialized() -> None:
|
||||
mcp_http_auth._auth_db = None
|
||||
await mcp_http_auth.close_auth_db()
|
||||
assert mcp_http_auth._auth_db is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_http_auth_denies_target_org_when_feature_disabled(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
validate_mock = AsyncMock(
|
||||
return_value=_build_validation(
|
||||
"org_admin",
|
||||
OrganizationAuthTokenType.mcp_admin_impersonation,
|
||||
)
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
mcp_http_auth,
|
||||
"validate_mcp_api_key",
|
||||
validate_mock,
|
||||
)
|
||||
monkeypatch.setattr(mcp_http_auth, "_is_admin_impersonation_enabled", lambda: False)
|
||||
app = _build_test_app()
|
||||
|
||||
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
|
||||
response = await client.post(
|
||||
"/mcp",
|
||||
headers={"x-api-key": "sk_live_admin", "x-target-org-id": "org_target"},
|
||||
json={},
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
assert response.json()["error"]["code"] == "UNAUTHORIZED"
|
||||
assert response.json()["error"]["message"] == "Impersonation not allowed"
|
||||
validate_mock.assert_awaited_once_with("sk_live_admin")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_http_auth_validates_api_key_before_feature_flag_denial(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
validate_mock = AsyncMock(side_effect=HTTPException(status_code=403, detail="Invalid credentials"))
|
||||
monkeypatch.setattr(mcp_http_auth, "validate_mcp_api_key", validate_mock)
|
||||
monkeypatch.setattr(mcp_http_auth, "_is_admin_impersonation_enabled", lambda: False)
|
||||
app = _build_test_app()
|
||||
|
||||
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
|
||||
response = await client.post(
|
||||
"/mcp",
|
||||
headers={"x-api-key": "bad-key", "x-target-org-id": "org_target"},
|
||||
json={},
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
assert response.json()["error"]["code"] == "UNAUTHORIZED"
|
||||
assert response.json()["error"]["message"] == "Invalid API key"
|
||||
validate_mock.assert_awaited_once_with("bad-key")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("target_org_id", ["", " \t "])
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_http_auth_denies_empty_or_whitespace_target_org_id_header(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
target_org_id: str,
|
||||
) -> None:
|
||||
validate_mock = AsyncMock(
|
||||
return_value=_build_validation(
|
||||
"org_admin",
|
||||
OrganizationAuthTokenType.mcp_admin_impersonation,
|
||||
)
|
||||
)
|
||||
monkeypatch.setattr(mcp_http_auth, "validate_mcp_api_key", validate_mock)
|
||||
monkeypatch.setattr(mcp_http_auth, "_is_admin_impersonation_enabled", lambda: True)
|
||||
app = _build_test_app()
|
||||
|
||||
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
|
||||
response = await client.post(
|
||||
"/mcp",
|
||||
headers={"x-api-key": "sk_live_admin", "x-target-org-id": target_org_id},
|
||||
json={},
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
assert response.json()["error"]["code"] == "UNAUTHORIZED"
|
||||
assert response.json()["error"]["message"] == "Impersonation not allowed"
|
||||
validate_mock.assert_awaited_once_with("sk_live_admin")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Session-based impersonation tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_impersonation_session_lifecycle() -> None:
|
||||
"""set → get → clear lifecycle."""
|
||||
admin_hash = mcp_http_auth.cache_key("sk_admin_key")
|
||||
session = mcp_http_auth.ImpersonationSession(
|
||||
admin_api_key_hash=admin_hash,
|
||||
admin_org_id="org_admin",
|
||||
target_org_id="org_target",
|
||||
target_api_key="sk_target_key",
|
||||
expires_at=time.monotonic() + 600,
|
||||
ttl_minutes=10,
|
||||
)
|
||||
mcp_http_auth.set_impersonation_session(session)
|
||||
|
||||
retrieved = mcp_http_auth.get_active_impersonation(admin_hash)
|
||||
assert retrieved is not None
|
||||
assert retrieved.target_org_id == "org_target"
|
||||
|
||||
cleared = mcp_http_auth.clear_impersonation_session(admin_hash)
|
||||
assert cleared is not None
|
||||
assert cleared.target_org_id == "org_target"
|
||||
|
||||
assert mcp_http_auth.get_active_impersonation(admin_hash) is None
|
||||
|
||||
|
||||
def test_impersonation_session_auto_expiry() -> None:
|
||||
"""Expired sessions are lazily cleaned up on get."""
|
||||
admin_hash = mcp_http_auth.cache_key("sk_admin_key")
|
||||
session = mcp_http_auth.ImpersonationSession(
|
||||
admin_api_key_hash=admin_hash,
|
||||
admin_org_id="org_admin",
|
||||
target_org_id="org_target",
|
||||
target_api_key="sk_target_key",
|
||||
expires_at=time.monotonic() - 1, # already expired
|
||||
ttl_minutes=1,
|
||||
)
|
||||
mcp_http_auth.set_impersonation_session(session)
|
||||
|
||||
assert mcp_http_auth.get_active_impersonation(admin_hash) is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_middleware_applies_session_impersonation(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""When a session is active, middleware auto-applies impersonation without header."""
|
||||
monkeypatch.setattr(
|
||||
mcp_http_auth,
|
||||
"validate_mcp_api_key",
|
||||
AsyncMock(
|
||||
return_value=_build_validation(
|
||||
"org_admin",
|
||||
OrganizationAuthTokenType.mcp_admin_impersonation,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
admin_hash = mcp_http_auth.cache_key("sk_live_admin")
|
||||
session = mcp_http_auth.ImpersonationSession(
|
||||
admin_api_key_hash=admin_hash,
|
||||
admin_org_id="org_admin",
|
||||
target_org_id="org_target",
|
||||
target_api_key="sk_live_target_key",
|
||||
expires_at=time.monotonic() + 600,
|
||||
ttl_minutes=10,
|
||||
)
|
||||
mcp_http_auth.set_impersonation_session(session)
|
||||
|
||||
app = _build_test_app()
|
||||
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
|
||||
response = await client.post("/mcp", headers={"x-api-key": "sk_live_admin"}, json={})
|
||||
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["api_key"] == "sk_live_target_key"
|
||||
assert body["organization_id"] == "org_target"
|
||||
assert body["admin_organization_id"] == "org_admin"
|
||||
assert body["impersonation_target_organization_id"] == "org_target"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_middleware_ignores_expired_session(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Expired session is ignored — middleware reverts to admin's own org."""
|
||||
monkeypatch.setattr(
|
||||
mcp_http_auth,
|
||||
"validate_mcp_api_key",
|
||||
AsyncMock(
|
||||
return_value=_build_validation(
|
||||
"org_admin",
|
||||
OrganizationAuthTokenType.mcp_admin_impersonation,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
admin_hash = mcp_http_auth.cache_key("sk_live_admin")
|
||||
session = mcp_http_auth.ImpersonationSession(
|
||||
admin_api_key_hash=admin_hash,
|
||||
admin_org_id="org_admin",
|
||||
target_org_id="org_target",
|
||||
target_api_key="sk_live_target_key",
|
||||
expires_at=time.monotonic() - 1, # expired
|
||||
ttl_minutes=1,
|
||||
)
|
||||
mcp_http_auth.set_impersonation_session(session)
|
||||
|
||||
app = _build_test_app()
|
||||
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
|
||||
response = await client.post("/mcp", headers={"x-api-key": "sk_live_admin"}, json={})
|
||||
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["api_key"] == "sk_live_admin"
|
||||
assert body["organization_id"] == "org_admin"
|
||||
assert body["admin_organization_id"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_auth_db_clears_impersonation_sessions() -> None:
|
||||
admin_hash = mcp_http_auth.cache_key("sk_admin_key")
|
||||
session = mcp_http_auth.ImpersonationSession(
|
||||
admin_api_key_hash=admin_hash,
|
||||
admin_org_id="org_admin",
|
||||
target_org_id="org_target",
|
||||
target_api_key="sk_target_key",
|
||||
expires_at=time.monotonic() + 600,
|
||||
ttl_minutes=10,
|
||||
)
|
||||
mcp_http_auth.set_impersonation_session(session)
|
||||
|
||||
dispose = AsyncMock()
|
||||
mcp_http_auth._auth_db = SimpleNamespace(engine=SimpleNamespace(dispose=dispose))
|
||||
|
||||
await mcp_http_auth.close_auth_db()
|
||||
|
||||
assert mcp_http_auth.get_active_impersonation(admin_hash) is None
|
||||
|
||||
Reference in New Issue
Block a user