Add gated admin impersonation controls for MCP API-key auth (#4822)
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -4,6 +4,8 @@ import asyncio
|
|||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
from contextvars import ContextVar, Token
|
||||||
|
from dataclasses import dataclass
|
||||||
from threading import RLock
|
from threading import RLock
|
||||||
|
|
||||||
import structlog
|
import structlog
|
||||||
@@ -14,6 +16,7 @@ from starlette.types import ASGIApp, Receive, Scope, Send
|
|||||||
|
|
||||||
from skyvern.config import settings
|
from skyvern.config import settings
|
||||||
from skyvern.forge.sdk.db.agent_db import AgentDB
|
from skyvern.forge.sdk.db.agent_db import AgentDB
|
||||||
|
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
|
||||||
from skyvern.forge.sdk.services.org_auth_service import resolve_org_from_api_key
|
from skyvern.forge.sdk.services.org_auth_service import resolve_org_from_api_key
|
||||||
|
|
||||||
from .api_key_hash import hash_api_key_for_cache
|
from .api_key_hash import hash_api_key_for_cache
|
||||||
@@ -21,16 +24,94 @@ from .client import reset_api_key_override, set_api_key_override
|
|||||||
|
|
||||||
LOG = structlog.get_logger(__name__)
|
LOG = structlog.get_logger(__name__)
|
||||||
API_KEY_HEADER = "x-api-key"
|
API_KEY_HEADER = "x-api-key"
|
||||||
|
TARGET_ORG_ID_HEADER = "x-target-org-id"
|
||||||
HEALTH_PATHS = {"/health", "/healthz"}
|
HEALTH_PATHS = {"/health", "/healthz"}
|
||||||
|
_MCP_ALLOWED_TOKEN_TYPES = (
|
||||||
|
OrganizationAuthTokenType.api,
|
||||||
|
OrganizationAuthTokenType.mcp_admin_impersonation,
|
||||||
|
)
|
||||||
_auth_db: AgentDB | None = None
|
_auth_db: AgentDB | None = None
|
||||||
_auth_db_lock = RLock()
|
_auth_db_lock = RLock()
|
||||||
_api_key_cache_lock = RLock()
|
_api_key_cache_lock = RLock()
|
||||||
_api_key_validation_cache: OrderedDict[str, tuple[str | None, float]] = OrderedDict()
|
_api_key_validation_cache: OrderedDict[str, tuple[MCPAPIKeyValidation | None, float]] = OrderedDict()
|
||||||
_NEGATIVE_CACHE_TTL_SECONDS = 5.0
|
_NEGATIVE_CACHE_TTL_SECONDS = 5.0
|
||||||
_VALIDATION_RETRY_EXHAUSTED_MESSAGE = "API key validation temporarily unavailable"
|
_VALIDATION_RETRY_EXHAUSTED_MESSAGE = "API key validation temporarily unavailable"
|
||||||
_MAX_VALIDATION_RETRIES = 2
|
_MAX_VALIDATION_RETRIES = 2
|
||||||
_RETRY_DELAY_SECONDS = 0.25
|
_RETRY_DELAY_SECONDS = 0.25
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Impersonation session state
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_admin_api_key_hash: ContextVar[str | None] = ContextVar("admin_api_key_hash", default=None)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ImpersonationSession:
|
||||||
|
admin_api_key_hash: str
|
||||||
|
admin_org_id: str
|
||||||
|
target_org_id: str
|
||||||
|
# Stored in plaintext in process memory for up to TTL. Acceptable V1 trade-off
|
||||||
|
# (in-process only, same as API key cache).
|
||||||
|
target_api_key: str
|
||||||
|
expires_at: float # time.monotonic() deadline
|
||||||
|
ttl_minutes: int
|
||||||
|
|
||||||
|
|
||||||
|
_impersonation_sessions: dict[str, ImpersonationSession] = {}
|
||||||
|
_impersonation_lock = RLock()
|
||||||
|
|
||||||
|
MAX_TTL_MINUTES = 120
|
||||||
|
DEFAULT_TTL_MINUTES = 30
|
||||||
|
_MAX_IMPERSONATION_SESSIONS = 100
|
||||||
|
|
||||||
|
|
||||||
|
def get_active_impersonation(admin_key_hash: str) -> ImpersonationSession | None:
|
||||||
|
with _impersonation_lock:
|
||||||
|
session = _impersonation_sessions.get(admin_key_hash)
|
||||||
|
if session is None:
|
||||||
|
return None
|
||||||
|
if session.expires_at <= time.monotonic():
|
||||||
|
_impersonation_sessions.pop(admin_key_hash, None)
|
||||||
|
return None
|
||||||
|
return session
|
||||||
|
|
||||||
|
|
||||||
|
def set_impersonation_session(session: ImpersonationSession) -> None:
|
||||||
|
with _impersonation_lock:
|
||||||
|
_impersonation_sessions[session.admin_api_key_hash] = session
|
||||||
|
# Sweep expired sessions, then evict oldest if over capacity
|
||||||
|
now = time.monotonic()
|
||||||
|
expired_keys = [k for k, s in _impersonation_sessions.items() if s.expires_at <= now]
|
||||||
|
for k in expired_keys:
|
||||||
|
_impersonation_sessions.pop(k, None)
|
||||||
|
if len(_impersonation_sessions) > _MAX_IMPERSONATION_SESSIONS:
|
||||||
|
oldest_key = min(_impersonation_sessions, key=lambda k: _impersonation_sessions[k].expires_at)
|
||||||
|
_impersonation_sessions.pop(oldest_key, None)
|
||||||
|
|
||||||
|
|
||||||
|
def clear_impersonation_session(admin_key_hash: str) -> ImpersonationSession | None:
|
||||||
|
with _impersonation_lock:
|
||||||
|
return _impersonation_sessions.pop(admin_key_hash, None)
|
||||||
|
|
||||||
|
|
||||||
|
def clear_all_impersonation_sessions() -> None:
|
||||||
|
with _impersonation_lock:
|
||||||
|
_impersonation_sessions.clear()
|
||||||
|
|
||||||
|
|
||||||
|
def get_admin_api_key_hash() -> str | None:
|
||||||
|
return _admin_api_key_hash.get()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class MCPAPIKeyValidation:
|
||||||
|
organization_id: str
|
||||||
|
token_type: OrganizationAuthTokenType
|
||||||
|
|
||||||
|
|
||||||
def _resolve_api_key_cache_ttl_seconds() -> float:
|
def _resolve_api_key_cache_ttl_seconds() -> float:
|
||||||
raw = os.environ.get("SKYVERN_MCP_API_KEY_CACHE_TTL_SECONDS", "30")
|
raw = os.environ.get("SKYVERN_MCP_API_KEY_CACHE_TTL_SECONDS", "30")
|
||||||
@@ -69,6 +150,7 @@ async def close_auth_db() -> None:
|
|||||||
_auth_db = None
|
_auth_db = None
|
||||||
with _api_key_cache_lock:
|
with _api_key_cache_lock:
|
||||||
_api_key_validation_cache.clear()
|
_api_key_validation_cache.clear()
|
||||||
|
clear_all_impersonation_sessions()
|
||||||
if db is None:
|
if db is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -78,24 +160,43 @@ async def close_auth_db() -> None:
|
|||||||
LOG.warning("Failed to dispose MCP auth DB engine", exc_info=True)
|
LOG.warning("Failed to dispose MCP auth DB engine", exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
def _cache_key(api_key: str) -> str:
|
def cache_key(api_key: str) -> str:
|
||||||
return hash_api_key_for_cache(api_key)
|
return hash_api_key_for_cache(api_key)
|
||||||
|
|
||||||
|
|
||||||
async def validate_mcp_api_key(api_key: str) -> str:
|
def _admin_organization_ids() -> set[str]:
|
||||||
"""Validate API key and return organization id for observability."""
|
try:
|
||||||
key = _cache_key(api_key)
|
from cloud.config import settings as cloud_settings # noqa: PLC0415
|
||||||
|
|
||||||
|
admin_ids = getattr(cloud_settings, "ADMIN_ORGANIZATION_IDS", [])
|
||||||
|
except ImportError:
|
||||||
|
admin_ids = getattr(settings, "ADMIN_ORGANIZATION_IDS", [])
|
||||||
|
return {org_id for org_id in admin_ids if org_id}
|
||||||
|
|
||||||
|
|
||||||
|
def _is_admin_impersonation_enabled() -> bool:
|
||||||
|
try:
|
||||||
|
from cloud.config import settings as cloud_settings # noqa: PLC0415
|
||||||
|
|
||||||
|
return bool(getattr(cloud_settings, "MCP_ADMIN_IMPERSONATION_ENABLED", False))
|
||||||
|
except ImportError:
|
||||||
|
return bool(getattr(settings, "MCP_ADMIN_IMPERSONATION_ENABLED", False))
|
||||||
|
|
||||||
|
|
||||||
|
async def validate_mcp_api_key(api_key: str) -> MCPAPIKeyValidation:
|
||||||
|
"""Validate API key and return caller organization + token type."""
|
||||||
|
key = cache_key(api_key)
|
||||||
|
|
||||||
# Check cache first.
|
# Check cache first.
|
||||||
with _api_key_cache_lock:
|
with _api_key_cache_lock:
|
||||||
cached = _api_key_validation_cache.get(key)
|
cached = _api_key_validation_cache.get(key)
|
||||||
if cached is not None:
|
if cached is not None:
|
||||||
organization_id, expires_at = cached
|
cached_validation, expires_at = cached
|
||||||
if expires_at > time.monotonic():
|
if expires_at > time.monotonic():
|
||||||
_api_key_validation_cache.move_to_end(key)
|
_api_key_validation_cache.move_to_end(key)
|
||||||
if organization_id is None:
|
if cached_validation is None:
|
||||||
raise HTTPException(status_code=401, detail="Invalid API key")
|
raise HTTPException(status_code=401, detail="Invalid API key")
|
||||||
return organization_id
|
return cached_validation
|
||||||
_api_key_validation_cache.pop(key, None)
|
_api_key_validation_cache.pop(key, None)
|
||||||
|
|
||||||
# Cache miss — do the DB lookup with simple retry on transient errors.
|
# Cache miss — do the DB lookup with simple retry on transient errors.
|
||||||
@@ -104,17 +205,24 @@ async def validate_mcp_api_key(api_key: str) -> str:
|
|||||||
if attempt > 0:
|
if attempt > 0:
|
||||||
await asyncio.sleep(_RETRY_DELAY_SECONDS)
|
await asyncio.sleep(_RETRY_DELAY_SECONDS)
|
||||||
try:
|
try:
|
||||||
validation = await resolve_org_from_api_key(api_key, _get_auth_db())
|
validation = await resolve_org_from_api_key(
|
||||||
organization_id = validation.organization.organization_id
|
api_key,
|
||||||
|
_get_auth_db(),
|
||||||
|
token_types=_MCP_ALLOWED_TOKEN_TYPES,
|
||||||
|
)
|
||||||
|
caller_validation = MCPAPIKeyValidation(
|
||||||
|
organization_id=validation.organization.organization_id,
|
||||||
|
token_type=validation.token.token_type,
|
||||||
|
)
|
||||||
with _api_key_cache_lock:
|
with _api_key_cache_lock:
|
||||||
_api_key_validation_cache[key] = (
|
_api_key_validation_cache[key] = (
|
||||||
organization_id,
|
caller_validation,
|
||||||
time.monotonic() + _API_KEY_CACHE_TTL_SECONDS,
|
time.monotonic() + _API_KEY_CACHE_TTL_SECONDS,
|
||||||
)
|
)
|
||||||
_api_key_validation_cache.move_to_end(key)
|
_api_key_validation_cache.move_to_end(key)
|
||||||
while len(_api_key_validation_cache) > _API_KEY_CACHE_MAX_SIZE:
|
while len(_api_key_validation_cache) > _API_KEY_CACHE_MAX_SIZE:
|
||||||
_api_key_validation_cache.popitem(last=False)
|
_api_key_validation_cache.popitem(last=False)
|
||||||
return organization_id
|
return caller_validation
|
||||||
except HTTPException as e:
|
except HTTPException as e:
|
||||||
if e.status_code in {401, 403}:
|
if e.status_code in {401, 403}:
|
||||||
with _api_key_cache_lock:
|
with _api_key_cache_lock:
|
||||||
@@ -149,6 +257,24 @@ def _service_unavailable_response(message: str) -> JSONResponse:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _deny_impersonation(
|
||||||
|
*,
|
||||||
|
reason: str,
|
||||||
|
caller_organization_id: str | None,
|
||||||
|
target_organization_id: str | None,
|
||||||
|
token_type: OrganizationAuthTokenType | None = None,
|
||||||
|
) -> JSONResponse:
|
||||||
|
log_kwargs: dict[str, object] = {
|
||||||
|
"reason": reason,
|
||||||
|
"caller_organization_id": caller_organization_id,
|
||||||
|
"target_organization_id": target_organization_id,
|
||||||
|
}
|
||||||
|
if token_type is not None:
|
||||||
|
log_kwargs["token_type"] = token_type.value
|
||||||
|
LOG.warning("MCP admin impersonation denied", **log_kwargs)
|
||||||
|
return _unauthorized_response("Impersonation not allowed")
|
||||||
|
|
||||||
|
|
||||||
class MCPAPIKeyMiddleware:
|
class MCPAPIKeyMiddleware:
|
||||||
"""Require x-api-key for MCP HTTP transport and scope requests to that key."""
|
"""Require x-api-key for MCP HTTP transport and scope requests to that key."""
|
||||||
|
|
||||||
@@ -176,10 +302,99 @@ class MCPAPIKeyMiddleware:
|
|||||||
await response(scope, receive, send)
|
await response(scope, receive, send)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
target_org_id_header = request.headers.get(TARGET_ORG_ID_HEADER)
|
||||||
|
admin_hash_token: Token[str | None] | None = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
organization_id = await validate_mcp_api_key(api_key)
|
validation = await validate_mcp_api_key(api_key)
|
||||||
|
caller_organization_id = validation.organization_id
|
||||||
|
admin_key_hash = cache_key(api_key)
|
||||||
|
|
||||||
scope.setdefault("state", {})
|
scope.setdefault("state", {})
|
||||||
scope["state"]["organization_id"] = organization_id
|
if target_org_id_header is not None:
|
||||||
|
# Explicit header-based impersonation (takes priority over session)
|
||||||
|
if not _is_admin_impersonation_enabled():
|
||||||
|
response = _deny_impersonation(
|
||||||
|
reason="feature_disabled",
|
||||||
|
caller_organization_id=caller_organization_id,
|
||||||
|
target_organization_id=target_org_id_header.strip() or target_org_id_header,
|
||||||
|
)
|
||||||
|
await response(scope, receive, send)
|
||||||
|
return
|
||||||
|
|
||||||
|
target_organization_id = target_org_id_header.strip()
|
||||||
|
if not target_organization_id:
|
||||||
|
response = _deny_impersonation(
|
||||||
|
reason="missing_target_organization_id",
|
||||||
|
caller_organization_id=caller_organization_id,
|
||||||
|
target_organization_id=target_org_id_header,
|
||||||
|
token_type=validation.token_type,
|
||||||
|
)
|
||||||
|
await response(scope, receive, send)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Delegate validation to the single source of truth in cloud/.
|
||||||
|
# The import can't fail here — _is_admin_impersonation_enabled()
|
||||||
|
# already returned True, which requires cloud.config to be importable.
|
||||||
|
try:
|
||||||
|
from cloud.mcp_admin_tools import validate_impersonation_target # noqa: PLC0415
|
||||||
|
except ImportError:
|
||||||
|
response = _deny_impersonation(
|
||||||
|
reason="impersonation_not_available",
|
||||||
|
caller_organization_id=caller_organization_id,
|
||||||
|
target_organization_id=target_organization_id,
|
||||||
|
)
|
||||||
|
await response(scope, receive, send)
|
||||||
|
return
|
||||||
|
|
||||||
|
result = await validate_impersonation_target(
|
||||||
|
caller_organization_id=caller_organization_id,
|
||||||
|
target_organization_id=target_organization_id,
|
||||||
|
token_type=validation.token_type,
|
||||||
|
)
|
||||||
|
if isinstance(result, str):
|
||||||
|
response = _deny_impersonation(
|
||||||
|
reason=result,
|
||||||
|
caller_organization_id=caller_organization_id,
|
||||||
|
target_organization_id=target_organization_id,
|
||||||
|
token_type=validation.token_type,
|
||||||
|
)
|
||||||
|
await response(scope, receive, send)
|
||||||
|
return
|
||||||
|
|
||||||
|
resolved_org_id, target_api_key = result
|
||||||
|
api_key = target_api_key
|
||||||
|
|
||||||
|
scope["state"]["organization_id"] = resolved_org_id
|
||||||
|
scope["state"]["admin_organization_id"] = caller_organization_id
|
||||||
|
scope["state"]["impersonation_target_organization_id"] = resolved_org_id
|
||||||
|
admin_hash_token = _admin_api_key_hash.set(admin_key_hash)
|
||||||
|
LOG.info(
|
||||||
|
"MCP admin impersonation allowed",
|
||||||
|
caller_organization_id=caller_organization_id,
|
||||||
|
target_organization_id=resolved_org_id,
|
||||||
|
token_type=validation.token_type.value,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# No explicit header — check for session-based impersonation
|
||||||
|
session = get_active_impersonation(admin_key_hash)
|
||||||
|
if session is not None:
|
||||||
|
# Session data (target API key) is cached for TTL. If the target org's
|
||||||
|
# key is revoked mid-session, impersonation continues until expiry —
|
||||||
|
# acceptable trade-off vs re-validating every request.
|
||||||
|
api_key = session.target_api_key
|
||||||
|
scope["state"]["organization_id"] = session.target_org_id
|
||||||
|
scope["state"]["admin_organization_id"] = session.admin_org_id
|
||||||
|
scope["state"]["impersonation_target_organization_id"] = session.target_org_id
|
||||||
|
admin_hash_token = _admin_api_key_hash.set(admin_key_hash)
|
||||||
|
LOG.info(
|
||||||
|
"MCP session impersonation applied",
|
||||||
|
caller_organization_id=session.admin_org_id,
|
||||||
|
target_organization_id=session.target_org_id,
|
||||||
|
ttl_minutes=session.ttl_minutes,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
scope["state"]["organization_id"] = caller_organization_id
|
||||||
except HTTPException as e:
|
except HTTPException as e:
|
||||||
if e.status_code in {401, 403}:
|
if e.status_code in {401, 403}:
|
||||||
response = _unauthorized_response("Invalid API key")
|
response = _unauthorized_response("Invalid API key")
|
||||||
@@ -204,3 +419,5 @@ class MCPAPIKeyMiddleware:
|
|||||||
await self.app(scope, receive, send)
|
await self.app(scope, receive, send)
|
||||||
finally:
|
finally:
|
||||||
reset_api_key_override(token)
|
reset_api_key_override(token)
|
||||||
|
if admin_hash_token is not None:
|
||||||
|
_admin_api_key_hash.reset(admin_hash_token)
|
||||||
|
|||||||
@@ -286,6 +286,14 @@ mcp.tool()(skyvern_workflow_run)
|
|||||||
mcp.tool()(skyvern_workflow_status)
|
mcp.tool()(skyvern_workflow_status)
|
||||||
mcp.tool()(skyvern_workflow_cancel)
|
mcp.tool()(skyvern_workflow_cancel)
|
||||||
|
|
||||||
|
# -- Admin impersonation (cloud-only, session-level org switching) --
|
||||||
|
try:
|
||||||
|
from cloud.mcp_admin_tools import register_admin_tools # noqa: PLC0415
|
||||||
|
|
||||||
|
register_admin_tools(mcp)
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
# -- Prompts (methodology guides injected into LLM conversations) --
|
# -- Prompts (methodology guides injected into LLM conversations) --
|
||||||
mcp.prompt()(build_workflow)
|
mcp.prompt()(build_workflow)
|
||||||
mcp.prompt()(debug_automation)
|
mcp.prompt()(debug_automation)
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ from enum import StrEnum
|
|||||||
|
|
||||||
class OrganizationAuthTokenType(StrEnum):
|
class OrganizationAuthTokenType(StrEnum):
|
||||||
api = "api"
|
api = "api"
|
||||||
|
mcp_admin_impersonation = "mcp_admin_impersonation"
|
||||||
onepassword_service_account = "onepassword_service_account"
|
onepassword_service_account = "onepassword_service_account"
|
||||||
azure_client_secret_credential = "azure_client_secret_credential"
|
azure_client_secret_credential = "azure_client_secret_credential"
|
||||||
custom_credential_service = "custom_credential_service"
|
custom_credential_service = "custom_credential_service"
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import time
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Annotated
|
from typing import Annotated, Sequence
|
||||||
|
|
||||||
import structlog
|
import structlog
|
||||||
from asyncache import cached
|
from asyncache import cached
|
||||||
@@ -176,6 +176,7 @@ async def _authenticate_user_helper(authorization: str) -> str:
|
|||||||
async def resolve_org_from_api_key(
|
async def resolve_org_from_api_key(
|
||||||
x_api_key: str,
|
x_api_key: str,
|
||||||
db: AgentDB,
|
db: AgentDB,
|
||||||
|
token_types: Sequence[OrganizationAuthTokenType] = (OrganizationAuthTokenType.api,),
|
||||||
) -> ApiKeyValidationResult:
|
) -> ApiKeyValidationResult:
|
||||||
"""Decode and validate the API key against the database."""
|
"""Decode and validate the API key against the database."""
|
||||||
try:
|
try:
|
||||||
@@ -202,12 +203,18 @@ async def resolve_org_from_api_key(
|
|||||||
LOG.warning("Organization not found", organization_id=api_key_data.sub, **payload)
|
LOG.warning("Organization not found", organization_id=api_key_data.sub, **payload)
|
||||||
raise HTTPException(status_code=404, detail="Organization not found")
|
raise HTTPException(status_code=404, detail="Organization not found")
|
||||||
|
|
||||||
|
api_key_db_obj: OrganizationAuthToken | None = None
|
||||||
|
# Try token types in priority order and stop at the first valid match.
|
||||||
|
for token_type in token_types:
|
||||||
api_key_db_obj = await db.validate_org_auth_token(
|
api_key_db_obj = await db.validate_org_auth_token(
|
||||||
organization_id=organization.organization_id,
|
organization_id=organization.organization_id,
|
||||||
token_type=OrganizationAuthTokenType.api,
|
token_type=token_type,
|
||||||
token=x_api_key,
|
token=x_api_key,
|
||||||
valid=None,
|
valid=None,
|
||||||
)
|
)
|
||||||
|
if api_key_db_obj:
|
||||||
|
break
|
||||||
|
|
||||||
if not api_key_db_obj:
|
if not api_key_db_obj:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import time
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from unittest.mock import AsyncMock
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
@@ -15,6 +16,7 @@ from starlette.routing import Route
|
|||||||
|
|
||||||
from skyvern.cli.core import client as client_mod
|
from skyvern.cli.core import client as client_mod
|
||||||
from skyvern.cli.core import mcp_http_auth
|
from skyvern.cli.core import mcp_http_auth
|
||||||
|
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
@@ -26,6 +28,7 @@ def _reset_auth_context() -> None:
|
|||||||
mcp_http_auth._API_KEY_CACHE_MAX_SIZE = 1024
|
mcp_http_auth._API_KEY_CACHE_MAX_SIZE = 1024
|
||||||
mcp_http_auth._MAX_VALIDATION_RETRIES = 2
|
mcp_http_auth._MAX_VALIDATION_RETRIES = 2
|
||||||
mcp_http_auth._RETRY_DELAY_SECONDS = 0.0 # no delay in tests
|
mcp_http_auth._RETRY_DELAY_SECONDS = 0.0 # no delay in tests
|
||||||
|
mcp_http_auth.clear_all_impersonation_sessions()
|
||||||
|
|
||||||
|
|
||||||
async def _echo_request_context(request: Request) -> JSONResponse:
|
async def _echo_request_context(request: Request) -> JSONResponse:
|
||||||
@@ -33,10 +36,34 @@ async def _echo_request_context(request: Request) -> JSONResponse:
|
|||||||
{
|
{
|
||||||
"api_key": client_mod.get_active_api_key(),
|
"api_key": client_mod.get_active_api_key(),
|
||||||
"organization_id": getattr(request.state, "organization_id", None),
|
"organization_id": getattr(request.state, "organization_id", None),
|
||||||
|
"admin_organization_id": getattr(request.state, "admin_organization_id", None),
|
||||||
|
"impersonation_target_organization_id": getattr(
|
||||||
|
request.state, "impersonation_target_organization_id", None
|
||||||
|
),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_validation(
|
||||||
|
organization_id: str,
|
||||||
|
token_type: OrganizationAuthTokenType = OrganizationAuthTokenType.api,
|
||||||
|
) -> mcp_http_auth.MCPAPIKeyValidation:
|
||||||
|
return mcp_http_auth.MCPAPIKeyValidation(
|
||||||
|
organization_id=organization_id,
|
||||||
|
token_type=token_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_resolved_validation(
|
||||||
|
organization_id: str,
|
||||||
|
token_type: OrganizationAuthTokenType = OrganizationAuthTokenType.api,
|
||||||
|
) -> SimpleNamespace:
|
||||||
|
return SimpleNamespace(
|
||||||
|
organization=SimpleNamespace(organization_id=organization_id),
|
||||||
|
token=SimpleNamespace(token_type=token_type),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _build_test_app() -> Starlette:
|
def _build_test_app() -> Starlette:
|
||||||
return Starlette(
|
return Starlette(
|
||||||
routes=[Route("/mcp", endpoint=_echo_request_context, methods=["POST"])],
|
routes=[Route("/mcp", endpoint=_echo_request_context, methods=["POST"])],
|
||||||
@@ -138,7 +165,7 @@ async def test_mcp_http_auth_sets_request_scoped_api_key(monkeypatch: pytest.Mon
|
|||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
mcp_http_auth,
|
mcp_http_auth,
|
||||||
"validate_mcp_api_key",
|
"validate_mcp_api_key",
|
||||||
AsyncMock(return_value="org_123"),
|
AsyncMock(return_value=_build_validation("org_123")),
|
||||||
)
|
)
|
||||||
app = _build_test_app()
|
app = _build_test_app()
|
||||||
|
|
||||||
@@ -149,6 +176,8 @@ async def test_mcp_http_auth_sets_request_scoped_api_key(monkeypatch: pytest.Mon
|
|||||||
assert response.json() == {
|
assert response.json() == {
|
||||||
"api_key": "sk_live_abc",
|
"api_key": "sk_live_abc",
|
||||||
"organization_id": "org_123",
|
"organization_id": "org_123",
|
||||||
|
"admin_organization_id": None,
|
||||||
|
"impersonation_target_organization_id": None,
|
||||||
}
|
}
|
||||||
assert client_mod.get_active_api_key() != "sk_live_abc"
|
assert client_mod.get_active_api_key() != "sk_live_abc"
|
||||||
|
|
||||||
@@ -157,10 +186,10 @@ async def test_mcp_http_auth_sets_request_scoped_api_key(monkeypatch: pytest.Mon
|
|||||||
async def test_validate_mcp_api_key_uses_ttl_cache(monkeypatch: pytest.MonkeyPatch) -> None:
|
async def test_validate_mcp_api_key_uses_ttl_cache(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
calls = 0
|
calls = 0
|
||||||
|
|
||||||
async def _resolve(_api_key: str, _db: object) -> object:
|
async def _resolve(_api_key: str, _db: object, **_: object) -> object:
|
||||||
nonlocal calls
|
nonlocal calls
|
||||||
calls += 1
|
calls += 1
|
||||||
return SimpleNamespace(organization=SimpleNamespace(organization_id="org_cached"))
|
return _build_resolved_validation("org_cached")
|
||||||
|
|
||||||
monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
|
monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
|
||||||
monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())
|
monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())
|
||||||
@@ -168,8 +197,8 @@ async def test_validate_mcp_api_key_uses_ttl_cache(monkeypatch: pytest.MonkeyPat
|
|||||||
first = await mcp_http_auth.validate_mcp_api_key("sk_test_cache")
|
first = await mcp_http_auth.validate_mcp_api_key("sk_test_cache")
|
||||||
second = await mcp_http_auth.validate_mcp_api_key("sk_test_cache")
|
second = await mcp_http_auth.validate_mcp_api_key("sk_test_cache")
|
||||||
|
|
||||||
assert first == "org_cached"
|
assert first.organization_id == "org_cached"
|
||||||
assert second == "org_cached"
|
assert second.organization_id == "org_cached"
|
||||||
assert calls == 1
|
assert calls == 1
|
||||||
|
|
||||||
|
|
||||||
@@ -177,21 +206,21 @@ async def test_validate_mcp_api_key_uses_ttl_cache(monkeypatch: pytest.MonkeyPat
|
|||||||
async def test_validate_mcp_api_key_cache_expires(monkeypatch: pytest.MonkeyPatch) -> None:
|
async def test_validate_mcp_api_key_cache_expires(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
calls = 0
|
calls = 0
|
||||||
|
|
||||||
async def _resolve(_api_key: str, _db: object) -> object:
|
async def _resolve(_api_key: str, _db: object, **_: object) -> object:
|
||||||
nonlocal calls
|
nonlocal calls
|
||||||
calls += 1
|
calls += 1
|
||||||
return SimpleNamespace(organization=SimpleNamespace(organization_id=f"org_{calls}"))
|
return _build_resolved_validation(f"org_{calls}")
|
||||||
|
|
||||||
monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
|
monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
|
||||||
monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())
|
monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())
|
||||||
|
|
||||||
first = await mcp_http_auth.validate_mcp_api_key("sk_test_cache_expire")
|
first = await mcp_http_auth.validate_mcp_api_key("sk_test_cache_expire")
|
||||||
cache_key = mcp_http_auth._cache_key("sk_test_cache_expire")
|
cache_key = mcp_http_auth.cache_key("sk_test_cache_expire")
|
||||||
mcp_http_auth._api_key_validation_cache[cache_key] = (first, 0.0)
|
mcp_http_auth._api_key_validation_cache[cache_key] = (first, 0.0)
|
||||||
second = await mcp_http_auth.validate_mcp_api_key("sk_test_cache_expire")
|
second = await mcp_http_auth.validate_mcp_api_key("sk_test_cache_expire")
|
||||||
|
|
||||||
assert first == "org_1"
|
assert first.organization_id == "org_1"
|
||||||
assert second == "org_2"
|
assert second.organization_id == "org_2"
|
||||||
assert calls == 2
|
assert calls == 2
|
||||||
|
|
||||||
|
|
||||||
@@ -199,7 +228,7 @@ async def test_validate_mcp_api_key_cache_expires(monkeypatch: pytest.MonkeyPatc
|
|||||||
async def test_validate_mcp_api_key_negative_caches_auth_failures(monkeypatch: pytest.MonkeyPatch) -> None:
|
async def test_validate_mcp_api_key_negative_caches_auth_failures(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
calls = 0
|
calls = 0
|
||||||
|
|
||||||
async def _resolve(_api_key: str, _db: object) -> object:
|
async def _resolve(_api_key: str, _db: object, **_: object) -> object:
|
||||||
nonlocal calls
|
nonlocal calls
|
||||||
calls += 1
|
calls += 1
|
||||||
raise HTTPException(status_code=401, detail="Invalid credentials")
|
raise HTTPException(status_code=401, detail="Invalid credentials")
|
||||||
@@ -222,22 +251,22 @@ async def test_validate_mcp_api_key_retries_transient_failure_without_negative_c
|
|||||||
) -> None:
|
) -> None:
|
||||||
calls = 0
|
calls = 0
|
||||||
|
|
||||||
async def _resolve(_api_key: str, _db: object) -> object:
|
async def _resolve(_api_key: str, _db: object, **_: object) -> object:
|
||||||
nonlocal calls
|
nonlocal calls
|
||||||
calls += 1
|
calls += 1
|
||||||
if calls == 1:
|
if calls == 1:
|
||||||
raise RuntimeError("transient db error")
|
raise RuntimeError("transient db error")
|
||||||
return SimpleNamespace(organization=SimpleNamespace(organization_id="org_recovered"))
|
return _build_resolved_validation("org_recovered")
|
||||||
|
|
||||||
monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
|
monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
|
||||||
monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())
|
monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())
|
||||||
|
|
||||||
recovered_org = await mcp_http_auth.validate_mcp_api_key("sk_test_transient")
|
recovered_org = await mcp_http_auth.validate_mcp_api_key("sk_test_transient")
|
||||||
|
|
||||||
cache_key = mcp_http_auth._cache_key("sk_test_transient")
|
cache_key = mcp_http_auth.cache_key("sk_test_transient")
|
||||||
assert mcp_http_auth._api_key_validation_cache[cache_key][0] == "org_recovered"
|
assert mcp_http_auth._api_key_validation_cache[cache_key][0].organization_id == "org_recovered"
|
||||||
|
|
||||||
assert recovered_org == "org_recovered"
|
assert recovered_org.organization_id == "org_recovered"
|
||||||
assert calls == 2
|
assert calls == 2
|
||||||
|
|
||||||
|
|
||||||
@@ -249,16 +278,16 @@ async def test_validate_mcp_api_key_concurrent_callers_all_succeed(
|
|||||||
collapses subsequent calls after the first one populates it."""
|
collapses subsequent calls after the first one populates it."""
|
||||||
calls = 0
|
calls = 0
|
||||||
|
|
||||||
async def _resolve(_api_key: str, _db: object) -> object:
|
async def _resolve(_api_key: str, _db: object, **_: object) -> object:
|
||||||
nonlocal calls
|
nonlocal calls
|
||||||
calls += 1
|
calls += 1
|
||||||
return SimpleNamespace(organization=SimpleNamespace(organization_id="org_concurrent"))
|
return _build_resolved_validation("org_concurrent")
|
||||||
|
|
||||||
monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
|
monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
|
||||||
monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())
|
monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())
|
||||||
|
|
||||||
results = await asyncio.gather(*[mcp_http_auth.validate_mcp_api_key("sk_test_concurrent") for _ in range(5)])
|
results = await asyncio.gather(*[mcp_http_auth.validate_mcp_api_key("sk_test_concurrent") for _ in range(5)])
|
||||||
assert all(r == "org_concurrent" for r in results)
|
assert all(r.organization_id == "org_concurrent" for r in results)
|
||||||
# First call populates cache; remaining may or may not hit DB depending on
|
# First call populates cache; remaining may or may not hit DB depending on
|
||||||
# scheduling, but all must succeed.
|
# scheduling, but all must succeed.
|
||||||
assert calls >= 1
|
assert calls >= 1
|
||||||
@@ -268,7 +297,7 @@ async def test_validate_mcp_api_key_concurrent_callers_all_succeed(
|
|||||||
async def test_validate_mcp_api_key_returns_503_after_retry_exhaustion(monkeypatch: pytest.MonkeyPatch) -> None:
|
async def test_validate_mcp_api_key_returns_503_after_retry_exhaustion(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
calls = 0
|
calls = 0
|
||||||
|
|
||||||
async def _resolve(_api_key: str, _db: object) -> object:
|
async def _resolve(_api_key: str, _db: object, **_: object) -> object:
|
||||||
nonlocal calls
|
nonlocal calls
|
||||||
calls += 1
|
calls += 1
|
||||||
raise RuntimeError("persistent db outage")
|
raise RuntimeError("persistent db outage")
|
||||||
@@ -302,3 +331,220 @@ async def test_close_auth_db_noop_when_uninitialized() -> None:
|
|||||||
mcp_http_auth._auth_db = None
|
mcp_http_auth._auth_db = None
|
||||||
await mcp_http_auth.close_auth_db()
|
await mcp_http_auth.close_auth_db()
|
||||||
assert mcp_http_auth._auth_db is None
|
assert mcp_http_auth._auth_db is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_mcp_http_auth_denies_target_org_when_feature_disabled(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
validate_mock = AsyncMock(
|
||||||
|
return_value=_build_validation(
|
||||||
|
"org_admin",
|
||||||
|
OrganizationAuthTokenType.mcp_admin_impersonation,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
mcp_http_auth,
|
||||||
|
"validate_mcp_api_key",
|
||||||
|
validate_mock,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(mcp_http_auth, "_is_admin_impersonation_enabled", lambda: False)
|
||||||
|
app = _build_test_app()
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
|
||||||
|
response = await client.post(
|
||||||
|
"/mcp",
|
||||||
|
headers={"x-api-key": "sk_live_admin", "x-target-org-id": "org_target"},
|
||||||
|
json={},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert response.json()["error"]["code"] == "UNAUTHORIZED"
|
||||||
|
assert response.json()["error"]["message"] == "Impersonation not allowed"
|
||||||
|
validate_mock.assert_awaited_once_with("sk_live_admin")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_mcp_http_auth_validates_api_key_before_feature_flag_denial(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
validate_mock = AsyncMock(side_effect=HTTPException(status_code=403, detail="Invalid credentials"))
|
||||||
|
monkeypatch.setattr(mcp_http_auth, "validate_mcp_api_key", validate_mock)
|
||||||
|
monkeypatch.setattr(mcp_http_auth, "_is_admin_impersonation_enabled", lambda: False)
|
||||||
|
app = _build_test_app()
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
|
||||||
|
response = await client.post(
|
||||||
|
"/mcp",
|
||||||
|
headers={"x-api-key": "bad-key", "x-target-org-id": "org_target"},
|
||||||
|
json={},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert response.json()["error"]["code"] == "UNAUTHORIZED"
|
||||||
|
assert response.json()["error"]["message"] == "Invalid API key"
|
||||||
|
validate_mock.assert_awaited_once_with("bad-key")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("target_org_id", ["", " \t "])
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_mcp_http_auth_denies_empty_or_whitespace_target_org_id_header(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
target_org_id: str,
|
||||||
|
) -> None:
|
||||||
|
validate_mock = AsyncMock(
|
||||||
|
return_value=_build_validation(
|
||||||
|
"org_admin",
|
||||||
|
OrganizationAuthTokenType.mcp_admin_impersonation,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(mcp_http_auth, "validate_mcp_api_key", validate_mock)
|
||||||
|
monkeypatch.setattr(mcp_http_auth, "_is_admin_impersonation_enabled", lambda: True)
|
||||||
|
app = _build_test_app()
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
|
||||||
|
response = await client.post(
|
||||||
|
"/mcp",
|
||||||
|
headers={"x-api-key": "sk_live_admin", "x-target-org-id": target_org_id},
|
||||||
|
json={},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert response.json()["error"]["code"] == "UNAUTHORIZED"
|
||||||
|
assert response.json()["error"]["message"] == "Impersonation not allowed"
|
||||||
|
validate_mock.assert_awaited_once_with("sk_live_admin")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Session-based impersonation tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_impersonation_session_lifecycle() -> None:
|
||||||
|
"""set → get → clear lifecycle."""
|
||||||
|
admin_hash = mcp_http_auth.cache_key("sk_admin_key")
|
||||||
|
session = mcp_http_auth.ImpersonationSession(
|
||||||
|
admin_api_key_hash=admin_hash,
|
||||||
|
admin_org_id="org_admin",
|
||||||
|
target_org_id="org_target",
|
||||||
|
target_api_key="sk_target_key",
|
||||||
|
expires_at=time.monotonic() + 600,
|
||||||
|
ttl_minutes=10,
|
||||||
|
)
|
||||||
|
mcp_http_auth.set_impersonation_session(session)
|
||||||
|
|
||||||
|
retrieved = mcp_http_auth.get_active_impersonation(admin_hash)
|
||||||
|
assert retrieved is not None
|
||||||
|
assert retrieved.target_org_id == "org_target"
|
||||||
|
|
||||||
|
cleared = mcp_http_auth.clear_impersonation_session(admin_hash)
|
||||||
|
assert cleared is not None
|
||||||
|
assert cleared.target_org_id == "org_target"
|
||||||
|
|
||||||
|
assert mcp_http_auth.get_active_impersonation(admin_hash) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_impersonation_session_auto_expiry() -> None:
|
||||||
|
"""Expired sessions are lazily cleaned up on get."""
|
||||||
|
admin_hash = mcp_http_auth.cache_key("sk_admin_key")
|
||||||
|
session = mcp_http_auth.ImpersonationSession(
|
||||||
|
admin_api_key_hash=admin_hash,
|
||||||
|
admin_org_id="org_admin",
|
||||||
|
target_org_id="org_target",
|
||||||
|
target_api_key="sk_target_key",
|
||||||
|
expires_at=time.monotonic() - 1, # already expired
|
||||||
|
ttl_minutes=1,
|
||||||
|
)
|
||||||
|
mcp_http_auth.set_impersonation_session(session)
|
||||||
|
|
||||||
|
assert mcp_http_auth.get_active_impersonation(admin_hash) is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_middleware_applies_session_impersonation(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
"""When a session is active, middleware auto-applies impersonation without header."""
|
||||||
|
monkeypatch.setattr(
|
||||||
|
mcp_http_auth,
|
||||||
|
"validate_mcp_api_key",
|
||||||
|
AsyncMock(
|
||||||
|
return_value=_build_validation(
|
||||||
|
"org_admin",
|
||||||
|
OrganizationAuthTokenType.mcp_admin_impersonation,
|
||||||
|
)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
admin_hash = mcp_http_auth.cache_key("sk_live_admin")
|
||||||
|
session = mcp_http_auth.ImpersonationSession(
|
||||||
|
admin_api_key_hash=admin_hash,
|
||||||
|
admin_org_id="org_admin",
|
||||||
|
target_org_id="org_target",
|
||||||
|
target_api_key="sk_live_target_key",
|
||||||
|
expires_at=time.monotonic() + 600,
|
||||||
|
ttl_minutes=10,
|
||||||
|
)
|
||||||
|
mcp_http_auth.set_impersonation_session(session)
|
||||||
|
|
||||||
|
app = _build_test_app()
|
||||||
|
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
|
||||||
|
response = await client.post("/mcp", headers={"x-api-key": "sk_live_admin"}, json={})
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
body = response.json()
|
||||||
|
assert body["api_key"] == "sk_live_target_key"
|
||||||
|
assert body["organization_id"] == "org_target"
|
||||||
|
assert body["admin_organization_id"] == "org_admin"
|
||||||
|
assert body["impersonation_target_organization_id"] == "org_target"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_middleware_ignores_expired_session(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
"""Expired session is ignored — middleware reverts to admin's own org."""
|
||||||
|
monkeypatch.setattr(
|
||||||
|
mcp_http_auth,
|
||||||
|
"validate_mcp_api_key",
|
||||||
|
AsyncMock(
|
||||||
|
return_value=_build_validation(
|
||||||
|
"org_admin",
|
||||||
|
OrganizationAuthTokenType.mcp_admin_impersonation,
|
||||||
|
)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
admin_hash = mcp_http_auth.cache_key("sk_live_admin")
|
||||||
|
session = mcp_http_auth.ImpersonationSession(
|
||||||
|
admin_api_key_hash=admin_hash,
|
||||||
|
admin_org_id="org_admin",
|
||||||
|
target_org_id="org_target",
|
||||||
|
target_api_key="sk_live_target_key",
|
||||||
|
expires_at=time.monotonic() - 1, # expired
|
||||||
|
ttl_minutes=1,
|
||||||
|
)
|
||||||
|
mcp_http_auth.set_impersonation_session(session)
|
||||||
|
|
||||||
|
app = _build_test_app()
|
||||||
|
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
|
||||||
|
response = await client.post("/mcp", headers={"x-api-key": "sk_live_admin"}, json={})
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
body = response.json()
|
||||||
|
assert body["api_key"] == "sk_live_admin"
|
||||||
|
assert body["organization_id"] == "org_admin"
|
||||||
|
assert body["admin_organization_id"] is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_close_auth_db_clears_impersonation_sessions() -> None:
|
||||||
|
admin_hash = mcp_http_auth.cache_key("sk_admin_key")
|
||||||
|
session = mcp_http_auth.ImpersonationSession(
|
||||||
|
admin_api_key_hash=admin_hash,
|
||||||
|
admin_org_id="org_admin",
|
||||||
|
target_org_id="org_target",
|
||||||
|
target_api_key="sk_target_key",
|
||||||
|
expires_at=time.monotonic() + 600,
|
||||||
|
ttl_minutes=10,
|
||||||
|
)
|
||||||
|
mcp_http_auth.set_impersonation_session(session)
|
||||||
|
|
||||||
|
dispose = AsyncMock()
|
||||||
|
mcp_http_auth._auth_db = SimpleNamespace(engine=SimpleNamespace(dispose=dispose))
|
||||||
|
|
||||||
|
await mcp_http_auth.close_auth_db()
|
||||||
|
|
||||||
|
assert mcp_http_auth.get_active_impersonation(admin_hash) is None
|
||||||
|
|||||||
Reference in New Issue
Block a user