From 46a7ec1d26c11bc82f4be02a772f4008e89d661d Mon Sep 17 00:00:00 2001 From: Marc Kelechava Date: Wed, 18 Feb 2026 11:34:12 -0800 Subject: [PATCH] align workflow CLI commands with MCP parity (#4792) --- docs/integrations/mcp.mdx | 24 ++ skyvern/cli/core/api_key_hash.py | 29 +++ skyvern/cli/core/client.py | 138 ++++++++-- skyvern/cli/core/mcp_http_auth.py | 206 +++++++++++++++ skyvern/cli/core/session_manager.py | 58 ++++- skyvern/cli/mcp_tools/session.py | 39 ++- skyvern/cli/run_commands.py | 97 +++---- skyvern/cli/workflow.py | 303 ++++++++++++++++------ tests/unit/test_cli_commands.py | 172 +++++++++++++ tests/unit/test_mcp_http_auth.py | 304 ++++++++++++++++++++++ tests/unit/test_mcp_session_lifecycle.py | 310 +++++++++++++++++++++++ tests/unit/test_run_commands_cleanup.py | 80 +++++- 12 files changed, 1609 insertions(+), 151 deletions(-) create mode 100644 skyvern/cli/core/api_key_hash.py create mode 100644 skyvern/cli/core/mcp_http_auth.py create mode 100644 tests/unit/test_mcp_http_auth.py diff --git a/docs/integrations/mcp.mdx b/docs/integrations/mcp.mdx index 00c48d9d..87485813 100644 --- a/docs/integrations/mcp.mdx +++ b/docs/integrations/mcp.mdx @@ -68,6 +68,30 @@ Add this to your MCP client's configuration file: Replace `/usr/bin/python3` with the output of `which python3` on your machine. For local mode, set `SKYVERN_BASE_URL` to `http://localhost:8000` and find your API key in the `.env` file after running `skyvern init`. +### Option C: Remote MCP over HTTPS (streamable HTTP) + +Use this when your team provides a hosted MCP endpoint (for example: `https://mcp.skyvern.com/mcp`). + +In remote HTTP mode: +- Clients must send `x-api-key` on every request. +- Use `skyvern_session_create` first, then pass `session_id` explicitly on subsequent browser tool calls. + +If your MCP client supports native remote HTTP transport, configure it directly: + +```json +{ + "mcpServers": { + "SkyvernRemote": { + "type": "streamable-http", + "url": "https://mcp.skyvern.com/mcp", + "headers": { + "x-api-key": "YOUR_SKYVERN_API_KEY" + } + } + } +} +``` + | Client | Path | diff --git a/skyvern/cli/core/api_key_hash.py b/skyvern/cli/core/api_key_hash.py new file mode 100644 index 00000000..ad89d414 --- /dev/null +++ b/skyvern/cli/core/api_key_hash.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +import hashlib +import os + + +def _resolve_api_key_hash_iterations() -> int: + raw = os.environ.get("SKYVERN_MCP_API_KEY_HASH_ITERATIONS", "120000") + try: + return max(10_000, int(raw)) + except ValueError: + return 120_000 + + +_API_KEY_HASH_ITERATIONS = _resolve_api_key_hash_iterations() +_API_KEY_HASH_SALT = os.environ.get( + "SKYVERN_MCP_API_KEY_HASH_SALT", + "skyvern-mcp-api-key-cache-v1", +).encode("utf-8") + + +def hash_api_key_for_cache(api_key: str) -> str: + """Derive a deterministic, non-reversible fingerprint for API-key keyed caches.""" + return hashlib.pbkdf2_hmac( + "sha256", + api_key.encode("utf-8"), + _API_KEY_HASH_SALT, + _API_KEY_HASH_ITERATIONS, + ).hex() diff --git a/skyvern/cli/core/client.py b/skyvern/cli/core/client.py index ef02634a..623f3ec1 100644 --- a/skyvern/cli/core/client.py +++ b/skyvern/cli/core/client.py @@ -1,7 +1,10 @@ from __future__ import annotations +import asyncio import os -from contextvars import ContextVar +from collections import OrderedDict +from contextvars import ContextVar, Token +from threading import RLock import structlog @@ -9,35 +12,126 @@ from skyvern.client import SkyvernEnvironment from skyvern.config import settings from skyvern.library.skyvern import Skyvern +from .api_key_hash import hash_api_key_for_cache + _skyvern_instance: ContextVar[Skyvern | None] = ContextVar("skyvern_instance", default=None) +_api_key_override: ContextVar[str | None] = ContextVar("skyvern_api_key_override", default=None) _global_skyvern_instance: Skyvern | None = None +_api_key_clients: OrderedDict[str, Skyvern] = OrderedDict() +_clients_lock = RLock() LOG = structlog.get_logger(__name__) +def _resolve_api_key_cache_size() -> int: + raw = os.environ.get("SKYVERN_MCP_API_KEY_CLIENT_CACHE_SIZE", "128") + try: + return max(1, int(raw)) + except ValueError: + return 128 + + +_API_KEY_CLIENT_CACHE_MAX = _resolve_api_key_cache_size() + + +def _cache_key(api_key: str) -> str: + """Hash API key so raw secrets are never stored as dict keys.""" + return hash_api_key_for_cache(api_key) + + +def _resolve_api_key() -> str | None: + return settings.SKYVERN_API_KEY or os.environ.get("SKYVERN_API_KEY") + + +def _resolve_base_url() -> str | None: + return settings.SKYVERN_BASE_URL or os.environ.get("SKYVERN_BASE_URL") + + +def _build_cloud_client(api_key: str) -> Skyvern: + return Skyvern( + api_key=api_key, + environment=SkyvernEnvironment.CLOUD, + base_url=_resolve_base_url(), + ) + + +def _close_skyvern_instance_best_effort(instance: Skyvern) -> None: + """Close a Skyvern instance, regardless of whether an event loop is running.""" + try: + loop = asyncio.get_running_loop() + except RuntimeError: + try: + asyncio.run(instance.aclose()) + except Exception: + LOG.debug("Failed to close evicted Skyvern client", exc_info=True) + return + + task = loop.create_task(instance.aclose()) + + def _on_done(done: asyncio.Task[None]) -> None: + try: + done.result() + except Exception: + LOG.debug("Failed to close evicted Skyvern client", exc_info=True) + + task.add_done_callback(_on_done) + + +def get_active_api_key() -> str | None: + """Return the effective API key for this request/context.""" + return _api_key_override.get() or _resolve_api_key() + + +def set_api_key_override(api_key: str | None) -> Token[str | None]: + """Set request-scoped API key override for MCP HTTP requests.""" + _skyvern_instance.set(None) + return _api_key_override.set(api_key) + + +def reset_api_key_override(token: Token[str | None]) -> None: + """Reset request-scoped API key override.""" + _api_key_override.reset(token) + _skyvern_instance.set(None) + + def get_skyvern() -> Skyvern: """Get or create a Skyvern client instance.""" global _global_skyvern_instance - instance = _skyvern_instance.get() - if instance is None: - instance = _global_skyvern_instance - if instance is not None: + override_api_key = _api_key_override.get() + if override_api_key: + instance = _skyvern_instance.get() + if instance is None: + key = _cache_key(override_api_key) + evicted_clients: list[Skyvern] = [] + # Hold lock across lookup + build + insert to prevent two coroutines + # from both building a client for the same API key concurrently. + with _clients_lock: + instance = _api_key_clients.get(key) + if instance is not None: + _api_key_clients.move_to_end(key) + else: + instance = _build_cloud_client(override_api_key) + _api_key_clients[key] = instance + _api_key_clients.move_to_end(key) + while len(_api_key_clients) > _API_KEY_CLIENT_CACHE_MAX: + _, evicted = _api_key_clients.popitem(last=False) + evicted_clients.append(evicted) + for evicted in evicted_clients: + _close_skyvern_instance_best_effort(evicted) _skyvern_instance.set(instance) return instance - api_key = settings.SKYVERN_API_KEY or os.environ.get("SKYVERN_API_KEY") - base_url = settings.SKYVERN_BASE_URL or os.environ.get("SKYVERN_BASE_URL") - - if api_key: - instance = Skyvern( - api_key=api_key, - environment=SkyvernEnvironment.CLOUD, - base_url=base_url, - ) - else: - instance = Skyvern.local() - - _global_skyvern_instance = instance + instance = _skyvern_instance.get() + if instance is None: + with _clients_lock: + instance = _global_skyvern_instance + if instance is None: + api_key = _resolve_api_key() + if api_key: + instance = _build_cloud_client(api_key) + else: + instance = Skyvern.local() + _global_skyvern_instance = instance _skyvern_instance.set(instance) return instance @@ -48,7 +142,12 @@ async def close_skyvern() -> None: instances: list[Skyvern] = [] seen: set[int] = set() - for candidate in (_skyvern_instance.get(), _global_skyvern_instance): + with _clients_lock: + candidates = (_skyvern_instance.get(), _global_skyvern_instance, *_api_key_clients.values()) + _api_key_clients.clear() + _global_skyvern_instance = None + + for candidate in candidates: if candidate is None or id(candidate) in seen: continue seen.add(id(candidate)) @@ -61,4 +160,3 @@ async def close_skyvern() -> None: LOG.warning("Failed to close Skyvern client", exc_info=True) _skyvern_instance.set(None) - _global_skyvern_instance = None diff --git a/skyvern/cli/core/mcp_http_auth.py b/skyvern/cli/core/mcp_http_auth.py new file mode 100644 index 00000000..07a7a168 --- /dev/null +++ b/skyvern/cli/core/mcp_http_auth.py @@ -0,0 +1,206 @@ +from __future__ import annotations + +import asyncio +import os +import time +from collections import OrderedDict +from threading import RLock + +import structlog +from fastapi import HTTPException +from starlette.requests import Request +from starlette.responses import JSONResponse +from starlette.types import ASGIApp, Receive, Scope, Send + +from skyvern.config import settings +from skyvern.forge.sdk.db.agent_db import AgentDB +from skyvern.forge.sdk.services.org_auth_service import resolve_org_from_api_key + +from .api_key_hash import hash_api_key_for_cache +from .client import reset_api_key_override, set_api_key_override + +LOG = structlog.get_logger(__name__) +API_KEY_HEADER = "x-api-key" +HEALTH_PATHS = {"/health", "/healthz"} +_auth_db: AgentDB | None = None +_auth_db_lock = RLock() +_api_key_cache_lock = RLock() +_api_key_validation_cache: OrderedDict[str, tuple[str | None, float]] = OrderedDict() +_NEGATIVE_CACHE_TTL_SECONDS = 5.0 +_VALIDATION_RETRY_EXHAUSTED_MESSAGE = "API key validation temporarily unavailable" +_MAX_VALIDATION_RETRIES = 2 +_RETRY_DELAY_SECONDS = 0.25 + + +def _resolve_api_key_cache_ttl_seconds() -> float: + raw = os.environ.get("SKYVERN_MCP_API_KEY_CACHE_TTL_SECONDS", "30") + try: + return max(1.0, float(raw)) + except ValueError: + return 30.0 + + +def _resolve_api_key_cache_max_size() -> int: + raw = os.environ.get("SKYVERN_MCP_API_KEY_CACHE_MAX_SIZE", "1024") + try: + return max(1, int(raw)) + except ValueError: + return 1024 + + +_API_KEY_CACHE_TTL_SECONDS = _resolve_api_key_cache_ttl_seconds() +_API_KEY_CACHE_MAX_SIZE = _resolve_api_key_cache_max_size() + + +def _get_auth_db() -> AgentDB: + global _auth_db + # Guard singleton init in case HTTP transport is served with threaded workers. + with _auth_db_lock: + if _auth_db is None: + _auth_db = AgentDB(settings.DATABASE_STRING, debug_enabled=settings.DEBUG_MODE) + return _auth_db + + +async def close_auth_db() -> None: + """Dispose the auth DB engine used by HTTP middleware, if initialized.""" + global _auth_db + with _auth_db_lock: + db = _auth_db + _auth_db = None + with _api_key_cache_lock: + _api_key_validation_cache.clear() + if db is None: + return + + try: + await db.engine.dispose() + except Exception: + LOG.warning("Failed to dispose MCP auth DB engine", exc_info=True) + + +def _cache_key(api_key: str) -> str: + return hash_api_key_for_cache(api_key) + + +async def validate_mcp_api_key(api_key: str) -> str: + """Validate API key and return organization id for observability.""" + key = _cache_key(api_key) + + # Check cache first. + with _api_key_cache_lock: + cached = _api_key_validation_cache.get(key) + if cached is not None: + organization_id, expires_at = cached + if expires_at > time.monotonic(): + _api_key_validation_cache.move_to_end(key) + if organization_id is None: + raise HTTPException(status_code=401, detail="Invalid API key") + return organization_id + _api_key_validation_cache.pop(key, None) + + # Cache miss — do the DB lookup with simple retry on transient errors. + last_exc: Exception | None = None + for attempt in range(_MAX_VALIDATION_RETRIES + 1): + if attempt > 0: + await asyncio.sleep(_RETRY_DELAY_SECONDS) + try: + validation = await resolve_org_from_api_key(api_key, _get_auth_db()) + organization_id = validation.organization.organization_id + with _api_key_cache_lock: + _api_key_validation_cache[key] = ( + organization_id, + time.monotonic() + _API_KEY_CACHE_TTL_SECONDS, + ) + _api_key_validation_cache.move_to_end(key) + while len(_api_key_validation_cache) > _API_KEY_CACHE_MAX_SIZE: + _api_key_validation_cache.popitem(last=False) + return organization_id + except HTTPException as e: + if e.status_code in {401, 403}: + with _api_key_cache_lock: + _api_key_validation_cache[key] = (None, time.monotonic() + _NEGATIVE_CACHE_TTL_SECONDS) + _api_key_validation_cache.move_to_end(key) + while len(_api_key_validation_cache) > _API_KEY_CACHE_MAX_SIZE: + _api_key_validation_cache.popitem(last=False) + raise + last_exc = e + except Exception as e: + last_exc = e + + LOG.warning("API key validation retries exhausted", attempts=_MAX_VALIDATION_RETRIES + 1, exc_info=last_exc) + raise HTTPException(status_code=503, detail=_VALIDATION_RETRY_EXHAUSTED_MESSAGE) + + +def _unauthorized_response(message: str) -> JSONResponse: + return JSONResponse({"error": {"code": "UNAUTHORIZED", "message": message}}, status_code=401) + + +def _internal_error_response() -> JSONResponse: + return JSONResponse( + {"error": {"code": "INTERNAL_ERROR", "message": "Internal server error"}}, + status_code=500, + ) + + +def _service_unavailable_response(message: str) -> JSONResponse: + return JSONResponse( + {"error": {"code": "SERVICE_UNAVAILABLE", "message": message}}, + status_code=503, + ) + + +class MCPAPIKeyMiddleware: + """Require x-api-key for MCP HTTP transport and scope requests to that key.""" + + def __init__(self, app: ASGIApp) -> None: + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] != "http": + await self.app(scope, receive, send) + return + + request = Request(scope, receive=receive) + if request.url.path in HEALTH_PATHS: + response = JSONResponse({"status": "ok"}) + await response(scope, receive, send) + return + + if request.method == "OPTIONS": + await self.app(scope, receive, send) + return + + api_key = request.headers.get(API_KEY_HEADER) + if not api_key: + response = _unauthorized_response("Missing x-api-key header") + await response(scope, receive, send) + return + + try: + organization_id = await validate_mcp_api_key(api_key) + scope.setdefault("state", {}) + scope["state"]["organization_id"] = organization_id + except HTTPException as e: + if e.status_code in {401, 403}: + response = _unauthorized_response("Invalid API key") + await response(scope, receive, send) + return + if e.status_code == 503: + response = _service_unavailable_response(e.detail or _VALIDATION_RETRY_EXHAUSTED_MESSAGE) + await response(scope, receive, send) + return + LOG.warning("Unexpected HTTPException during MCP API key validation", status_code=e.status_code) + response = _internal_error_response() + await response(scope, receive, send) + return + except Exception: + LOG.exception("Unexpected MCP API key validation failure") + response = _internal_error_response() + await response(scope, receive, send) + return + + token = set_api_key_override(api_key) + try: + await self.app(scope, receive, send) + finally: + reset_api_key_override(token) diff --git a/skyvern/cli/core/session_manager.py b/skyvern/cli/core/session_manager.py index 2b48bbe5..83af9286 100644 --- a/skyvern/cli/core/session_manager.py +++ b/skyvern/cli/core/session_manager.py @@ -7,7 +7,8 @@ from typing import TYPE_CHECKING, Any, AsyncIterator import structlog -from .client import get_skyvern +from .api_key_hash import hash_api_key_for_cache +from .client import get_active_api_key, get_skyvern from .result import BrowserContext, ErrorCode, make_error LOG = structlog.get_logger(__name__) @@ -21,6 +22,7 @@ if TYPE_CHECKING: class SessionState: browser: SkyvernBrowser | None = None context: BrowserContext | None = None + api_key_hash: str | None = None console_messages: list[dict[str, Any]] = field(default_factory=list) tracing_active: bool = False har_enabled: bool = False @@ -28,26 +30,52 @@ class SessionState: _current_session: ContextVar[SessionState | None] = ContextVar("mcp_session", default=None) _global_session: SessionState | None = None +_stateless_http_mode = False def get_current_session() -> SessionState: global _global_session state = _current_session.get() - if state is None: - if _global_session is None: - _global_session = SessionState() - state = _global_session + if state is not None: + return state + + # In stateless HTTP mode, avoid process-wide fallback state so requests + # cannot inherit session context from other requests. + if _stateless_http_mode: + state = SessionState() _current_session.set(state) + return state + + if _global_session is None: + _global_session = SessionState() + state = _global_session + _current_session.set(state) return state def set_current_session(state: SessionState) -> None: global _global_session - _global_session = state + if not _stateless_http_mode: + _global_session = state _current_session.set(state) +def set_stateless_http_mode(enabled: bool) -> None: + global _stateless_http_mode + _stateless_http_mode = enabled + + +def is_stateless_http_mode() -> bool: + return _stateless_http_mode + + +def _api_key_hash(api_key: str | None) -> str | None: + if not api_key: + return None + return hash_api_key_for_cache(api_key) + + def _matches_current( current: SessionState, *, @@ -57,6 +85,8 @@ def _matches_current( ) -> bool: if current.browser is None or current.context is None: return False + if current.api_key_hash != _api_key_hash(get_active_api_key()): + return False if session_id: return current.context.mode == "cloud_session" and current.context.session_id == session_id @@ -84,35 +114,39 @@ async def resolve_browser( skyvern = get_skyvern() current = get_current_session() + if _stateless_http_mode and not (session_id or cdp_url or local or create_session): + raise BrowserNotAvailableError() + if _matches_current(current, session_id=session_id, cdp_url=cdp_url, local=local): - # _matches_current() guarantees both are non-None - assert current.browser is not None and current.context is not None + if current.browser is None or current.context is None: + raise RuntimeError("Expected active browser and context for matching session") return current.browser, current.context + active_api_key_hash = _api_key_hash(get_active_api_key()) browser: SkyvernBrowser | None = None try: if session_id: browser = await skyvern.connect_to_cloud_browser_session(session_id) ctx = BrowserContext(mode="cloud_session", session_id=session_id) - set_current_session(SessionState(browser=browser, context=ctx)) + set_current_session(SessionState(browser=browser, context=ctx, api_key_hash=active_api_key_hash)) return browser, ctx if cdp_url: browser = await skyvern.connect_to_browser_over_cdp(cdp_url) ctx = BrowserContext(mode="cdp", cdp_url=cdp_url) - set_current_session(SessionState(browser=browser, context=ctx)) + set_current_session(SessionState(browser=browser, context=ctx, api_key_hash=active_api_key_hash)) return browser, ctx if local: browser = await skyvern.launch_local_browser(headless=headless) ctx = BrowserContext(mode="local") - set_current_session(SessionState(browser=browser, context=ctx)) + set_current_session(SessionState(browser=browser, context=ctx, api_key_hash=active_api_key_hash)) return browser, ctx if create_session: browser = await skyvern.launch_cloud_browser(timeout=timeout) ctx = BrowserContext(mode="cloud_session", session_id=browser.browser_session_id) - set_current_session(SessionState(browser=browser, context=ctx)) + set_current_session(SessionState(browser=browser, context=ctx, api_key_hash=active_api_key_hash)) return browser, ctx except Exception: if browser is not None: diff --git a/skyvern/cli/mcp_tools/session.py b/skyvern/cli/mcp_tools/session.py index 7b22fa89..37c2eac0 100644 --- a/skyvern/cli/mcp_tools/session.py +++ b/skyvern/cli/mcp_tools/session.py @@ -4,7 +4,11 @@ from typing import Annotated, Any from pydantic import Field +from skyvern.cli.core.api_key_hash import hash_api_key_for_cache +from skyvern.cli.core.client import get_active_api_key +from skyvern.cli.core.session_manager import is_stateless_http_mode from skyvern.cli.core.session_ops import do_session_close, do_session_create, do_session_list +from skyvern.schemas.runs import ProxyLocation from ._common import BrowserContext, ErrorCode, Timer, make_error, make_result from ._session import ( @@ -16,6 +20,13 @@ from ._session import ( ) +def _session_api_key_hash() -> str | None: + api_key = get_active_api_key() + if not api_key: + return None + return hash_api_key_for_cache(api_key) + + async def skyvern_session_create( timeout: Annotated[int | None, Field(description="Session timeout in minutes (5-1440)")] = 60, proxy_location: Annotated[str | None, Field(description="Proxy location: RESIDENTIAL, US, etc.")] = None, @@ -29,7 +40,33 @@ async def skyvern_session_create( """ with Timer() as timer: try: + if is_stateless_http_mode() and local: + return make_result( + "skyvern_session_create", + ok=False, + error=make_error( + ErrorCode.INVALID_INPUT, + "Local browser sessions are not supported in stateless HTTP mode", + "Use cloud sessions for remote MCP transport", + ), + ) + skyvern = get_skyvern() + if is_stateless_http_mode(): + proxy = ProxyLocation(proxy_location) if proxy_location else None + session = await skyvern.create_browser_session(timeout=timeout or 60, proxy_location=proxy) + timer.mark("sdk") + ctx = BrowserContext(mode="cloud_session", session_id=session.browser_session_id) + return make_result( + "skyvern_session_create", + browser_context=ctx, + data={ + "session_id": session.browser_session_id, + "timeout_minutes": timeout or 60, + }, + timing_ms=timer.timing_ms, + ) + browser, result = await do_session_create( skyvern, timeout=timeout or 60, @@ -43,7 +80,7 @@ async def skyvern_session_create( ctx = BrowserContext(mode="local") else: ctx = BrowserContext(mode="cloud_session", session_id=result.session_id) - set_current_session(SessionState(browser=browser, context=ctx)) + set_current_session(SessionState(browser=browser, context=ctx, api_key_hash=_session_api_key_hash())) except ValueError as e: return make_result( diff --git a/skyvern/cli/run_commands.py b/skyvern/cli/run_commands.py index 1d48be22..f9b31826 100644 --- a/skyvern/cli/run_commands.py +++ b/skyvern/cli/run_commands.py @@ -5,7 +5,7 @@ import logging import os import shutil import subprocess -from typing import Any, List, Optional +from typing import Annotated, List, Literal, Optional import psutil import typer @@ -13,17 +13,17 @@ import uvicorn from dotenv import load_dotenv, set_key from rich.panel import Panel from rich.prompt import Confirm +from starlette.middleware import Middleware from skyvern.cli.console import console from skyvern.cli.core.client import close_skyvern -from skyvern.cli.core.session_manager import close_current_session +from skyvern.cli.core.mcp_http_auth import MCPAPIKeyMiddleware, close_auth_db +from skyvern.cli.core.session_manager import close_current_session, set_stateless_http_mode from skyvern.cli.mcp_tools import mcp # Uses standalone fastmcp (v2.x) from skyvern.cli.utils import start_services -from skyvern.client import SkyvernEnvironment from skyvern.config import settings from skyvern.forge.sdk.core import skyvern_context from skyvern.forge.sdk.forge_log import setup_logger -from skyvern.library.skyvern import Skyvern from skyvern.services.script_service import run_script from skyvern.utils import detect_os from skyvern.utils.env_paths import resolve_backend_env_path, resolve_frontend_env_path @@ -36,7 +36,10 @@ async def _cleanup_mcp_resources() -> None: try: await close_current_session() finally: - await close_skyvern() + try: + await close_skyvern() + finally: + await close_auth_db() def _cleanup_mcp_resources_blocking() -> None: @@ -67,39 +70,6 @@ def _cleanup_mcp_resources_sync() -> None: logger.debug("Skipping MCP cleanup because event loop is still running") -@mcp.tool() -async def skyvern_run_task(prompt: str, url: str) -> dict[str, Any]: - """Use Skyvern to execute anything in the browser. Useful for accomplishing tasks that require browser automation. - - This tool uses Skyvern's browser automation to navigate websites and perform actions to achieve - the user's intended outcome. It can handle tasks like form filling, clicking buttons, data extraction, - and multi-step workflows. - - It can even help you find updated data on the internet if your model information is outdated. - - Args: - prompt: A natural language description of what needs to be accomplished (e.g. "Book a flight from - NYC to LA", "Sign up for the newsletter", "Find the price of item X", "Apply to a job") - url: The starting URL of the website where the task should be performed - """ - skyvern_agent = Skyvern( - environment=SkyvernEnvironment.CLOUD, - base_url=settings.SKYVERN_BASE_URL, - api_key=settings.SKYVERN_API_KEY, - ) - res = await skyvern_agent.run_task(prompt=prompt, url=url, user_agent="skyvern-mcp", wait_for_completion=True) - - output = res.model_dump()["output"] - if res.app_url: - task_url = res.app_url - else: - if res.run_id and res.run_id.startswith("wr_"): - task_url = f"{settings.SKYVERN_APP_URL.rstrip('/')}/runs/{res.run_id}/overview" - else: - task_url = f"{settings.SKYVERN_APP_URL.rstrip('/')}/tasks/{res.run_id}/actions" - return {"output": output, "task_url": task_url, "run_id": res.run_id} - - def get_pids_on_port(port: int) -> List[int]: """Return a list of PIDs listening on the given port.""" pids = [] @@ -295,20 +265,61 @@ def run_dev() -> None: @run_app.command(name="mcp") -def run_mcp() -> None: - """Run the MCP server.""" - # This breaks the MCP processing because it expects json output only - # console.print(Panel("[bold green]Starting MCP Server...[/bold green]", border_style="green")) +def run_mcp( + transport: Annotated[ + Literal["stdio", "sse", "streamable-http"], + typer.Option( + "--transport", + help="MCP transport: stdio (default), sse, or streamable-http.", + ), + ] = "stdio", + host: Annotated[str, typer.Option("--host", help="Host for HTTP transports.")] = "0.0.0.0", + port: Annotated[int, typer.Option("--port", help="Port for HTTP transports.")] = 8000, + path: Annotated[str, typer.Option("--path", help="HTTP endpoint path for MCP transport.")] = "/mcp", + stateless_http: Annotated[ + bool, + typer.Option( + "--stateless-http/--no-stateless-http", + help="Use stateless HTTP semantics for HTTP transports (ignored for stdio).", + ), + ] = True, +) -> None: + """Run the MCP server with configurable transport for local or remote hosting.""" + path = _normalize_mcp_path(path) + stateless_http_enabled = transport != "stdio" and stateless_http # atexit covers signal-based exits (SIGTERM); finally covers normal # mcp.run() completion or unhandled exceptions. Both are needed because # atexit doesn't fire on normal return and finally doesn't fire on signals. atexit.register(_cleanup_mcp_resources_sync) + set_stateless_http_mode(stateless_http_enabled) try: - mcp.run(transport="stdio") + if transport == "stdio": + mcp.run(transport="stdio") + return + + middleware = [Middleware(MCPAPIKeyMiddleware)] + mcp.run( + transport=transport, + host=host, + port=port, + path=path, + middleware=middleware, + stateless_http=stateless_http_enabled, + ) finally: + set_stateless_http_mode(False) _cleanup_mcp_resources_blocking() +def _normalize_mcp_path(path: str) -> str: + path = path.strip() + if not path: + return "/mcp" + if not path.startswith("/"): + return f"/{path}" + return path + + @run_app.command( name="code", context_settings={"allow_interspersed_args": False}, diff --git a/skyvern/cli/workflow.py b/skyvern/cli/workflow.py index eb8d6a20..9e9cc219 100644 --- a/skyvern/cli/workflow.py +++ b/skyvern/cli/workflow.py @@ -1,27 +1,80 @@ -"""Workflow-related CLI helpers.""" +"""Workflow-related CLI commands with MCP-parity flags and output.""" from __future__ import annotations +import asyncio import json -import os +import sys +from pathlib import Path +from typing import Any, Callable, Coroutine import typer from dotenv import load_dotenv -from rich.panel import Panel -from skyvern.client import Skyvern from skyvern.config import settings from skyvern.utils.env_paths import resolve_backend_env_path -from .console import console -from .tasks import _list_workflow_tasks +from .commands._output import output, output_error +from .mcp_tools.workflow import skyvern_workflow_cancel as tool_workflow_cancel +from .mcp_tools.workflow import skyvern_workflow_create as tool_workflow_create +from .mcp_tools.workflow import skyvern_workflow_delete as tool_workflow_delete +from .mcp_tools.workflow import skyvern_workflow_get as tool_workflow_get +from .mcp_tools.workflow import skyvern_workflow_list as tool_workflow_list +from .mcp_tools.workflow import skyvern_workflow_run as tool_workflow_run +from .mcp_tools.workflow import skyvern_workflow_status as tool_workflow_status +from .mcp_tools.workflow import skyvern_workflow_update as tool_workflow_update -workflow_app = typer.Typer(help="Manage Skyvern workflows.") +workflow_app = typer.Typer(help="Manage Skyvern workflows.", no_args_is_help=True) + + +def _emit_tool_result(result: dict[str, Any], *, json_output: bool) -> None: + if json_output: + json.dump(result, sys.stdout, indent=2, default=str) + sys.stdout.write("\n") + if not result.get("ok", False): + raise SystemExit(1) + return + + if result.get("ok", False): + output(result.get("data"), action=str(result.get("action", "")), json_mode=False) + return + + err = result.get("error") or {} + output_error(str(err.get("message", "Unknown error")), hint=str(err.get("hint", "")), json_mode=False) + + +def _run_tool( + runner: Callable[[], Coroutine[Any, Any, dict[str, Any]]], + *, + json_output: bool, + hint_on_exception: str, +) -> None: + try: + result: dict[str, Any] = asyncio.run(runner()) + _emit_tool_result(result, json_output=json_output) + except typer.BadParameter: + raise + except Exception as e: + output_error(str(e), hint=hint_on_exception, json_mode=json_output) + + +def _resolve_inline_or_file(value: str | None, *, param_name: str) -> str | None: + if value is None or not value.startswith("@"): + return value + + file_path = value[1:] + if not file_path: + raise typer.BadParameter(f"{param_name} file path cannot be empty after '@'.") + + path = Path(file_path).expanduser() + try: + return path.read_text(encoding="utf-8") + except OSError as e: + raise typer.BadParameter(f"Unable to read {param_name} file '{path}': {e}") from e @workflow_app.callback() def workflow_callback( - ctx: typer.Context, api_key: str | None = typer.Option( None, "--api-key", @@ -29,86 +82,188 @@ def workflow_callback( envvar="SKYVERN_API_KEY", ), ) -> None: - """Store the provided API key in the Typer context.""" - ctx.obj = {"api_key": api_key} - - -def _get_client(api_key: str | None = None) -> Skyvern: - """Instantiate a Skyvern SDK client using environment variables.""" + """Load workflow CLI environment and optional API key override.""" load_dotenv(resolve_backend_env_path()) - key = api_key or os.getenv("SKYVERN_API_KEY") or settings.SKYVERN_API_KEY - return Skyvern(base_url=settings.SKYVERN_BASE_URL, api_key=key) + if api_key: + settings.SKYVERN_API_KEY = api_key + + +@workflow_app.command("list") +def workflow_list( + search: str | None = typer.Option( + None, + "--search", + help="Search across workflow titles, folder names, and parameter metadata.", + ), + page: int = typer.Option(1, "--page", min=1, help="Page number (1-based)."), + page_size: int = typer.Option(10, "--page-size", min=1, max=100, help="Results per page."), + only_workflows: bool = typer.Option( + False, + "--only-workflows", + help="Only return multi-step workflows (exclude saved tasks).", + ), + json_output: bool = typer.Option(False, "--json", help="Output as JSON."), +) -> None: + """List workflows.""" + + async def _run() -> dict[str, Any]: + return await tool_workflow_list( + search=search, + page=page, + page_size=page_size, + only_workflows=only_workflows, + ) + + _run_tool(_run, json_output=json_output, hint_on_exception="Check your API key and workflow list filters.") + + +@workflow_app.command("get") +def workflow_get( + workflow_id: str = typer.Option(..., "--id", help="Workflow permanent ID (wpid_...)."), + version: int | None = typer.Option(None, "--version", min=1, help="Specific version to retrieve."), + json_output: bool = typer.Option(False, "--json", help="Output as JSON."), +) -> None: + """Get a workflow definition by ID.""" + + async def _run() -> dict[str, Any]: + return await tool_workflow_get(workflow_id=workflow_id, version=version) + + _run_tool(_run, json_output=json_output, hint_on_exception="Check your API key and workflow ID.") + + +@workflow_app.command("create") +def workflow_create( + definition: str = typer.Option( + ..., + "--definition", + help="Workflow definition as YAML/JSON string or @file path.", + ), + definition_format: str = typer.Option( + "auto", + "--format", + help="Definition format: json, yaml, or auto.", + ), + folder_id: str | None = typer.Option(None, "--folder-id", help="Folder ID (fld_...) for the workflow."), + json_output: bool = typer.Option(False, "--json", help="Output as JSON."), +) -> None: + """Create a workflow.""" + + async def _run() -> dict[str, Any]: + resolved_definition = _resolve_inline_or_file(definition, param_name="definition") + assert resolved_definition is not None + return await tool_workflow_create( + definition=resolved_definition, + format=definition_format, + folder_id=folder_id, + ) + + _run_tool(_run, json_output=json_output, hint_on_exception="Check the workflow definition syntax.") + + +@workflow_app.command("update") +def workflow_update( + workflow_id: str = typer.Option(..., "--id", help="Workflow permanent ID (wpid_...)."), + definition: str = typer.Option( + ..., + "--definition", + help="Updated workflow definition as YAML/JSON string or @file path.", + ), + definition_format: str = typer.Option( + "auto", + "--format", + help="Definition format: json, yaml, or auto.", + ), + json_output: bool = typer.Option(False, "--json", help="Output as JSON."), +) -> None: + """Update a workflow definition.""" + + async def _run() -> dict[str, Any]: + resolved_definition = _resolve_inline_or_file(definition, param_name="definition") + assert resolved_definition is not None + return await tool_workflow_update( + workflow_id=workflow_id, + definition=resolved_definition, + format=definition_format, + ) + + _run_tool(_run, json_output=json_output, hint_on_exception="Check the workflow ID and definition syntax.") + + +@workflow_app.command("delete") +def workflow_delete( + workflow_id: str = typer.Option(..., "--id", help="Workflow permanent ID (wpid_...)."), + force: bool = typer.Option(False, "--force", help="Confirm permanent deletion."), + json_output: bool = typer.Option(False, "--json", help="Output as JSON."), +) -> None: + """Delete a workflow.""" + + async def _run() -> dict[str, Any]: + return await tool_workflow_delete(workflow_id=workflow_id, force=force) + + _run_tool(_run, json_output=json_output, hint_on_exception="Check the workflow ID and your permissions.") @workflow_app.command("run") -def run_workflow( - ctx: typer.Context, - workflow_id: str = typer.Argument(..., help="Workflow permanent ID"), - parameters: str = typer.Option("{}", "--parameters", "-p", help="JSON parameters for the workflow"), - title: str | None = typer.Option(None, "--title", help="Title for the workflow run"), - max_steps: int | None = typer.Option(None, "--max-steps", help="Override the workflow max steps"), +def workflow_run( + workflow_id: str = typer.Option(..., "--id", help="Workflow permanent ID (wpid_...)."), + params: str | None = typer.Option( + None, + "--params", + "--parameters", + "-p", + help="Workflow parameters as JSON string or @file path.", + ), + session: str | None = typer.Option(None, "--session", help="Browser session ID (pbs_...) to reuse."), + webhook: str | None = typer.Option(None, "--webhook", help="Status webhook callback URL."), + proxy: str | None = typer.Option(None, "--proxy", help="Proxy location (e.g., RESIDENTIAL)."), + wait: bool = typer.Option(False, "--wait", help="Wait for workflow completion before returning."), + timeout: int = typer.Option( + 300, + "--timeout", + min=10, + max=3600, + help="Max wait time in seconds when --wait is set.", + ), + json_output: bool = typer.Option(False, "--json", help="Output as JSON."), ) -> None: """Run a workflow.""" - try: - params_dict = json.loads(parameters) if parameters else {} - except json.JSONDecodeError: - console.print(f"[red]Invalid JSON for parameters: {parameters}[/red]") - raise typer.Exit(code=1) - client = _get_client(ctx.obj.get("api_key") if ctx.obj else None) - run_resp = client.run_workflow( - workflow_id=workflow_id, - parameters=params_dict, - title=title, - max_steps_override=max_steps, - ) - console.print( - Panel( - f"Started workflow run [bold]{run_resp.run_id}[/bold]", - border_style="green", + async def _run() -> dict[str, Any]: + resolved_params = _resolve_inline_or_file(params, param_name="params") + return await tool_workflow_run( + workflow_id=workflow_id, + parameters=resolved_params, + browser_session_id=session, + webhook_url=webhook, + proxy_location=proxy, + wait=wait, + timeout_seconds=timeout, ) - ) - -@workflow_app.command("cancel") -def cancel_workflow( - ctx: typer.Context, - run_id: str = typer.Argument(..., help="ID of the workflow run"), -) -> None: - """Cancel a running workflow.""" - client = _get_client(ctx.obj.get("api_key") if ctx.obj else None) - client.cancel_run(run_id=run_id) - console.print(Panel(f"Cancel signal sent for run {run_id}", border_style="red")) + _run_tool(_run, json_output=json_output, hint_on_exception="Check the workflow ID and run parameters.") @workflow_app.command("status") def workflow_status( - ctx: typer.Context, - run_id: str = typer.Argument(..., help="ID of the workflow run"), - tasks: bool = typer.Option(False, "--tasks", help="Show task executions"), + run_id: str = typer.Option(..., "--run-id", help="Run ID (wr_... or tsk_v2_...)."), + json_output: bool = typer.Option(False, "--json", help="Output as JSON."), ) -> None: - """Retrieve status information for a workflow run.""" - client = _get_client(ctx.obj.get("api_key") if ctx.obj else None) - run = client.get_run(run_id=run_id) - console.print(Panel(run.model_dump_json(indent=2), border_style="cyan")) - if tasks: - task_list = _list_workflow_tasks(client, run_id) - console.print(Panel(json.dumps(task_list, indent=2), border_style="magenta")) + """Get workflow run status.""" + + async def _run() -> dict[str, Any]: + return await tool_workflow_status(run_id=run_id) + + _run_tool(_run, json_output=json_output, hint_on_exception="Check the run ID and API key.") -@workflow_app.command("list") -def list_workflows( - ctx: typer.Context, - page: int = typer.Option(1, "--page", help="Page number"), - page_size: int = typer.Option(10, "--page-size", help="Number of workflows to return"), - template: bool = typer.Option(False, "--template", help="List template workflows"), +@workflow_app.command("cancel") +def workflow_cancel( + run_id: str = typer.Option(..., "--run-id", help="Run ID (wr_... or tsk_v2_...)."), + json_output: bool = typer.Option(False, "--json", help="Output as JSON."), ) -> None: - """List workflows for the organization.""" - client = _get_client(ctx.obj.get("api_key") if ctx.obj else None) - resp = client._client_wrapper.httpx_client.request( - "api/v1/workflows", - method="GET", - params={"page": page, "page_size": page_size, "template": template}, - ) - resp.raise_for_status() - console.print(Panel(json.dumps(resp.json(), indent=2), border_style="cyan")) + """Cancel a workflow run.""" + + async def _run() -> dict[str, Any]: + return await tool_workflow_cancel(run_id=run_id) + + _run_tool(_run, json_output=json_output, hint_on_exception="Check the run ID and API key.") diff --git a/tests/unit/test_cli_commands.py b/tests/unit/test_cli_commands.py index 08edf9a7..cc9c4d80 100644 --- a/tests/unit/test_cli_commands.py +++ b/tests/unit/test_cli_commands.py @@ -296,3 +296,175 @@ class TestBrowserCommands: parsed = json.loads(capsys.readouterr().out) assert parsed["ok"] is False assert "Invalid state" in parsed["error"]["message"] + + +# --------------------------------------------------------------------------- +# Workflow command behavior +# --------------------------------------------------------------------------- + + +class TestWorkflowCommands: + def test_workflow_get_outputs_mcp_envelope_in_json_mode( + self, monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture + ) -> None: + from skyvern.cli import workflow as workflow_cmd + + expected = { + "ok": True, + "action": "skyvern_workflow_get", + "browser_context": {"mode": "none", "session_id": None, "cdp_url": None}, + "data": {"workflow_permanent_id": "wpid_123"}, + "artifacts": [], + "timing_ms": {}, + "warnings": [], + "error": None, + } + tool = AsyncMock(return_value=expected) + monkeypatch.setattr(workflow_cmd, "tool_workflow_get", tool) + + workflow_cmd.workflow_get(workflow_id="wpid_123", version=2, json_output=True) + + parsed = json.loads(capsys.readouterr().out) + assert parsed == expected + assert tool.await_args.kwargs == {"workflow_id": "wpid_123", "version": 2} + + def test_workflow_create_reads_definition_from_file( + self, + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, + capsys: pytest.CaptureFixture, + ) -> None: + from skyvern.cli import workflow as workflow_cmd + + definition_file = tmp_path / "workflow.json" + definition_text = '{"title": "Example", "workflow_definition": {"blocks": []}}' + definition_file.write_text(definition_text) + + tool = AsyncMock( + return_value={ + "ok": True, + "action": "skyvern_workflow_create", + "browser_context": {"mode": "none", "session_id": None, "cdp_url": None}, + "data": {"workflow_permanent_id": "wpid_new"}, + "artifacts": [], + "timing_ms": {}, + "warnings": [], + "error": None, + } + ) + monkeypatch.setattr(workflow_cmd, "tool_workflow_create", tool) + + workflow_cmd.workflow_create( + definition=f"@{definition_file}", + definition_format="json", + folder_id="fld_123", + json_output=True, + ) + + assert tool.await_args.kwargs == { + "definition": definition_text, + "format": "json", + "folder_id": "fld_123", + } + parsed = json.loads(capsys.readouterr().out) + assert parsed["ok"] is True + assert parsed["data"]["workflow_permanent_id"] == "wpid_new" + + def test_workflow_run_reads_params_file_and_maps_options( + self, + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, + capsys: pytest.CaptureFixture, + ) -> None: + from skyvern.cli import workflow as workflow_cmd + + params_file = tmp_path / "params.json" + params_file.write_text('{"company": "Acme"}') + + tool = AsyncMock( + return_value={ + "ok": True, + "action": "skyvern_workflow_run", + "browser_context": {"mode": "none", "session_id": None, "cdp_url": None}, + "data": {"run_id": "wr_123", "status": "queued"}, + "artifacts": [], + "timing_ms": {}, + "warnings": [], + "error": None, + } + ) + monkeypatch.setattr(workflow_cmd, "tool_workflow_run", tool) + + workflow_cmd.workflow_run( + workflow_id="wpid_123", + params=f"@{params_file}", + session="pbs_456", + webhook="https://example.com/webhook", + proxy="RESIDENTIAL", + wait=True, + timeout=450, + json_output=True, + ) + + assert tool.await_args.kwargs == { + "workflow_id": "wpid_123", + "parameters": '{"company": "Acme"}', + "browser_session_id": "pbs_456", + "webhook_url": "https://example.com/webhook", + "proxy_location": "RESIDENTIAL", + "wait": True, + "timeout_seconds": 450, + } + parsed = json.loads(capsys.readouterr().out) + assert parsed["ok"] is True + assert parsed["data"]["run_id"] == "wr_123" + + def test_workflow_status_json_error_exits_nonzero( + self, monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture + ) -> None: + from skyvern.cli import workflow as workflow_cmd + + tool = AsyncMock( + return_value={ + "ok": False, + "action": "skyvern_workflow_status", + "browser_context": {"mode": "none", "session_id": None, "cdp_url": None}, + "data": None, + "artifacts": [], + "timing_ms": {}, + "warnings": [], + "error": { + "code": "RUN_NOT_FOUND", + "message": "Run 'wr_missing' not found", + "hint": "Verify the run ID", + "details": {}, + }, + } + ) + monkeypatch.setattr(workflow_cmd, "tool_workflow_status", tool) + + with pytest.raises(SystemExit, match="1"): + workflow_cmd.workflow_status(run_id="wr_missing", json_output=True) + + parsed = json.loads(capsys.readouterr().out) + assert parsed["ok"] is False + assert parsed["error"]["code"] == "RUN_NOT_FOUND" + + def test_workflow_update_missing_definition_file_raises_bad_parameter( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ) -> None: + from skyvern.cli import workflow as workflow_cmd + + tool = AsyncMock() + monkeypatch.setattr(workflow_cmd, "tool_workflow_update", tool) + missing_file = tmp_path / "missing-definition.json" + + with pytest.raises(typer.BadParameter, match="Unable to read definition file"): + workflow_cmd.workflow_update( + workflow_id="wpid_123", + definition=f"@{missing_file}", + definition_format="json", + json_output=False, + ) + + tool.assert_not_called() diff --git a/tests/unit/test_mcp_http_auth.py b/tests/unit/test_mcp_http_auth.py new file mode 100644 index 00000000..8cfcc4dc --- /dev/null +++ b/tests/unit/test_mcp_http_auth.py @@ -0,0 +1,304 @@ +from __future__ import annotations + +import asyncio +from types import SimpleNamespace +from unittest.mock import AsyncMock + +import httpx +import pytest +from fastapi import HTTPException +from starlette.applications import Starlette +from starlette.middleware import Middleware +from starlette.requests import Request +from starlette.responses import JSONResponse +from starlette.routing import Route + +from skyvern.cli.core import client as client_mod +from skyvern.cli.core import mcp_http_auth + + +@pytest.fixture(autouse=True) +def _reset_auth_context() -> None: + client_mod._api_key_override.set(None) + mcp_http_auth._auth_db = None + mcp_http_auth._api_key_validation_cache.clear() + mcp_http_auth._API_KEY_CACHE_TTL_SECONDS = 30.0 + mcp_http_auth._API_KEY_CACHE_MAX_SIZE = 1024 + mcp_http_auth._MAX_VALIDATION_RETRIES = 2 + mcp_http_auth._RETRY_DELAY_SECONDS = 0.0 # no delay in tests + + +async def _echo_request_context(request: Request) -> JSONResponse: + return JSONResponse( + { + "api_key": client_mod.get_active_api_key(), + "organization_id": getattr(request.state, "organization_id", None), + } + ) + + +def _build_test_app() -> Starlette: + return Starlette( + routes=[Route("/mcp", endpoint=_echo_request_context, methods=["POST"])], + middleware=[Middleware(mcp_http_auth.MCPAPIKeyMiddleware)], + ) + + +@pytest.mark.asyncio +async def test_mcp_http_auth_rejects_missing_api_key() -> None: + app = _build_test_app() + + async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client: + response = await client.post("/mcp", json={}) + + assert response.status_code == 401 + assert response.json()["error"]["code"] == "UNAUTHORIZED" + assert "x-api-key" in response.json()["error"]["message"] + + +@pytest.mark.asyncio +async def test_mcp_http_auth_allows_health_checks_without_api_key() -> None: + app = _build_test_app() + + async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client: + response = await client.get("/healthz") + + assert response.status_code == 200 + assert response.json() == {"status": "ok"} + + +@pytest.mark.asyncio +async def test_mcp_http_auth_rejects_invalid_api_key(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + mcp_http_auth, + "validate_mcp_api_key", + AsyncMock(side_effect=HTTPException(status_code=403, detail="Invalid credentials")), + ) + app = _build_test_app() + + async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client: + response = await client.post("/mcp", headers={"x-api-key": "bad-key"}, json={}) + + assert response.status_code == 401 + assert response.json()["error"]["code"] == "UNAUTHORIZED" + assert response.json()["error"]["message"] == "Invalid API key" + + +@pytest.mark.asyncio +async def test_mcp_http_auth_returns_500_on_non_auth_http_exception(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + mcp_http_auth, + "validate_mcp_api_key", + AsyncMock(side_effect=HTTPException(status_code=500, detail="db down")), + ) + app = _build_test_app() + + async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client: + response = await client.post("/mcp", headers={"x-api-key": "sk_live_abc"}, json={}) + + assert response.status_code == 500 + assert response.json()["error"]["code"] == "INTERNAL_ERROR" + + +@pytest.mark.asyncio +async def test_mcp_http_auth_returns_503_on_transient_validation_exhaustion(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + mcp_http_auth, + "validate_mcp_api_key", + AsyncMock(side_effect=HTTPException(status_code=503, detail="API key validation temporarily unavailable")), + ) + app = _build_test_app() + + async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client: + response = await client.post("/mcp", headers={"x-api-key": "sk_live_abc"}, json={}) + + assert response.status_code == 503 + assert response.json()["error"]["code"] == "SERVICE_UNAVAILABLE" + assert response.json()["error"]["message"] == "API key validation temporarily unavailable" + + +@pytest.mark.asyncio +async def test_mcp_http_auth_returns_500_on_unexpected_validation_error(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + mcp_http_auth, + "validate_mcp_api_key", + AsyncMock(side_effect=RuntimeError("boom")), + ) + app = _build_test_app() + + async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client: + response = await client.post("/mcp", headers={"x-api-key": "sk_live_abc"}, json={}) + + assert response.status_code == 500 + assert response.json()["error"]["code"] == "INTERNAL_ERROR" + + +@pytest.mark.asyncio +async def test_mcp_http_auth_sets_request_scoped_api_key(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + mcp_http_auth, + "validate_mcp_api_key", + AsyncMock(return_value="org_123"), + ) + app = _build_test_app() + + async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client: + response = await client.post("/mcp", headers={"x-api-key": "sk_live_abc"}, json={}) + + assert response.status_code == 200 + assert response.json() == { + "api_key": "sk_live_abc", + "organization_id": "org_123", + } + assert client_mod.get_active_api_key() != "sk_live_abc" + + +@pytest.mark.asyncio +async def test_validate_mcp_api_key_uses_ttl_cache(monkeypatch: pytest.MonkeyPatch) -> None: + calls = 0 + + async def _resolve(_api_key: str, _db: object) -> object: + nonlocal calls + calls += 1 + return SimpleNamespace(organization=SimpleNamespace(organization_id="org_cached")) + + monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve) + monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object()) + + first = await mcp_http_auth.validate_mcp_api_key("sk_test_cache") + second = await mcp_http_auth.validate_mcp_api_key("sk_test_cache") + + assert first == "org_cached" + assert second == "org_cached" + assert calls == 1 + + +@pytest.mark.asyncio +async def test_validate_mcp_api_key_cache_expires(monkeypatch: pytest.MonkeyPatch) -> None: + calls = 0 + + async def _resolve(_api_key: str, _db: object) -> object: + nonlocal calls + calls += 1 + return SimpleNamespace(organization=SimpleNamespace(organization_id=f"org_{calls}")) + + monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve) + monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object()) + + first = await mcp_http_auth.validate_mcp_api_key("sk_test_cache_expire") + cache_key = mcp_http_auth._cache_key("sk_test_cache_expire") + mcp_http_auth._api_key_validation_cache[cache_key] = (first, 0.0) + second = await mcp_http_auth.validate_mcp_api_key("sk_test_cache_expire") + + assert first == "org_1" + assert second == "org_2" + assert calls == 2 + + +@pytest.mark.asyncio +async def test_validate_mcp_api_key_negative_caches_auth_failures(monkeypatch: pytest.MonkeyPatch) -> None: + calls = 0 + + async def _resolve(_api_key: str, _db: object) -> object: + nonlocal calls + calls += 1 + raise HTTPException(status_code=401, detail="Invalid credentials") + + monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve) + monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object()) + + with pytest.raises(HTTPException, match="Invalid credentials"): + await mcp_http_auth.validate_mcp_api_key("sk_test_auth_failure") + + with pytest.raises(HTTPException, match="Invalid API key"): + await mcp_http_auth.validate_mcp_api_key("sk_test_auth_failure") + + assert calls == 1 + + +@pytest.mark.asyncio +async def test_validate_mcp_api_key_retries_transient_failure_without_negative_cache( + monkeypatch: pytest.MonkeyPatch, +) -> None: + calls = 0 + + async def _resolve(_api_key: str, _db: object) -> object: + nonlocal calls + calls += 1 + if calls == 1: + raise RuntimeError("transient db error") + return SimpleNamespace(organization=SimpleNamespace(organization_id="org_recovered")) + + monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve) + monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object()) + + recovered_org = await mcp_http_auth.validate_mcp_api_key("sk_test_transient") + + cache_key = mcp_http_auth._cache_key("sk_test_transient") + assert mcp_http_auth._api_key_validation_cache[cache_key][0] == "org_recovered" + + assert recovered_org == "org_recovered" + assert calls == 2 + + +@pytest.mark.asyncio +async def test_validate_mcp_api_key_concurrent_callers_all_succeed( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Multiple concurrent callers for the same key all succeed; the cache + collapses subsequent calls after the first one populates it.""" + calls = 0 + + async def _resolve(_api_key: str, _db: object) -> object: + nonlocal calls + calls += 1 + return SimpleNamespace(organization=SimpleNamespace(organization_id="org_concurrent")) + + monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve) + monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object()) + + results = await asyncio.gather(*[mcp_http_auth.validate_mcp_api_key("sk_test_concurrent") for _ in range(5)]) + assert all(r == "org_concurrent" for r in results) + # First call populates cache; remaining may or may not hit DB depending on + # scheduling, but all must succeed. + assert calls >= 1 + + +@pytest.mark.asyncio +async def test_validate_mcp_api_key_returns_503_after_retry_exhaustion(monkeypatch: pytest.MonkeyPatch) -> None: + calls = 0 + + async def _resolve(_api_key: str, _db: object) -> object: + nonlocal calls + calls += 1 + raise RuntimeError("persistent db outage") + + monkeypatch.setattr(mcp_http_auth, "_MAX_VALIDATION_RETRIES", 2) + monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve) + monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object()) + + with pytest.raises(HTTPException, match="temporarily unavailable") as exc_info: + await mcp_http_auth.validate_mcp_api_key("sk_test_transient_exhausted") + + assert exc_info.value.status_code == 503 + assert calls == 3 # initial + 2 retries + + +@pytest.mark.asyncio +async def test_close_auth_db_disposes_engine() -> None: + dispose = AsyncMock() + mcp_http_auth._auth_db = SimpleNamespace(engine=SimpleNamespace(dispose=dispose)) + mcp_http_auth._api_key_validation_cache["k"] = ("org", 123.0) + + await mcp_http_auth.close_auth_db() + + dispose.assert_awaited_once() + assert mcp_http_auth._auth_db is None + assert mcp_http_auth._api_key_validation_cache == {} + + +@pytest.mark.asyncio +async def test_close_auth_db_noop_when_uninitialized() -> None: + mcp_http_auth._auth_db = None + await mcp_http_auth.close_auth_db() + assert mcp_http_auth._auth_db is None diff --git a/tests/unit/test_mcp_session_lifecycle.py b/tests/unit/test_mcp_session_lifecycle.py index 568fccda..a26d8a12 100644 --- a/tests/unit/test_mcp_session_lifecycle.py +++ b/tests/unit/test_mcp_session_lifecycle.py @@ -1,5 +1,6 @@ from __future__ import annotations +from types import SimpleNamespace from unittest.mock import AsyncMock, MagicMock import pytest @@ -14,10 +15,13 @@ from skyvern.cli.mcp_tools import session as mcp_session @pytest.fixture(autouse=True) def _reset_singletons() -> None: client_mod._skyvern_instance.set(None) + client_mod._api_key_override.set(None) client_mod._global_skyvern_instance = None + client_mod._api_key_clients.clear() session_manager._current_session.set(None) session_manager._global_session = None + session_manager.set_stateless_http_mode(False) mcp_session.set_current_session(mcp_session.SessionState()) @@ -47,6 +51,115 @@ def test_get_skyvern_reuses_global_instance_across_contexts(monkeypatch: pytest. assert len(created) == 1 +def test_get_skyvern_reuses_override_instance_per_api_key(monkeypatch: pytest.MonkeyPatch) -> None: + created_keys: list[str] = [] + + class FakeSkyvern: + def __init__(self, *args: object, **kwargs: object) -> None: + created_keys.append(kwargs["api_key"]) + + @classmethod + def local(cls) -> FakeSkyvern: + return cls(api_key="local") + + async def aclose(self) -> None: + return None + + monkeypatch.setattr(client_mod, "Skyvern", FakeSkyvern) + monkeypatch.setattr(client_mod.settings, "SKYVERN_API_KEY", None) + monkeypatch.setattr(client_mod.settings, "SKYVERN_BASE_URL", None) + + token = client_mod.set_api_key_override("sk_key_a") + try: + first = client_mod.get_skyvern() + client_mod._skyvern_instance.set(None) + second = client_mod.get_skyvern() + finally: + client_mod.reset_api_key_override(token) + + assert first is second + assert created_keys == ["sk_key_a"] + + +def test_get_skyvern_override_client_cache_uses_lru_eviction(monkeypatch: pytest.MonkeyPatch) -> None: + created_keys: list[str] = [] + + class FakeSkyvern: + def __init__(self, *args: object, **kwargs: object) -> None: + created_keys.append(kwargs["api_key"]) + + @classmethod + def local(cls) -> FakeSkyvern: + return cls(api_key="local") + + async def aclose(self) -> None: + return None + + monkeypatch.setattr(client_mod, "Skyvern", FakeSkyvern) + monkeypatch.setattr(client_mod.settings, "SKYVERN_API_KEY", None) + monkeypatch.setattr(client_mod.settings, "SKYVERN_BASE_URL", None) + monkeypatch.setattr(client_mod, "_API_KEY_CLIENT_CACHE_MAX", 2) + + for key in ("sk_key_a", "sk_key_b"): + token = client_mod.set_api_key_override(key) + try: + client_mod.get_skyvern() + finally: + client_mod.reset_api_key_override(token) + + # Touch key_a so key_b becomes least-recently-used. + token = client_mod.set_api_key_override("sk_key_a") + try: + client_mod._skyvern_instance.set(None) + client_mod.get_skyvern() + finally: + client_mod.reset_api_key_override(token) + + # Adding key_c should evict key_b. + token = client_mod.set_api_key_override("sk_key_c") + try: + client_mod.get_skyvern() + finally: + client_mod.reset_api_key_override(token) + + assert list(client_mod._api_key_clients.keys()) == [ + client_mod._cache_key("sk_key_a"), + client_mod._cache_key("sk_key_c"), + ] + # key_a, key_b, key_c were created exactly once each. + assert created_keys == ["sk_key_a", "sk_key_b", "sk_key_c"] + + +def test_get_skyvern_override_cache_closes_evicted_client(monkeypatch: pytest.MonkeyPatch) -> None: + closed_keys: list[str] = [] + + class FakeSkyvern: + def __init__(self, *args: object, **kwargs: object) -> None: + self.api_key = kwargs["api_key"] + + @classmethod + def local(cls) -> FakeSkyvern: + return cls(api_key="local") + + async def aclose(self) -> None: + closed_keys.append(self.api_key) + + monkeypatch.setattr(client_mod, "Skyvern", FakeSkyvern) + monkeypatch.setattr(client_mod.settings, "SKYVERN_API_KEY", None) + monkeypatch.setattr(client_mod.settings, "SKYVERN_BASE_URL", None) + monkeypatch.setattr(client_mod, "_API_KEY_CLIENT_CACHE_MAX", 1) + + for key in ("sk_key_a", "sk_key_b"): + token = client_mod.set_api_key_override(key) + try: + client_mod.get_skyvern() + finally: + client_mod.reset_api_key_override(token) + + assert list(client_mod._api_key_clients.keys()) == [client_mod._cache_key("sk_key_b")] + assert closed_keys == ["sk_key_a"] + + @pytest.mark.asyncio async def test_close_skyvern_closes_singleton() -> None: fake = MagicMock() @@ -75,12 +188,53 @@ def test_get_current_session_falls_back_to_global_state() -> None: assert recovered is state +def test_get_current_session_stateless_mode_ignores_global_state() -> None: + global_state = session_manager.SessionState( + browser=MagicMock(), + context=BrowserContext(mode="cloud_session", session_id="pbs_999"), + ) + session_manager._global_session = global_state + session_manager._current_session.set(None) + + session_manager.set_stateless_http_mode(True) + try: + recovered = session_manager.get_current_session() + finally: + session_manager.set_stateless_http_mode(False) + + assert recovered is not global_state + assert recovered.browser is None + assert recovered.context is None + + +def test_set_current_session_stateless_mode_does_not_override_global_state() -> None: + global_state = session_manager.SessionState( + browser=MagicMock(), + context=BrowserContext(mode="cloud_session", session_id="pbs_global"), + ) + session_manager._global_session = global_state + replacement = session_manager.SessionState( + browser=MagicMock(), + context=BrowserContext(mode="cloud_session", session_id="pbs_request"), + ) + + session_manager.set_stateless_http_mode(True) + try: + session_manager.set_current_session(replacement) + finally: + session_manager.set_stateless_http_mode(False) + + assert session_manager._global_session is global_state + assert session_manager._current_session.get() is replacement + + @pytest.mark.asyncio async def test_resolve_browser_reuses_matching_cloud_session(monkeypatch: pytest.MonkeyPatch) -> None: current_browser = MagicMock() current_state = session_manager.SessionState( browser=current_browser, context=BrowserContext(mode="cloud_session", session_id="pbs_123"), + api_key_hash=session_manager._api_key_hash(client_mod.get_active_api_key()), ) session_manager.set_current_session(current_state) @@ -95,6 +249,92 @@ async def test_resolve_browser_reuses_matching_cloud_session(monkeypatch: pytest fake_skyvern.connect_to_cloud_browser_session.assert_not_awaited() +@pytest.mark.asyncio +async def test_resolve_browser_does_not_reuse_session_for_different_api_key( + monkeypatch: pytest.MonkeyPatch, +) -> None: + current_browser = MagicMock() + session_manager.set_current_session( + session_manager.SessionState( + browser=current_browser, + context=BrowserContext(mode="cloud_session", session_id="pbs_123"), + api_key_hash=session_manager._api_key_hash("sk_key_a"), + ) + ) + + replacement_browser = MagicMock() + fake_skyvern = MagicMock() + fake_skyvern.connect_to_cloud_browser_session = AsyncMock(return_value=replacement_browser) + monkeypatch.setattr(session_manager, "get_skyvern", lambda: fake_skyvern) + + token = client_mod.set_api_key_override("sk_key_b") + try: + browser, ctx = await session_manager.resolve_browser(session_id="pbs_123") + finally: + client_mod.reset_api_key_override(token) + + assert browser is replacement_browser + assert ctx.session_id == "pbs_123" + fake_skyvern.connect_to_cloud_browser_session.assert_awaited_once_with("pbs_123") + + +@pytest.mark.asyncio +async def test_resolve_browser_stateless_mode_does_not_write_global_session(monkeypatch: pytest.MonkeyPatch) -> None: + global_state = session_manager.SessionState( + browser=MagicMock(), + context=BrowserContext(mode="cloud_session", session_id="pbs_global"), + ) + session_manager._global_session = global_state + + replacement_browser = MagicMock() + fake_skyvern = MagicMock() + fake_skyvern.connect_to_cloud_browser_session = AsyncMock(return_value=replacement_browser) + monkeypatch.setattr(session_manager, "get_skyvern", lambda: fake_skyvern) + + session_manager.set_stateless_http_mode(True) + try: + browser, ctx = await session_manager.resolve_browser(session_id="pbs_123") + finally: + session_manager.set_stateless_http_mode(False) + + assert browser is replacement_browser + assert ctx.session_id == "pbs_123" + assert session_manager._global_session is global_state + + +@pytest.mark.asyncio +async def test_resolve_browser_blocks_implicit_session_in_stateless_mode() -> None: + session_manager.set_current_session( + session_manager.SessionState( + browser=MagicMock(), + context=BrowserContext(mode="cloud_session", session_id="pbs_123"), + api_key_hash=session_manager._api_key_hash("sk_key_a"), + ) + ) + session_manager.set_stateless_http_mode(True) + try: + with pytest.raises(session_manager.BrowserNotAvailableError): + await session_manager.resolve_browser() + finally: + session_manager.set_stateless_http_mode(False) + + +@pytest.mark.asyncio +async def test_resolve_browser_raises_for_invalid_matching_state(monkeypatch: pytest.MonkeyPatch) -> None: + session_manager.set_current_session( + session_manager.SessionState( + browser=None, + context=BrowserContext(mode="cloud_session", session_id="pbs_123"), + ) + ) + + monkeypatch.setattr(session_manager, "_matches_current", lambda *args, **kwargs: True) + monkeypatch.setattr(session_manager, "get_skyvern", lambda: MagicMock()) + + with pytest.raises(RuntimeError, match="Expected active browser and context"): + await session_manager.resolve_browser(session_id="pbs_123") + + @pytest.mark.asyncio async def test_session_close_with_matching_session_id_closes_browser_handle(monkeypatch: pytest.MonkeyPatch) -> None: current_browser = MagicMock() @@ -262,3 +502,73 @@ async def test_close_current_session_still_closes_browser_when_api_fails(monkeyp # _browser_session_id should NOT be cleared (API close failed, let browser.close() try) assert browser._browser_session_id == "pbs_fail" assert session_manager.get_current_session().browser is None + + +# --------------------------------------------------------------------------- +# Tests for stateless HTTP mode session creation +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_session_create_stateless_mode_returns_session_without_persisting_browser( + monkeypatch: pytest.MonkeyPatch, +) -> None: + session_manager.set_stateless_http_mode(True) + fake_skyvern = MagicMock() + fake_skyvern.create_browser_session = AsyncMock(return_value=SimpleNamespace(browser_session_id="pbs_abc")) + monkeypatch.setattr(mcp_session, "get_skyvern", lambda: fake_skyvern) + do_session_create = AsyncMock() + monkeypatch.setattr(mcp_session, "do_session_create", do_session_create) + + try: + result = await mcp_session.skyvern_session_create(timeout=45) + finally: + session_manager.set_stateless_http_mode(False) + + assert result["ok"] is True + assert result["data"] == {"session_id": "pbs_abc", "timeout_minutes": 45} + do_session_create.assert_not_awaited() + assert mcp_session.get_current_session().browser is None + assert mcp_session.get_current_session().context is None + + +@pytest.mark.asyncio +async def test_session_create_stateless_mode_rejects_local() -> None: + session_manager.set_stateless_http_mode(True) + try: + result = await mcp_session.skyvern_session_create(local=True) + finally: + session_manager.set_stateless_http_mode(False) + + assert result["ok"] is False + assert result["error"]["code"] == mcp_session.ErrorCode.INVALID_INPUT + + +@pytest.mark.asyncio +async def test_session_create_persists_active_api_key_hash_in_session_state( + monkeypatch: pytest.MonkeyPatch, +) -> None: + fake_skyvern = MagicMock() + monkeypatch.setattr(mcp_session, "get_skyvern", lambda: fake_skyvern) + + fake_browser = MagicMock() + do_session_create = AsyncMock( + return_value=( + fake_browser, + SimpleNamespace(local=False, session_id="pbs_123", timeout_minutes=60, headless=False), + ) + ) + monkeypatch.setattr(mcp_session, "do_session_create", do_session_create) + + token = client_mod.set_api_key_override("sk_key_create") + try: + result = await mcp_session.skyvern_session_create(timeout=60) + finally: + client_mod.reset_api_key_override(token) + + assert result["ok"] is True + current = mcp_session.get_current_session() + assert current.browser is fake_browser + assert current.context == BrowserContext(mode="cloud_session", session_id="pbs_123") + assert current.api_key_hash == session_manager._api_key_hash("sk_key_create") + assert current.api_key_hash != "sk_key_create" diff --git a/tests/unit/test_run_commands_cleanup.py b/tests/unit/test_run_commands_cleanup.py index 74736cfc..34fc7909 100644 --- a/tests/unit/test_run_commands_cleanup.py +++ b/tests/unit/test_run_commands_cleanup.py @@ -1,7 +1,7 @@ from __future__ import annotations import asyncio -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import AsyncMock, MagicMock, call import pytest @@ -13,6 +13,42 @@ def _reset_cleanup_state() -> None: run_commands._mcp_cleanup_done = False +@pytest.mark.asyncio +async def test_cleanup_mcp_resources_closes_auth_db(monkeypatch: pytest.MonkeyPatch) -> None: + close_current_session = AsyncMock() + close_skyvern = AsyncMock() + close_auth_db = AsyncMock() + + monkeypatch.setattr(run_commands, "close_current_session", close_current_session) + monkeypatch.setattr(run_commands, "close_skyvern", close_skyvern) + monkeypatch.setattr(run_commands, "close_auth_db", close_auth_db) + + await run_commands._cleanup_mcp_resources() + + close_current_session.assert_awaited_once() + close_skyvern.assert_awaited_once() + close_auth_db.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_cleanup_mcp_resources_closes_auth_db_on_skyvern_close_error(monkeypatch: pytest.MonkeyPatch) -> None: + close_current_session = AsyncMock() + close_auth_db = AsyncMock() + + async def _failing_close_skyvern() -> None: + raise RuntimeError("close failed") + + monkeypatch.setattr(run_commands, "close_current_session", close_current_session) + monkeypatch.setattr(run_commands, "close_skyvern", _failing_close_skyvern) + monkeypatch.setattr(run_commands, "close_auth_db", close_auth_db) + + with pytest.raises(RuntimeError, match="close failed"): + await run_commands._cleanup_mcp_resources() + + close_current_session.assert_awaited_once() + close_auth_db.assert_awaited_once() + + def test_cleanup_mcp_resources_sync_runs_without_running_loop(monkeypatch: pytest.MonkeyPatch) -> None: cleanup = AsyncMock() monkeypatch.setattr(run_commands, "_cleanup_mcp_resources", cleanup) @@ -50,14 +86,56 @@ def test_run_mcp_calls_blocking_cleanup_in_finally(monkeypatch: pytest.MonkeyPat cleanup_blocking = MagicMock() register = MagicMock() run = MagicMock(side_effect=RuntimeError("boom")) + set_stateless = MagicMock() monkeypatch.setattr(run_commands, "_cleanup_mcp_resources_blocking", cleanup_blocking) monkeypatch.setattr(run_commands.atexit, "register", register) monkeypatch.setattr(run_commands.mcp, "run", run) + monkeypatch.setattr(run_commands, "set_stateless_http_mode", set_stateless) with pytest.raises(RuntimeError, match="boom"): run_commands.run_mcp() register.assert_called_once_with(run_commands._cleanup_mcp_resources_sync) run.assert_called_once_with(transport="stdio") + set_stateless.assert_has_calls([call(False), call(False)]) cleanup_blocking.assert_called_once() + + +def test_run_mcp_http_transport_wires_auth_middleware(monkeypatch: pytest.MonkeyPatch) -> None: + cleanup_blocking = MagicMock() + register = MagicMock() + run = MagicMock() + set_stateless = MagicMock() + + monkeypatch.setattr(run_commands, "_cleanup_mcp_resources_blocking", cleanup_blocking) + monkeypatch.setattr(run_commands.atexit, "register", register) + monkeypatch.setattr(run_commands.mcp, "run", run) + monkeypatch.setattr(run_commands, "set_stateless_http_mode", set_stateless) + + run_commands.run_mcp( + transport="streamable-http", + host="127.0.0.1", + port=9010, + path="mcp", + stateless_http=True, + ) + + register.assert_called_once_with(run_commands._cleanup_mcp_resources_sync) + run.assert_called_once() + kwargs = run.call_args.kwargs + assert kwargs["transport"] == "streamable-http" + assert kwargs["host"] == "127.0.0.1" + assert kwargs["port"] == 9010 + assert kwargs["path"] == "/mcp" + assert kwargs["stateless_http"] is True + middleware = kwargs["middleware"] + assert len(middleware) == 1 + assert middleware[0].cls is run_commands.MCPAPIKeyMiddleware + set_stateless.assert_has_calls([call(True), call(False)]) + cleanup_blocking.assert_called_once() + + +def test_run_task_tool_registration_points_to_browser_module() -> None: + tool = run_commands.mcp._tool_manager._tools["skyvern_run_task"] # type: ignore[attr-defined] + assert tool.fn.__module__ == "skyvern.cli.mcp_tools.browser"