align workflow CLI commands with MCP parity (#4792)

This commit is contained in:
Marc Kelechava
2026-02-18 11:34:12 -08:00
committed by GitHub
parent 2f6850ce20
commit 46a7ec1d26
12 changed files with 1609 additions and 151 deletions

View File

@@ -0,0 +1,29 @@
from __future__ import annotations
import hashlib
import os
def _resolve_api_key_hash_iterations() -> int:
raw = os.environ.get("SKYVERN_MCP_API_KEY_HASH_ITERATIONS", "120000")
try:
return max(10_000, int(raw))
except ValueError:
return 120_000
_API_KEY_HASH_ITERATIONS = _resolve_api_key_hash_iterations()
_API_KEY_HASH_SALT = os.environ.get(
"SKYVERN_MCP_API_KEY_HASH_SALT",
"skyvern-mcp-api-key-cache-v1",
).encode("utf-8")
def hash_api_key_for_cache(api_key: str) -> str:
"""Derive a deterministic, non-reversible fingerprint for API-key keyed caches."""
return hashlib.pbkdf2_hmac(
"sha256",
api_key.encode("utf-8"),
_API_KEY_HASH_SALT,
_API_KEY_HASH_ITERATIONS,
).hex()

View File

@@ -1,7 +1,10 @@
from __future__ import annotations
import asyncio
import os
from contextvars import ContextVar
from collections import OrderedDict
from contextvars import ContextVar, Token
from threading import RLock
import structlog
@@ -9,35 +12,126 @@ from skyvern.client import SkyvernEnvironment
from skyvern.config import settings
from skyvern.library.skyvern import Skyvern
from .api_key_hash import hash_api_key_for_cache
_skyvern_instance: ContextVar[Skyvern | None] = ContextVar("skyvern_instance", default=None)
_api_key_override: ContextVar[str | None] = ContextVar("skyvern_api_key_override", default=None)
_global_skyvern_instance: Skyvern | None = None
_api_key_clients: OrderedDict[str, Skyvern] = OrderedDict()
_clients_lock = RLock()
LOG = structlog.get_logger(__name__)
def _resolve_api_key_cache_size() -> int:
raw = os.environ.get("SKYVERN_MCP_API_KEY_CLIENT_CACHE_SIZE", "128")
try:
return max(1, int(raw))
except ValueError:
return 128
_API_KEY_CLIENT_CACHE_MAX = _resolve_api_key_cache_size()
def _cache_key(api_key: str) -> str:
"""Hash API key so raw secrets are never stored as dict keys."""
return hash_api_key_for_cache(api_key)
def _resolve_api_key() -> str | None:
return settings.SKYVERN_API_KEY or os.environ.get("SKYVERN_API_KEY")
def _resolve_base_url() -> str | None:
return settings.SKYVERN_BASE_URL or os.environ.get("SKYVERN_BASE_URL")
def _build_cloud_client(api_key: str) -> Skyvern:
return Skyvern(
api_key=api_key,
environment=SkyvernEnvironment.CLOUD,
base_url=_resolve_base_url(),
)
def _close_skyvern_instance_best_effort(instance: Skyvern) -> None:
"""Close a Skyvern instance, regardless of whether an event loop is running."""
try:
loop = asyncio.get_running_loop()
except RuntimeError:
try:
asyncio.run(instance.aclose())
except Exception:
LOG.debug("Failed to close evicted Skyvern client", exc_info=True)
return
task = loop.create_task(instance.aclose())
def _on_done(done: asyncio.Task[None]) -> None:
try:
done.result()
except Exception:
LOG.debug("Failed to close evicted Skyvern client", exc_info=True)
task.add_done_callback(_on_done)
def get_active_api_key() -> str | None:
"""Return the effective API key for this request/context."""
return _api_key_override.get() or _resolve_api_key()
def set_api_key_override(api_key: str | None) -> Token[str | None]:
"""Set request-scoped API key override for MCP HTTP requests."""
_skyvern_instance.set(None)
return _api_key_override.set(api_key)
def reset_api_key_override(token: Token[str | None]) -> None:
"""Reset request-scoped API key override."""
_api_key_override.reset(token)
_skyvern_instance.set(None)
def get_skyvern() -> Skyvern:
"""Get or create a Skyvern client instance."""
global _global_skyvern_instance
instance = _skyvern_instance.get()
if instance is None:
instance = _global_skyvern_instance
if instance is not None:
override_api_key = _api_key_override.get()
if override_api_key:
instance = _skyvern_instance.get()
if instance is None:
key = _cache_key(override_api_key)
evicted_clients: list[Skyvern] = []
# Hold lock across lookup + build + insert to prevent two coroutines
# from both building a client for the same API key concurrently.
with _clients_lock:
instance = _api_key_clients.get(key)
if instance is not None:
_api_key_clients.move_to_end(key)
else:
instance = _build_cloud_client(override_api_key)
_api_key_clients[key] = instance
_api_key_clients.move_to_end(key)
while len(_api_key_clients) > _API_KEY_CLIENT_CACHE_MAX:
_, evicted = _api_key_clients.popitem(last=False)
evicted_clients.append(evicted)
for evicted in evicted_clients:
_close_skyvern_instance_best_effort(evicted)
_skyvern_instance.set(instance)
return instance
api_key = settings.SKYVERN_API_KEY or os.environ.get("SKYVERN_API_KEY")
base_url = settings.SKYVERN_BASE_URL or os.environ.get("SKYVERN_BASE_URL")
if api_key:
instance = Skyvern(
api_key=api_key,
environment=SkyvernEnvironment.CLOUD,
base_url=base_url,
)
else:
instance = Skyvern.local()
_global_skyvern_instance = instance
instance = _skyvern_instance.get()
if instance is None:
with _clients_lock:
instance = _global_skyvern_instance
if instance is None:
api_key = _resolve_api_key()
if api_key:
instance = _build_cloud_client(api_key)
else:
instance = Skyvern.local()
_global_skyvern_instance = instance
_skyvern_instance.set(instance)
return instance
@@ -48,7 +142,12 @@ async def close_skyvern() -> None:
instances: list[Skyvern] = []
seen: set[int] = set()
for candidate in (_skyvern_instance.get(), _global_skyvern_instance):
with _clients_lock:
candidates = (_skyvern_instance.get(), _global_skyvern_instance, *_api_key_clients.values())
_api_key_clients.clear()
_global_skyvern_instance = None
for candidate in candidates:
if candidate is None or id(candidate) in seen:
continue
seen.add(id(candidate))
@@ -61,4 +160,3 @@ async def close_skyvern() -> None:
LOG.warning("Failed to close Skyvern client", exc_info=True)
_skyvern_instance.set(None)
_global_skyvern_instance = None

View File

@@ -0,0 +1,206 @@
from __future__ import annotations
import asyncio
import os
import time
from collections import OrderedDict
from threading import RLock
import structlog
from fastapi import HTTPException
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.types import ASGIApp, Receive, Scope, Send
from skyvern.config import settings
from skyvern.forge.sdk.db.agent_db import AgentDB
from skyvern.forge.sdk.services.org_auth_service import resolve_org_from_api_key
from .api_key_hash import hash_api_key_for_cache
from .client import reset_api_key_override, set_api_key_override
LOG = structlog.get_logger(__name__)
API_KEY_HEADER = "x-api-key"
HEALTH_PATHS = {"/health", "/healthz"}
_auth_db: AgentDB | None = None
_auth_db_lock = RLock()
_api_key_cache_lock = RLock()
_api_key_validation_cache: OrderedDict[str, tuple[str | None, float]] = OrderedDict()
_NEGATIVE_CACHE_TTL_SECONDS = 5.0
_VALIDATION_RETRY_EXHAUSTED_MESSAGE = "API key validation temporarily unavailable"
_MAX_VALIDATION_RETRIES = 2
_RETRY_DELAY_SECONDS = 0.25
def _resolve_api_key_cache_ttl_seconds() -> float:
raw = os.environ.get("SKYVERN_MCP_API_KEY_CACHE_TTL_SECONDS", "30")
try:
return max(1.0, float(raw))
except ValueError:
return 30.0
def _resolve_api_key_cache_max_size() -> int:
raw = os.environ.get("SKYVERN_MCP_API_KEY_CACHE_MAX_SIZE", "1024")
try:
return max(1, int(raw))
except ValueError:
return 1024
_API_KEY_CACHE_TTL_SECONDS = _resolve_api_key_cache_ttl_seconds()
_API_KEY_CACHE_MAX_SIZE = _resolve_api_key_cache_max_size()
def _get_auth_db() -> AgentDB:
global _auth_db
# Guard singleton init in case HTTP transport is served with threaded workers.
with _auth_db_lock:
if _auth_db is None:
_auth_db = AgentDB(settings.DATABASE_STRING, debug_enabled=settings.DEBUG_MODE)
return _auth_db
async def close_auth_db() -> None:
"""Dispose the auth DB engine used by HTTP middleware, if initialized."""
global _auth_db
with _auth_db_lock:
db = _auth_db
_auth_db = None
with _api_key_cache_lock:
_api_key_validation_cache.clear()
if db is None:
return
try:
await db.engine.dispose()
except Exception:
LOG.warning("Failed to dispose MCP auth DB engine", exc_info=True)
def _cache_key(api_key: str) -> str:
return hash_api_key_for_cache(api_key)
async def validate_mcp_api_key(api_key: str) -> str:
"""Validate API key and return organization id for observability."""
key = _cache_key(api_key)
# Check cache first.
with _api_key_cache_lock:
cached = _api_key_validation_cache.get(key)
if cached is not None:
organization_id, expires_at = cached
if expires_at > time.monotonic():
_api_key_validation_cache.move_to_end(key)
if organization_id is None:
raise HTTPException(status_code=401, detail="Invalid API key")
return organization_id
_api_key_validation_cache.pop(key, None)
# Cache miss — do the DB lookup with simple retry on transient errors.
last_exc: Exception | None = None
for attempt in range(_MAX_VALIDATION_RETRIES + 1):
if attempt > 0:
await asyncio.sleep(_RETRY_DELAY_SECONDS)
try:
validation = await resolve_org_from_api_key(api_key, _get_auth_db())
organization_id = validation.organization.organization_id
with _api_key_cache_lock:
_api_key_validation_cache[key] = (
organization_id,
time.monotonic() + _API_KEY_CACHE_TTL_SECONDS,
)
_api_key_validation_cache.move_to_end(key)
while len(_api_key_validation_cache) > _API_KEY_CACHE_MAX_SIZE:
_api_key_validation_cache.popitem(last=False)
return organization_id
except HTTPException as e:
if e.status_code in {401, 403}:
with _api_key_cache_lock:
_api_key_validation_cache[key] = (None, time.monotonic() + _NEGATIVE_CACHE_TTL_SECONDS)
_api_key_validation_cache.move_to_end(key)
while len(_api_key_validation_cache) > _API_KEY_CACHE_MAX_SIZE:
_api_key_validation_cache.popitem(last=False)
raise
last_exc = e
except Exception as e:
last_exc = e
LOG.warning("API key validation retries exhausted", attempts=_MAX_VALIDATION_RETRIES + 1, exc_info=last_exc)
raise HTTPException(status_code=503, detail=_VALIDATION_RETRY_EXHAUSTED_MESSAGE)
def _unauthorized_response(message: str) -> JSONResponse:
return JSONResponse({"error": {"code": "UNAUTHORIZED", "message": message}}, status_code=401)
def _internal_error_response() -> JSONResponse:
return JSONResponse(
{"error": {"code": "INTERNAL_ERROR", "message": "Internal server error"}},
status_code=500,
)
def _service_unavailable_response(message: str) -> JSONResponse:
return JSONResponse(
{"error": {"code": "SERVICE_UNAVAILABLE", "message": message}},
status_code=503,
)
class MCPAPIKeyMiddleware:
"""Require x-api-key for MCP HTTP transport and scope requests to that key."""
def __init__(self, app: ASGIApp) -> None:
self.app = app
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http":
await self.app(scope, receive, send)
return
request = Request(scope, receive=receive)
if request.url.path in HEALTH_PATHS:
response = JSONResponse({"status": "ok"})
await response(scope, receive, send)
return
if request.method == "OPTIONS":
await self.app(scope, receive, send)
return
api_key = request.headers.get(API_KEY_HEADER)
if not api_key:
response = _unauthorized_response("Missing x-api-key header")
await response(scope, receive, send)
return
try:
organization_id = await validate_mcp_api_key(api_key)
scope.setdefault("state", {})
scope["state"]["organization_id"] = organization_id
except HTTPException as e:
if e.status_code in {401, 403}:
response = _unauthorized_response("Invalid API key")
await response(scope, receive, send)
return
if e.status_code == 503:
response = _service_unavailable_response(e.detail or _VALIDATION_RETRY_EXHAUSTED_MESSAGE)
await response(scope, receive, send)
return
LOG.warning("Unexpected HTTPException during MCP API key validation", status_code=e.status_code)
response = _internal_error_response()
await response(scope, receive, send)
return
except Exception:
LOG.exception("Unexpected MCP API key validation failure")
response = _internal_error_response()
await response(scope, receive, send)
return
token = set_api_key_override(api_key)
try:
await self.app(scope, receive, send)
finally:
reset_api_key_override(token)

View File

@@ -7,7 +7,8 @@ from typing import TYPE_CHECKING, Any, AsyncIterator
import structlog
from .client import get_skyvern
from .api_key_hash import hash_api_key_for_cache
from .client import get_active_api_key, get_skyvern
from .result import BrowserContext, ErrorCode, make_error
LOG = structlog.get_logger(__name__)
@@ -21,6 +22,7 @@ if TYPE_CHECKING:
class SessionState:
browser: SkyvernBrowser | None = None
context: BrowserContext | None = None
api_key_hash: str | None = None
console_messages: list[dict[str, Any]] = field(default_factory=list)
tracing_active: bool = False
har_enabled: bool = False
@@ -28,26 +30,52 @@ class SessionState:
_current_session: ContextVar[SessionState | None] = ContextVar("mcp_session", default=None)
_global_session: SessionState | None = None
_stateless_http_mode = False
def get_current_session() -> SessionState:
global _global_session
state = _current_session.get()
if state is None:
if _global_session is None:
_global_session = SessionState()
state = _global_session
if state is not None:
return state
# In stateless HTTP mode, avoid process-wide fallback state so requests
# cannot inherit session context from other requests.
if _stateless_http_mode:
state = SessionState()
_current_session.set(state)
return state
if _global_session is None:
_global_session = SessionState()
state = _global_session
_current_session.set(state)
return state
def set_current_session(state: SessionState) -> None:
global _global_session
_global_session = state
if not _stateless_http_mode:
_global_session = state
_current_session.set(state)
def set_stateless_http_mode(enabled: bool) -> None:
global _stateless_http_mode
_stateless_http_mode = enabled
def is_stateless_http_mode() -> bool:
return _stateless_http_mode
def _api_key_hash(api_key: str | None) -> str | None:
if not api_key:
return None
return hash_api_key_for_cache(api_key)
def _matches_current(
current: SessionState,
*,
@@ -57,6 +85,8 @@ def _matches_current(
) -> bool:
if current.browser is None or current.context is None:
return False
if current.api_key_hash != _api_key_hash(get_active_api_key()):
return False
if session_id:
return current.context.mode == "cloud_session" and current.context.session_id == session_id
@@ -84,35 +114,39 @@ async def resolve_browser(
skyvern = get_skyvern()
current = get_current_session()
if _stateless_http_mode and not (session_id or cdp_url or local or create_session):
raise BrowserNotAvailableError()
if _matches_current(current, session_id=session_id, cdp_url=cdp_url, local=local):
# _matches_current() guarantees both are non-None
assert current.browser is not None and current.context is not None
if current.browser is None or current.context is None:
raise RuntimeError("Expected active browser and context for matching session")
return current.browser, current.context
active_api_key_hash = _api_key_hash(get_active_api_key())
browser: SkyvernBrowser | None = None
try:
if session_id:
browser = await skyvern.connect_to_cloud_browser_session(session_id)
ctx = BrowserContext(mode="cloud_session", session_id=session_id)
set_current_session(SessionState(browser=browser, context=ctx))
set_current_session(SessionState(browser=browser, context=ctx, api_key_hash=active_api_key_hash))
return browser, ctx
if cdp_url:
browser = await skyvern.connect_to_browser_over_cdp(cdp_url)
ctx = BrowserContext(mode="cdp", cdp_url=cdp_url)
set_current_session(SessionState(browser=browser, context=ctx))
set_current_session(SessionState(browser=browser, context=ctx, api_key_hash=active_api_key_hash))
return browser, ctx
if local:
browser = await skyvern.launch_local_browser(headless=headless)
ctx = BrowserContext(mode="local")
set_current_session(SessionState(browser=browser, context=ctx))
set_current_session(SessionState(browser=browser, context=ctx, api_key_hash=active_api_key_hash))
return browser, ctx
if create_session:
browser = await skyvern.launch_cloud_browser(timeout=timeout)
ctx = BrowserContext(mode="cloud_session", session_id=browser.browser_session_id)
set_current_session(SessionState(browser=browser, context=ctx))
set_current_session(SessionState(browser=browser, context=ctx, api_key_hash=active_api_key_hash))
return browser, ctx
except Exception:
if browser is not None: