Add gated admin impersonation controls for MCP API-key auth (#4822)
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -4,6 +4,8 @@ import asyncio
|
||||
import os
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from contextvars import ContextVar, Token
|
||||
from dataclasses import dataclass
|
||||
from threading import RLock
|
||||
|
||||
import structlog
|
||||
@@ -14,6 +16,7 @@ from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
from skyvern.config import settings
|
||||
from skyvern.forge.sdk.db.agent_db import AgentDB
|
||||
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
|
||||
from skyvern.forge.sdk.services.org_auth_service import resolve_org_from_api_key
|
||||
|
||||
from .api_key_hash import hash_api_key_for_cache
|
||||
@@ -21,16 +24,94 @@ from .client import reset_api_key_override, set_api_key_override
|
||||
|
||||
LOG = structlog.get_logger(__name__)
|
||||
API_KEY_HEADER = "x-api-key"
|
||||
TARGET_ORG_ID_HEADER = "x-target-org-id"
|
||||
HEALTH_PATHS = {"/health", "/healthz"}
|
||||
_MCP_ALLOWED_TOKEN_TYPES = (
|
||||
OrganizationAuthTokenType.api,
|
||||
OrganizationAuthTokenType.mcp_admin_impersonation,
|
||||
)
|
||||
_auth_db: AgentDB | None = None
|
||||
_auth_db_lock = RLock()
|
||||
_api_key_cache_lock = RLock()
|
||||
_api_key_validation_cache: OrderedDict[str, tuple[str | None, float]] = OrderedDict()
|
||||
_api_key_validation_cache: OrderedDict[str, tuple[MCPAPIKeyValidation | None, float]] = OrderedDict()
|
||||
_NEGATIVE_CACHE_TTL_SECONDS = 5.0
|
||||
_VALIDATION_RETRY_EXHAUSTED_MESSAGE = "API key validation temporarily unavailable"
|
||||
_MAX_VALIDATION_RETRIES = 2
|
||||
_RETRY_DELAY_SECONDS = 0.25
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Impersonation session state
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_admin_api_key_hash: ContextVar[str | None] = ContextVar("admin_api_key_hash", default=None)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ImpersonationSession:
|
||||
admin_api_key_hash: str
|
||||
admin_org_id: str
|
||||
target_org_id: str
|
||||
# Stored in plaintext in process memory for up to TTL. Acceptable V1 trade-off
|
||||
# (in-process only, same as API key cache).
|
||||
target_api_key: str
|
||||
expires_at: float # time.monotonic() deadline
|
||||
ttl_minutes: int
|
||||
|
||||
|
||||
_impersonation_sessions: dict[str, ImpersonationSession] = {}
|
||||
_impersonation_lock = RLock()
|
||||
|
||||
MAX_TTL_MINUTES = 120
|
||||
DEFAULT_TTL_MINUTES = 30
|
||||
_MAX_IMPERSONATION_SESSIONS = 100
|
||||
|
||||
|
||||
def get_active_impersonation(admin_key_hash: str) -> ImpersonationSession | None:
|
||||
with _impersonation_lock:
|
||||
session = _impersonation_sessions.get(admin_key_hash)
|
||||
if session is None:
|
||||
return None
|
||||
if session.expires_at <= time.monotonic():
|
||||
_impersonation_sessions.pop(admin_key_hash, None)
|
||||
return None
|
||||
return session
|
||||
|
||||
|
||||
def set_impersonation_session(session: ImpersonationSession) -> None:
|
||||
with _impersonation_lock:
|
||||
_impersonation_sessions[session.admin_api_key_hash] = session
|
||||
# Sweep expired sessions, then evict oldest if over capacity
|
||||
now = time.monotonic()
|
||||
expired_keys = [k for k, s in _impersonation_sessions.items() if s.expires_at <= now]
|
||||
for k in expired_keys:
|
||||
_impersonation_sessions.pop(k, None)
|
||||
if len(_impersonation_sessions) > _MAX_IMPERSONATION_SESSIONS:
|
||||
oldest_key = min(_impersonation_sessions, key=lambda k: _impersonation_sessions[k].expires_at)
|
||||
_impersonation_sessions.pop(oldest_key, None)
|
||||
|
||||
|
||||
def clear_impersonation_session(admin_key_hash: str) -> ImpersonationSession | None:
|
||||
with _impersonation_lock:
|
||||
return _impersonation_sessions.pop(admin_key_hash, None)
|
||||
|
||||
|
||||
def clear_all_impersonation_sessions() -> None:
|
||||
with _impersonation_lock:
|
||||
_impersonation_sessions.clear()
|
||||
|
||||
|
||||
def get_admin_api_key_hash() -> str | None:
|
||||
return _admin_api_key_hash.get()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MCPAPIKeyValidation:
|
||||
organization_id: str
|
||||
token_type: OrganizationAuthTokenType
|
||||
|
||||
|
||||
def _resolve_api_key_cache_ttl_seconds() -> float:
|
||||
raw = os.environ.get("SKYVERN_MCP_API_KEY_CACHE_TTL_SECONDS", "30")
|
||||
@@ -69,6 +150,7 @@ async def close_auth_db() -> None:
|
||||
_auth_db = None
|
||||
with _api_key_cache_lock:
|
||||
_api_key_validation_cache.clear()
|
||||
clear_all_impersonation_sessions()
|
||||
if db is None:
|
||||
return
|
||||
|
||||
@@ -78,24 +160,43 @@ async def close_auth_db() -> None:
|
||||
LOG.warning("Failed to dispose MCP auth DB engine", exc_info=True)
|
||||
|
||||
|
||||
def _cache_key(api_key: str) -> str:
|
||||
def cache_key(api_key: str) -> str:
|
||||
return hash_api_key_for_cache(api_key)
|
||||
|
||||
|
||||
async def validate_mcp_api_key(api_key: str) -> str:
|
||||
"""Validate API key and return organization id for observability."""
|
||||
key = _cache_key(api_key)
|
||||
def _admin_organization_ids() -> set[str]:
|
||||
try:
|
||||
from cloud.config import settings as cloud_settings # noqa: PLC0415
|
||||
|
||||
admin_ids = getattr(cloud_settings, "ADMIN_ORGANIZATION_IDS", [])
|
||||
except ImportError:
|
||||
admin_ids = getattr(settings, "ADMIN_ORGANIZATION_IDS", [])
|
||||
return {org_id for org_id in admin_ids if org_id}
|
||||
|
||||
|
||||
def _is_admin_impersonation_enabled() -> bool:
|
||||
try:
|
||||
from cloud.config import settings as cloud_settings # noqa: PLC0415
|
||||
|
||||
return bool(getattr(cloud_settings, "MCP_ADMIN_IMPERSONATION_ENABLED", False))
|
||||
except ImportError:
|
||||
return bool(getattr(settings, "MCP_ADMIN_IMPERSONATION_ENABLED", False))
|
||||
|
||||
|
||||
async def validate_mcp_api_key(api_key: str) -> MCPAPIKeyValidation:
|
||||
"""Validate API key and return caller organization + token type."""
|
||||
key = cache_key(api_key)
|
||||
|
||||
# Check cache first.
|
||||
with _api_key_cache_lock:
|
||||
cached = _api_key_validation_cache.get(key)
|
||||
if cached is not None:
|
||||
organization_id, expires_at = cached
|
||||
cached_validation, expires_at = cached
|
||||
if expires_at > time.monotonic():
|
||||
_api_key_validation_cache.move_to_end(key)
|
||||
if organization_id is None:
|
||||
if cached_validation is None:
|
||||
raise HTTPException(status_code=401, detail="Invalid API key")
|
||||
return organization_id
|
||||
return cached_validation
|
||||
_api_key_validation_cache.pop(key, None)
|
||||
|
||||
# Cache miss — do the DB lookup with simple retry on transient errors.
|
||||
@@ -104,17 +205,24 @@ async def validate_mcp_api_key(api_key: str) -> str:
|
||||
if attempt > 0:
|
||||
await asyncio.sleep(_RETRY_DELAY_SECONDS)
|
||||
try:
|
||||
validation = await resolve_org_from_api_key(api_key, _get_auth_db())
|
||||
organization_id = validation.organization.organization_id
|
||||
validation = await resolve_org_from_api_key(
|
||||
api_key,
|
||||
_get_auth_db(),
|
||||
token_types=_MCP_ALLOWED_TOKEN_TYPES,
|
||||
)
|
||||
caller_validation = MCPAPIKeyValidation(
|
||||
organization_id=validation.organization.organization_id,
|
||||
token_type=validation.token.token_type,
|
||||
)
|
||||
with _api_key_cache_lock:
|
||||
_api_key_validation_cache[key] = (
|
||||
organization_id,
|
||||
caller_validation,
|
||||
time.monotonic() + _API_KEY_CACHE_TTL_SECONDS,
|
||||
)
|
||||
_api_key_validation_cache.move_to_end(key)
|
||||
while len(_api_key_validation_cache) > _API_KEY_CACHE_MAX_SIZE:
|
||||
_api_key_validation_cache.popitem(last=False)
|
||||
return organization_id
|
||||
return caller_validation
|
||||
except HTTPException as e:
|
||||
if e.status_code in {401, 403}:
|
||||
with _api_key_cache_lock:
|
||||
@@ -149,6 +257,24 @@ def _service_unavailable_response(message: str) -> JSONResponse:
|
||||
)
|
||||
|
||||
|
||||
def _deny_impersonation(
|
||||
*,
|
||||
reason: str,
|
||||
caller_organization_id: str | None,
|
||||
target_organization_id: str | None,
|
||||
token_type: OrganizationAuthTokenType | None = None,
|
||||
) -> JSONResponse:
|
||||
log_kwargs: dict[str, object] = {
|
||||
"reason": reason,
|
||||
"caller_organization_id": caller_organization_id,
|
||||
"target_organization_id": target_organization_id,
|
||||
}
|
||||
if token_type is not None:
|
||||
log_kwargs["token_type"] = token_type.value
|
||||
LOG.warning("MCP admin impersonation denied", **log_kwargs)
|
||||
return _unauthorized_response("Impersonation not allowed")
|
||||
|
||||
|
||||
class MCPAPIKeyMiddleware:
|
||||
"""Require x-api-key for MCP HTTP transport and scope requests to that key."""
|
||||
|
||||
@@ -176,10 +302,99 @@ class MCPAPIKeyMiddleware:
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
|
||||
target_org_id_header = request.headers.get(TARGET_ORG_ID_HEADER)
|
||||
admin_hash_token: Token[str | None] | None = None
|
||||
|
||||
try:
|
||||
organization_id = await validate_mcp_api_key(api_key)
|
||||
validation = await validate_mcp_api_key(api_key)
|
||||
caller_organization_id = validation.organization_id
|
||||
admin_key_hash = cache_key(api_key)
|
||||
|
||||
scope.setdefault("state", {})
|
||||
scope["state"]["organization_id"] = organization_id
|
||||
if target_org_id_header is not None:
|
||||
# Explicit header-based impersonation (takes priority over session)
|
||||
if not _is_admin_impersonation_enabled():
|
||||
response = _deny_impersonation(
|
||||
reason="feature_disabled",
|
||||
caller_organization_id=caller_organization_id,
|
||||
target_organization_id=target_org_id_header.strip() or target_org_id_header,
|
||||
)
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
|
||||
target_organization_id = target_org_id_header.strip()
|
||||
if not target_organization_id:
|
||||
response = _deny_impersonation(
|
||||
reason="missing_target_organization_id",
|
||||
caller_organization_id=caller_organization_id,
|
||||
target_organization_id=target_org_id_header,
|
||||
token_type=validation.token_type,
|
||||
)
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
|
||||
# Delegate validation to the single source of truth in cloud/.
|
||||
# The import can't fail here — _is_admin_impersonation_enabled()
|
||||
# already returned True, which requires cloud.config to be importable.
|
||||
try:
|
||||
from cloud.mcp_admin_tools import validate_impersonation_target # noqa: PLC0415
|
||||
except ImportError:
|
||||
response = _deny_impersonation(
|
||||
reason="impersonation_not_available",
|
||||
caller_organization_id=caller_organization_id,
|
||||
target_organization_id=target_organization_id,
|
||||
)
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
|
||||
result = await validate_impersonation_target(
|
||||
caller_organization_id=caller_organization_id,
|
||||
target_organization_id=target_organization_id,
|
||||
token_type=validation.token_type,
|
||||
)
|
||||
if isinstance(result, str):
|
||||
response = _deny_impersonation(
|
||||
reason=result,
|
||||
caller_organization_id=caller_organization_id,
|
||||
target_organization_id=target_organization_id,
|
||||
token_type=validation.token_type,
|
||||
)
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
|
||||
resolved_org_id, target_api_key = result
|
||||
api_key = target_api_key
|
||||
|
||||
scope["state"]["organization_id"] = resolved_org_id
|
||||
scope["state"]["admin_organization_id"] = caller_organization_id
|
||||
scope["state"]["impersonation_target_organization_id"] = resolved_org_id
|
||||
admin_hash_token = _admin_api_key_hash.set(admin_key_hash)
|
||||
LOG.info(
|
||||
"MCP admin impersonation allowed",
|
||||
caller_organization_id=caller_organization_id,
|
||||
target_organization_id=resolved_org_id,
|
||||
token_type=validation.token_type.value,
|
||||
)
|
||||
else:
|
||||
# No explicit header — check for session-based impersonation
|
||||
session = get_active_impersonation(admin_key_hash)
|
||||
if session is not None:
|
||||
# Session data (target API key) is cached for TTL. If the target org's
|
||||
# key is revoked mid-session, impersonation continues until expiry —
|
||||
# acceptable trade-off vs re-validating every request.
|
||||
api_key = session.target_api_key
|
||||
scope["state"]["organization_id"] = session.target_org_id
|
||||
scope["state"]["admin_organization_id"] = session.admin_org_id
|
||||
scope["state"]["impersonation_target_organization_id"] = session.target_org_id
|
||||
admin_hash_token = _admin_api_key_hash.set(admin_key_hash)
|
||||
LOG.info(
|
||||
"MCP session impersonation applied",
|
||||
caller_organization_id=session.admin_org_id,
|
||||
target_organization_id=session.target_org_id,
|
||||
ttl_minutes=session.ttl_minutes,
|
||||
)
|
||||
else:
|
||||
scope["state"]["organization_id"] = caller_organization_id
|
||||
except HTTPException as e:
|
||||
if e.status_code in {401, 403}:
|
||||
response = _unauthorized_response("Invalid API key")
|
||||
@@ -204,3 +419,5 @@ class MCPAPIKeyMiddleware:
|
||||
await self.app(scope, receive, send)
|
||||
finally:
|
||||
reset_api_key_override(token)
|
||||
if admin_hash_token is not None:
|
||||
_admin_api_key_hash.reset(admin_hash_token)
|
||||
|
||||
@@ -286,6 +286,14 @@ mcp.tool()(skyvern_workflow_run)
|
||||
mcp.tool()(skyvern_workflow_status)
|
||||
mcp.tool()(skyvern_workflow_cancel)
|
||||
|
||||
# -- Admin impersonation (cloud-only, session-level org switching) --
|
||||
try:
|
||||
from cloud.mcp_admin_tools import register_admin_tools # noqa: PLC0415
|
||||
|
||||
register_admin_tools(mcp)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# -- Prompts (methodology guides injected into LLM conversations) --
|
||||
mcp.prompt()(build_workflow)
|
||||
mcp.prompt()(debug_automation)
|
||||
|
||||
@@ -3,6 +3,7 @@ from enum import StrEnum
|
||||
|
||||
class OrganizationAuthTokenType(StrEnum):
|
||||
api = "api"
|
||||
mcp_admin_impersonation = "mcp_admin_impersonation"
|
||||
onepassword_service_account = "onepassword_service_account"
|
||||
azure_client_secret_credential = "azure_client_secret_credential"
|
||||
custom_credential_service = "custom_credential_service"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Annotated
|
||||
from typing import Annotated, Sequence
|
||||
|
||||
import structlog
|
||||
from asyncache import cached
|
||||
@@ -176,6 +176,7 @@ async def _authenticate_user_helper(authorization: str) -> str:
|
||||
async def resolve_org_from_api_key(
|
||||
x_api_key: str,
|
||||
db: AgentDB,
|
||||
token_types: Sequence[OrganizationAuthTokenType] = (OrganizationAuthTokenType.api,),
|
||||
) -> ApiKeyValidationResult:
|
||||
"""Decode and validate the API key against the database."""
|
||||
try:
|
||||
@@ -202,12 +203,18 @@ async def resolve_org_from_api_key(
|
||||
LOG.warning("Organization not found", organization_id=api_key_data.sub, **payload)
|
||||
raise HTTPException(status_code=404, detail="Organization not found")
|
||||
|
||||
api_key_db_obj = await db.validate_org_auth_token(
|
||||
organization_id=organization.organization_id,
|
||||
token_type=OrganizationAuthTokenType.api,
|
||||
token=x_api_key,
|
||||
valid=None,
|
||||
)
|
||||
api_key_db_obj: OrganizationAuthToken | None = None
|
||||
# Try token types in priority order and stop at the first valid match.
|
||||
for token_type in token_types:
|
||||
api_key_db_obj = await db.validate_org_auth_token(
|
||||
organization_id=organization.organization_id,
|
||||
token_type=token_type,
|
||||
token=x_api_key,
|
||||
valid=None,
|
||||
)
|
||||
if api_key_db_obj:
|
||||
break
|
||||
|
||||
if not api_key_db_obj:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
@@ -15,6 +16,7 @@ from starlette.routing import Route
|
||||
|
||||
from skyvern.cli.core import client as client_mod
|
||||
from skyvern.cli.core import mcp_http_auth
|
||||
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
@@ -26,6 +28,7 @@ def _reset_auth_context() -> None:
|
||||
mcp_http_auth._API_KEY_CACHE_MAX_SIZE = 1024
|
||||
mcp_http_auth._MAX_VALIDATION_RETRIES = 2
|
||||
mcp_http_auth._RETRY_DELAY_SECONDS = 0.0 # no delay in tests
|
||||
mcp_http_auth.clear_all_impersonation_sessions()
|
||||
|
||||
|
||||
async def _echo_request_context(request: Request) -> JSONResponse:
|
||||
@@ -33,10 +36,34 @@ async def _echo_request_context(request: Request) -> JSONResponse:
|
||||
{
|
||||
"api_key": client_mod.get_active_api_key(),
|
||||
"organization_id": getattr(request.state, "organization_id", None),
|
||||
"admin_organization_id": getattr(request.state, "admin_organization_id", None),
|
||||
"impersonation_target_organization_id": getattr(
|
||||
request.state, "impersonation_target_organization_id", None
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _build_validation(
|
||||
organization_id: str,
|
||||
token_type: OrganizationAuthTokenType = OrganizationAuthTokenType.api,
|
||||
) -> mcp_http_auth.MCPAPIKeyValidation:
|
||||
return mcp_http_auth.MCPAPIKeyValidation(
|
||||
organization_id=organization_id,
|
||||
token_type=token_type,
|
||||
)
|
||||
|
||||
|
||||
def _build_resolved_validation(
|
||||
organization_id: str,
|
||||
token_type: OrganizationAuthTokenType = OrganizationAuthTokenType.api,
|
||||
) -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
organization=SimpleNamespace(organization_id=organization_id),
|
||||
token=SimpleNamespace(token_type=token_type),
|
||||
)
|
||||
|
||||
|
||||
def _build_test_app() -> Starlette:
|
||||
return Starlette(
|
||||
routes=[Route("/mcp", endpoint=_echo_request_context, methods=["POST"])],
|
||||
@@ -138,7 +165,7 @@ async def test_mcp_http_auth_sets_request_scoped_api_key(monkeypatch: pytest.Mon
|
||||
monkeypatch.setattr(
|
||||
mcp_http_auth,
|
||||
"validate_mcp_api_key",
|
||||
AsyncMock(return_value="org_123"),
|
||||
AsyncMock(return_value=_build_validation("org_123")),
|
||||
)
|
||||
app = _build_test_app()
|
||||
|
||||
@@ -149,6 +176,8 @@ async def test_mcp_http_auth_sets_request_scoped_api_key(monkeypatch: pytest.Mon
|
||||
assert response.json() == {
|
||||
"api_key": "sk_live_abc",
|
||||
"organization_id": "org_123",
|
||||
"admin_organization_id": None,
|
||||
"impersonation_target_organization_id": None,
|
||||
}
|
||||
assert client_mod.get_active_api_key() != "sk_live_abc"
|
||||
|
||||
@@ -157,10 +186,10 @@ async def test_mcp_http_auth_sets_request_scoped_api_key(monkeypatch: pytest.Mon
|
||||
async def test_validate_mcp_api_key_uses_ttl_cache(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
calls = 0
|
||||
|
||||
async def _resolve(_api_key: str, _db: object) -> object:
|
||||
async def _resolve(_api_key: str, _db: object, **_: object) -> object:
|
||||
nonlocal calls
|
||||
calls += 1
|
||||
return SimpleNamespace(organization=SimpleNamespace(organization_id="org_cached"))
|
||||
return _build_resolved_validation("org_cached")
|
||||
|
||||
monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
|
||||
monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())
|
||||
@@ -168,8 +197,8 @@ async def test_validate_mcp_api_key_uses_ttl_cache(monkeypatch: pytest.MonkeyPat
|
||||
first = await mcp_http_auth.validate_mcp_api_key("sk_test_cache")
|
||||
second = await mcp_http_auth.validate_mcp_api_key("sk_test_cache")
|
||||
|
||||
assert first == "org_cached"
|
||||
assert second == "org_cached"
|
||||
assert first.organization_id == "org_cached"
|
||||
assert second.organization_id == "org_cached"
|
||||
assert calls == 1
|
||||
|
||||
|
||||
@@ -177,21 +206,21 @@ async def test_validate_mcp_api_key_uses_ttl_cache(monkeypatch: pytest.MonkeyPat
|
||||
async def test_validate_mcp_api_key_cache_expires(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
calls = 0
|
||||
|
||||
async def _resolve(_api_key: str, _db: object) -> object:
|
||||
async def _resolve(_api_key: str, _db: object, **_: object) -> object:
|
||||
nonlocal calls
|
||||
calls += 1
|
||||
return SimpleNamespace(organization=SimpleNamespace(organization_id=f"org_{calls}"))
|
||||
return _build_resolved_validation(f"org_{calls}")
|
||||
|
||||
monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
|
||||
monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())
|
||||
|
||||
first = await mcp_http_auth.validate_mcp_api_key("sk_test_cache_expire")
|
||||
cache_key = mcp_http_auth._cache_key("sk_test_cache_expire")
|
||||
cache_key = mcp_http_auth.cache_key("sk_test_cache_expire")
|
||||
mcp_http_auth._api_key_validation_cache[cache_key] = (first, 0.0)
|
||||
second = await mcp_http_auth.validate_mcp_api_key("sk_test_cache_expire")
|
||||
|
||||
assert first == "org_1"
|
||||
assert second == "org_2"
|
||||
assert first.organization_id == "org_1"
|
||||
assert second.organization_id == "org_2"
|
||||
assert calls == 2
|
||||
|
||||
|
||||
@@ -199,7 +228,7 @@ async def test_validate_mcp_api_key_cache_expires(monkeypatch: pytest.MonkeyPatc
|
||||
async def test_validate_mcp_api_key_negative_caches_auth_failures(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
calls = 0
|
||||
|
||||
async def _resolve(_api_key: str, _db: object) -> object:
|
||||
async def _resolve(_api_key: str, _db: object, **_: object) -> object:
|
||||
nonlocal calls
|
||||
calls += 1
|
||||
raise HTTPException(status_code=401, detail="Invalid credentials")
|
||||
@@ -222,22 +251,22 @@ async def test_validate_mcp_api_key_retries_transient_failure_without_negative_c
|
||||
) -> None:
|
||||
calls = 0
|
||||
|
||||
async def _resolve(_api_key: str, _db: object) -> object:
|
||||
async def _resolve(_api_key: str, _db: object, **_: object) -> object:
|
||||
nonlocal calls
|
||||
calls += 1
|
||||
if calls == 1:
|
||||
raise RuntimeError("transient db error")
|
||||
return SimpleNamespace(organization=SimpleNamespace(organization_id="org_recovered"))
|
||||
return _build_resolved_validation("org_recovered")
|
||||
|
||||
monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
|
||||
monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())
|
||||
|
||||
recovered_org = await mcp_http_auth.validate_mcp_api_key("sk_test_transient")
|
||||
|
||||
cache_key = mcp_http_auth._cache_key("sk_test_transient")
|
||||
assert mcp_http_auth._api_key_validation_cache[cache_key][0] == "org_recovered"
|
||||
cache_key = mcp_http_auth.cache_key("sk_test_transient")
|
||||
assert mcp_http_auth._api_key_validation_cache[cache_key][0].organization_id == "org_recovered"
|
||||
|
||||
assert recovered_org == "org_recovered"
|
||||
assert recovered_org.organization_id == "org_recovered"
|
||||
assert calls == 2
|
||||
|
||||
|
||||
@@ -249,16 +278,16 @@ async def test_validate_mcp_api_key_concurrent_callers_all_succeed(
|
||||
collapses subsequent calls after the first one populates it."""
|
||||
calls = 0
|
||||
|
||||
async def _resolve(_api_key: str, _db: object) -> object:
|
||||
async def _resolve(_api_key: str, _db: object, **_: object) -> object:
|
||||
nonlocal calls
|
||||
calls += 1
|
||||
return SimpleNamespace(organization=SimpleNamespace(organization_id="org_concurrent"))
|
||||
return _build_resolved_validation("org_concurrent")
|
||||
|
||||
monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
|
||||
monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())
|
||||
|
||||
results = await asyncio.gather(*[mcp_http_auth.validate_mcp_api_key("sk_test_concurrent") for _ in range(5)])
|
||||
assert all(r == "org_concurrent" for r in results)
|
||||
assert all(r.organization_id == "org_concurrent" for r in results)
|
||||
# First call populates cache; remaining may or may not hit DB depending on
|
||||
# scheduling, but all must succeed.
|
||||
assert calls >= 1
|
||||
@@ -268,7 +297,7 @@ async def test_validate_mcp_api_key_concurrent_callers_all_succeed(
|
||||
async def test_validate_mcp_api_key_returns_503_after_retry_exhaustion(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
calls = 0
|
||||
|
||||
async def _resolve(_api_key: str, _db: object) -> object:
|
||||
async def _resolve(_api_key: str, _db: object, **_: object) -> object:
|
||||
nonlocal calls
|
||||
calls += 1
|
||||
raise RuntimeError("persistent db outage")
|
||||
@@ -302,3 +331,220 @@ async def test_close_auth_db_noop_when_uninitialized() -> None:
|
||||
mcp_http_auth._auth_db = None
|
||||
await mcp_http_auth.close_auth_db()
|
||||
assert mcp_http_auth._auth_db is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_http_auth_denies_target_org_when_feature_disabled(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
validate_mock = AsyncMock(
|
||||
return_value=_build_validation(
|
||||
"org_admin",
|
||||
OrganizationAuthTokenType.mcp_admin_impersonation,
|
||||
)
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
mcp_http_auth,
|
||||
"validate_mcp_api_key",
|
||||
validate_mock,
|
||||
)
|
||||
monkeypatch.setattr(mcp_http_auth, "_is_admin_impersonation_enabled", lambda: False)
|
||||
app = _build_test_app()
|
||||
|
||||
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
|
||||
response = await client.post(
|
||||
"/mcp",
|
||||
headers={"x-api-key": "sk_live_admin", "x-target-org-id": "org_target"},
|
||||
json={},
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
assert response.json()["error"]["code"] == "UNAUTHORIZED"
|
||||
assert response.json()["error"]["message"] == "Impersonation not allowed"
|
||||
validate_mock.assert_awaited_once_with("sk_live_admin")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_http_auth_validates_api_key_before_feature_flag_denial(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
validate_mock = AsyncMock(side_effect=HTTPException(status_code=403, detail="Invalid credentials"))
|
||||
monkeypatch.setattr(mcp_http_auth, "validate_mcp_api_key", validate_mock)
|
||||
monkeypatch.setattr(mcp_http_auth, "_is_admin_impersonation_enabled", lambda: False)
|
||||
app = _build_test_app()
|
||||
|
||||
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
|
||||
response = await client.post(
|
||||
"/mcp",
|
||||
headers={"x-api-key": "bad-key", "x-target-org-id": "org_target"},
|
||||
json={},
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
assert response.json()["error"]["code"] == "UNAUTHORIZED"
|
||||
assert response.json()["error"]["message"] == "Invalid API key"
|
||||
validate_mock.assert_awaited_once_with("bad-key")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("target_org_id", ["", " \t "])
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_http_auth_denies_empty_or_whitespace_target_org_id_header(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
target_org_id: str,
|
||||
) -> None:
|
||||
validate_mock = AsyncMock(
|
||||
return_value=_build_validation(
|
||||
"org_admin",
|
||||
OrganizationAuthTokenType.mcp_admin_impersonation,
|
||||
)
|
||||
)
|
||||
monkeypatch.setattr(mcp_http_auth, "validate_mcp_api_key", validate_mock)
|
||||
monkeypatch.setattr(mcp_http_auth, "_is_admin_impersonation_enabled", lambda: True)
|
||||
app = _build_test_app()
|
||||
|
||||
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
|
||||
response = await client.post(
|
||||
"/mcp",
|
||||
headers={"x-api-key": "sk_live_admin", "x-target-org-id": target_org_id},
|
||||
json={},
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
assert response.json()["error"]["code"] == "UNAUTHORIZED"
|
||||
assert response.json()["error"]["message"] == "Impersonation not allowed"
|
||||
validate_mock.assert_awaited_once_with("sk_live_admin")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Session-based impersonation tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_impersonation_session_lifecycle() -> None:
|
||||
"""set → get → clear lifecycle."""
|
||||
admin_hash = mcp_http_auth.cache_key("sk_admin_key")
|
||||
session = mcp_http_auth.ImpersonationSession(
|
||||
admin_api_key_hash=admin_hash,
|
||||
admin_org_id="org_admin",
|
||||
target_org_id="org_target",
|
||||
target_api_key="sk_target_key",
|
||||
expires_at=time.monotonic() + 600,
|
||||
ttl_minutes=10,
|
||||
)
|
||||
mcp_http_auth.set_impersonation_session(session)
|
||||
|
||||
retrieved = mcp_http_auth.get_active_impersonation(admin_hash)
|
||||
assert retrieved is not None
|
||||
assert retrieved.target_org_id == "org_target"
|
||||
|
||||
cleared = mcp_http_auth.clear_impersonation_session(admin_hash)
|
||||
assert cleared is not None
|
||||
assert cleared.target_org_id == "org_target"
|
||||
|
||||
assert mcp_http_auth.get_active_impersonation(admin_hash) is None
|
||||
|
||||
|
||||
def test_impersonation_session_auto_expiry() -> None:
|
||||
"""Expired sessions are lazily cleaned up on get."""
|
||||
admin_hash = mcp_http_auth.cache_key("sk_admin_key")
|
||||
session = mcp_http_auth.ImpersonationSession(
|
||||
admin_api_key_hash=admin_hash,
|
||||
admin_org_id="org_admin",
|
||||
target_org_id="org_target",
|
||||
target_api_key="sk_target_key",
|
||||
expires_at=time.monotonic() - 1, # already expired
|
||||
ttl_minutes=1,
|
||||
)
|
||||
mcp_http_auth.set_impersonation_session(session)
|
||||
|
||||
assert mcp_http_auth.get_active_impersonation(admin_hash) is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_middleware_applies_session_impersonation(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""When a session is active, middleware auto-applies impersonation without header."""
|
||||
monkeypatch.setattr(
|
||||
mcp_http_auth,
|
||||
"validate_mcp_api_key",
|
||||
AsyncMock(
|
||||
return_value=_build_validation(
|
||||
"org_admin",
|
||||
OrganizationAuthTokenType.mcp_admin_impersonation,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
admin_hash = mcp_http_auth.cache_key("sk_live_admin")
|
||||
session = mcp_http_auth.ImpersonationSession(
|
||||
admin_api_key_hash=admin_hash,
|
||||
admin_org_id="org_admin",
|
||||
target_org_id="org_target",
|
||||
target_api_key="sk_live_target_key",
|
||||
expires_at=time.monotonic() + 600,
|
||||
ttl_minutes=10,
|
||||
)
|
||||
mcp_http_auth.set_impersonation_session(session)
|
||||
|
||||
app = _build_test_app()
|
||||
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
|
||||
response = await client.post("/mcp", headers={"x-api-key": "sk_live_admin"}, json={})
|
||||
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["api_key"] == "sk_live_target_key"
|
||||
assert body["organization_id"] == "org_target"
|
||||
assert body["admin_organization_id"] == "org_admin"
|
||||
assert body["impersonation_target_organization_id"] == "org_target"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_middleware_ignores_expired_session(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Expired session is ignored — middleware reverts to admin's own org."""
|
||||
monkeypatch.setattr(
|
||||
mcp_http_auth,
|
||||
"validate_mcp_api_key",
|
||||
AsyncMock(
|
||||
return_value=_build_validation(
|
||||
"org_admin",
|
||||
OrganizationAuthTokenType.mcp_admin_impersonation,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
admin_hash = mcp_http_auth.cache_key("sk_live_admin")
|
||||
session = mcp_http_auth.ImpersonationSession(
|
||||
admin_api_key_hash=admin_hash,
|
||||
admin_org_id="org_admin",
|
||||
target_org_id="org_target",
|
||||
target_api_key="sk_live_target_key",
|
||||
expires_at=time.monotonic() - 1, # expired
|
||||
ttl_minutes=1,
|
||||
)
|
||||
mcp_http_auth.set_impersonation_session(session)
|
||||
|
||||
app = _build_test_app()
|
||||
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
|
||||
response = await client.post("/mcp", headers={"x-api-key": "sk_live_admin"}, json={})
|
||||
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["api_key"] == "sk_live_admin"
|
||||
assert body["organization_id"] == "org_admin"
|
||||
assert body["admin_organization_id"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_auth_db_clears_impersonation_sessions() -> None:
|
||||
admin_hash = mcp_http_auth.cache_key("sk_admin_key")
|
||||
session = mcp_http_auth.ImpersonationSession(
|
||||
admin_api_key_hash=admin_hash,
|
||||
admin_org_id="org_admin",
|
||||
target_org_id="org_target",
|
||||
target_api_key="sk_target_key",
|
||||
expires_at=time.monotonic() + 600,
|
||||
ttl_minutes=10,
|
||||
)
|
||||
mcp_http_auth.set_impersonation_session(session)
|
||||
|
||||
dispose = AsyncMock()
|
||||
mcp_http_auth._auth_db = SimpleNamespace(engine=SimpleNamespace(dispose=dispose))
|
||||
|
||||
await mcp_http_auth.close_auth_db()
|
||||
|
||||
assert mcp_http_auth.get_active_impersonation(admin_hash) is None
|
||||
|
||||
Reference in New Issue
Block a user