Add gated admin impersonation controls for MCP API-key auth (#4822)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Marc Kelechava
2026-02-19 18:56:06 -08:00
committed by GitHub
parent 36e600eeb9
commit 71f2b7a201
5 changed files with 520 additions and 41 deletions

View File

@@ -4,6 +4,8 @@ import asyncio
import os import os
import time import time
from collections import OrderedDict from collections import OrderedDict
from contextvars import ContextVar, Token
from dataclasses import dataclass
from threading import RLock from threading import RLock
import structlog import structlog
@@ -14,6 +16,7 @@ from starlette.types import ASGIApp, Receive, Scope, Send
from skyvern.config import settings from skyvern.config import settings
from skyvern.forge.sdk.db.agent_db import AgentDB from skyvern.forge.sdk.db.agent_db import AgentDB
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
from skyvern.forge.sdk.services.org_auth_service import resolve_org_from_api_key from skyvern.forge.sdk.services.org_auth_service import resolve_org_from_api_key
from .api_key_hash import hash_api_key_for_cache from .api_key_hash import hash_api_key_for_cache
@@ -21,16 +24,94 @@ from .client import reset_api_key_override, set_api_key_override
LOG = structlog.get_logger(__name__) LOG = structlog.get_logger(__name__)
API_KEY_HEADER = "x-api-key" API_KEY_HEADER = "x-api-key"
TARGET_ORG_ID_HEADER = "x-target-org-id"
HEALTH_PATHS = {"/health", "/healthz"} HEALTH_PATHS = {"/health", "/healthz"}
_MCP_ALLOWED_TOKEN_TYPES = (
OrganizationAuthTokenType.api,
OrganizationAuthTokenType.mcp_admin_impersonation,
)
_auth_db: AgentDB | None = None _auth_db: AgentDB | None = None
_auth_db_lock = RLock() _auth_db_lock = RLock()
_api_key_cache_lock = RLock() _api_key_cache_lock = RLock()
_api_key_validation_cache: OrderedDict[str, tuple[str | None, float]] = OrderedDict() _api_key_validation_cache: OrderedDict[str, tuple[MCPAPIKeyValidation | None, float]] = OrderedDict()
_NEGATIVE_CACHE_TTL_SECONDS = 5.0 _NEGATIVE_CACHE_TTL_SECONDS = 5.0
_VALIDATION_RETRY_EXHAUSTED_MESSAGE = "API key validation temporarily unavailable" _VALIDATION_RETRY_EXHAUSTED_MESSAGE = "API key validation temporarily unavailable"
_MAX_VALIDATION_RETRIES = 2 _MAX_VALIDATION_RETRIES = 2
_RETRY_DELAY_SECONDS = 0.25 _RETRY_DELAY_SECONDS = 0.25
# ---------------------------------------------------------------------------
# Impersonation session state
# ---------------------------------------------------------------------------
_admin_api_key_hash: ContextVar[str | None] = ContextVar("admin_api_key_hash", default=None)
@dataclass(frozen=True)
class ImpersonationSession:
admin_api_key_hash: str
admin_org_id: str
target_org_id: str
# Stored in plaintext in process memory for up to TTL. Acceptable V1 trade-off
# (in-process only, same as API key cache).
target_api_key: str
expires_at: float # time.monotonic() deadline
ttl_minutes: int
_impersonation_sessions: dict[str, ImpersonationSession] = {}
_impersonation_lock = RLock()
MAX_TTL_MINUTES = 120
DEFAULT_TTL_MINUTES = 30
_MAX_IMPERSONATION_SESSIONS = 100
def get_active_impersonation(admin_key_hash: str) -> ImpersonationSession | None:
with _impersonation_lock:
session = _impersonation_sessions.get(admin_key_hash)
if session is None:
return None
if session.expires_at <= time.monotonic():
_impersonation_sessions.pop(admin_key_hash, None)
return None
return session
def set_impersonation_session(session: ImpersonationSession) -> None:
with _impersonation_lock:
_impersonation_sessions[session.admin_api_key_hash] = session
# Sweep expired sessions, then evict oldest if over capacity
now = time.monotonic()
expired_keys = [k for k, s in _impersonation_sessions.items() if s.expires_at <= now]
for k in expired_keys:
_impersonation_sessions.pop(k, None)
if len(_impersonation_sessions) > _MAX_IMPERSONATION_SESSIONS:
oldest_key = min(_impersonation_sessions, key=lambda k: _impersonation_sessions[k].expires_at)
_impersonation_sessions.pop(oldest_key, None)
def clear_impersonation_session(admin_key_hash: str) -> ImpersonationSession | None:
with _impersonation_lock:
return _impersonation_sessions.pop(admin_key_hash, None)
def clear_all_impersonation_sessions() -> None:
with _impersonation_lock:
_impersonation_sessions.clear()
def get_admin_api_key_hash() -> str | None:
return _admin_api_key_hash.get()
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class MCPAPIKeyValidation:
organization_id: str
token_type: OrganizationAuthTokenType
def _resolve_api_key_cache_ttl_seconds() -> float: def _resolve_api_key_cache_ttl_seconds() -> float:
raw = os.environ.get("SKYVERN_MCP_API_KEY_CACHE_TTL_SECONDS", "30") raw = os.environ.get("SKYVERN_MCP_API_KEY_CACHE_TTL_SECONDS", "30")
@@ -69,6 +150,7 @@ async def close_auth_db() -> None:
_auth_db = None _auth_db = None
with _api_key_cache_lock: with _api_key_cache_lock:
_api_key_validation_cache.clear() _api_key_validation_cache.clear()
clear_all_impersonation_sessions()
if db is None: if db is None:
return return
@@ -78,24 +160,43 @@ async def close_auth_db() -> None:
LOG.warning("Failed to dispose MCP auth DB engine", exc_info=True) LOG.warning("Failed to dispose MCP auth DB engine", exc_info=True)
def _cache_key(api_key: str) -> str: def cache_key(api_key: str) -> str:
return hash_api_key_for_cache(api_key) return hash_api_key_for_cache(api_key)
async def validate_mcp_api_key(api_key: str) -> str: def _admin_organization_ids() -> set[str]:
"""Validate API key and return organization id for observability.""" try:
key = _cache_key(api_key) from cloud.config import settings as cloud_settings # noqa: PLC0415
admin_ids = getattr(cloud_settings, "ADMIN_ORGANIZATION_IDS", [])
except ImportError:
admin_ids = getattr(settings, "ADMIN_ORGANIZATION_IDS", [])
return {org_id for org_id in admin_ids if org_id}
def _is_admin_impersonation_enabled() -> bool:
try:
from cloud.config import settings as cloud_settings # noqa: PLC0415
return bool(getattr(cloud_settings, "MCP_ADMIN_IMPERSONATION_ENABLED", False))
except ImportError:
return bool(getattr(settings, "MCP_ADMIN_IMPERSONATION_ENABLED", False))
async def validate_mcp_api_key(api_key: str) -> MCPAPIKeyValidation:
"""Validate API key and return caller organization + token type."""
key = cache_key(api_key)
# Check cache first. # Check cache first.
with _api_key_cache_lock: with _api_key_cache_lock:
cached = _api_key_validation_cache.get(key) cached = _api_key_validation_cache.get(key)
if cached is not None: if cached is not None:
organization_id, expires_at = cached cached_validation, expires_at = cached
if expires_at > time.monotonic(): if expires_at > time.monotonic():
_api_key_validation_cache.move_to_end(key) _api_key_validation_cache.move_to_end(key)
if organization_id is None: if cached_validation is None:
raise HTTPException(status_code=401, detail="Invalid API key") raise HTTPException(status_code=401, detail="Invalid API key")
return organization_id return cached_validation
_api_key_validation_cache.pop(key, None) _api_key_validation_cache.pop(key, None)
# Cache miss — do the DB lookup with simple retry on transient errors. # Cache miss — do the DB lookup with simple retry on transient errors.
@@ -104,17 +205,24 @@ async def validate_mcp_api_key(api_key: str) -> str:
if attempt > 0: if attempt > 0:
await asyncio.sleep(_RETRY_DELAY_SECONDS) await asyncio.sleep(_RETRY_DELAY_SECONDS)
try: try:
validation = await resolve_org_from_api_key(api_key, _get_auth_db()) validation = await resolve_org_from_api_key(
organization_id = validation.organization.organization_id api_key,
_get_auth_db(),
token_types=_MCP_ALLOWED_TOKEN_TYPES,
)
caller_validation = MCPAPIKeyValidation(
organization_id=validation.organization.organization_id,
token_type=validation.token.token_type,
)
with _api_key_cache_lock: with _api_key_cache_lock:
_api_key_validation_cache[key] = ( _api_key_validation_cache[key] = (
organization_id, caller_validation,
time.monotonic() + _API_KEY_CACHE_TTL_SECONDS, time.monotonic() + _API_KEY_CACHE_TTL_SECONDS,
) )
_api_key_validation_cache.move_to_end(key) _api_key_validation_cache.move_to_end(key)
while len(_api_key_validation_cache) > _API_KEY_CACHE_MAX_SIZE: while len(_api_key_validation_cache) > _API_KEY_CACHE_MAX_SIZE:
_api_key_validation_cache.popitem(last=False) _api_key_validation_cache.popitem(last=False)
return organization_id return caller_validation
except HTTPException as e: except HTTPException as e:
if e.status_code in {401, 403}: if e.status_code in {401, 403}:
with _api_key_cache_lock: with _api_key_cache_lock:
@@ -149,6 +257,24 @@ def _service_unavailable_response(message: str) -> JSONResponse:
) )
def _deny_impersonation(
*,
reason: str,
caller_organization_id: str | None,
target_organization_id: str | None,
token_type: OrganizationAuthTokenType | None = None,
) -> JSONResponse:
log_kwargs: dict[str, object] = {
"reason": reason,
"caller_organization_id": caller_organization_id,
"target_organization_id": target_organization_id,
}
if token_type is not None:
log_kwargs["token_type"] = token_type.value
LOG.warning("MCP admin impersonation denied", **log_kwargs)
return _unauthorized_response("Impersonation not allowed")
class MCPAPIKeyMiddleware: class MCPAPIKeyMiddleware:
"""Require x-api-key for MCP HTTP transport and scope requests to that key.""" """Require x-api-key for MCP HTTP transport and scope requests to that key."""
@@ -176,10 +302,99 @@ class MCPAPIKeyMiddleware:
await response(scope, receive, send) await response(scope, receive, send)
return return
target_org_id_header = request.headers.get(TARGET_ORG_ID_HEADER)
admin_hash_token: Token[str | None] | None = None
try: try:
organization_id = await validate_mcp_api_key(api_key) validation = await validate_mcp_api_key(api_key)
caller_organization_id = validation.organization_id
admin_key_hash = cache_key(api_key)
scope.setdefault("state", {}) scope.setdefault("state", {})
scope["state"]["organization_id"] = organization_id if target_org_id_header is not None:
# Explicit header-based impersonation (takes priority over session)
if not _is_admin_impersonation_enabled():
response = _deny_impersonation(
reason="feature_disabled",
caller_organization_id=caller_organization_id,
target_organization_id=target_org_id_header.strip() or target_org_id_header,
)
await response(scope, receive, send)
return
target_organization_id = target_org_id_header.strip()
if not target_organization_id:
response = _deny_impersonation(
reason="missing_target_organization_id",
caller_organization_id=caller_organization_id,
target_organization_id=target_org_id_header,
token_type=validation.token_type,
)
await response(scope, receive, send)
return
# Delegate validation to the single source of truth in cloud/.
# The import can't fail here — _is_admin_impersonation_enabled()
# already returned True, which requires cloud.config to be importable.
try:
from cloud.mcp_admin_tools import validate_impersonation_target # noqa: PLC0415
except ImportError:
response = _deny_impersonation(
reason="impersonation_not_available",
caller_organization_id=caller_organization_id,
target_organization_id=target_organization_id,
)
await response(scope, receive, send)
return
result = await validate_impersonation_target(
caller_organization_id=caller_organization_id,
target_organization_id=target_organization_id,
token_type=validation.token_type,
)
if isinstance(result, str):
response = _deny_impersonation(
reason=result,
caller_organization_id=caller_organization_id,
target_organization_id=target_organization_id,
token_type=validation.token_type,
)
await response(scope, receive, send)
return
resolved_org_id, target_api_key = result
api_key = target_api_key
scope["state"]["organization_id"] = resolved_org_id
scope["state"]["admin_organization_id"] = caller_organization_id
scope["state"]["impersonation_target_organization_id"] = resolved_org_id
admin_hash_token = _admin_api_key_hash.set(admin_key_hash)
LOG.info(
"MCP admin impersonation allowed",
caller_organization_id=caller_organization_id,
target_organization_id=resolved_org_id,
token_type=validation.token_type.value,
)
else:
# No explicit header — check for session-based impersonation
session = get_active_impersonation(admin_key_hash)
if session is not None:
# Session data (target API key) is cached for TTL. If the target org's
# key is revoked mid-session, impersonation continues until expiry —
# acceptable trade-off vs re-validating every request.
api_key = session.target_api_key
scope["state"]["organization_id"] = session.target_org_id
scope["state"]["admin_organization_id"] = session.admin_org_id
scope["state"]["impersonation_target_organization_id"] = session.target_org_id
admin_hash_token = _admin_api_key_hash.set(admin_key_hash)
LOG.info(
"MCP session impersonation applied",
caller_organization_id=session.admin_org_id,
target_organization_id=session.target_org_id,
ttl_minutes=session.ttl_minutes,
)
else:
scope["state"]["organization_id"] = caller_organization_id
except HTTPException as e: except HTTPException as e:
if e.status_code in {401, 403}: if e.status_code in {401, 403}:
response = _unauthorized_response("Invalid API key") response = _unauthorized_response("Invalid API key")
@@ -204,3 +419,5 @@ class MCPAPIKeyMiddleware:
await self.app(scope, receive, send) await self.app(scope, receive, send)
finally: finally:
reset_api_key_override(token) reset_api_key_override(token)
if admin_hash_token is not None:
_admin_api_key_hash.reset(admin_hash_token)

