Add gated admin impersonation controls for MCP API-key auth (#4822)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Marc Kelechava
2026-02-19 18:56:06 -08:00
committed by GitHub
parent 36e600eeb9
commit 71f2b7a201
5 changed files with 520 additions and 41 deletions

View File

@@ -4,6 +4,8 @@ import asyncio
import os
import time
from collections import OrderedDict
from contextvars import ContextVar, Token
from dataclasses import dataclass
from threading import RLock
import structlog
@@ -14,6 +16,7 @@ from starlette.types import ASGIApp, Receive, Scope, Send
from skyvern.config import settings
from skyvern.forge.sdk.db.agent_db import AgentDB
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
from skyvern.forge.sdk.services.org_auth_service import resolve_org_from_api_key
from .api_key_hash import hash_api_key_for_cache
@@ -21,16 +24,94 @@ from .client import reset_api_key_override, set_api_key_override
LOG = structlog.get_logger(__name__)
API_KEY_HEADER = "x-api-key"
TARGET_ORG_ID_HEADER = "x-target-org-id"
HEALTH_PATHS = {"/health", "/healthz"}
_MCP_ALLOWED_TOKEN_TYPES = (
OrganizationAuthTokenType.api,
OrganizationAuthTokenType.mcp_admin_impersonation,
)
_auth_db: AgentDB | None = None
_auth_db_lock = RLock()
_api_key_cache_lock = RLock()
_api_key_validation_cache: OrderedDict[str, tuple[str | None, float]] = OrderedDict()
_api_key_validation_cache: OrderedDict[str, tuple[MCPAPIKeyValidation | None, float]] = OrderedDict()
_NEGATIVE_CACHE_TTL_SECONDS = 5.0
_VALIDATION_RETRY_EXHAUSTED_MESSAGE = "API key validation temporarily unavailable"
_MAX_VALIDATION_RETRIES = 2
_RETRY_DELAY_SECONDS = 0.25
# ---------------------------------------------------------------------------
# Impersonation session state
# ---------------------------------------------------------------------------
_admin_api_key_hash: ContextVar[str | None] = ContextVar("admin_api_key_hash", default=None)
@dataclass(frozen=True)
class ImpersonationSession:
admin_api_key_hash: str
admin_org_id: str
target_org_id: str
# Stored in plaintext in process memory for up to TTL. Acceptable V1 trade-off
# (in-process only, same as API key cache).
target_api_key: str
expires_at: float # time.monotonic() deadline
ttl_minutes: int
_impersonation_sessions: dict[str, ImpersonationSession] = {}
_impersonation_lock = RLock()
MAX_TTL_MINUTES = 120
DEFAULT_TTL_MINUTES = 30
_MAX_IMPERSONATION_SESSIONS = 100
def get_active_impersonation(admin_key_hash: str) -> ImpersonationSession | None:
with _impersonation_lock:
session = _impersonation_sessions.get(admin_key_hash)
if session is None:
return None
if session.expires_at <= time.monotonic():
_impersonation_sessions.pop(admin_key_hash, None)
return None
return session
def set_impersonation_session(session: ImpersonationSession) -> None:
with _impersonation_lock:
_impersonation_sessions[session.admin_api_key_hash] = session
# Sweep expired sessions, then evict oldest if over capacity
now = time.monotonic()
expired_keys = [k for k, s in _impersonation_sessions.items() if s.expires_at <= now]
for k in expired_keys:
_impersonation_sessions.pop(k, None)
if len(_impersonation_sessions) > _MAX_IMPERSONATION_SESSIONS:
oldest_key = min(_impersonation_sessions, key=lambda k: _impersonation_sessions[k].expires_at)
_impersonation_sessions.pop(oldest_key, None)
def clear_impersonation_session(admin_key_hash: str) -> ImpersonationSession | None:
with _impersonation_lock:
return _impersonation_sessions.pop(admin_key_hash, None)
def clear_all_impersonation_sessions() -> None:
with _impersonation_lock:
_impersonation_sessions.clear()
def get_admin_api_key_hash() -> str | None:
return _admin_api_key_hash.get()
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class MCPAPIKeyValidation:
organization_id: str
token_type: OrganizationAuthTokenType
def _resolve_api_key_cache_ttl_seconds() -> float:
raw = os.environ.get("SKYVERN_MCP_API_KEY_CACHE_TTL_SECONDS", "30")
@@ -69,6 +150,7 @@ async def close_auth_db() -> None:
_auth_db = None
with _api_key_cache_lock:
_api_key_validation_cache.clear()
clear_all_impersonation_sessions()
if db is None:
return
@@ -78,24 +160,43 @@ async def close_auth_db() -> None:
LOG.warning("Failed to dispose MCP auth DB engine", exc_info=True)
def _cache_key(api_key: str) -> str:
def cache_key(api_key: str) -> str:
return hash_api_key_for_cache(api_key)
async def validate_mcp_api_key(api_key: str) -> str:
"""Validate API key and return organization id for observability."""
key = _cache_key(api_key)
def _admin_organization_ids() -> set[str]:
try:
from cloud.config import settings as cloud_settings # noqa: PLC0415
admin_ids = getattr(cloud_settings, "ADMIN_ORGANIZATION_IDS", [])
except ImportError:
admin_ids = getattr(settings, "ADMIN_ORGANIZATION_IDS", [])
return {org_id for org_id in admin_ids if org_id}
def _is_admin_impersonation_enabled() -> bool:
try:
from cloud.config import settings as cloud_settings # noqa: PLC0415
return bool(getattr(cloud_settings, "MCP_ADMIN_IMPERSONATION_ENABLED", False))
except ImportError:
return bool(getattr(settings, "MCP_ADMIN_IMPERSONATION_ENABLED", False))
async def validate_mcp_api_key(api_key: str) -> MCPAPIKeyValidation:
"""Validate API key and return caller organization + token type."""
key = cache_key(api_key)
# Check cache first.
with _api_key_cache_lock:
cached = _api_key_validation_cache.get(key)
if cached is not None:
organization_id, expires_at = cached
cached_validation, expires_at = cached
if expires_at > time.monotonic():
_api_key_validation_cache.move_to_end(key)
if organization_id is None:
if cached_validation is None:
raise HTTPException(status_code=401, detail="Invalid API key")
return organization_id
return cached_validation
_api_key_validation_cache.pop(key, None)
# Cache miss — do the DB lookup with simple retry on transient errors.
@@ -104,17 +205,24 @@ async def validate_mcp_api_key(api_key: str) -> str:
if attempt > 0:
await asyncio.sleep(_RETRY_DELAY_SECONDS)
try:
validation = await resolve_org_from_api_key(api_key, _get_auth_db())
organization_id = validation.organization.organization_id
validation = await resolve_org_from_api_key(
api_key,
_get_auth_db(),
token_types=_MCP_ALLOWED_TOKEN_TYPES,
)
caller_validation = MCPAPIKeyValidation(
organization_id=validation.organization.organization_id,
token_type=validation.token.token_type,
)
with _api_key_cache_lock:
_api_key_validation_cache[key] = (
organization_id,
caller_validation,
time.monotonic() + _API_KEY_CACHE_TTL_SECONDS,
)
_api_key_validation_cache.move_to_end(key)
while len(_api_key_validation_cache) > _API_KEY_CACHE_MAX_SIZE:
_api_key_validation_cache.popitem(last=False)
return organization_id
return caller_validation
except HTTPException as e:
if e.status_code in {401, 403}:
with _api_key_cache_lock:
@@ -149,6 +257,24 @@ def _service_unavailable_response(message: str) -> JSONResponse:
)
def _deny_impersonation(
*,
reason: str,
caller_organization_id: str | None,
target_organization_id: str | None,
token_type: OrganizationAuthTokenType | None = None,
) -> JSONResponse:
log_kwargs: dict[str, object] = {
"reason": reason,
"caller_organization_id": caller_organization_id,
"target_organization_id": target_organization_id,
}
if token_type is not None:
log_kwargs["token_type"] = token_type.value
LOG.warning("MCP admin impersonation denied", **log_kwargs)
return _unauthorized_response("Impersonation not allowed")
class MCPAPIKeyMiddleware:
"""Require x-api-key for MCP HTTP transport and scope requests to that key."""
@@ -176,10 +302,99 @@ class MCPAPIKeyMiddleware:
await response(scope, receive, send)
return
target_org_id_header = request.headers.get(TARGET_ORG_ID_HEADER)
admin_hash_token: Token[str | None] | None = None
try:
organization_id = await validate_mcp_api_key(api_key)
validation = await validate_mcp_api_key(api_key)
caller_organization_id = validation.organization_id
admin_key_hash = cache_key(api_key)
scope.setdefault("state", {})
scope["state"]["organization_id"] = organization_id
if target_org_id_header is not None:
# Explicit header-based impersonation (takes priority over session)
if not _is_admin_impersonation_enabled():
response = _deny_impersonation(
reason="feature_disabled",
caller_organization_id=caller_organization_id,
target_organization_id=target_org_id_header.strip() or target_org_id_header,
)
await response(scope, receive, send)
return
target_organization_id = target_org_id_header.strip()
if not target_organization_id:
response = _deny_impersonation(
reason="missing_target_organization_id",
caller_organization_id=caller_organization_id,
target_organization_id=target_org_id_header,
token_type=validation.token_type,
)
await response(scope, receive, send)
return
# Delegate validation to the single source of truth in cloud/.
# The import can't fail here — _is_admin_impersonation_enabled()
# already returned True, which requires cloud.config to be importable.
try:
from cloud.mcp_admin_tools import validate_impersonation_target # noqa: PLC0415
except ImportError:
response = _deny_impersonation(
reason="impersonation_not_available",
caller_organization_id=caller_organization_id,
target_organization_id=target_organization_id,
)
await response(scope, receive, send)
return
result = await validate_impersonation_target(
caller_organization_id=caller_organization_id,
target_organization_id=target_organization_id,
token_type=validation.token_type,
)
if isinstance(result, str):
response = _deny_impersonation(
reason=result,
caller_organization_id=caller_organization_id,
target_organization_id=target_organization_id,
token_type=validation.token_type,
)
await response(scope, receive, send)
return
resolved_org_id, target_api_key = result
api_key = target_api_key
scope["state"]["organization_id"] = resolved_org_id
scope["state"]["admin_organization_id"] = caller_organization_id
scope["state"]["impersonation_target_organization_id"] = resolved_org_id
admin_hash_token = _admin_api_key_hash.set(admin_key_hash)
LOG.info(
"MCP admin impersonation allowed",
caller_organization_id=caller_organization_id,
target_organization_id=resolved_org_id,
token_type=validation.token_type.value,
)
else:
# No explicit header — check for session-based impersonation
session = get_active_impersonation(admin_key_hash)
if session is not None:
# Session data (target API key) is cached for TTL. If the target org's
# key is revoked mid-session, impersonation continues until expiry —
# acceptable trade-off vs re-validating every request.
api_key = session.target_api_key
scope["state"]["organization_id"] = session.target_org_id
scope["state"]["admin_organization_id"] = session.admin_org_id
scope["state"]["impersonation_target_organization_id"] = session.target_org_id
admin_hash_token = _admin_api_key_hash.set(admin_key_hash)
LOG.info(
"MCP session impersonation applied",
caller_organization_id=session.admin_org_id,
target_organization_id=session.target_org_id,
ttl_minutes=session.ttl_minutes,
)
else:
scope["state"]["organization_id"] = caller_organization_id
except HTTPException as e:
if e.status_code in {401, 403}:
response = _unauthorized_response("Invalid API key")
@@ -204,3 +419,5 @@ class MCPAPIKeyMiddleware:
await self.app(scope, receive, send)
finally:
reset_api_key_override(token)
if admin_hash_token is not None:
_admin_api_key_hash.reset(admin_hash_token)

