From 71f2b7a2018e9f19582e120cb9d089da68fb2b50 Mon Sep 17 00:00:00 2001 From: Marc Kelechava Date: Thu, 19 Feb 2026 18:56:06 -0800 Subject: [PATCH] Add gated admin impersonation controls for MCP API-key auth (#4822) Co-authored-by: Claude Opus 4.6 (1M context) --- skyvern/cli/core/mcp_http_auth.py | 245 ++++++++++++++- skyvern/cli/mcp_tools/__init__.py | 8 + skyvern/forge/sdk/db/enums.py | 1 + .../forge/sdk/services/org_auth_service.py | 21 +- tests/unit/test_mcp_http_auth.py | 286 ++++++++++++++++-- 5 files changed, 520 insertions(+), 41 deletions(-) diff --git a/skyvern/cli/core/mcp_http_auth.py b/skyvern/cli/core/mcp_http_auth.py index 07a7a168..44f1f846 100644 --- a/skyvern/cli/core/mcp_http_auth.py +++ b/skyvern/cli/core/mcp_http_auth.py @@ -4,6 +4,8 @@ import asyncio import os import time from collections import OrderedDict +from contextvars import ContextVar, Token +from dataclasses import dataclass from threading import RLock import structlog @@ -14,6 +16,7 @@ from starlette.types import ASGIApp, Receive, Scope, Send from skyvern.config import settings from skyvern.forge.sdk.db.agent_db import AgentDB +from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType from skyvern.forge.sdk.services.org_auth_service import resolve_org_from_api_key from .api_key_hash import hash_api_key_for_cache @@ -21,16 +24,94 @@ from .client import reset_api_key_override, set_api_key_override LOG = structlog.get_logger(__name__) API_KEY_HEADER = "x-api-key" +TARGET_ORG_ID_HEADER = "x-target-org-id" HEALTH_PATHS = {"/health", "/healthz"} +_MCP_ALLOWED_TOKEN_TYPES = ( + OrganizationAuthTokenType.api, + OrganizationAuthTokenType.mcp_admin_impersonation, +) _auth_db: AgentDB | None = None _auth_db_lock = RLock() _api_key_cache_lock = RLock() -_api_key_validation_cache: OrderedDict[str, tuple[str | None, float]] = OrderedDict() +_api_key_validation_cache: OrderedDict[str, tuple[MCPAPIKeyValidation | None, float]] = OrderedDict() _NEGATIVE_CACHE_TTL_SECONDS = 5.0 _VALIDATION_RETRY_EXHAUSTED_MESSAGE = "API key validation temporarily unavailable" _MAX_VALIDATION_RETRIES = 2 _RETRY_DELAY_SECONDS = 0.25 +# --------------------------------------------------------------------------- +# Impersonation session state +# --------------------------------------------------------------------------- + +_admin_api_key_hash: ContextVar[str | None] = ContextVar("admin_api_key_hash", default=None) + + +@dataclass(frozen=True) +class ImpersonationSession: + admin_api_key_hash: str + admin_org_id: str + target_org_id: str + # Stored in plaintext in process memory for up to TTL. Acceptable V1 trade-off + # (in-process only, same as API key cache). + target_api_key: str + expires_at: float # time.monotonic() deadline + ttl_minutes: int + + +_impersonation_sessions: dict[str, ImpersonationSession] = {} +_impersonation_lock = RLock() + +MAX_TTL_MINUTES = 120 +DEFAULT_TTL_MINUTES = 30 +_MAX_IMPERSONATION_SESSIONS = 100 + + +def get_active_impersonation(admin_key_hash: str) -> ImpersonationSession | None: + with _impersonation_lock: + session = _impersonation_sessions.get(admin_key_hash) + if session is None: + return None + if session.expires_at <= time.monotonic(): + _impersonation_sessions.pop(admin_key_hash, None) + return None + return session + + +def set_impersonation_session(session: ImpersonationSession) -> None: + with _impersonation_lock: + _impersonation_sessions[session.admin_api_key_hash] = session + # Sweep expired sessions, then evict oldest if over capacity + now = time.monotonic() + expired_keys = [k for k, s in _impersonation_sessions.items() if s.expires_at <= now] + for k in expired_keys: + _impersonation_sessions.pop(k, None) + if len(_impersonation_sessions) > _MAX_IMPERSONATION_SESSIONS: + oldest_key = min(_impersonation_sessions, key=lambda k: _impersonation_sessions[k].expires_at) + _impersonation_sessions.pop(oldest_key, None) + + +def clear_impersonation_session(admin_key_hash: str) -> ImpersonationSession | None: + with _impersonation_lock: + return _impersonation_sessions.pop(admin_key_hash, None) + + +def clear_all_impersonation_sessions() -> None: + with _impersonation_lock: + _impersonation_sessions.clear() + + +def get_admin_api_key_hash() -> str | None: + return _admin_api_key_hash.get() + + +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class MCPAPIKeyValidation: + organization_id: str + token_type: OrganizationAuthTokenType + def _resolve_api_key_cache_ttl_seconds() -> float: raw = os.environ.get("SKYVERN_MCP_API_KEY_CACHE_TTL_SECONDS", "30") @@ -69,6 +150,7 @@ async def close_auth_db() -> None: _auth_db = None with _api_key_cache_lock: _api_key_validation_cache.clear() + clear_all_impersonation_sessions() if db is None: return @@ -78,24 +160,43 @@ async def close_auth_db() -> None: LOG.warning("Failed to dispose MCP auth DB engine", exc_info=True) -def _cache_key(api_key: str) -> str: +def cache_key(api_key: str) -> str: return hash_api_key_for_cache(api_key) -async def validate_mcp_api_key(api_key: str) -> str: - """Validate API key and return organization id for observability.""" - key = _cache_key(api_key) +def _admin_organization_ids() -> set[str]: + try: + from cloud.config import settings as cloud_settings # noqa: PLC0415 + + admin_ids = getattr(cloud_settings, "ADMIN_ORGANIZATION_IDS", []) + except ImportError: + admin_ids = getattr(settings, "ADMIN_ORGANIZATION_IDS", []) + return {org_id for org_id in admin_ids if org_id} + + +def _is_admin_impersonation_enabled() -> bool: + try: + from cloud.config import settings as cloud_settings # noqa: PLC0415 + + return bool(getattr(cloud_settings, "MCP_ADMIN_IMPERSONATION_ENABLED", False)) + except ImportError: + return bool(getattr(settings, "MCP_ADMIN_IMPERSONATION_ENABLED", False)) + + +async def validate_mcp_api_key(api_key: str) -> MCPAPIKeyValidation: + """Validate API key and return caller organization + token type.""" + key = cache_key(api_key) # Check cache first. with _api_key_cache_lock: cached = _api_key_validation_cache.get(key) if cached is not None: - organization_id, expires_at = cached + cached_validation, expires_at = cached if expires_at > time.monotonic(): _api_key_validation_cache.move_to_end(key) - if organization_id is None: + if cached_validation is None: raise HTTPException(status_code=401, detail="Invalid API key") - return organization_id + return cached_validation _api_key_validation_cache.pop(key, None) # Cache miss — do the DB lookup with simple retry on transient errors. @@ -104,17 +205,24 @@ async def validate_mcp_api_key(api_key: str) -> str: if attempt > 0: await asyncio.sleep(_RETRY_DELAY_SECONDS) try: - validation = await resolve_org_from_api_key(api_key, _get_auth_db()) - organization_id = validation.organization.organization_id + validation = await resolve_org_from_api_key( + api_key, + _get_auth_db(), + token_types=_MCP_ALLOWED_TOKEN_TYPES, + ) + caller_validation = MCPAPIKeyValidation( + organization_id=validation.organization.organization_id, + token_type=validation.token.token_type, + ) with _api_key_cache_lock: _api_key_validation_cache[key] = ( - organization_id, + caller_validation, time.monotonic() + _API_KEY_CACHE_TTL_SECONDS, ) _api_key_validation_cache.move_to_end(key) while len(_api_key_validation_cache) > _API_KEY_CACHE_MAX_SIZE: _api_key_validation_cache.popitem(last=False) - return organization_id + return caller_validation except HTTPException as e: if e.status_code in {401, 403}: with _api_key_cache_lock: @@ -149,6 +257,24 @@ def _service_unavailable_response(message: str) -> JSONResponse: ) +def _deny_impersonation( + *, + reason: str, + caller_organization_id: str | None, + target_organization_id: str | None, + token_type: OrganizationAuthTokenType | None = None, +) -> JSONResponse: + log_kwargs: dict[str, object] = { + "reason": reason, + "caller_organization_id": caller_organization_id, + "target_organization_id": target_organization_id, + } + if token_type is not None: + log_kwargs["token_type"] = token_type.value + LOG.warning("MCP admin impersonation denied", **log_kwargs) + return _unauthorized_response("Impersonation not allowed") + + class MCPAPIKeyMiddleware: """Require x-api-key for MCP HTTP transport and scope requests to that key.""" @@ -176,10 +302,99 @@ class MCPAPIKeyMiddleware: await response(scope, receive, send) return + target_org_id_header = request.headers.get(TARGET_ORG_ID_HEADER) + admin_hash_token: Token[str | None] | None = None + try: - organization_id = await validate_mcp_api_key(api_key) + validation = await validate_mcp_api_key(api_key) + caller_organization_id = validation.organization_id + admin_key_hash = cache_key(api_key) + scope.setdefault("state", {}) - scope["state"]["organization_id"] = organization_id + if target_org_id_header is not None: + # Explicit header-based impersonation (takes priority over session) + if not _is_admin_impersonation_enabled(): + response = _deny_impersonation( + reason="feature_disabled", + caller_organization_id=caller_organization_id, + target_organization_id=target_org_id_header.strip() or target_org_id_header, + ) + await response(scope, receive, send) + return + + target_organization_id = target_org_id_header.strip() + if not target_organization_id: + response = _deny_impersonation( + reason="missing_target_organization_id", + caller_organization_id=caller_organization_id, + target_organization_id=target_org_id_header, + token_type=validation.token_type, + ) + await response(scope, receive, send) + return + + # Delegate validation to the single source of truth in cloud/. + # The import can't fail here — _is_admin_impersonation_enabled() + # already returned True, which requires cloud.config to be importable. + try: + from cloud.mcp_admin_tools import validate_impersonation_target # noqa: PLC0415 + except ImportError: + response = _deny_impersonation( + reason="impersonation_not_available", + caller_organization_id=caller_organization_id, + target_organization_id=target_organization_id, + ) + await response(scope, receive, send) + return + + result = await validate_impersonation_target( + caller_organization_id=caller_organization_id, + target_organization_id=target_organization_id, + token_type=validation.token_type, + ) + if isinstance(result, str): + response = _deny_impersonation( + reason=result, + caller_organization_id=caller_organization_id, + target_organization_id=target_organization_id, + token_type=validation.token_type, + ) + await response(scope, receive, send) + return + + resolved_org_id, target_api_key = result + api_key = target_api_key + + scope["state"]["organization_id"] = resolved_org_id + scope["state"]["admin_organization_id"] = caller_organization_id + scope["state"]["impersonation_target_organization_id"] = resolved_org_id + admin_hash_token = _admin_api_key_hash.set(admin_key_hash) + LOG.info( + "MCP admin impersonation allowed", + caller_organization_id=caller_organization_id, + target_organization_id=resolved_org_id, + token_type=validation.token_type.value, + ) + else: + # No explicit header — check for session-based impersonation + session = get_active_impersonation(admin_key_hash) + if session is not None: + # Session data (target API key) is cached for TTL. If the target org's + # key is revoked mid-session, impersonation continues until expiry — + # acceptable trade-off vs re-validating every request. + api_key = session.target_api_key + scope["state"]["organization_id"] = session.target_org_id + scope["state"]["admin_organization_id"] = session.admin_org_id + scope["state"]["impersonation_target_organization_id"] = session.target_org_id + admin_hash_token = _admin_api_key_hash.set(admin_key_hash) + LOG.info( + "MCP session impersonation applied", + caller_organization_id=session.admin_org_id, + target_organization_id=session.target_org_id, + ttl_minutes=session.ttl_minutes, + ) + else: + scope["state"]["organization_id"] = caller_organization_id except HTTPException as e: if e.status_code in {401, 403}: response = _unauthorized_response("Invalid API key") @@ -204,3 +419,5 @@ class MCPAPIKeyMiddleware: await self.app(scope, receive, send) finally: reset_api_key_override(token) + if admin_hash_token is not None: + _admin_api_key_hash.reset(admin_hash_token) diff --git a/skyvern/cli/mcp_tools/__init__.py b/skyvern/cli/mcp_tools/__init__.py index c14839de..02cf57dc 100644 --- a/skyvern/cli/mcp_tools/__init__.py +++ b/skyvern/cli/mcp_tools/__init__.py @@ -286,6 +286,14 @@ mcp.tool()(skyvern_workflow_run) mcp.tool()(skyvern_workflow_status) mcp.tool()(skyvern_workflow_cancel) +# -- Admin impersonation (cloud-only, session-level org switching) -- +try: + from cloud.mcp_admin_tools import register_admin_tools # noqa: PLC0415 + + register_admin_tools(mcp) +except ImportError: + pass + # -- Prompts (methodology guides injected into LLM conversations) -- mcp.prompt()(build_workflow) mcp.prompt()(debug_automation) diff --git a/skyvern/forge/sdk/db/enums.py b/skyvern/forge/sdk/db/enums.py index de227125..b07e32fd 100644 --- a/skyvern/forge/sdk/db/enums.py +++ b/skyvern/forge/sdk/db/enums.py @@ -3,6 +3,7 @@ from enum import StrEnum class OrganizationAuthTokenType(StrEnum): api = "api" + mcp_admin_impersonation = "mcp_admin_impersonation" onepassword_service_account = "onepassword_service_account" azure_client_secret_credential = "azure_client_secret_credential" custom_credential_service = "custom_credential_service" diff --git a/skyvern/forge/sdk/services/org_auth_service.py b/skyvern/forge/sdk/services/org_auth_service.py index 5118d9a4..742dc7d1 100644 --- a/skyvern/forge/sdk/services/org_auth_service.py +++ b/skyvern/forge/sdk/services/org_auth_service.py @@ -1,6 +1,6 @@ import time from dataclasses import dataclass -from typing import Annotated +from typing import Annotated, Sequence import structlog from asyncache import cached @@ -176,6 +176,7 @@ async def _authenticate_user_helper(authorization: str) -> str: async def resolve_org_from_api_key( x_api_key: str, db: AgentDB, + token_types: Sequence[OrganizationAuthTokenType] = (OrganizationAuthTokenType.api,), ) -> ApiKeyValidationResult: """Decode and validate the API key against the database.""" try: @@ -202,12 +203,18 @@ async def resolve_org_from_api_key( LOG.warning("Organization not found", organization_id=api_key_data.sub, **payload) raise HTTPException(status_code=404, detail="Organization not found") - api_key_db_obj = await db.validate_org_auth_token( - organization_id=organization.organization_id, - token_type=OrganizationAuthTokenType.api, - token=x_api_key, - valid=None, - ) + api_key_db_obj: OrganizationAuthToken | None = None + # Try token types in priority order and stop at the first valid match. + for token_type in token_types: + api_key_db_obj = await db.validate_org_auth_token( + organization_id=organization.organization_id, + token_type=token_type, + token=x_api_key, + valid=None, + ) + if api_key_db_obj: + break + if not api_key_db_obj: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, diff --git a/tests/unit/test_mcp_http_auth.py b/tests/unit/test_mcp_http_auth.py index 8cfcc4dc..4b38ecc6 100644 --- a/tests/unit/test_mcp_http_auth.py +++ b/tests/unit/test_mcp_http_auth.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import time from types import SimpleNamespace from unittest.mock import AsyncMock @@ -15,6 +16,7 @@ from starlette.routing import Route from skyvern.cli.core import client as client_mod from skyvern.cli.core import mcp_http_auth +from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType @pytest.fixture(autouse=True) @@ -26,6 +28,7 @@ def _reset_auth_context() -> None: mcp_http_auth._API_KEY_CACHE_MAX_SIZE = 1024 mcp_http_auth._MAX_VALIDATION_RETRIES = 2 mcp_http_auth._RETRY_DELAY_SECONDS = 0.0 # no delay in tests + mcp_http_auth.clear_all_impersonation_sessions() async def _echo_request_context(request: Request) -> JSONResponse: @@ -33,10 +36,34 @@ async def _echo_request_context(request: Request) -> JSONResponse: { "api_key": client_mod.get_active_api_key(), "organization_id": getattr(request.state, "organization_id", None), + "admin_organization_id": getattr(request.state, "admin_organization_id", None), + "impersonation_target_organization_id": getattr( + request.state, "impersonation_target_organization_id", None + ), } ) +def _build_validation( + organization_id: str, + token_type: OrganizationAuthTokenType = OrganizationAuthTokenType.api, +) -> mcp_http_auth.MCPAPIKeyValidation: + return mcp_http_auth.MCPAPIKeyValidation( + organization_id=organization_id, + token_type=token_type, + ) + + +def _build_resolved_validation( + organization_id: str, + token_type: OrganizationAuthTokenType = OrganizationAuthTokenType.api, +) -> SimpleNamespace: + return SimpleNamespace( + organization=SimpleNamespace(organization_id=organization_id), + token=SimpleNamespace(token_type=token_type), + ) + + def _build_test_app() -> Starlette: return Starlette( routes=[Route("/mcp", endpoint=_echo_request_context, methods=["POST"])], @@ -138,7 +165,7 @@ async def test_mcp_http_auth_sets_request_scoped_api_key(monkeypatch: pytest.Mon monkeypatch.setattr( mcp_http_auth, "validate_mcp_api_key", - AsyncMock(return_value="org_123"), + AsyncMock(return_value=_build_validation("org_123")), ) app = _build_test_app() @@ -149,6 +176,8 @@ async def test_mcp_http_auth_sets_request_scoped_api_key(monkeypatch: pytest.Mon assert response.json() == { "api_key": "sk_live_abc", "organization_id": "org_123", + "admin_organization_id": None, + "impersonation_target_organization_id": None, } assert client_mod.get_active_api_key() != "sk_live_abc" @@ -157,10 +186,10 @@ async def test_mcp_http_auth_sets_request_scoped_api_key(monkeypatch: pytest.Mon async def test_validate_mcp_api_key_uses_ttl_cache(monkeypatch: pytest.MonkeyPatch) -> None: calls = 0 - async def _resolve(_api_key: str, _db: object) -> object: + async def _resolve(_api_key: str, _db: object, **_: object) -> object: nonlocal calls calls += 1 - return SimpleNamespace(organization=SimpleNamespace(organization_id="org_cached")) + return _build_resolved_validation("org_cached") monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve) monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object()) @@ -168,8 +197,8 @@ async def test_validate_mcp_api_key_uses_ttl_cache(monkeypatch: pytest.MonkeyPat first = await mcp_http_auth.validate_mcp_api_key("sk_test_cache") second = await mcp_http_auth.validate_mcp_api_key("sk_test_cache") - assert first == "org_cached" - assert second == "org_cached" + assert first.organization_id == "org_cached" + assert second.organization_id == "org_cached" assert calls == 1 @@ -177,21 +206,21 @@ async def test_validate_mcp_api_key_uses_ttl_cache(monkeypatch: pytest.MonkeyPat async def test_validate_mcp_api_key_cache_expires(monkeypatch: pytest.MonkeyPatch) -> None: calls = 0 - async def _resolve(_api_key: str, _db: object) -> object: + async def _resolve(_api_key: str, _db: object, **_: object) -> object: nonlocal calls calls += 1 - return SimpleNamespace(organization=SimpleNamespace(organization_id=f"org_{calls}")) + return _build_resolved_validation(f"org_{calls}") monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve) monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object()) first = await mcp_http_auth.validate_mcp_api_key("sk_test_cache_expire") - cache_key = mcp_http_auth._cache_key("sk_test_cache_expire") + cache_key = mcp_http_auth.cache_key("sk_test_cache_expire") mcp_http_auth._api_key_validation_cache[cache_key] = (first, 0.0) second = await mcp_http_auth.validate_mcp_api_key("sk_test_cache_expire") - assert first == "org_1" - assert second == "org_2" + assert first.organization_id == "org_1" + assert second.organization_id == "org_2" assert calls == 2 @@ -199,7 +228,7 @@ async def test_validate_mcp_api_key_cache_expires(monkeypatch: pytest.MonkeyPatc async def test_validate_mcp_api_key_negative_caches_auth_failures(monkeypatch: pytest.MonkeyPatch) -> None: calls = 0 - async def _resolve(_api_key: str, _db: object) -> object: + async def _resolve(_api_key: str, _db: object, **_: object) -> object: nonlocal calls calls += 1 raise HTTPException(status_code=401, detail="Invalid credentials") @@ -222,22 +251,22 @@ async def test_validate_mcp_api_key_retries_transient_failure_without_negative_c ) -> None: calls = 0 - async def _resolve(_api_key: str, _db: object) -> object: + async def _resolve(_api_key: str, _db: object, **_: object) -> object: nonlocal calls calls += 1 if calls == 1: raise RuntimeError("transient db error") - return SimpleNamespace(organization=SimpleNamespace(organization_id="org_recovered")) + return _build_resolved_validation("org_recovered") monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve) monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object()) recovered_org = await mcp_http_auth.validate_mcp_api_key("sk_test_transient") - cache_key = mcp_http_auth._cache_key("sk_test_transient") - assert mcp_http_auth._api_key_validation_cache[cache_key][0] == "org_recovered" + cache_key = mcp_http_auth.cache_key("sk_test_transient") + assert mcp_http_auth._api_key_validation_cache[cache_key][0].organization_id == "org_recovered" - assert recovered_org == "org_recovered" + assert recovered_org.organization_id == "org_recovered" assert calls == 2 @@ -249,16 +278,16 @@ async def test_validate_mcp_api_key_concurrent_callers_all_succeed( collapses subsequent calls after the first one populates it.""" calls = 0 - async def _resolve(_api_key: str, _db: object) -> object: + async def _resolve(_api_key: str, _db: object, **_: object) -> object: nonlocal calls calls += 1 - return SimpleNamespace(organization=SimpleNamespace(organization_id="org_concurrent")) + return _build_resolved_validation("org_concurrent") monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve) monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object()) results = await asyncio.gather(*[mcp_http_auth.validate_mcp_api_key("sk_test_concurrent") for _ in range(5)]) - assert all(r == "org_concurrent" for r in results) + assert all(r.organization_id == "org_concurrent" for r in results) # First call populates cache; remaining may or may not hit DB depending on # scheduling, but all must succeed. assert calls >= 1 @@ -268,7 +297,7 @@ async def test_validate_mcp_api_key_concurrent_callers_all_succeed( async def test_validate_mcp_api_key_returns_503_after_retry_exhaustion(monkeypatch: pytest.MonkeyPatch) -> None: calls = 0 - async def _resolve(_api_key: str, _db: object) -> object: + async def _resolve(_api_key: str, _db: object, **_: object) -> object: nonlocal calls calls += 1 raise RuntimeError("persistent db outage") @@ -302,3 +331,220 @@ async def test_close_auth_db_noop_when_uninitialized() -> None: mcp_http_auth._auth_db = None await mcp_http_auth.close_auth_db() assert mcp_http_auth._auth_db is None + + +@pytest.mark.asyncio +async def test_mcp_http_auth_denies_target_org_when_feature_disabled(monkeypatch: pytest.MonkeyPatch) -> None: + validate_mock = AsyncMock( + return_value=_build_validation( + "org_admin", + OrganizationAuthTokenType.mcp_admin_impersonation, + ) + ) + monkeypatch.setattr( + mcp_http_auth, + "validate_mcp_api_key", + validate_mock, + ) + monkeypatch.setattr(mcp_http_auth, "_is_admin_impersonation_enabled", lambda: False) + app = _build_test_app() + + async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client: + response = await client.post( + "/mcp", + headers={"x-api-key": "sk_live_admin", "x-target-org-id": "org_target"}, + json={}, + ) + + assert response.status_code == 401 + assert response.json()["error"]["code"] == "UNAUTHORIZED" + assert response.json()["error"]["message"] == "Impersonation not allowed" + validate_mock.assert_awaited_once_with("sk_live_admin") + + +@pytest.mark.asyncio +async def test_mcp_http_auth_validates_api_key_before_feature_flag_denial(monkeypatch: pytest.MonkeyPatch) -> None: + validate_mock = AsyncMock(side_effect=HTTPException(status_code=403, detail="Invalid credentials")) + monkeypatch.setattr(mcp_http_auth, "validate_mcp_api_key", validate_mock) + monkeypatch.setattr(mcp_http_auth, "_is_admin_impersonation_enabled", lambda: False) + app = _build_test_app() + + async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client: + response = await client.post( + "/mcp", + headers={"x-api-key": "bad-key", "x-target-org-id": "org_target"}, + json={}, + ) + + assert response.status_code == 401 + assert response.json()["error"]["code"] == "UNAUTHORIZED" + assert response.json()["error"]["message"] == "Invalid API key" + validate_mock.assert_awaited_once_with("bad-key") + + +@pytest.mark.parametrize("target_org_id", ["", " \t "]) +@pytest.mark.asyncio +async def test_mcp_http_auth_denies_empty_or_whitespace_target_org_id_header( + monkeypatch: pytest.MonkeyPatch, + target_org_id: str, +) -> None: + validate_mock = AsyncMock( + return_value=_build_validation( + "org_admin", + OrganizationAuthTokenType.mcp_admin_impersonation, + ) + ) + monkeypatch.setattr(mcp_http_auth, "validate_mcp_api_key", validate_mock) + monkeypatch.setattr(mcp_http_auth, "_is_admin_impersonation_enabled", lambda: True) + app = _build_test_app() + + async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client: + response = await client.post( + "/mcp", + headers={"x-api-key": "sk_live_admin", "x-target-org-id": target_org_id}, + json={}, + ) + + assert response.status_code == 401 + assert response.json()["error"]["code"] == "UNAUTHORIZED" + assert response.json()["error"]["message"] == "Impersonation not allowed" + validate_mock.assert_awaited_once_with("sk_live_admin") + + +# --------------------------------------------------------------------------- +# Session-based impersonation tests +# --------------------------------------------------------------------------- + + +def test_impersonation_session_lifecycle() -> None: + """set → get → clear lifecycle.""" + admin_hash = mcp_http_auth.cache_key("sk_admin_key") + session = mcp_http_auth.ImpersonationSession( + admin_api_key_hash=admin_hash, + admin_org_id="org_admin", + target_org_id="org_target", + target_api_key="sk_target_key", + expires_at=time.monotonic() + 600, + ttl_minutes=10, + ) + mcp_http_auth.set_impersonation_session(session) + + retrieved = mcp_http_auth.get_active_impersonation(admin_hash) + assert retrieved is not None + assert retrieved.target_org_id == "org_target" + + cleared = mcp_http_auth.clear_impersonation_session(admin_hash) + assert cleared is not None + assert cleared.target_org_id == "org_target" + + assert mcp_http_auth.get_active_impersonation(admin_hash) is None + + +def test_impersonation_session_auto_expiry() -> None: + """Expired sessions are lazily cleaned up on get.""" + admin_hash = mcp_http_auth.cache_key("sk_admin_key") + session = mcp_http_auth.ImpersonationSession( + admin_api_key_hash=admin_hash, + admin_org_id="org_admin", + target_org_id="org_target", + target_api_key="sk_target_key", + expires_at=time.monotonic() - 1, # already expired + ttl_minutes=1, + ) + mcp_http_auth.set_impersonation_session(session) + + assert mcp_http_auth.get_active_impersonation(admin_hash) is None + + +@pytest.mark.asyncio +async def test_middleware_applies_session_impersonation(monkeypatch: pytest.MonkeyPatch) -> None: + """When a session is active, middleware auto-applies impersonation without header.""" + monkeypatch.setattr( + mcp_http_auth, + "validate_mcp_api_key", + AsyncMock( + return_value=_build_validation( + "org_admin", + OrganizationAuthTokenType.mcp_admin_impersonation, + ) + ), + ) + + admin_hash = mcp_http_auth.cache_key("sk_live_admin") + session = mcp_http_auth.ImpersonationSession( + admin_api_key_hash=admin_hash, + admin_org_id="org_admin", + target_org_id="org_target", + target_api_key="sk_live_target_key", + expires_at=time.monotonic() + 600, + ttl_minutes=10, + ) + mcp_http_auth.set_impersonation_session(session) + + app = _build_test_app() + async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client: + response = await client.post("/mcp", headers={"x-api-key": "sk_live_admin"}, json={}) + + assert response.status_code == 200 + body = response.json() + assert body["api_key"] == "sk_live_target_key" + assert body["organization_id"] == "org_target" + assert body["admin_organization_id"] == "org_admin" + assert body["impersonation_target_organization_id"] == "org_target" + + +@pytest.mark.asyncio +async def test_middleware_ignores_expired_session(monkeypatch: pytest.MonkeyPatch) -> None: + """Expired session is ignored — middleware reverts to admin's own org.""" + monkeypatch.setattr( + mcp_http_auth, + "validate_mcp_api_key", + AsyncMock( + return_value=_build_validation( + "org_admin", + OrganizationAuthTokenType.mcp_admin_impersonation, + ) + ), + ) + + admin_hash = mcp_http_auth.cache_key("sk_live_admin") + session = mcp_http_auth.ImpersonationSession( + admin_api_key_hash=admin_hash, + admin_org_id="org_admin", + target_org_id="org_target", + target_api_key="sk_live_target_key", + expires_at=time.monotonic() - 1, # expired + ttl_minutes=1, + ) + mcp_http_auth.set_impersonation_session(session) + + app = _build_test_app() + async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client: + response = await client.post("/mcp", headers={"x-api-key": "sk_live_admin"}, json={}) + + assert response.status_code == 200 + body = response.json() + assert body["api_key"] == "sk_live_admin" + assert body["organization_id"] == "org_admin" + assert body["admin_organization_id"] is None + + +@pytest.mark.asyncio +async def test_close_auth_db_clears_impersonation_sessions() -> None: + admin_hash = mcp_http_auth.cache_key("sk_admin_key") + session = mcp_http_auth.ImpersonationSession( + admin_api_key_hash=admin_hash, + admin_org_id="org_admin", + target_org_id="org_target", + target_api_key="sk_target_key", + expires_at=time.monotonic() + 600, + ttl_minutes=10, + ) + mcp_http_auth.set_impersonation_session(session) + + dispose = AsyncMock() + mcp_http_auth._auth_db = SimpleNamespace(engine=SimpleNamespace(dispose=dispose)) + + await mcp_http_auth.close_auth_db() + + assert mcp_http_auth.get_active_impersonation(admin_hash) is None