View File

@@ -286,6 +286,14 @@ mcp.tool()(skyvern_workflow_run)
mcp.tool()(skyvern_workflow_status) mcp.tool()(skyvern_workflow_status)
mcp.tool()(skyvern_workflow_cancel) mcp.tool()(skyvern_workflow_cancel)
# -- Admin impersonation (cloud-only, session-level org switching) --
try:
from cloud.mcp_admin_tools import register_admin_tools # noqa: PLC0415
register_admin_tools(mcp)
except ImportError:
pass
# -- Prompts (methodology guides injected into LLM conversations) -- # -- Prompts (methodology guides injected into LLM conversations) --
mcp.prompt()(build_workflow) mcp.prompt()(build_workflow)
mcp.prompt()(debug_automation) mcp.prompt()(debug_automation)

View File

@@ -3,6 +3,7 @@ from enum import StrEnum
class OrganizationAuthTokenType(StrEnum): class OrganizationAuthTokenType(StrEnum):
api = "api" api = "api"
mcp_admin_impersonation = "mcp_admin_impersonation"
onepassword_service_account = "onepassword_service_account" onepassword_service_account = "onepassword_service_account"
azure_client_secret_credential = "azure_client_secret_credential" azure_client_secret_credential = "azure_client_secret_credential"
custom_credential_service = "custom_credential_service" custom_credential_service = "custom_credential_service"