View File

@@ -286,6 +286,14 @@ mcp.tool()(skyvern_workflow_run)
mcp.tool()(skyvern_workflow_status)
mcp.tool()(skyvern_workflow_cancel)
# -- Admin impersonation (cloud-only, session-level org switching) --
try:
from cloud.mcp_admin_tools import register_admin_tools # noqa: PLC0415
register_admin_tools(mcp)
except ImportError:
pass
# -- Prompts (methodology guides injected into LLM conversations) --
mcp.prompt()(build_workflow)
mcp.prompt()(debug_automation)

View File

@@ -3,6 +3,7 @@ from enum import StrEnum
class OrganizationAuthTokenType(StrEnum):
api = "api"
mcp_admin_impersonation = "mcp_admin_impersonation"
onepassword_service_account = "onepassword_service_account"
azure_client_secret_credential = "azure_client_secret_credential"
custom_credential_service = "custom_credential_service"

View File

@@ -1,6 +1,6 @@
import time
from dataclasses import dataclass
from typing import Annotated
from typing import Annotated, Sequence
import structlog
from asyncache import cached
@@ -176,6 +176,7 @@ async def _authenticate_user_helper(authorization: str) -> str:
async def resolve_org_from_api_key(
x_api_key: str,
db: AgentDB,
token_types: Sequence[OrganizationAuthTokenType] = (OrganizationAuthTokenType.api,),
) -> ApiKeyValidationResult:
"""Decode and validate the API key against the database."""
try:
@@ -202,12 +203,18 @@ async def resolve_org_from_api_key(
LOG.warning("Organization not found", organization_id=api_key_data.sub, **payload)
raise HTTPException(status_code=404, detail="Organization not found")
api_key_db_obj: OrganizationAuthToken | None = None
# Try token types in priority order and stop at the first valid match.
for token_type in token_types:
api_key_db_obj = await db.validate_org_auth_token(
organization_id=organization.organization_id,
token_type=OrganizationAuthTokenType.api,
token_type=token_type,
token=x_api_key,
valid=None,
)
if api_key_db_obj:
break
if not api_key_db_obj:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import asyncio
import time
from types import SimpleNamespace
from unittest.mock import AsyncMock
@@ -15,6 +16,7 @@ from starlette.routing import Route
from skyvern.cli.core import client as client_mod
from skyvern.cli.core import mcp_http_auth
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
@pytest.fixture(autouse=True)
@@ -26,6 +28,7 @@ def _reset_auth_context() -> None:
mcp_http_auth._API_KEY_CACHE_MAX_SIZE = 1024
mcp_http_auth._MAX_VALIDATION_RETRIES = 2
mcp_http_auth._RETRY_DELAY_SECONDS = 0.0 # no delay in tests
mcp_http_auth.clear_all_impersonation_sessions()
async def _echo_request_context(request: Request) -> JSONResponse:
@@ -33,10 +36,34 @@ async def _echo_request_context(request: Request) -> JSONResponse:
{
"api_key": client_mod.get_active_api_key(),
"organization_id": getattr(request.state, "organization_id", None),
"admin_organization_id": getattr(request.state, "admin_organization_id", None),
"impersonation_target_organization_id": getattr(
request.state, "impersonation_target_organization_id", None
),
}
)
def _build_validation(
organization_id: str,
token_type: OrganizationAuthTokenType = OrganizationAuthTokenType.api,
) -> mcp_http_auth.MCPAPIKeyValidation:
return mcp_http_auth.MCPAPIKeyValidation(
organization_id=organization_id,
token_type=token_type,
)
def _build_resolved_validation(
organization_id: str,
token_type: OrganizationAuthTokenType = OrganizationAuthTokenType.api,
) -> SimpleNamespace:
return SimpleNamespace(
organization=SimpleNamespace(organization_id=organization_id),
token=SimpleNamespace(token_type=token_type),
)
def _build_test_app() -> Starlette:
return Starlette(
routes=[Route("/mcp", endpoint=_echo_request_context, methods=["POST"])],
@@ -138,7 +165,7 @@ async def test_mcp_http_auth_sets_request_scoped_api_key(monkeypatch: pytest.Mon
monkeypatch.setattr(
mcp_http_auth,
"validate_mcp_api_key",
AsyncMock(return_value="org_123"),
AsyncMock(return_value=_build_validation("org_123")),
)
app = _build_test_app()
@@ -149,6 +176,8 @@ async def test_mcp_http_auth_sets_request_scoped_api_key(monkeypatch: pytest.Mon
assert response.json() == {
"api_key": "sk_live_abc",
"organization_id": "org_123",
"admin_organization_id": None,
"impersonation_target_organization_id": None,
}
assert client_mod.get_active_api_key() != "sk_live_abc"
@@ -157,10 +186,10 @@ async def test_mcp_http_auth_sets_request_scoped_api_key(monkeypatch: pytest.Mon
async def test_validate_mcp_api_key_uses_ttl_cache(monkeypatch: pytest.MonkeyPatch) -> None:
calls = 0
async def _resolve(_api_key: str, _db: object) -> object:
async def _resolve(_api_key: str, _db: object, **_: object) -> object:
nonlocal calls
calls += 1
return SimpleNamespace(organization=SimpleNamespace(organization_id="org_cached"))
return _build_resolved_validation("org_cached")
monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())
@@ -168,8 +197,8 @@ async def test_validate_mcp_api_key_uses_ttl_cache(monkeypatch: pytest.MonkeyPat
first = await mcp_http_auth.validate_mcp_api_key("sk_test_cache")
second = await mcp_http_auth.validate_mcp_api_key("sk_test_cache")
assert first == "org_cached"
assert second == "org_cached"
assert first.organization_id == "org_cached"
assert second.organization_id == "org_cached"
assert calls == 1
@@ -177,21 +206,21 @@ async def test_validate_mcp_api_key_uses_ttl_cache(monkeypatch: pytest.MonkeyPat
async def test_validate_mcp_api_key_cache_expires(monkeypatch: pytest.MonkeyPatch) -> None:
calls = 0
async def _resolve(_api_key: str, _db: object) -> object:
async def _resolve(_api_key: str, _db: object, **_: object) -> object:
nonlocal calls
calls += 1
return SimpleNamespace(organization=SimpleNamespace(organization_id=f"org_{calls}"))
return _build_resolved_validation(f"org_{calls}")
monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())
first = await mcp_http_auth.validate_mcp_api_key("sk_test_cache_expire")
cache_key = mcp_http_auth._cache_key("sk_test_cache_expire")
cache_key = mcp_http_auth.cache_key("sk_test_cache_expire")
mcp_http_auth._api_key_validation_cache[cache_key] = (first, 0.0)
second = await mcp_http_auth.validate_mcp_api_key("sk_test_cache_expire")
assert first == "org_1"
assert second == "org_2"
assert first.organization_id == "org_1"
assert second.organization_id == "org_2"
assert calls == 2
@@ -199,7 +228,7 @@ async def test_validate_mcp_api_key_cache_expires(monkeypatch: pytest.MonkeyPatc
async def test_validate_mcp_api_key_negative_caches_auth_failures(monkeypatch: pytest.MonkeyPatch) -> None:
calls = 0
async def _resolve(_api_key: str, _db: object) -> object:
async def _resolve(_api_key: str, _db: object, **_: object) -> object:
nonlocal calls
calls += 1
raise HTTPException(status_code=401, detail="Invalid credentials")
@@ -222,22 +251,22 @@ async def test_validate_mcp_api_key_retries_transient_failure_without_negative_c
) -> None:
calls = 0
async def _resolve(_api_key: str, _db: object) -> object:
async def _resolve(_api_key: str, _db: object, **_: object) -> object:
nonlocal calls
calls += 1
if calls == 1:
raise RuntimeError("transient db error")
return SimpleNamespace(organization=SimpleNamespace(organization_id="org_recovered"))
return _build_resolved_validation("org_recovered")
monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())
recovered_org = await mcp_http_auth.validate_mcp_api_key("sk_test_transient")
cache_key = mcp_http_auth._cache_key("sk_test_transient")
assert mcp_http_auth._api_key_validation_cache[cache_key][0] == "org_recovered"
cache_key = mcp_http_auth.cache_key("sk_test_transient")
assert mcp_http_auth._api_key_validation_cache[cache_key][0].organization_id == "org_recovered"
assert recovered_org == "org_recovered"
assert recovered_org.organization_id == "org_recovered"
assert calls == 2
@@ -249,16 +278,16 @@ async def test_validate_mcp_api_key_concurrent_callers_all_succeed(
collapses subsequent calls after the first one populates it."""
calls = 0
async def _resolve(_api_key: str, _db: object) -> object:
async def _resolve(_api_key: str, _db: object, **_: object) -> object:
nonlocal calls
calls += 1
return SimpleNamespace(organization=SimpleNamespace(organization_id="org_concurrent"))
return _build_resolved_validation("org_concurrent")
monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())
results = await asyncio.gather(*[mcp_http_auth.validate_mcp_api_key("sk_test_concurrent") for _ in range(5)])
assert all(r == "org_concurrent" for r in results)
assert all(r.organization_id == "org_concurrent" for r in results)
# First call populates cache; remaining may or may not hit DB depending on
# scheduling, but all must succeed.
assert calls >= 1
@@ -268,7 +297,7 @@ async def test_validate_mcp_api_key_concurrent_callers_all_succeed(
async def test_validate_mcp_api_key_returns_503_after_retry_exhaustion(monkeypatch: pytest.MonkeyPatch) -> None:
calls = 0
async def _resolve(_api_key: str, _db: object) -> object:
async def _resolve(_api_key: str, _db: object, **_: object) -> object:
nonlocal calls
calls += 1
raise RuntimeError("persistent db outage")
@@ -302,3 +331,220 @@ async def test_close_auth_db_noop_when_uninitialized() -> None:
mcp_http_auth._auth_db = None
await mcp_http_auth.close_auth_db()
assert mcp_http_auth._auth_db is None
@pytest.mark.asyncio
async def test_mcp_http_auth_denies_target_org_when_feature_disabled(monkeypatch: pytest.MonkeyPatch) -> None:
validate_mock = AsyncMock(
return_value=_build_validation(
"org_admin",
OrganizationAuthTokenType.mcp_admin_impersonation,
)
)
monkeypatch.setattr(
mcp_http_auth,
"validate_mcp_api_key",
validate_mock,
)
monkeypatch.setattr(mcp_http_auth, "_is_admin_impersonation_enabled", lambda: False)
app = _build_test_app()
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
response = await client.post(
"/mcp",
headers={"x-api-key": "sk_live_admin", "x-target-org-id": "org_target"},
json={},
)
assert response.status_code == 401
assert response.json()["error"]["code"] == "UNAUTHORIZED"
assert response.json()["error"]["message"] == "Impersonation not allowed"
validate_mock.assert_awaited_once_with("sk_live_admin")
@pytest.mark.asyncio
async def test_mcp_http_auth_validates_api_key_before_feature_flag_denial(monkeypatch: pytest.MonkeyPatch) -> None:
validate_mock = AsyncMock(side_effect=HTTPException(status_code=403, detail="Invalid credentials"))
monkeypatch.setattr(mcp_http_auth, "validate_mcp_api_key", validate_mock)
monkeypatch.setattr(mcp_http_auth, "_is_admin_impersonation_enabled", lambda: False)
app = _build_test_app()
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
response = await client.post(
"/mcp",
headers={"x-api-key": "bad-key", "x-target-org-id": "org_target"},
json={},
)
assert response.status_code == 401
assert response.json()["error"]["code"] == "UNAUTHORIZED"
assert response.json()["error"]["message"] == "Invalid API key"
validate_mock.assert_awaited_once_with("bad-key")
@pytest.mark.parametrize("target_org_id", ["", " \t "])
@pytest.mark.asyncio
async def test_mcp_http_auth_denies_empty_or_whitespace_target_org_id_header(
monkeypatch: pytest.MonkeyPatch,
target_org_id: str,
) -> None:
validate_mock = AsyncMock(
return_value=_build_validation(
"org_admin",
OrganizationAuthTokenType.mcp_admin_impersonation,
)
)
monkeypatch.setattr(mcp_http_auth, "validate_mcp_api_key", validate_mock)
monkeypatch.setattr(mcp_http_auth, "_is_admin_impersonation_enabled", lambda: True)
app = _build_test_app()
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
response = await client.post(
"/mcp",
headers={"x-api-key": "sk_live_admin", "x-target-org-id": target_org_id},
json={},
)
assert response.status_code == 401
assert response.json()["error"]["code"] == "UNAUTHORIZED"
assert response.json()["error"]["message"] == "Impersonation not allowed"
validate_mock.assert_awaited_once_with("sk_live_admin")
# ---------------------------------------------------------------------------
# Session-based impersonation tests
# ---------------------------------------------------------------------------
def test_impersonation_session_lifecycle() -> None:
"""set → get → clear lifecycle."""
admin_hash = mcp_http_auth.cache_key("sk_admin_key")
session = mcp_http_auth.ImpersonationSession(
admin_api_key_hash=admin_hash,
admin_org_id="org_admin",
target_org_id="org_target",
target_api_key="sk_target_key",
expires_at=time.monotonic() + 600,
ttl_minutes=10,
)
mcp_http_auth.set_impersonation_session(session)
retrieved = mcp_http_auth.get_active_impersonation(admin_hash)
assert retrieved is not None
assert retrieved.target_org_id == "org_target"
cleared = mcp_http_auth.clear_impersonation_session(admin_hash)
assert cleared is not None
assert cleared.target_org_id == "org_target"
assert mcp_http_auth.get_active_impersonation(admin_hash) is None
def test_impersonation_session_auto_expiry() -> None:
"""Expired sessions are lazily cleaned up on get."""
admin_hash = mcp_http_auth.cache_key("sk_admin_key")
session = mcp_http_auth.ImpersonationSession(
admin_api_key_hash=admin_hash,
admin_org_id="org_admin",
target_org_id="org_target",
target_api_key="sk_target_key",
expires_at=time.monotonic() - 1, # already expired
ttl_minutes=1,
)
mcp_http_auth.set_impersonation_session(session)
assert mcp_http_auth.get_active_impersonation(admin_hash) is None
@pytest.mark.asyncio
async def test_middleware_applies_session_impersonation(monkeypatch: pytest.MonkeyPatch) -> None:
"""When a session is active, middleware auto-applies impersonation without header."""
monkeypatch.setattr(
mcp_http_auth,
"validate_mcp_api_key",
AsyncMock(
return_value=_build_validation(
"org_admin",
OrganizationAuthTokenType.mcp_admin_impersonation,
)
),
)
admin_hash = mcp_http_auth.cache_key("sk_live_admin")
session = mcp_http_auth.ImpersonationSession(
admin_api_key_hash=admin_hash,
admin_org_id="org_admin",
target_org_id="org_target",
target_api_key="sk_live_target_key",
expires_at=time.monotonic() + 600,
ttl_minutes=10,
)
mcp_http_auth.set_impersonation_session(session)
app = _build_test_app()
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
response = await client.post("/mcp", headers={"x-api-key": "sk_live_admin"}, json={})
assert response.status_code == 200
body = response.json()
assert body["api_key"] == "sk_live_target_key"
assert body["organization_id"] == "org_target"
assert body["admin_organization_id"] == "org_admin"
assert body["impersonation_target_organization_id"] == "org_target"
@pytest.mark.asyncio
async def test_middleware_ignores_expired_session(monkeypatch: pytest.MonkeyPatch) -> None:
"""Expired session is ignored — middleware reverts to admin's own org."""
monkeypatch.setattr(
mcp_http_auth,
"validate_mcp_api_key",
AsyncMock(
return_value=_build_validation(
"org_admin",
OrganizationAuthTokenType.mcp_admin_impersonation,
)
),
)
admin_hash = mcp_http_auth.cache_key("sk_live_admin")
session = mcp_http_auth.ImpersonationSession(
admin_api_key_hash=admin_hash,
admin_org_id="org_admin",
target_org_id="org_target",
target_api_key="sk_live_target_key",
expires_at=time.monotonic() - 1, # expired
ttl_minutes=1,
)
mcp_http_auth.set_impersonation_session(session)
app = _build_test_app()
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
response = await client.post("/mcp", headers={"x-api-key": "sk_live_admin"}, json={})
assert response.status_code == 200
body = response.json()
assert body["api_key"] == "sk_live_admin"
assert body["organization_id"] == "org_admin"
assert body["admin_organization_id"] is None
@pytest.mark.asyncio
async def test_close_auth_db_clears_impersonation_sessions() -> None:
admin_hash = mcp_http_auth.cache_key("sk_admin_key")
session = mcp_http_auth.ImpersonationSession(
admin_api_key_hash=admin_hash,
admin_org_id="org_admin",
target_org_id="org_target",
target_api_key="sk_target_key",
expires_at=time.monotonic() + 600,
ttl_minutes=10,
)
mcp_http_auth.set_impersonation_session(session)
dispose = AsyncMock()
mcp_http_auth._auth_db = SimpleNamespace(engine=SimpleNamespace(dispose=dispose))
await mcp_http_auth.close_auth_db()
assert mcp_http_auth.get_active_impersonation(admin_hash) is None