align workflow CLI commands with MCP parity (#4792)
This commit is contained in:
29
skyvern/cli/core/api_key_hash.py
Normal file
29
skyvern/cli/core/api_key_hash.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
|
||||
|
||||
def _resolve_api_key_hash_iterations() -> int:
    """Return the PBKDF2 iteration count used for API-key cache fingerprints.

    Reads ``SKYVERN_MCP_API_KEY_HASH_ITERATIONS``. Values below 10,000 are
    raised to 10,000; text that is not an integer falls back to 120,000.
    """
    floor = 10_000
    fallback = 120_000
    configured = os.environ.get("SKYVERN_MCP_API_KEY_HASH_ITERATIONS", "120000")
    try:
        requested = int(configured)
    except ValueError:
        return fallback
    return requested if requested > floor else floor
|
||||
|
||||
|
||||
# PBKDF2 work factor, resolved once when the module is imported.
_API_KEY_HASH_ITERATIONS = _resolve_api_key_hash_iterations()
# Salt bytes for the fingerprint; overridable through the environment.
_API_KEY_HASH_SALT = bytes(
    os.environ.get("SKYVERN_MCP_API_KEY_HASH_SALT", "skyvern-mcp-api-key-cache-v1"),
    "utf-8",
)
|
||||
|
||||
|
||||
def hash_api_key_for_cache(api_key: str) -> str:
    """Return a hex PBKDF2-HMAC-SHA256 digest of ``api_key`` for use as a cache key.

    The salt and iteration count come from module-level settings, so within
    one process the same key always maps to the same digest.
    """
    key_bytes = api_key.encode("utf-8")
    digest = hashlib.pbkdf2_hmac(
        "sha256",
        key_bytes,
        _API_KEY_HASH_SALT,
        _API_KEY_HASH_ITERATIONS,
    )
    return digest.hex()
|
||||
@@ -1,7 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from contextvars import ContextVar
|
||||
from collections import OrderedDict
|
||||
from contextvars import ContextVar, Token
|
||||
from threading import RLock
|
||||
|
||||
import structlog
|
||||
|
||||
@@ -9,35 +12,126 @@ from skyvern.client import SkyvernEnvironment
|
||||
from skyvern.config import settings
|
||||
from skyvern.library.skyvern import Skyvern
|
||||
|
||||
from .api_key_hash import hash_api_key_for_cache
|
||||
|
||||
# Client bound to the current context, if one has been resolved.
_skyvern_instance: ContextVar[Skyvern | None] = ContextVar("skyvern_instance", default=None)
# API key override for the current context; None means no override is set.
_api_key_override: ContextVar[str | None] = ContextVar("skyvern_api_key_override", default=None)
# Process-wide fallback client.
_global_skyvern_instance: Skyvern | None = None
# Clients built for override API keys, keyed by the hashed key. Insertion
# order is kept so the oldest entry can be evicted first.
_api_key_clients: OrderedDict[str, Skyvern] = OrderedDict()
# Guards _api_key_clients and _global_skyvern_instance.
_clients_lock = RLock()
LOG = structlog.get_logger(__name__)
|
||||
|
||||
|
||||
def _resolve_api_key_cache_size() -> int:
    """Return the maximum number of per-API-key clients to keep cached.

    Reads ``SKYVERN_MCP_API_KEY_CLIENT_CACHE_SIZE``. The result is never
    below 1; text that is not an integer yields 128.
    """
    fallback = 128
    text = os.environ.get("SKYVERN_MCP_API_KEY_CLIENT_CACHE_SIZE", "128")
    try:
        size = int(text)
    except ValueError:
        size = fallback
    return size if size >= 1 else 1


# Upper bound on cached per-key clients, resolved once at import time.
_API_KEY_CLIENT_CACHE_MAX = _resolve_api_key_cache_size()
|
||||
|
||||
|
||||
def _cache_key(api_key: str) -> str:
    """Return the hashed form of ``api_key`` used as the client-cache dict key.

    Hashing keeps the raw secret out of the cache's keys.
    """
    hashed = hash_api_key_for_cache(api_key)
    return hashed
|
||||
|
||||
|
||||
def _resolve_api_key() -> str | None:
    """Return the configured API key: the settings value if set, else the ``SKYVERN_API_KEY`` env var."""
    configured = settings.SKYVERN_API_KEY
    if configured:
        return configured
    return os.environ.get("SKYVERN_API_KEY")
|
||||
|
||||
|
||||
def _resolve_base_url() -> str | None:
    """Return the configured base URL: the settings value if set, else the ``SKYVERN_BASE_URL`` env var."""
    configured = settings.SKYVERN_BASE_URL
    if configured:
        return configured
    return os.environ.get("SKYVERN_BASE_URL")
|
||||
|
||||
|
||||
def _build_cloud_client(api_key: str) -> Skyvern:
    """Construct a cloud-environment Skyvern client for ``api_key`` using the resolved base URL."""
    base_url = _resolve_base_url()
    return Skyvern(api_key=api_key, environment=SkyvernEnvironment.CLOUD, base_url=base_url)
|
||||
|
||||
|
||||
# Strong references to in-flight close tasks. The event loop only keeps weak
# references to tasks, so an otherwise unreferenced task may be garbage
# collected before it finishes.
_pending_close_tasks: set[asyncio.Task[None]] = set()


def _close_skyvern_instance_best_effort(instance: Skyvern) -> None:
    """Close a Skyvern client without raising, whether or not an event loop is running.

    With no running loop in the calling thread, ``instance.aclose()`` is run
    to completion via ``asyncio.run``. With a running loop, the close is
    scheduled as a task on that loop and this function returns immediately.

    Failures are logged at debug level and never propagated.
    """
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # No loop in this thread: drive the close synchronously.
        try:
            asyncio.run(instance.aclose())
        except Exception:
            LOG.debug("Failed to close evicted Skyvern client", exc_info=True)
        return

    task = loop.create_task(instance.aclose())
    # Hold the task until its done-callback fires so it cannot be collected early.
    _pending_close_tasks.add(task)

    def _on_done(done: asyncio.Task[None]) -> None:
        _pending_close_tasks.discard(done)
        if done.cancelled():
            # result() would raise CancelledError, which `except Exception`
            # does not catch; report it instead of letting the callback raise.
            LOG.debug("Close of evicted Skyvern client was cancelled")
            return
        try:
            done.result()
        except Exception:
            LOG.debug("Failed to close evicted Skyvern client", exc_info=True)

    task.add_done_callback(_on_done)
|
||||
|
||||
|
||||
def get_active_api_key() -> str | None:
    """Return the API key in effect for the current context.

    A non-empty context-scoped override wins; otherwise the configured key
    (settings or environment) is returned.
    """
    override = _api_key_override.get()
    if override:
        return override
    return _resolve_api_key()
|
||||
|
||||
|
||||
def set_api_key_override(api_key: str | None) -> Token[str | None]:
    """Install a context-scoped API key override and return the token that undoes it.

    The context's cached client is cleared first.
    NOTE(review): presumably so the next client lookup is not served a client
    built for a different key; confirm against ``get_skyvern``.
    """
    _skyvern_instance.set(None)
    reset_token = _api_key_override.set(api_key)
    return reset_token
|
||||
|
||||
|
||||
def reset_api_key_override(token: Token[str | None]) -> None:
    """Restore the API key override to its value before ``token`` was issued.

    The context's cached client is cleared afterwards.
    """
    _api_key_override.reset(token)
    _skyvern_instance.set(None)
|
||||
|
||||
|
||||
def get_skyvern() -> Skyvern:
|
||||
"""Get or create a Skyvern client instance."""
|
||||
global _global_skyvern_instance
|
||||
|
||||
instance = _skyvern_instance.get()
|
||||
if instance is None:
|
||||
instance = _global_skyvern_instance
|
||||
if instance is not None:
|
||||
override_api_key = _api_key_override.get()
|
||||
if override_api_key:
|
||||
instance = _skyvern_instance.get()
|
||||
if instance is None:
|
||||
key = _cache_key(override_api_key)
|
||||
evicted_clients: list[Skyvern] = []
|
||||
# Hold lock across lookup + build + insert to prevent two coroutines
|
||||
# from both building a client for the same API key concurrently.
|
||||
with _clients_lock:
|
||||
instance = _api_key_clients.get(key)
|
||||
if instance is not None:
|
||||
_api_key_clients.move_to_end(key)
|
||||
else:
|
||||
instance = _build_cloud_client(override_api_key)
|
||||
_api_key_clients[key] = instance
|
||||
_api_key_clients.move_to_end(key)
|
||||
while len(_api_key_clients) > _API_KEY_CLIENT_CACHE_MAX:
|
||||
_, evicted = _api_key_clients.popitem(last=False)
|
||||
evicted_clients.append(evicted)
|
||||
for evicted in evicted_clients:
|
||||
_close_skyvern_instance_best_effort(evicted)
|
||||
_skyvern_instance.set(instance)
|
||||
return instance
|
||||
|
||||
api_key = settings.SKYVERN_API_KEY or os.environ.get("SKYVERN_API_KEY")
|
||||
base_url = settings.SKYVERN_BASE_URL or os.environ.get("SKYVERN_BASE_URL")
|
||||
|
||||
if api_key:
|
||||
instance = Skyvern(
|
||||
api_key=api_key,
|
||||
environment=SkyvernEnvironment.CLOUD,
|
||||
base_url=base_url,
|
||||
)
|
||||
else:
|
||||
instance = Skyvern.local()
|
||||
|
||||
_global_skyvern_instance = instance
|
||||
instance = _skyvern_instance.get()
|
||||
if instance is None:
|
||||
with _clients_lock:
|
||||
instance = _global_skyvern_instance
|
||||
if instance is None:
|
||||
api_key = _resolve_api_key()
|
||||
if api_key:
|
||||
instance = _build_cloud_client(api_key)
|
||||
else:
|
||||
instance = Skyvern.local()
|
||||
_global_skyvern_instance = instance
|
||||
_skyvern_instance.set(instance)
|
||||
return instance
|
||||
|
||||
@@ -48,7 +142,12 @@ async def close_skyvern() -> None:
|
||||
|
||||
instances: list[Skyvern] = []
|
||||
seen: set[int] = set()
|
||||
for candidate in (_skyvern_instance.get(), _global_skyvern_instance):
|
||||
with _clients_lock:
|
||||
candidates = (_skyvern_instance.get(), _global_skyvern_instance, *_api_key_clients.values())
|
||||
_api_key_clients.clear()
|
||||
_global_skyvern_instance = None
|
||||
|
||||
for candidate in candidates:
|
||||
if candidate is None or id(candidate) in seen:
|
||||
continue
|
||||
seen.add(id(candidate))
|
||||
@@ -61,4 +160,3 @@ async def close_skyvern() -> None:
|
||||
LOG.warning("Failed to close Skyvern client", exc_info=True)
|
||||
|
||||
_skyvern_instance.set(None)
|
||||
_global_skyvern_instance = None
|
||||
|
||||
206
skyvern/cli/core/mcp_http_auth.py
Normal file
206
skyvern/cli/core/mcp_http_auth.py
Normal file
@@ -0,0 +1,206 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from threading import RLock
|
||||
|
||||
import structlog
|
||||
from fastapi import HTTPException
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
from skyvern.config import settings
|
||||
from skyvern.forge.sdk.db.agent_db import AgentDB
|
||||
from skyvern.forge.sdk.services.org_auth_service import resolve_org_from_api_key
|
||||
|
||||
from .api_key_hash import hash_api_key_for_cache
|
||||
from .client import reset_api_key_override, set_api_key_override
|
||||
|
||||
LOG = structlog.get_logger(__name__)

# Request header that carries the caller's API key.
API_KEY_HEADER = "x-api-key"
# Paths the middleware answers directly, without authentication.
HEALTH_PATHS = {"/health", "/healthz"}

# Lazily created DB handle used for key validation, and the lock guarding it.
_auth_db: AgentDB | None = None
_auth_db_lock = RLock()

# Validation cache: hashed key -> (organization id, or None for a rejected
# key; expiry on the time.monotonic() clock).
_api_key_cache_lock = RLock()
_api_key_validation_cache: OrderedDict[str, tuple[str | None, float]] = OrderedDict()
# Lifetime of a cached rejection, in seconds.
_NEGATIVE_CACHE_TTL_SECONDS = 5.0

# Retry policy for validation lookups that fail with non-auth errors.
_VALIDATION_RETRY_EXHAUSTED_MESSAGE = "API key validation temporarily unavailable"
_MAX_VALIDATION_RETRIES = 2
_RETRY_DELAY_SECONDS = 0.25
|
||||
|
||||
|
||||
def _resolve_api_key_cache_ttl_seconds() -> float:
    """Return how long, in seconds, a successful API-key validation stays cached.

    Reads ``SKYVERN_MCP_API_KEY_CACHE_TTL_SECONDS``. Values below one second
    are raised to 1.0. Text that is not a number, and the non-finite values
    (``nan``, ``inf``, ``-inf``) that ``float()`` would otherwise accept,
    fall back to 30.0.
    """
    import math  # local import keeps this block self-contained

    default_ttl = 30.0
    raw = os.environ.get("SKYVERN_MCP_API_KEY_CACHE_TTL_SECONDS", "30")
    try:
        ttl = float(raw)
    except ValueError:
        return default_ttl
    if not math.isfinite(ttl):
        # An infinite TTL would keep a validated key cached until restart,
        # even after the key is revoked; NaN has no meaningful ordering.
        return default_ttl
    return max(1.0, ttl)
|
||||
|
||||
|
||||
def _resolve_api_key_cache_max_size() -> int:
    """Return the maximum number of entries allowed in the validation cache.

    Reads ``SKYVERN_MCP_API_KEY_CACHE_MAX_SIZE``. The result is at least 1;
    text that is not an integer yields 1024.
    """
    configured = os.environ.get("SKYVERN_MCP_API_KEY_CACHE_MAX_SIZE", "1024")
    try:
        parsed = int(configured)
    except ValueError:
        return 1024
    if parsed < 1:
        return 1
    return parsed
|
||||
|
||||
|
||||
# Validation-cache limits, read from the environment once at import time.
_API_KEY_CACHE_TTL_SECONDS, _API_KEY_CACHE_MAX_SIZE = (
    _resolve_api_key_cache_ttl_seconds(),
    _resolve_api_key_cache_max_size(),
)
|
||||
|
||||
|
||||
def _get_auth_db() -> AgentDB:
    """Return the shared auth ``AgentDB``, creating it on first use.

    Creation happens under ``_auth_db_lock`` so that concurrent first callers
    end up sharing a single instance.
    """
    global _auth_db
    with _auth_db_lock:
        db = _auth_db
        if db is None:
            db = AgentDB(settings.DATABASE_STRING, debug_enabled=settings.DEBUG_MODE)
            _auth_db = db
        return db
|
||||
|
||||
|
||||
async def close_auth_db() -> None:
    """Release the middleware's auth DB engine, if one was created, and empty the validation cache.

    The cache is cleared even when no DB handle exists. Disposal errors are
    logged as warnings rather than raised.
    """
    global _auth_db
    with _auth_db_lock:
        # Detach the handle under the lock; disposal happens outside it.
        db, _auth_db = _auth_db, None
    with _api_key_cache_lock:
        _api_key_validation_cache.clear()

    if db is not None:
        try:
            await db.engine.dispose()
        except Exception:
            LOG.warning("Failed to dispose MCP auth DB engine", exc_info=True)
|
||||
|
||||
|
||||
def _cache_key(api_key: str) -> str:
    """Return the hashed form of ``api_key`` used as the validation-cache key."""
    hashed = hash_api_key_for_cache(api_key)
    return hashed
|
||||
|
||||
|
||||
def _store_validation_result(key: str, organization_id: str | None, ttl_seconds: float) -> None:
    """Record a validation outcome as most-recently-used and evict the oldest entries over the size cap.

    ``organization_id`` is ``None`` for a rejected key.
    """
    with _api_key_cache_lock:
        _api_key_validation_cache[key] = (organization_id, time.monotonic() + ttl_seconds)
        _api_key_validation_cache.move_to_end(key)
        while len(_api_key_validation_cache) > _API_KEY_CACHE_MAX_SIZE:
            _api_key_validation_cache.popitem(last=False)


async def validate_mcp_api_key(api_key: str) -> str:
    """Validate ``api_key`` and return the owning organization id.

    Results are cached under the hashed key: successes for the configured
    TTL, rejections (401/403) for ``_NEGATIVE_CACHE_TTL_SECONDS``. On a cache
    miss the key is resolved via ``resolve_org_from_api_key``; any failure
    other than a 401/403 is retried up to ``_MAX_VALIDATION_RETRIES`` times
    with a fixed delay between attempts.

    Raises:
        HTTPException: 401 when a cached rejection is hit; the original
            401/403 when the lookup rejects the key; 503 once all attempts
            have failed for other reasons.
    """
    key = _cache_key(api_key)

    # Serve from cache when a live entry exists; drop it when expired.
    with _api_key_cache_lock:
        cached = _api_key_validation_cache.get(key)
        if cached is not None:
            organization_id, expires_at = cached
            if expires_at > time.monotonic():
                _api_key_validation_cache.move_to_end(key)
                if organization_id is None:
                    raise HTTPException(status_code=401, detail="Invalid API key")
                return organization_id
            _api_key_validation_cache.pop(key, None)

    # Cache miss: look the key up, retrying on non-auth failures.
    last_exc: Exception | None = None
    for attempt in range(_MAX_VALIDATION_RETRIES + 1):
        if attempt > 0:
            await asyncio.sleep(_RETRY_DELAY_SECONDS)
        try:
            validation = await resolve_org_from_api_key(api_key, _get_auth_db())
            organization_id = validation.organization.organization_id
            _store_validation_result(key, organization_id, _API_KEY_CACHE_TTL_SECONDS)
            return organization_id
        except HTTPException as e:
            if e.status_code in {401, 403}:
                # Definitive rejection: cache it briefly and do not retry.
                _store_validation_result(key, None, _NEGATIVE_CACHE_TTL_SECONDS)
                raise
            last_exc = e
        except Exception as e:
            last_exc = e

    LOG.warning("API key validation retries exhausted", attempts=_MAX_VALIDATION_RETRIES + 1, exc_info=last_exc)
    raise HTTPException(status_code=503, detail=_VALIDATION_RETRY_EXHAUSTED_MESSAGE)
|
||||
|
||||
|
||||
def _unauthorized_response(message: str) -> JSONResponse:
    """Build a 401 JSON response carrying ``message`` under the ``UNAUTHORIZED`` error code."""
    payload = {"error": {"code": "UNAUTHORIZED", "message": message}}
    return JSONResponse(payload, status_code=401)
|
||||
|
||||
|
||||
def _internal_error_response() -> JSONResponse:
    """Build a 500 JSON response with the generic ``INTERNAL_ERROR`` payload."""
    payload = {"error": {"code": "INTERNAL_ERROR", "message": "Internal server error"}}
    return JSONResponse(payload, status_code=500)
|
||||
|
||||
|
||||
def _service_unavailable_response(message: str) -> JSONResponse:
    """Build a 503 JSON response carrying ``message`` under the ``SERVICE_UNAVAILABLE`` error code."""
    payload = {"error": {"code": "SERVICE_UNAVAILABLE", "message": message}}
    return JSONResponse(payload, status_code=503)
|
||||
|
||||
|
||||
class MCPAPIKeyMiddleware:
    """ASGI middleware that requires an ``x-api-key`` header on MCP HTTP requests.

    Non-HTTP scopes and ``OPTIONS`` requests are passed through untouched, and
    health paths are answered directly. For every other request the key is
    validated; on success the organization id is stored in ``scope["state"]``
    and the key is installed as the API key override while the wrapped app runs.
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    @staticmethod
    def _failure_response(exc: HTTPException) -> JSONResponse:
        """Map an ``HTTPException`` raised during validation to the JSON error response to send."""
        if exc.status_code in {401, 403}:
            return _unauthorized_response("Invalid API key")
        if exc.status_code == 503:
            return _service_unavailable_response(exc.detail or _VALIDATION_RETRY_EXHAUSTED_MESSAGE)
        LOG.warning("Unexpected HTTPException during MCP API key validation", status_code=exc.status_code)
        return _internal_error_response()

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        request = Request(scope, receive=receive)

        # Health checks are served here and never reach the wrapped app.
        if request.url.path in HEALTH_PATHS:
            await JSONResponse({"status": "ok"})(scope, receive, send)
            return

        # OPTIONS requests go through without a key.
        if request.method == "OPTIONS":
            await self.app(scope, receive, send)
            return

        api_key = request.headers.get(API_KEY_HEADER)
        if not api_key:
            await _unauthorized_response("Missing x-api-key header")(scope, receive, send)
            return

        failure: JSONResponse | None = None
        try:
            organization_id = await validate_mcp_api_key(api_key)
            scope.setdefault("state", {})
            scope["state"]["organization_id"] = organization_id
        except HTTPException as exc:
            failure = self._failure_response(exc)
        except Exception:
            LOG.exception("Unexpected MCP API key validation failure")
            failure = _internal_error_response()

        if failure is not None:
            await failure(scope, receive, send)
            return

        # Scope the validated key to this request; always undo it afterwards.
        token = set_api_key_override(api_key)
        try:
            await self.app(scope, receive, send)
        finally:
            reset_api_key_override(token)
|
||||
@@ -7,7 +7,8 @@ from typing import TYPE_CHECKING, Any, AsyncIterator
|
||||
|
||||
import structlog
|
||||
|
||||
from .client import get_skyvern
|
||||
from .api_key_hash import hash_api_key_for_cache
|
||||
from .client import get_active_api_key, get_skyvern
|
||||
from .result import BrowserContext, ErrorCode, make_error
|
||||
|
||||
LOG = structlog.get_logger(__name__)
|
||||
@@ -21,6 +22,7 @@ if TYPE_CHECKING:
|
||||
class SessionState:
|
||||
browser: SkyvernBrowser | None = None
|
||||
context: BrowserContext | None = None
|
||||
api_key_hash: str | None = None
|
||||
console_messages: list[dict[str, Any]] = field(default_factory=list)
|
||||
tracing_active: bool = False
|
||||
har_enabled: bool = False
|
||||
@@ -28,26 +30,52 @@ class SessionState:
|
||||
|
||||
_current_session: ContextVar[SessionState | None] = ContextVar("mcp_session", default=None)
|
||||
_global_session: SessionState | None = None
|
||||
_stateless_http_mode = False
|
||||
|
||||
|
||||
def get_current_session() -> SessionState:
|
||||
global _global_session
|
||||
|
||||
state = _current_session.get()
|
||||
if state is None:
|
||||
if _global_session is None:
|
||||
_global_session = SessionState()
|
||||
state = _global_session
|
||||
if state is not None:
|
||||
return state
|
||||
|
||||
# In stateless HTTP mode, avoid process-wide fallback state so requests
|
||||
# cannot inherit session context from other requests.
|
||||
if _stateless_http_mode:
|
||||
state = SessionState()
|
||||
_current_session.set(state)
|
||||
return state
|
||||
|
||||
if _global_session is None:
|
||||
_global_session = SessionState()
|
||||
state = _global_session
|
||||
_current_session.set(state)
|
||||
return state
|
||||
|
||||
|
||||
def set_current_session(state: SessionState) -> None:
|
||||
global _global_session
|
||||
_global_session = state
|
||||
if not _stateless_http_mode:
|
||||
_global_session = state
|
||||
_current_session.set(state)
|
||||
|
||||
|
||||
def set_stateless_http_mode(enabled: bool) -> None:
    """Turn the module-wide stateless-HTTP flag on or off."""
    global _stateless_http_mode
    _stateless_http_mode = enabled
|
||||
|
||||
|
||||
def is_stateless_http_mode() -> bool:
    """Return the current value of the module-wide stateless-HTTP flag."""
    flag = _stateless_http_mode
    return flag
|
||||
|
||||
|
||||
def _api_key_hash(api_key: str | None) -> str | None:
    """Return the cache fingerprint of ``api_key``, or ``None`` when the key is missing or empty."""
    return hash_api_key_for_cache(api_key) if api_key else None
|
||||
|
||||
|
||||
def _matches_current(
|
||||
current: SessionState,
|
||||
*,
|
||||
@@ -57,6 +85,8 @@ def _matches_current(
|
||||
) -> bool:
|
||||
if current.browser is None or current.context is None:
|
||||
return False
|
||||
if current.api_key_hash != _api_key_hash(get_active_api_key()):
|
||||
return False
|
||||
|
||||
if session_id:
|
||||
return current.context.mode == "cloud_session" and current.context.session_id == session_id
|
||||
@@ -84,35 +114,39 @@ async def resolve_browser(
|
||||
skyvern = get_skyvern()
|
||||
current = get_current_session()
|
||||
|
||||
if _stateless_http_mode and not (session_id or cdp_url or local or create_session):
|
||||
raise BrowserNotAvailableError()
|
||||
|
||||
if _matches_current(current, session_id=session_id, cdp_url=cdp_url, local=local):
|
||||
# _matches_current() guarantees both are non-None
|
||||
assert current.browser is not None and current.context is not None
|
||||
if current.browser is None or current.context is None:
|
||||
raise RuntimeError("Expected active browser and context for matching session")
|
||||
return current.browser, current.context
|
||||
|
||||
active_api_key_hash = _api_key_hash(get_active_api_key())
|
||||
browser: SkyvernBrowser | None = None
|
||||
try:
|
||||
if session_id:
|
||||
browser = await skyvern.connect_to_cloud_browser_session(session_id)
|
||||
ctx = BrowserContext(mode="cloud_session", session_id=session_id)
|
||||
set_current_session(SessionState(browser=browser, context=ctx))
|
||||
set_current_session(SessionState(browser=browser, context=ctx, api_key_hash=active_api_key_hash))
|
||||
return browser, ctx
|
||||
|
||||
if cdp_url:
|
||||
browser = await skyvern.connect_to_browser_over_cdp(cdp_url)
|
||||
ctx = BrowserContext(mode="cdp", cdp_url=cdp_url)
|
||||
set_current_session(SessionState(browser=browser, context=ctx))
|
||||
set_current_session(SessionState(browser=browser, context=ctx, api_key_hash=active_api_key_hash))
|
||||
return browser, ctx
|
||||
|
||||
if local:
|
||||
browser = await skyvern.launch_local_browser(headless=headless)
|
||||
ctx = BrowserContext(mode="local")
|
||||
set_current_session(SessionState(browser=browser, context=ctx))
|
||||
set_current_session(SessionState(browser=browser, context=ctx, api_key_hash=active_api_key_hash))
|
||||
return browser, ctx
|
||||
|
||||
if create_session:
|
||||
browser = await skyvern.launch_cloud_browser(timeout=timeout)
|
||||
ctx = BrowserContext(mode="cloud_session", session_id=browser.browser_session_id)
|
||||
set_current_session(SessionState(browser=browser, context=ctx))
|
||||
set_current_session(SessionState(browser=browser, context=ctx, api_key_hash=active_api_key_hash))
|
||||
return browser, ctx
|
||||
except Exception:
|
||||
if browser is not None:
|
||||
|
||||
Reference in New Issue
Block a user