View File

@@ -1,6 +1,6 @@
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import Annotated from typing import Annotated, Sequence
import structlog import structlog
from asyncache import cached from asyncache import cached
@@ -176,6 +176,7 @@ async def _authenticate_user_helper(authorization: str) -> str:
async def resolve_org_from_api_key( async def resolve_org_from_api_key(
x_api_key: str, x_api_key: str,
db: AgentDB, db: AgentDB,
token_types: Sequence[OrganizationAuthTokenType] = (OrganizationAuthTokenType.api,),
) -> ApiKeyValidationResult: ) -> ApiKeyValidationResult:
"""Decode and validate the API key against the database.""" """Decode and validate the API key against the database."""
try: try:
@@ -202,12 +203,18 @@ async def resolve_org_from_api_key(
LOG.warning("Organization not found", organization_id=api_key_data.sub, **payload) LOG.warning("Organization not found", organization_id=api_key_data.sub, **payload)
raise HTTPException(status_code=404, detail="Organization not found") raise HTTPException(status_code=404, detail="Organization not found")
api_key_db_obj: OrganizationAuthToken | None = None
# Try token types in priority order and stop at the first valid match.
for token_type in token_types:
api_key_db_obj = await db.validate_org_auth_token( api_key_db_obj = await db.validate_org_auth_token(
organization_id=organization.organization_id, organization_id=organization.organization_id,
token_type=OrganizationAuthTokenType.api, token_type=token_type,
token=x_api_key, token=x_api_key,
valid=None, valid=None,
) )
if api_key_db_obj:
break
if not api_key_db_obj: if not api_key_db_obj:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,

View File

@@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import time
from types import SimpleNamespace from types import SimpleNamespace
from unittest.mock import AsyncMock from unittest.mock import AsyncMock
@@ -15,6 +16,7 @@ from starlette.routing import Route
from skyvern.cli.core import client as client_mod from skyvern.cli.core import client as client_mod
from skyvern.cli.core import mcp_http_auth from skyvern.cli.core import mcp_http_auth
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
@@ -26,6 +28,7 @@ def _reset_auth_context() -> None:
mcp_http_auth._API_KEY_CACHE_MAX_SIZE = 1024 mcp_http_auth._API_KEY_CACHE_MAX_SIZE = 1024
mcp_http_auth._MAX_VALIDATION_RETRIES = 2 mcp_http_auth._MAX_VALIDATION_RETRIES = 2
mcp_http_auth._RETRY_DELAY_SECONDS = 0.0 # no delay in tests mcp_http_auth._RETRY_DELAY_SECONDS = 0.0 # no delay in tests
mcp_http_auth.clear_all_impersonation_sessions()
async def _echo_request_context(request: Request) -> JSONResponse: async def _echo_request_context(request: Request) -> JSONResponse:
@@ -33,10 +36,34 @@ async def _echo_request_context(request: Request) -> JSONResponse:
{ {
"api_key": client_mod.get_active_api_key(), "api_key": client_mod.get_active_api_key(),
"organization_id": getattr(request.state, "organization_id", None), "organization_id": getattr(request.state, "organization_id", None),
"admin_organization_id": getattr(request.state, "admin_organization_id", None),
"impersonation_target_organization_id": getattr(
request.state, "impersonation_target_organization_id", None
),
} }
) )
def _build_validation(
organization_id: str,
token_type: OrganizationAuthTokenType = OrganizationAuthTokenType.api,
) -> mcp_http_auth.MCPAPIKeyValidation:
return mcp_http_auth.MCPAPIKeyValidation(
organization_id=organization_id,
token_type=token_type,
)
def _build_resolved_validation(
organization_id: str,
token_type: OrganizationAuthTokenType = OrganizationAuthTokenType.api,
) -> SimpleNamespace:
return SimpleNamespace(
organization=SimpleNamespace(organization_id=organization_id),
token=SimpleNamespace(token_type=token_type),
)
def _build_test_app() -> Starlette: def _build_test_app() -> Starlette:
return Starlette( return Starlette(
routes=[Route("/mcp", endpoint=_echo_request_context, methods=["POST"])], routes=[Route("/mcp", endpoint=_echo_request_context, methods=["POST"])],
@@ -138,7 +165,7 @@ async def test_mcp_http_auth_sets_request_scoped_api_key(monkeypatch: pytest.Mon
monkeypatch.setattr( monkeypatch.setattr(
mcp_http_auth, mcp_http_auth,
"validate_mcp_api_key", "validate_mcp_api_key",
AsyncMock(return_value="org_123"), AsyncMock(return_value=_build_validation("org_123")),
) )
app = _build_test_app() app = _build_test_app()
@@ -149,6 +176,8 @@ async def test_mcp_http_auth_sets_request_scoped_api_key(monkeypatch: pytest.Mon
assert response.json() == { assert response.json() == {
"api_key": "sk_live_abc", "api_key": "sk_live_abc",
"organization_id": "org_123", "organization_id": "org_123",
"admin_organization_id": None,
"impersonation_target_organization_id": None,
} }
assert client_mod.get_active_api_key() != "sk_live_abc" assert client_mod.get_active_api_key() != "sk_live_abc"
@@ -157,10 +186,10 @@ async def test_mcp_http_auth_sets_request_scoped_api_key(monkeypatch: pytest.Mon
async def test_validate_mcp_api_key_uses_ttl_cache(monkeypatch: pytest.MonkeyPatch) -> None: async def test_validate_mcp_api_key_uses_ttl_cache(monkeypatch: pytest.MonkeyPatch) -> None:
calls = 0 calls = 0
async def _resolve(_api_key: str, _db: object) -> object: async def _resolve(_api_key: str, _db: object, **_: object) -> object:
nonlocal calls nonlocal calls
calls += 1 calls += 1
return SimpleNamespace(organization=SimpleNamespace(organization_id="org_cached")) return _build_resolved_validation("org_cached")
monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve) monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object()) monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())
@@ -168,8 +197,8 @@ async def test_validate_mcp_api_key_uses_ttl_cache(monkeypatch: pytest.MonkeyPat
first = await mcp_http_auth.validate_mcp_api_key("sk_test_cache") first = await mcp_http_auth.validate_mcp_api_key("sk_test_cache")
second = await mcp_http_auth.validate_mcp_api_key("sk_test_cache") second = await mcp_http_auth.validate_mcp_api_key("sk_test_cache")
assert first == "org_cached" assert first.organization_id == "org_cached"
assert second == "org_cached" assert second.organization_id == "org_cached"
assert calls == 1 assert calls == 1
@@ -177,21 +206,21 @@ async def test_validate_mcp_api_key_uses_ttl_cache(monkeypatch: pytest.MonkeyPat
async def test_validate_mcp_api_key_cache_expires(monkeypatch: pytest.MonkeyPatch) -> None: async def test_validate_mcp_api_key_cache_expires(monkeypatch: pytest.MonkeyPatch) -> None:
calls = 0 calls = 0
async def _resolve(_api_key: str, _db: object) -> object: async def _resolve(_api_key: str, _db: object, **_: object) -> object:
nonlocal calls nonlocal calls
calls += 1 calls += 1
return SimpleNamespace(organization=SimpleNamespace(organization_id=f"org_{calls}")) return _build_resolved_validation(f"org_{calls}")
monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve) monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object()) monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())
first = await mcp_http_auth.validate_mcp_api_key("sk_test_cache_expire") first = await mcp_http_auth.validate_mcp_api_key("sk_test_cache_expire")
cache_key = mcp_http_auth._cache_key("sk_test_cache_expire") cache_key = mcp_http_auth.cache_key("sk_test_cache_expire")
mcp_http_auth._api_key_validation_cache[cache_key] = (first, 0.0) mcp_http_auth._api_key_validation_cache[cache_key] = (first, 0.0)
second = await mcp_http_auth.validate_mcp_api_key("sk_test_cache_expire") second = await mcp_http_auth.validate_mcp_api_key("sk_test_cache_expire")
assert first == "org_1" assert first.organization_id == "org_1"
assert second == "org_2" assert second.organization_id == "org_2"
assert calls == 2 assert calls == 2
@@ -199,7 +228,7 @@ async def test_validate_mcp_api_key_cache_expires(monkeypatch: pytest.MonkeyPatc
async def test_validate_mcp_api_key_negative_caches_auth_failures(monkeypatch: pytest.MonkeyPatch) -> None: async def test_validate_mcp_api_key_negative_caches_auth_failures(monkeypatch: pytest.MonkeyPatch) -> None:
calls = 0 calls = 0
async def _resolve(_api_key: str, _db: object) -> object: async def _resolve(_api_key: str, _db: object, **_: object) -> object:
nonlocal calls nonlocal calls
calls += 1 calls += 1
raise HTTPException(status_code=401, detail="Invalid credentials") raise HTTPException(status_code=401, detail="Invalid credentials")
@@ -222,22 +251,22 @@ async def test_validate_mcp_api_key_retries_transient_failure_without_negative_c
) -> None: ) -> None:
calls = 0 calls = 0
async def _resolve(_api_key: str, _db: object) -> object: async def _resolve(_api_key: str, _db: object, **_: object) -> object:
nonlocal calls nonlocal calls
calls += 1 calls += 1
if calls == 1: if calls == 1:
raise RuntimeError("transient db error") raise RuntimeError("transient db error")
return SimpleNamespace(organization=SimpleNamespace(organization_id="org_recovered")) return _build_resolved_validation("org_recovered")
monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve) monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object()) monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())
recovered_org = await mcp_http_auth.validate_mcp_api_key("sk_test_transient") recovered_org = await mcp_http_auth.validate_mcp_api_key("sk_test_transient")
cache_key = mcp_http_auth._cache_key("sk_test_transient") cache_key = mcp_http_auth.cache_key("sk_test_transient")
assert mcp_http_auth._api_key_validation_cache[cache_key][0] == "org_recovered" assert mcp_http_auth._api_key_validation_cache[cache_key][0].organization_id == "org_recovered"
assert recovered_org == "org_recovered" assert recovered_org.organization_id == "org_recovered"
assert calls == 2 assert calls == 2
@@ -249,16 +278,16 @@ async def test_validate_mcp_api_key_concurrent_callers_all_succeed(
collapses subsequent calls after the first one populates it.""" collapses subsequent calls after the first one populates it."""
calls = 0 calls = 0
async def _resolve(_api_key: str, _db: object) -> object: async def _resolve(_api_key: str, _db: object, **_: object) -> object:
nonlocal calls nonlocal calls
calls += 1 calls += 1
return SimpleNamespace(organization=SimpleNamespace(organization_id="org_concurrent")) return _build_resolved_validation("org_concurrent")
monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve) monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object()) monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())
results = await asyncio.gather(*[mcp_http_auth.validate_mcp_api_key("sk_test_concurrent") for _ in range(5)]) results = await asyncio.gather(*[mcp_http_auth.validate_mcp_api_key("sk_test_concurrent") for _ in range(5)])
assert all(r == "org_concurrent" for r in results) assert all(r.organization_id == "org_concurrent" for r in results)
# First call populates cache; remaining may or may not hit DB depending on # First call populates cache; remaining may or may not hit DB depending on
# scheduling, but all must succeed. # scheduling, but all must succeed.
assert calls >= 1 assert calls >= 1
@@ -268,7 +297,7 @@ async def test_validate_mcp_api_key_concurrent_callers_all_succeed(
async def test_validate_mcp_api_key_returns_503_after_retry_exhaustion(monkeypatch: pytest.MonkeyPatch) -> None: async def test_validate_mcp_api_key_returns_503_after_retry_exhaustion(monkeypatch: pytest.MonkeyPatch) -> None:
calls = 0 calls = 0
async def _resolve(_api_key: str, _db: object) -> object: async def _resolve(_api_key: str, _db: object, **_: object) -> object:
nonlocal calls nonlocal calls
calls += 1 calls += 1
raise RuntimeError("persistent db outage") raise RuntimeError("persistent db outage")
@@ -302,3 +331,220 @@ async def test_close_auth_db_noop_when_uninitialized() -> None:
mcp_http_auth._auth_db = None mcp_http_auth._auth_db = None
await mcp_http_auth.close_auth_db() await mcp_http_auth.close_auth_db()
assert mcp_http_auth._auth_db is None assert mcp_http_auth._auth_db is None
@pytest.mark.asyncio
async def test_mcp_http_auth_denies_target_org_when_feature_disabled(monkeypatch: pytest.MonkeyPatch) -> None:
validate_mock = AsyncMock(
return_value=_build_validation(
"org_admin",
OrganizationAuthTokenType.mcp_admin_impersonation,
)
)
monkeypatch.setattr(
mcp_http_auth,
"validate_mcp_api_key",
validate_mock,
)
monkeypatch.setattr(mcp_http_auth, "_is_admin_impersonation_enabled", lambda: False)
app = _build_test_app()
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
response = await client.post(
"/mcp",
headers={"x-api-key": "sk_live_admin", "x-target-org-id": "org_target"},
json={},
)
assert response.status_code == 401
assert response.json()["error"]["code"] == "UNAUTHORIZED"
assert response.json()["error"]["message"] == "Impersonation not allowed"
validate_mock.assert_awaited_once_with("sk_live_admin")
@pytest.mark.asyncio
async def test_mcp_http_auth_validates_api_key_before_feature_flag_denial(monkeypatch: pytest.MonkeyPatch) -> None:
validate_mock = AsyncMock(side_effect=HTTPException(status_code=403, detail="Invalid credentials"))
monkeypatch.setattr(mcp_http_auth, "validate_mcp_api_key", validate_mock)
monkeypatch.setattr(mcp_http_auth, "_is_admin_impersonation_enabled", lambda: False)
app = _build_test_app()
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
response = await client.post(
"/mcp",
headers={"x-api-key": "bad-key", "x-target-org-id": "org_target"},
json={},
)
assert response.status_code == 401
assert response.json()["error"]["code"] == "UNAUTHORIZED"
assert response.json()["error"]["message"] == "Invalid API key"
validate_mock.assert_awaited_once_with("bad-key")
@pytest.mark.parametrize("target_org_id", ["", " \t "])
@pytest.mark.asyncio
async def test_mcp_http_auth_denies_empty_or_whitespace_target_org_id_header(
monkeypatch: pytest.MonkeyPatch,
target_org_id: str,
) -> None:
validate_mock = AsyncMock(
return_value=_build_validation(
"org_admin",
OrganizationAuthTokenType.mcp_admin_impersonation,
)
)
monkeypatch.setattr(mcp_http_auth, "validate_mcp_api_key", validate_mock)
monkeypatch.setattr(mcp_http_auth, "_is_admin_impersonation_enabled", lambda: True)
app = _build_test_app()
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
response = await client.post(
"/mcp",
headers={"x-api-key": "sk_live_admin", "x-target-org-id": target_org_id},
json={},
)
assert response.status_code == 401
assert response.json()["error"]["code"] == "UNAUTHORIZED"
assert response.json()["error"]["message"] == "Impersonation not allowed"
validate_mock.assert_awaited_once_with("sk_live_admin")
# ---------------------------------------------------------------------------
# Session-based impersonation tests
# ---------------------------------------------------------------------------
def test_impersonation_session_lifecycle() -> None:
"""set → get → clear lifecycle."""
admin_hash = mcp_http_auth.cache_key("sk_admin_key")
session = mcp_http_auth.ImpersonationSession(
admin_api_key_hash=admin_hash,
admin_org_id="org_admin",
target_org_id="org_target",
target_api_key="sk_target_key",
expires_at=time.monotonic() + 600,
ttl_minutes=10,
)
mcp_http_auth.set_impersonation_session(session)
retrieved = mcp_http_auth.get_active_impersonation(admin_hash)
assert retrieved is not None
assert retrieved.target_org_id == "org_target"
cleared = mcp_http_auth.clear_impersonation_session(admin_hash)
assert cleared is not None
assert cleared.target_org_id == "org_target"
assert mcp_http_auth.get_active_impersonation(admin_hash) is None
def test_impersonation_session_auto_expiry() -> None:
"""Expired sessions are lazily cleaned up on get."""
admin_hash = mcp_http_auth.cache_key("sk_admin_key")
session = mcp_http_auth.ImpersonationSession(
admin_api_key_hash=admin_hash,
admin_org_id="org_admin",
target_org_id="org_target",
target_api_key="sk_target_key",
expires_at=time.monotonic() - 1, # already expired
ttl_minutes=1,
)
mcp_http_auth.set_impersonation_session(session)
assert mcp_http_auth.get_active_impersonation(admin_hash) is None
@pytest.mark.asyncio
async def test_middleware_applies_session_impersonation(monkeypatch: pytest.MonkeyPatch) -> None:
"""When a session is active, middleware auto-applies impersonation without header."""
monkeypatch.setattr(
mcp_http_auth,
"validate_mcp_api_key",
AsyncMock(
return_value=_build_validation(
"org_admin",
OrganizationAuthTokenType.mcp_admin_impersonation,
)
),
)
admin_hash = mcp_http_auth.cache_key("sk_live_admin")
session = mcp_http_auth.ImpersonationSession(
admin_api_key_hash=admin_hash,
admin_org_id="org_admin",
target_org_id="org_target",
target_api_key="sk_live_target_key",
expires_at=time.monotonic() + 600,
ttl_minutes=10,
)
mcp_http_auth.set_impersonation_session(session)
app = _build_test_app()
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
response = await client.post("/mcp", headers={"x-api-key": "sk_live_admin"}, json={})
assert response.status_code == 200
body = response.json()
assert body["api_key"] == "sk_live_target_key"
assert body["organization_id"] == "org_target"
assert body["admin_organization_id"] == "org_admin"
assert body["impersonation_target_organization_id"] == "org_target"
@pytest.mark.asyncio
async def test_middleware_ignores_expired_session(monkeypatch: pytest.MonkeyPatch) -> None:
"""Expired session is ignored — middleware reverts to admin's own org."""
monkeypatch.setattr(
mcp_http_auth,
"validate_mcp_api_key",
AsyncMock(
return_value=_build_validation(
"org_admin",
OrganizationAuthTokenType.mcp_admin_impersonation,
)
),
)
admin_hash = mcp_http_auth.cache_key("sk_live_admin")
session = mcp_http_auth.ImpersonationSession(
admin_api_key_hash=admin_hash,
admin_org_id="org_admin",
target_org_id="org_target",
target_api_key="sk_live_target_key",
expires_at=time.monotonic() - 1, # expired
ttl_minutes=1,
)
mcp_http_auth.set_impersonation_session(session)
app = _build_test_app()
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
response = await client.post("/mcp", headers={"x-api-key": "sk_live_admin"}, json={})
assert response.status_code == 200
body = response.json()
assert body["api_key"] == "sk_live_admin"
assert body["organization_id"] == "org_admin"
assert body["admin_organization_id"] is None
@pytest.mark.asyncio
async def test_close_auth_db_clears_impersonation_sessions() -> None:
admin_hash = mcp_http_auth.cache_key("sk_admin_key")
session = mcp_http_auth.ImpersonationSession(
admin_api_key_hash=admin_hash,
admin_org_id="org_admin",
target_org_id="org_target",
target_api_key="sk_target_key",
expires_at=time.monotonic() + 600,
ttl_minutes=10,
)
mcp_http_auth.set_impersonation_session(session)
dispose = AsyncMock()
mcp_http_auth._auth_db = SimpleNamespace(engine=SimpleNamespace(dispose=dispose))
await mcp_http_auth.close_auth_db()
assert mcp_http_auth.get_active_impersonation(admin_hash) is None