align workflow CLI commands with MCP parity (#4792)

This commit is contained in:
Marc Kelechava
2026-02-18 11:34:12 -08:00
committed by GitHub
parent 2f6850ce20
commit 46a7ec1d26
12 changed files with 1609 additions and 151 deletions

View File

@@ -68,6 +68,30 @@ Add this to your MCP client's configuration file:
Replace `/usr/bin/python3` with the output of `which python3` on your machine. For local mode, set `SKYVERN_BASE_URL` to `http://localhost:8000` and find your API key in the `.env` file after running `skyvern init`.
### Option C: Remote MCP over HTTPS (streamable HTTP)
Use this when your team provides a hosted MCP endpoint (for example: `https://mcp.skyvern.com/mcp`).
In remote HTTP mode:
- Clients must send `x-api-key` on every request.
- Use `skyvern_session_create` first, then pass `session_id` explicitly on subsequent browser tool calls.
If your MCP client supports native remote HTTP transport, configure it directly:
```json
{
"mcpServers": {
"SkyvernRemote": {
"type": "streamable-http",
"url": "https://mcp.skyvern.com/mcp",
"headers": {
"x-api-key": "YOUR_SKYVERN_API_KEY"
}
}
}
}
```
<Accordion title="Config file locations by client">
| Client | Path |

View File

@@ -0,0 +1,29 @@
from __future__ import annotations
import hashlib
import os
def _resolve_api_key_hash_iterations() -> int:
raw = os.environ.get("SKYVERN_MCP_API_KEY_HASH_ITERATIONS", "120000")
try:
return max(10_000, int(raw))
except ValueError:
return 120_000
_API_KEY_HASH_ITERATIONS = _resolve_api_key_hash_iterations()
_API_KEY_HASH_SALT = os.environ.get(
"SKYVERN_MCP_API_KEY_HASH_SALT",
"skyvern-mcp-api-key-cache-v1",
).encode("utf-8")
def hash_api_key_for_cache(api_key: str) -> str:
"""Derive a deterministic, non-reversible fingerprint for API-key keyed caches."""
return hashlib.pbkdf2_hmac(
"sha256",
api_key.encode("utf-8"),
_API_KEY_HASH_SALT,
_API_KEY_HASH_ITERATIONS,
).hex()

View File

@@ -1,7 +1,10 @@
from __future__ import annotations
import asyncio
import os
from contextvars import ContextVar
from collections import OrderedDict
from contextvars import ContextVar, Token
from threading import RLock
import structlog
@@ -9,35 +12,126 @@ from skyvern.client import SkyvernEnvironment
from skyvern.config import settings
from skyvern.library.skyvern import Skyvern
from .api_key_hash import hash_api_key_for_cache
_skyvern_instance: ContextVar[Skyvern | None] = ContextVar("skyvern_instance", default=None)
_api_key_override: ContextVar[str | None] = ContextVar("skyvern_api_key_override", default=None)
_global_skyvern_instance: Skyvern | None = None
_api_key_clients: OrderedDict[str, Skyvern] = OrderedDict()
_clients_lock = RLock()
LOG = structlog.get_logger(__name__)
def _resolve_api_key_cache_size() -> int:
    """Read the per-API-key client cache bound from the environment, flooring at 1."""
    configured = os.environ.get("SKYVERN_MCP_API_KEY_CLIENT_CACHE_SIZE", "128")
    try:
        size = int(configured)
    except ValueError:
        # Non-numeric override: keep the shipped default.
        return 128
    return max(1, size)
_API_KEY_CLIENT_CACHE_MAX = _resolve_api_key_cache_size()
def _cache_key(api_key: str) -> str:
    """Hash API key so raw secrets are never stored as dict keys.

    Delegates to the shared PBKDF2 fingerprint helper so every cache keyed by
    API key uses the same derivation.
    """
    return hash_api_key_for_cache(api_key)
def _resolve_api_key() -> str | None:
    """Return the configured API key, preferring settings over the environment."""
    configured = settings.SKYVERN_API_KEY
    if configured:
        return configured
    return os.environ.get("SKYVERN_API_KEY")
def _resolve_base_url() -> str | None:
    """Return the configured base URL, preferring settings over the environment."""
    configured = settings.SKYVERN_BASE_URL
    if configured:
        return configured
    return os.environ.get("SKYVERN_BASE_URL")
def _build_cloud_client(api_key: str) -> Skyvern:
    """Construct a cloud-mode Skyvern client bound to the given API key.

    The base URL comes from the shared settings/environment resolution in
    `_resolve_base_url()`.
    """
    return Skyvern(
        api_key=api_key,
        environment=SkyvernEnvironment.CLOUD,
        base_url=_resolve_base_url(),
    )
def _close_skyvern_instance_best_effort(instance: Skyvern) -> None:
    """Close a Skyvern instance, regardless of whether an event loop is running."""
    try:
        running_loop = asyncio.get_running_loop()
    except RuntimeError:
        running_loop = None

    if running_loop is None:
        # No loop in this thread: drive the coroutine to completion synchronously.
        try:
            asyncio.run(instance.aclose())
        except Exception:
            LOG.debug("Failed to close evicted Skyvern client", exc_info=True)
        return

    # A loop is running: schedule the close and report failures via callback.
    close_task = running_loop.create_task(instance.aclose())

    def _log_close_failure(finished: asyncio.Task[None]) -> None:
        try:
            finished.result()
        except Exception:
            LOG.debug("Failed to close evicted Skyvern client", exc_info=True)

    close_task.add_done_callback(_log_close_failure)
def get_active_api_key() -> str | None:
    """Return the effective API key for this request/context."""
    override = _api_key_override.get()
    if override:
        return override
    return _resolve_api_key()
def set_api_key_override(api_key: str | None) -> Token[str | None]:
    """Set request-scoped API key override for MCP HTTP requests.

    Drops any context-cached client first so the next `get_skyvern()` call
    resolves against the new key; returns the ContextVar token for reset.
    """
    _skyvern_instance.set(None)
    return _api_key_override.set(api_key)
def reset_api_key_override(token: Token[str | None]) -> None:
    """Reset request-scoped API key override.

    Restores the previous override via the token, then drops the
    context-cached client so it cannot leak into later work on this context.
    """
    _api_key_override.reset(token)
    _skyvern_instance.set(None)
def get_skyvern() -> Skyvern:
"""Get or create a Skyvern client instance."""
global _global_skyvern_instance
instance = _skyvern_instance.get()
if instance is None:
instance = _global_skyvern_instance
if instance is not None:
override_api_key = _api_key_override.get()
if override_api_key:
instance = _skyvern_instance.get()
if instance is None:
key = _cache_key(override_api_key)
evicted_clients: list[Skyvern] = []
# Hold lock across lookup + build + insert to prevent two coroutines
# from both building a client for the same API key concurrently.
with _clients_lock:
instance = _api_key_clients.get(key)
if instance is not None:
_api_key_clients.move_to_end(key)
else:
instance = _build_cloud_client(override_api_key)
_api_key_clients[key] = instance
_api_key_clients.move_to_end(key)
while len(_api_key_clients) > _API_KEY_CLIENT_CACHE_MAX:
_, evicted = _api_key_clients.popitem(last=False)
evicted_clients.append(evicted)
for evicted in evicted_clients:
_close_skyvern_instance_best_effort(evicted)
_skyvern_instance.set(instance)
return instance
api_key = settings.SKYVERN_API_KEY or os.environ.get("SKYVERN_API_KEY")
base_url = settings.SKYVERN_BASE_URL or os.environ.get("SKYVERN_BASE_URL")
if api_key:
instance = Skyvern(
api_key=api_key,
environment=SkyvernEnvironment.CLOUD,
base_url=base_url,
)
else:
instance = Skyvern.local()
_global_skyvern_instance = instance
instance = _skyvern_instance.get()
if instance is None:
with _clients_lock:
instance = _global_skyvern_instance
if instance is None:
api_key = _resolve_api_key()
if api_key:
instance = _build_cloud_client(api_key)
else:
instance = Skyvern.local()
_global_skyvern_instance = instance
_skyvern_instance.set(instance)
return instance
@@ -48,7 +142,12 @@ async def close_skyvern() -> None:
instances: list[Skyvern] = []
seen: set[int] = set()
for candidate in (_skyvern_instance.get(), _global_skyvern_instance):
with _clients_lock:
candidates = (_skyvern_instance.get(), _global_skyvern_instance, *_api_key_clients.values())
_api_key_clients.clear()
_global_skyvern_instance = None
for candidate in candidates:
if candidate is None or id(candidate) in seen:
continue
seen.add(id(candidate))
@@ -61,4 +160,3 @@ async def close_skyvern() -> None:
LOG.warning("Failed to close Skyvern client", exc_info=True)
_skyvern_instance.set(None)
_global_skyvern_instance = None

View File

@@ -0,0 +1,206 @@
from __future__ import annotations
import asyncio
import os
import time
from collections import OrderedDict
from threading import RLock
import structlog
from fastapi import HTTPException
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.types import ASGIApp, Receive, Scope, Send
from skyvern.config import settings
from skyvern.forge.sdk.db.agent_db import AgentDB
from skyvern.forge.sdk.services.org_auth_service import resolve_org_from_api_key
from .api_key_hash import hash_api_key_for_cache
from .client import reset_api_key_override, set_api_key_override
LOG = structlog.get_logger(__name__)
API_KEY_HEADER = "x-api-key"
HEALTH_PATHS = {"/health", "/healthz"}
_auth_db: AgentDB | None = None
_auth_db_lock = RLock()
_api_key_cache_lock = RLock()
_api_key_validation_cache: OrderedDict[str, tuple[str | None, float]] = OrderedDict()
_NEGATIVE_CACHE_TTL_SECONDS = 5.0
_VALIDATION_RETRY_EXHAUSTED_MESSAGE = "API key validation temporarily unavailable"
_MAX_VALIDATION_RETRIES = 2
_RETRY_DELAY_SECONDS = 0.25
def _resolve_api_key_cache_ttl_seconds() -> float:
    """Read the positive-validation cache TTL from the environment, flooring at 1s."""
    configured = os.environ.get("SKYVERN_MCP_API_KEY_CACHE_TTL_SECONDS", "30")
    try:
        ttl = float(configured)
    except ValueError:
        # Non-numeric override: keep the shipped default.
        return 30.0
    return max(1.0, ttl)
def _resolve_api_key_cache_max_size() -> int:
    """Read the validation-cache entry bound from the environment, flooring at 1."""
    configured = os.environ.get("SKYVERN_MCP_API_KEY_CACHE_MAX_SIZE", "1024")
    try:
        bound = int(configured)
    except ValueError:
        # Non-numeric override: keep the shipped default.
        return 1024
    return max(1, bound)
_API_KEY_CACHE_TTL_SECONDS = _resolve_api_key_cache_ttl_seconds()
_API_KEY_CACHE_MAX_SIZE = _resolve_api_key_cache_max_size()
def _get_auth_db() -> AgentDB:
    """Return the lazily-initialized AgentDB used for API-key validation."""
    global _auth_db
    # Guard singleton init in case HTTP transport is served with threaded workers.
    with _auth_db_lock:
        if _auth_db is None:
            _auth_db = AgentDB(settings.DATABASE_STRING, debug_enabled=settings.DEBUG_MODE)
        return _auth_db
async def close_auth_db() -> None:
    """Dispose the auth DB engine used by HTTP middleware, if initialized."""
    global _auth_db
    # Detach the singleton under its lock so no other thread keeps using it.
    with _auth_db_lock:
        detached = _auth_db
        _auth_db = None
    # Drop cached validations unconditionally, even if the DB was never built.
    with _api_key_cache_lock:
        _api_key_validation_cache.clear()
    if detached is not None:
        try:
            await detached.engine.dispose()
        except Exception:
            LOG.warning("Failed to dispose MCP auth DB engine", exc_info=True)
def _cache_key(api_key: str) -> str:
    """Fingerprint the API key so raw secrets never appear as cache keys."""
    return hash_api_key_for_cache(api_key)
async def validate_mcp_api_key(api_key: str) -> str:
    """Validate API key and return organization id for observability.

    Successful validations are cached for `_API_KEY_CACHE_TTL_SECONDS`;
    explicit 401/403 rejections are negatively cached for
    `_NEGATIVE_CACHE_TTL_SECONDS` (stored as a `None` organization id).
    Other failures are retried up to `_MAX_VALIDATION_RETRIES` times and then
    surfaced as a 503.

    Raises:
        HTTPException: 401 for an invalid key (cached or fresh), 503 when
        retries are exhausted; other HTTP statuses from the resolver are
        retried and folded into the 503 path.
    """
    key = _cache_key(api_key)
    # Check cache first.
    with _api_key_cache_lock:
        cached = _api_key_validation_cache.get(key)
        if cached is not None:
            organization_id, expires_at = cached
            if expires_at > time.monotonic():
                # Fresh entry: refresh LRU position before answering.
                _api_key_validation_cache.move_to_end(key)
                if organization_id is None:
                    # Negative-cache hit: replay the rejection without a DB round-trip.
                    raise HTTPException(status_code=401, detail="Invalid API key")
                return organization_id
            # Expired entry: evict and fall through to a fresh lookup.
            _api_key_validation_cache.pop(key, None)
    # Cache miss — do the DB lookup with simple retry on transient errors.
    last_exc: Exception | None = None
    for attempt in range(_MAX_VALIDATION_RETRIES + 1):
        if attempt > 0:
            await asyncio.sleep(_RETRY_DELAY_SECONDS)
        try:
            validation = await resolve_org_from_api_key(api_key, _get_auth_db())
            organization_id = validation.organization.organization_id
            with _api_key_cache_lock:
                _api_key_validation_cache[key] = (
                    organization_id,
                    time.monotonic() + _API_KEY_CACHE_TTL_SECONDS,
                )
                _api_key_validation_cache.move_to_end(key)
                # Bound the cache; evict least-recently-used entries.
                while len(_api_key_validation_cache) > _API_KEY_CACHE_MAX_SIZE:
                    _api_key_validation_cache.popitem(last=False)
            return organization_id
        except HTTPException as e:
            if e.status_code in {401, 403}:
                # Definitive rejection: negatively cache briefly, then re-raise.
                with _api_key_cache_lock:
                    _api_key_validation_cache[key] = (None, time.monotonic() + _NEGATIVE_CACHE_TTL_SECONDS)
                    _api_key_validation_cache.move_to_end(key)
                    while len(_api_key_validation_cache) > _API_KEY_CACHE_MAX_SIZE:
                        _api_key_validation_cache.popitem(last=False)
                raise
            # Non-auth HTTP errors are treated as transient and retried.
            last_exc = e
        except Exception as e:
            last_exc = e
    LOG.warning("API key validation retries exhausted", attempts=_MAX_VALIDATION_RETRIES + 1, exc_info=last_exc)
    raise HTTPException(status_code=503, detail=_VALIDATION_RETRY_EXHAUSTED_MESSAGE)
def _unauthorized_response(message: str) -> JSONResponse:
    """Build a 401 response in the shared MCP error envelope."""
    payload = {"error": {"code": "UNAUTHORIZED", "message": message}}
    return JSONResponse(payload, status_code=401)
def _internal_error_response() -> JSONResponse:
    """Build a generic 500 response in the shared MCP error envelope."""
    payload = {"error": {"code": "INTERNAL_ERROR", "message": "Internal server error"}}
    return JSONResponse(payload, status_code=500)
def _service_unavailable_response(message: str) -> JSONResponse:
    """Build a 503 response in the shared MCP error envelope."""
    payload = {"error": {"code": "SERVICE_UNAVAILABLE", "message": message}}
    return JSONResponse(payload, status_code=503)
class MCPAPIKeyMiddleware:
    """Require x-api-key for MCP HTTP transport and scope requests to that key.

    Pure ASGI middleware: health checks are answered directly and OPTIONS
    requests pass through unauthenticated; every other HTTP request must carry
    a valid `x-api-key` header. The validated key is installed as a
    context-scoped override for the duration of the downstream call.
    """

    def __init__(self, app: ASGIApp) -> None:
        # Downstream ASGI application invoked once the request is authorized.
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        # Non-HTTP traffic (e.g. lifespan) is not authenticated here.
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return
        request = Request(scope, receive=receive)
        # Health endpoints answer directly, without auth or the inner app.
        if request.url.path in HEALTH_PATHS:
            response = JSONResponse({"status": "ok"})
            await response(scope, receive, send)
            return
        # OPTIONS (e.g. CORS preflight) passes through unauthenticated.
        if request.method == "OPTIONS":
            await self.app(scope, receive, send)
            return
        api_key = request.headers.get(API_KEY_HEADER)
        if not api_key:
            response = _unauthorized_response("Missing x-api-key header")
            await response(scope, receive, send)
            return
        try:
            organization_id = await validate_mcp_api_key(api_key)
            # Expose the org id on request state for downstream observability.
            scope.setdefault("state", {})
            scope["state"]["organization_id"] = organization_id
        except HTTPException as e:
            if e.status_code in {401, 403}:
                # Respond with a generic 401; do not echo validator details.
                response = _unauthorized_response("Invalid API key")
                await response(scope, receive, send)
                return
            if e.status_code == 503:
                response = _service_unavailable_response(e.detail or _VALIDATION_RETRY_EXHAUSTED_MESSAGE)
                await response(scope, receive, send)
                return
            LOG.warning("Unexpected HTTPException during MCP API key validation", status_code=e.status_code)
            response = _internal_error_response()
            await response(scope, receive, send)
            return
        except Exception:
            LOG.exception("Unexpected MCP API key validation failure")
            response = _internal_error_response()
            await response(scope, receive, send)
            return
        # Scope the validated key to this request; always restore afterwards.
        token = set_api_key_override(api_key)
        try:
            await self.app(scope, receive, send)
        finally:
            reset_api_key_override(token)

View File

@@ -7,7 +7,8 @@ from typing import TYPE_CHECKING, Any, AsyncIterator
import structlog
from .client import get_skyvern
from .api_key_hash import hash_api_key_for_cache
from .client import get_active_api_key, get_skyvern
from .result import BrowserContext, ErrorCode, make_error
LOG = structlog.get_logger(__name__)
@@ -21,6 +22,7 @@ if TYPE_CHECKING:
class SessionState:
browser: SkyvernBrowser | None = None
context: BrowserContext | None = None
api_key_hash: str | None = None
console_messages: list[dict[str, Any]] = field(default_factory=list)
tracing_active: bool = False
har_enabled: bool = False
@@ -28,26 +30,52 @@ class SessionState:
_current_session: ContextVar[SessionState | None] = ContextVar("mcp_session", default=None)
_global_session: SessionState | None = None
_stateless_http_mode = False
def get_current_session() -> SessionState:
global _global_session
state = _current_session.get()
if state is None:
if _global_session is None:
_global_session = SessionState()
state = _global_session
if state is not None:
return state
# In stateless HTTP mode, avoid process-wide fallback state so requests
# cannot inherit session context from other requests.
if _stateless_http_mode:
state = SessionState()
_current_session.set(state)
return state
if _global_session is None:
_global_session = SessionState()
state = _global_session
_current_session.set(state)
return state
def set_current_session(state: SessionState) -> None:
global _global_session
_global_session = state
if not _stateless_http_mode:
_global_session = state
_current_session.set(state)
def set_stateless_http_mode(enabled: bool) -> None:
    """Record whether the MCP HTTP transport runs with stateless semantics."""
    global _stateless_http_mode
    _stateless_http_mode = enabled


def is_stateless_http_mode() -> bool:
    """Return True when stateless HTTP mode is active (no cross-request session reuse)."""
    return _stateless_http_mode
def _api_key_hash(api_key: str | None) -> str | None:
    """Fingerprint the API key; None/empty keys map to None."""
    return hash_api_key_for_cache(api_key) if api_key else None
def _matches_current(
current: SessionState,
*,
@@ -57,6 +85,8 @@ def _matches_current(
) -> bool:
if current.browser is None or current.context is None:
return False
if current.api_key_hash != _api_key_hash(get_active_api_key()):
return False
if session_id:
return current.context.mode == "cloud_session" and current.context.session_id == session_id
@@ -84,35 +114,39 @@ async def resolve_browser(
skyvern = get_skyvern()
current = get_current_session()
if _stateless_http_mode and not (session_id or cdp_url or local or create_session):
raise BrowserNotAvailableError()
if _matches_current(current, session_id=session_id, cdp_url=cdp_url, local=local):
# _matches_current() guarantees both are non-None
assert current.browser is not None and current.context is not None
if current.browser is None or current.context is None:
raise RuntimeError("Expected active browser and context for matching session")
return current.browser, current.context
active_api_key_hash = _api_key_hash(get_active_api_key())
browser: SkyvernBrowser | None = None
try:
if session_id:
browser = await skyvern.connect_to_cloud_browser_session(session_id)
ctx = BrowserContext(mode="cloud_session", session_id=session_id)
set_current_session(SessionState(browser=browser, context=ctx))
set_current_session(SessionState(browser=browser, context=ctx, api_key_hash=active_api_key_hash))
return browser, ctx
if cdp_url:
browser = await skyvern.connect_to_browser_over_cdp(cdp_url)
ctx = BrowserContext(mode="cdp", cdp_url=cdp_url)
set_current_session(SessionState(browser=browser, context=ctx))
set_current_session(SessionState(browser=browser, context=ctx, api_key_hash=active_api_key_hash))
return browser, ctx
if local:
browser = await skyvern.launch_local_browser(headless=headless)
ctx = BrowserContext(mode="local")
set_current_session(SessionState(browser=browser, context=ctx))
set_current_session(SessionState(browser=browser, context=ctx, api_key_hash=active_api_key_hash))
return browser, ctx
if create_session:
browser = await skyvern.launch_cloud_browser(timeout=timeout)
ctx = BrowserContext(mode="cloud_session", session_id=browser.browser_session_id)
set_current_session(SessionState(browser=browser, context=ctx))
set_current_session(SessionState(browser=browser, context=ctx, api_key_hash=active_api_key_hash))
return browser, ctx
except Exception:
if browser is not None:

View File

@@ -4,7 +4,11 @@ from typing import Annotated, Any
from pydantic import Field
from skyvern.cli.core.api_key_hash import hash_api_key_for_cache
from skyvern.cli.core.client import get_active_api_key
from skyvern.cli.core.session_manager import is_stateless_http_mode
from skyvern.cli.core.session_ops import do_session_close, do_session_create, do_session_list
from skyvern.schemas.runs import ProxyLocation
from ._common import BrowserContext, ErrorCode, Timer, make_error, make_result
from ._session import (
@@ -16,6 +20,13 @@ from ._session import (
)
def _session_api_key_hash() -> str | None:
    """Fingerprint of the API key active in this request context, if any."""
    active_key = get_active_api_key()
    return hash_api_key_for_cache(active_key) if active_key else None
async def skyvern_session_create(
timeout: Annotated[int | None, Field(description="Session timeout in minutes (5-1440)")] = 60,
proxy_location: Annotated[str | None, Field(description="Proxy location: RESIDENTIAL, US, etc.")] = None,
@@ -29,7 +40,33 @@ async def skyvern_session_create(
"""
with Timer() as timer:
try:
if is_stateless_http_mode() and local:
return make_result(
"skyvern_session_create",
ok=False,
error=make_error(
ErrorCode.INVALID_INPUT,
"Local browser sessions are not supported in stateless HTTP mode",
"Use cloud sessions for remote MCP transport",
),
)
skyvern = get_skyvern()
if is_stateless_http_mode():
proxy = ProxyLocation(proxy_location) if proxy_location else None
session = await skyvern.create_browser_session(timeout=timeout or 60, proxy_location=proxy)
timer.mark("sdk")
ctx = BrowserContext(mode="cloud_session", session_id=session.browser_session_id)
return make_result(
"skyvern_session_create",
browser_context=ctx,
data={
"session_id": session.browser_session_id,
"timeout_minutes": timeout or 60,
},
timing_ms=timer.timing_ms,
)
browser, result = await do_session_create(
skyvern,
timeout=timeout or 60,
@@ -43,7 +80,7 @@ async def skyvern_session_create(
ctx = BrowserContext(mode="local")
else:
ctx = BrowserContext(mode="cloud_session", session_id=result.session_id)
set_current_session(SessionState(browser=browser, context=ctx))
set_current_session(SessionState(browser=browser, context=ctx, api_key_hash=_session_api_key_hash()))
except ValueError as e:
return make_result(

View File

@@ -5,7 +5,7 @@ import logging
import os
import shutil
import subprocess
from typing import Any, List, Optional
from typing import Annotated, List, Literal, Optional
import psutil
import typer
@@ -13,17 +13,17 @@ import uvicorn
from dotenv import load_dotenv, set_key
from rich.panel import Panel
from rich.prompt import Confirm
from starlette.middleware import Middleware
from skyvern.cli.console import console
from skyvern.cli.core.client import close_skyvern
from skyvern.cli.core.session_manager import close_current_session
from skyvern.cli.core.mcp_http_auth import MCPAPIKeyMiddleware, close_auth_db
from skyvern.cli.core.session_manager import close_current_session, set_stateless_http_mode
from skyvern.cli.mcp_tools import mcp # Uses standalone fastmcp (v2.x)
from skyvern.cli.utils import start_services
from skyvern.client import SkyvernEnvironment
from skyvern.config import settings
from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.forge_log import setup_logger
from skyvern.library.skyvern import Skyvern
from skyvern.services.script_service import run_script
from skyvern.utils import detect_os
from skyvern.utils.env_paths import resolve_backend_env_path, resolve_frontend_env_path
@@ -36,7 +36,10 @@ async def _cleanup_mcp_resources() -> None:
try:
await close_current_session()
finally:
await close_skyvern()
try:
await close_skyvern()
finally:
await close_auth_db()
def _cleanup_mcp_resources_blocking() -> None:
@@ -67,39 +70,6 @@ def _cleanup_mcp_resources_sync() -> None:
logger.debug("Skipping MCP cleanup because event loop is still running")
@mcp.tool()
async def skyvern_run_task(prompt: str, url: str) -> dict[str, Any]:
    """Use Skyvern to execute anything in the browser. Useful for accomplishing tasks that require browser automation.

    This tool uses Skyvern's browser automation to navigate websites and perform actions to achieve
    the user's intended outcome. It can handle tasks like form filling, clicking buttons, data extraction,
    and multi-step workflows.

    It can even help you find updated data on the internet if your model information is outdated.

    Args:
        prompt: A natural language description of what needs to be accomplished (e.g. "Book a flight from
            NYC to LA", "Sign up for the newsletter", "Find the price of item X", "Apply to a job")
        url: The starting URL of the website where the task should be performed
    """
    # Builds a dedicated cloud client from static settings for this one task.
    skyvern_agent = Skyvern(
        environment=SkyvernEnvironment.CLOUD,
        base_url=settings.SKYVERN_BASE_URL,
        api_key=settings.SKYVERN_API_KEY,
    )
    res = await skyvern_agent.run_task(prompt=prompt, url=url, user_agent="skyvern-mcp", wait_for_completion=True)
    output = res.model_dump()["output"]
    # Prefer the app-provided URL; otherwise synthesize one from the run id
    # (a "wr_" prefix routes to the /runs/ overview page).
    if res.app_url:
        task_url = res.app_url
    else:
        if res.run_id and res.run_id.startswith("wr_"):
            task_url = f"{settings.SKYVERN_APP_URL.rstrip('/')}/runs/{res.run_id}/overview"
        else:
            task_url = f"{settings.SKYVERN_APP_URL.rstrip('/')}/tasks/{res.run_id}/actions"
    return {"output": output, "task_url": task_url, "run_id": res.run_id}
def get_pids_on_port(port: int) -> List[int]:
"""Return a list of PIDs listening on the given port."""
pids = []
@@ -295,20 +265,61 @@ def run_dev() -> None:
@run_app.command(name="mcp")
def run_mcp() -> None:
"""Run the MCP server."""
# This breaks the MCP processing because it expects json output only
# console.print(Panel("[bold green]Starting MCP Server...[/bold green]", border_style="green"))
def run_mcp(
transport: Annotated[
Literal["stdio", "sse", "streamable-http"],
typer.Option(
"--transport",
help="MCP transport: stdio (default), sse, or streamable-http.",
),
] = "stdio",
host: Annotated[str, typer.Option("--host", help="Host for HTTP transports.")] = "0.0.0.0",
port: Annotated[int, typer.Option("--port", help="Port for HTTP transports.")] = 8000,
path: Annotated[str, typer.Option("--path", help="HTTP endpoint path for MCP transport.")] = "/mcp",
stateless_http: Annotated[
bool,
typer.Option(
"--stateless-http/--no-stateless-http",
help="Use stateless HTTP semantics for HTTP transports (ignored for stdio).",
),
] = True,
) -> None:
"""Run the MCP server with configurable transport for local or remote hosting."""
path = _normalize_mcp_path(path)
stateless_http_enabled = transport != "stdio" and stateless_http
# atexit covers signal-based exits (SIGTERM); finally covers normal
# mcp.run() completion or unhandled exceptions. Both are needed because
# atexit doesn't fire on normal return and finally doesn't fire on signals.
atexit.register(_cleanup_mcp_resources_sync)
set_stateless_http_mode(stateless_http_enabled)
try:
mcp.run(transport="stdio")
if transport == "stdio":
mcp.run(transport="stdio")
return
middleware = [Middleware(MCPAPIKeyMiddleware)]
mcp.run(
transport=transport,
host=host,
port=port,
path=path,
middleware=middleware,
stateless_http=stateless_http_enabled,
)
finally:
set_stateless_http_mode(False)
_cleanup_mcp_resources_blocking()
def _normalize_mcp_path(path: str) -> str:
    """Normalize the MCP HTTP endpoint path: default to '/mcp', ensure a leading slash."""
    trimmed = path.strip()
    if not trimmed:
        return "/mcp"
    return trimmed if trimmed.startswith("/") else f"/{trimmed}"
@run_app.command(
name="code",
context_settings={"allow_interspersed_args": False},

View File

@@ -1,27 +1,80 @@
"""Workflow-related CLI helpers."""
"""Workflow-related CLI commands with MCP-parity flags and output."""
from __future__ import annotations
import asyncio
import json
import os
import sys
from pathlib import Path
from typing import Any, Callable, Coroutine
import typer
from dotenv import load_dotenv
from rich.panel import Panel
from skyvern.client import Skyvern
from skyvern.config import settings
from skyvern.utils.env_paths import resolve_backend_env_path
from .console import console
from .tasks import _list_workflow_tasks
from .commands._output import output, output_error
from .mcp_tools.workflow import skyvern_workflow_cancel as tool_workflow_cancel
from .mcp_tools.workflow import skyvern_workflow_create as tool_workflow_create
from .mcp_tools.workflow import skyvern_workflow_delete as tool_workflow_delete
from .mcp_tools.workflow import skyvern_workflow_get as tool_workflow_get
from .mcp_tools.workflow import skyvern_workflow_list as tool_workflow_list
from .mcp_tools.workflow import skyvern_workflow_run as tool_workflow_run
from .mcp_tools.workflow import skyvern_workflow_status as tool_workflow_status
from .mcp_tools.workflow import skyvern_workflow_update as tool_workflow_update
workflow_app = typer.Typer(help="Manage Skyvern workflows.")
workflow_app = typer.Typer(help="Manage Skyvern workflows.", no_args_is_help=True)
def _emit_tool_result(result: dict[str, Any], *, json_output: bool) -> None:
    """Render a tool-result envelope; in JSON mode a failed result exits with code 1."""
    ok = result.get("ok", False)
    if json_output:
        # Machine-readable path: dump the envelope verbatim and signal failure
        # only through the exit code.
        json.dump(result, sys.stdout, indent=2, default=str)
        sys.stdout.write("\n")
        if not ok:
            raise SystemExit(1)
        return
    if ok:
        output(result.get("data"), action=str(result.get("action", "")), json_mode=False)
    else:
        err = result.get("error") or {}
        output_error(str(err.get("message", "Unknown error")), hint=str(err.get("hint", "")), json_mode=False)
def _run_tool(
    runner: Callable[[], Coroutine[Any, Any, dict[str, Any]]],
    *,
    json_output: bool,
    hint_on_exception: str,
) -> None:
    """Execute an async tool coroutine and render its result envelope."""
    try:
        envelope: dict[str, Any] = asyncio.run(runner())
        _emit_tool_result(envelope, json_output=json_output)
    except typer.BadParameter:
        # Parameter validation errors surface through Typer's own handling.
        raise
    except Exception as exc:
        output_error(str(exc), hint=hint_on_exception, json_mode=json_output)
def _resolve_inline_or_file(value: str | None, *, param_name: str) -> str | None:
if value is None or not value.startswith("@"):
return value
file_path = value[1:]
if not file_path:
raise typer.BadParameter(f"{param_name} file path cannot be empty after '@'.")
path = Path(file_path).expanduser()
try:
return path.read_text(encoding="utf-8")
except OSError as e:
raise typer.BadParameter(f"Unable to read {param_name} file '{path}': {e}") from e
@workflow_app.callback()
def workflow_callback(
ctx: typer.Context,
api_key: str | None = typer.Option(
None,
"--api-key",
@@ -29,86 +82,188 @@ def workflow_callback(
envvar="SKYVERN_API_KEY",
),
) -> None:
"""Store the provided API key in the Typer context."""
ctx.obj = {"api_key": api_key}
def _get_client(api_key: str | None = None) -> Skyvern:
"""Instantiate a Skyvern SDK client using environment variables."""
"""Load workflow CLI environment and optional API key override."""
load_dotenv(resolve_backend_env_path())
key = api_key or os.getenv("SKYVERN_API_KEY") or settings.SKYVERN_API_KEY
return Skyvern(base_url=settings.SKYVERN_BASE_URL, api_key=key)
if api_key:
settings.SKYVERN_API_KEY = api_key
@workflow_app.command("list")
def workflow_list(
    search: str | None = typer.Option(
        None,
        "--search",
        help="Search across workflow titles, folder names, and parameter metadata.",
    ),
    page: int = typer.Option(1, "--page", min=1, help="Page number (1-based)."),
    page_size: int = typer.Option(10, "--page-size", min=1, max=100, help="Results per page."),
    only_workflows: bool = typer.Option(
        False,
        "--only-workflows",
        help="Only return multi-step workflows (exclude saved tasks).",
    ),
    json_output: bool = typer.Option(False, "--json", help="Output as JSON."),
) -> None:
    """List workflows.

    Delegates to the `skyvern_workflow_list` MCP tool so CLI and MCP output
    stay in parity; `--json` emits the raw tool result envelope.
    """
    async def _run() -> dict[str, Any]:
        return await tool_workflow_list(
            search=search,
            page=page,
            page_size=page_size,
            only_workflows=only_workflows,
        )

    _run_tool(_run, json_output=json_output, hint_on_exception="Check your API key and workflow list filters.")
@workflow_app.command("get")
def workflow_get(
    workflow_id: str = typer.Option(..., "--id", help="Workflow permanent ID (wpid_...)."),
    version: int | None = typer.Option(None, "--version", min=1, help="Specific version to retrieve."),
    json_output: bool = typer.Option(False, "--json", help="Output as JSON."),
) -> None:
    """Get a workflow definition by ID.

    Delegates to the `skyvern_workflow_get` MCP tool; `--version` pins a
    specific version when given.
    """
    async def _run() -> dict[str, Any]:
        return await tool_workflow_get(workflow_id=workflow_id, version=version)

    _run_tool(_run, json_output=json_output, hint_on_exception="Check your API key and workflow ID.")
@workflow_app.command("create")
def workflow_create(
    definition: str = typer.Option(
        ...,
        "--definition",
        help="Workflow definition as YAML/JSON string or @file path.",
    ),
    definition_format: str = typer.Option(
        "auto",
        "--format",
        help="Definition format: json, yaml, or auto.",
    ),
    folder_id: str | None = typer.Option(None, "--folder-id", help="Folder ID (fld_...) for the workflow."),
    json_output: bool = typer.Option(False, "--json", help="Output as JSON."),
) -> None:
    """Create a workflow.

    `--definition` accepts inline YAML/JSON or an `@path/to/file` reference;
    the `@` form is resolved inside the runner so file-read errors surface as
    CLI parameter errors.
    """
    async def _run() -> dict[str, Any]:
        resolved_definition = _resolve_inline_or_file(definition, param_name="definition")
        # --definition is required, so resolution never yields None here.
        assert resolved_definition is not None
        return await tool_workflow_create(
            definition=resolved_definition,
            format=definition_format,
            folder_id=folder_id,
        )

    _run_tool(_run, json_output=json_output, hint_on_exception="Check the workflow definition syntax.")
@workflow_app.command("update")
def workflow_update(
    workflow_id: str = typer.Option(..., "--id", help="Workflow permanent ID (wpid_...)."),
    definition: str = typer.Option(
        ...,
        "--definition",
        help="Updated workflow definition as YAML/JSON string or @file path.",
    ),
    definition_format: str = typer.Option(
        "auto",
        "--format",
        help="Definition format: json, yaml, or auto.",
    ),
    json_output: bool = typer.Option(False, "--json", help="Output as JSON."),
) -> None:
    """Update a workflow definition.

    `--definition` accepts inline YAML/JSON or an `@path/to/file` reference,
    matching `workflow create`.
    """
    async def _run() -> dict[str, Any]:
        resolved_definition = _resolve_inline_or_file(definition, param_name="definition")
        # --definition is required, so resolution never yields None here.
        assert resolved_definition is not None
        return await tool_workflow_update(
            workflow_id=workflow_id,
            definition=resolved_definition,
            format=definition_format,
        )

    _run_tool(_run, json_output=json_output, hint_on_exception="Check the workflow ID and definition syntax.")
@workflow_app.command("delete")
def workflow_delete(
    workflow_id: str = typer.Option(..., "--id", help="Workflow permanent ID (wpid_...)."),
    force: bool = typer.Option(False, "--force", help="Confirm permanent deletion."),
    json_output: bool = typer.Option(False, "--json", help="Output as JSON."),
) -> None:
    """Delete a workflow.

    `--force` is passed through to the `skyvern_workflow_delete` tool as the
    deletion confirmation flag.
    """
    async def _run() -> dict[str, Any]:
        return await tool_workflow_delete(workflow_id=workflow_id, force=force)

    _run_tool(_run, json_output=json_output, hint_on_exception="Check the workflow ID and your permissions.")
@workflow_app.command("run")
def run_workflow(
ctx: typer.Context,
workflow_id: str = typer.Argument(..., help="Workflow permanent ID"),
parameters: str = typer.Option("{}", "--parameters", "-p", help="JSON parameters for the workflow"),
title: str | None = typer.Option(None, "--title", help="Title for the workflow run"),
max_steps: int | None = typer.Option(None, "--max-steps", help="Override the workflow max steps"),
def workflow_run(
workflow_id: str = typer.Option(..., "--id", help="Workflow permanent ID (wpid_...)."),
params: str | None = typer.Option(
None,
"--params",
"--parameters",
"-p",
help="Workflow parameters as JSON string or @file path.",
),
session: str | None = typer.Option(None, "--session", help="Browser session ID (pbs_...) to reuse."),
webhook: str | None = typer.Option(None, "--webhook", help="Status webhook callback URL."),
proxy: str | None = typer.Option(None, "--proxy", help="Proxy location (e.g., RESIDENTIAL)."),
wait: bool = typer.Option(False, "--wait", help="Wait for workflow completion before returning."),
timeout: int = typer.Option(
300,
"--timeout",
min=10,
max=3600,
help="Max wait time in seconds when --wait is set.",
),
json_output: bool = typer.Option(False, "--json", help="Output as JSON."),
) -> None:
"""Run a workflow."""
try:
params_dict = json.loads(parameters) if parameters else {}
except json.JSONDecodeError:
console.print(f"[red]Invalid JSON for parameters: {parameters}[/red]")
raise typer.Exit(code=1)
client = _get_client(ctx.obj.get("api_key") if ctx.obj else None)
run_resp = client.run_workflow(
workflow_id=workflow_id,
parameters=params_dict,
title=title,
max_steps_override=max_steps,
)
console.print(
Panel(
f"Started workflow run [bold]{run_resp.run_id}[/bold]",
border_style="green",
async def _run() -> dict[str, Any]:
resolved_params = _resolve_inline_or_file(params, param_name="params")
return await tool_workflow_run(
workflow_id=workflow_id,
parameters=resolved_params,
browser_session_id=session,
webhook_url=webhook,
proxy_location=proxy,
wait=wait,
timeout_seconds=timeout,
)
)
@workflow_app.command("cancel")
def cancel_workflow(
ctx: typer.Context,
run_id: str = typer.Argument(..., help="ID of the workflow run"),
) -> None:
"""Cancel a running workflow."""
client = _get_client(ctx.obj.get("api_key") if ctx.obj else None)
client.cancel_run(run_id=run_id)
console.print(Panel(f"Cancel signal sent for run {run_id}", border_style="red"))
_run_tool(_run, json_output=json_output, hint_on_exception="Check the workflow ID and run parameters.")
@workflow_app.command("status")
def workflow_status(
ctx: typer.Context,
run_id: str = typer.Argument(..., help="ID of the workflow run"),
tasks: bool = typer.Option(False, "--tasks", help="Show task executions"),
run_id: str = typer.Option(..., "--run-id", help="Run ID (wr_... or tsk_v2_...)."),
json_output: bool = typer.Option(False, "--json", help="Output as JSON."),
) -> None:
"""Retrieve status information for a workflow run."""
client = _get_client(ctx.obj.get("api_key") if ctx.obj else None)
run = client.get_run(run_id=run_id)
console.print(Panel(run.model_dump_json(indent=2), border_style="cyan"))
if tasks:
task_list = _list_workflow_tasks(client, run_id)
console.print(Panel(json.dumps(task_list, indent=2), border_style="magenta"))
"""Get workflow run status."""
async def _run() -> dict[str, Any]:
return await tool_workflow_status(run_id=run_id)
_run_tool(_run, json_output=json_output, hint_on_exception="Check the run ID and API key.")
@workflow_app.command("list")
def list_workflows(
ctx: typer.Context,
page: int = typer.Option(1, "--page", help="Page number"),
page_size: int = typer.Option(10, "--page-size", help="Number of workflows to return"),
template: bool = typer.Option(False, "--template", help="List template workflows"),
@workflow_app.command("cancel")
def workflow_cancel(
run_id: str = typer.Option(..., "--run-id", help="Run ID (wr_... or tsk_v2_...)."),
json_output: bool = typer.Option(False, "--json", help="Output as JSON."),
) -> None:
"""List workflows for the organization."""
client = _get_client(ctx.obj.get("api_key") if ctx.obj else None)
resp = client._client_wrapper.httpx_client.request(
"api/v1/workflows",
method="GET",
params={"page": page, "page_size": page_size, "template": template},
)
resp.raise_for_status()
console.print(Panel(json.dumps(resp.json(), indent=2), border_style="cyan"))
"""Cancel a workflow run."""
async def _run() -> dict[str, Any]:
return await tool_workflow_cancel(run_id=run_id)
_run_tool(_run, json_output=json_output, hint_on_exception="Check the run ID and API key.")

View File

@@ -296,3 +296,175 @@ class TestBrowserCommands:
parsed = json.loads(capsys.readouterr().out)
assert parsed["ok"] is False
assert "Invalid state" in parsed["error"]["message"]
# ---------------------------------------------------------------------------
# Workflow command behavior
# ---------------------------------------------------------------------------
class TestWorkflowCommands:
    """CLI-level behavior of the ``skyvern workflow ...`` command group.

    Each test monkeypatches the underlying MCP tool coroutine and checks
    that the CLI maps options to tool kwargs and, in ``--json`` mode,
    prints the MCP envelope verbatim.
    """

    def test_workflow_get_outputs_mcp_envelope_in_json_mode(
        self, monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture
    ) -> None:
        """`workflow get --json` emits the tool's envelope unchanged."""
        from skyvern.cli import workflow as workflow_cmd

        expected = {
            "ok": True,
            "action": "skyvern_workflow_get",
            "browser_context": {"mode": "none", "session_id": None, "cdp_url": None},
            "data": {"workflow_permanent_id": "wpid_123"},
            "artifacts": [],
            "timing_ms": {},
            "warnings": [],
            "error": None,
        }
        tool = AsyncMock(return_value=expected)
        monkeypatch.setattr(workflow_cmd, "tool_workflow_get", tool)
        workflow_cmd.workflow_get(workflow_id="wpid_123", version=2, json_output=True)
        parsed = json.loads(capsys.readouterr().out)
        assert parsed == expected
        # The tool must receive exactly the CLI-provided identifiers.
        assert tool.await_args.kwargs == {"workflow_id": "wpid_123", "version": 2}

    def test_workflow_create_reads_definition_from_file(
        self,
        tmp_path: Path,
        monkeypatch: pytest.MonkeyPatch,
        capsys: pytest.CaptureFixture,
    ) -> None:
        """`--definition @file` is read from disk and passed through as text."""
        from skyvern.cli import workflow as workflow_cmd

        definition_file = tmp_path / "workflow.json"
        definition_text = '{"title": "Example", "workflow_definition": {"blocks": []}}'
        definition_file.write_text(definition_text)
        tool = AsyncMock(
            return_value={
                "ok": True,
                "action": "skyvern_workflow_create",
                "browser_context": {"mode": "none", "session_id": None, "cdp_url": None},
                "data": {"workflow_permanent_id": "wpid_new"},
                "artifacts": [],
                "timing_ms": {},
                "warnings": [],
                "error": None,
            }
        )
        monkeypatch.setattr(workflow_cmd, "tool_workflow_create", tool)
        workflow_cmd.workflow_create(
            definition=f"@{definition_file}",
            definition_format="json",
            folder_id="fld_123",
            json_output=True,
        )
        # The tool receives the raw file contents, not the @-path.
        assert tool.await_args.kwargs == {
            "definition": definition_text,
            "format": "json",
            "folder_id": "fld_123",
        }
        parsed = json.loads(capsys.readouterr().out)
        assert parsed["ok"] is True
        assert parsed["data"]["workflow_permanent_id"] == "wpid_new"

    def test_workflow_run_reads_params_file_and_maps_options(
        self,
        tmp_path: Path,
        monkeypatch: pytest.MonkeyPatch,
        capsys: pytest.CaptureFixture,
    ) -> None:
        """Every `workflow run` option maps onto the matching tool kwarg."""
        from skyvern.cli import workflow as workflow_cmd

        params_file = tmp_path / "params.json"
        params_file.write_text('{"company": "Acme"}')
        tool = AsyncMock(
            return_value={
                "ok": True,
                "action": "skyvern_workflow_run",
                "browser_context": {"mode": "none", "session_id": None, "cdp_url": None},
                "data": {"run_id": "wr_123", "status": "queued"},
                "artifacts": [],
                "timing_ms": {},
                "warnings": [],
                "error": None,
            }
        )
        monkeypatch.setattr(workflow_cmd, "tool_workflow_run", tool)
        workflow_cmd.workflow_run(
            workflow_id="wpid_123",
            params=f"@{params_file}",
            session="pbs_456",
            webhook="https://example.com/webhook",
            proxy="RESIDENTIAL",
            wait=True,
            timeout=450,
            json_output=True,
        )
        assert tool.await_args.kwargs == {
            "workflow_id": "wpid_123",
            "parameters": '{"company": "Acme"}',
            "browser_session_id": "pbs_456",
            "webhook_url": "https://example.com/webhook",
            "proxy_location": "RESIDENTIAL",
            "wait": True,
            "timeout_seconds": 450,
        }
        parsed = json.loads(capsys.readouterr().out)
        assert parsed["ok"] is True
        assert parsed["data"]["run_id"] == "wr_123"

    def test_workflow_status_json_error_exits_nonzero(
        self, monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture
    ) -> None:
        """An `ok: false` envelope still prints in --json mode but exits 1."""
        from skyvern.cli import workflow as workflow_cmd

        tool = AsyncMock(
            return_value={
                "ok": False,
                "action": "skyvern_workflow_status",
                "browser_context": {"mode": "none", "session_id": None, "cdp_url": None},
                "data": None,
                "artifacts": [],
                "timing_ms": {},
                "warnings": [],
                "error": {
                    "code": "RUN_NOT_FOUND",
                    "message": "Run 'wr_missing' not found",
                    "hint": "Verify the run ID",
                    "details": {},
                },
            }
        )
        monkeypatch.setattr(workflow_cmd, "tool_workflow_status", tool)
        with pytest.raises(SystemExit, match="1"):
            workflow_cmd.workflow_status(run_id="wr_missing", json_output=True)
        parsed = json.loads(capsys.readouterr().out)
        assert parsed["ok"] is False
        assert parsed["error"]["code"] == "RUN_NOT_FOUND"

    def test_workflow_update_missing_definition_file_raises_bad_parameter(
        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """A missing @file fails fast with BadParameter before the tool runs."""
        from skyvern.cli import workflow as workflow_cmd

        tool = AsyncMock()
        monkeypatch.setattr(workflow_cmd, "tool_workflow_update", tool)
        missing_file = tmp_path / "missing-definition.json"
        with pytest.raises(typer.BadParameter, match="Unable to read definition file"):
            workflow_cmd.workflow_update(
                workflow_id="wpid_123",
                definition=f"@{missing_file}",
                definition_format="json",
                json_output=False,
            )
        tool.assert_not_called()

View File

@@ -0,0 +1,304 @@
from __future__ import annotations
import asyncio
from types import SimpleNamespace
from unittest.mock import AsyncMock
import httpx
import pytest
from fastapi import HTTPException
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.routing import Route
from skyvern.cli.core import client as client_mod
from skyvern.cli.core import mcp_http_auth
@pytest.fixture(autouse=True)
def _reset_auth_context():
    """Reset auth singletons and apply fast test tuning for each test.

    The previous version overwrote the module-level tuning constants
    (cache TTL/size, retry count, retry delay) with no teardown, so the
    test-only values (e.g. zero retry delay) leaked into any test module
    that ran later in the same session. Save and restore them instead.
    """
    saved_tuning = (
        mcp_http_auth._API_KEY_CACHE_TTL_SECONDS,
        mcp_http_auth._API_KEY_CACHE_MAX_SIZE,
        mcp_http_auth._MAX_VALIDATION_RETRIES,
        mcp_http_auth._RETRY_DELAY_SECONDS,
    )
    client_mod._api_key_override.set(None)
    mcp_http_auth._auth_db = None
    mcp_http_auth._api_key_validation_cache.clear()
    mcp_http_auth._API_KEY_CACHE_TTL_SECONDS = 30.0
    mcp_http_auth._API_KEY_CACHE_MAX_SIZE = 1024
    mcp_http_auth._MAX_VALIDATION_RETRIES = 2
    mcp_http_auth._RETRY_DELAY_SECONDS = 0.0  # no delay in tests
    try:
        yield
    finally:
        # Restore module tuning so other test modules see production defaults.
        (
            mcp_http_auth._API_KEY_CACHE_TTL_SECONDS,
            mcp_http_auth._API_KEY_CACHE_MAX_SIZE,
            mcp_http_auth._MAX_VALIDATION_RETRIES,
            mcp_http_auth._RETRY_DELAY_SECONDS,
        ) = saved_tuning
async def _echo_request_context(request: Request) -> JSONResponse:
    """Echo back what the auth middleware installed for this request."""
    payload = {
        "api_key": client_mod.get_active_api_key(),
        "organization_id": getattr(request.state, "organization_id", None),
    }
    return JSONResponse(payload)
def _build_test_app() -> Starlette:
    """Build a minimal app: one POST /mcp route behind the API-key middleware."""
    mcp_route = Route("/mcp", endpoint=_echo_request_context, methods=["POST"])
    auth_stack = [Middleware(mcp_http_auth.MCPAPIKeyMiddleware)]
    return Starlette(routes=[mcp_route], middleware=auth_stack)
@pytest.mark.asyncio
async def test_mcp_http_auth_rejects_missing_api_key() -> None:
    """A POST to /mcp without an x-api-key header yields 401 UNAUTHORIZED."""
    app = _build_test_app()
    async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
        response = await client.post("/mcp", json={})
    assert response.status_code == 401
    assert response.json()["error"]["code"] == "UNAUTHORIZED"
    # The message should tell the caller which header is required.
    assert "x-api-key" in response.json()["error"]["message"]
@pytest.mark.asyncio
async def test_mcp_http_auth_allows_health_checks_without_api_key() -> None:
    """GET /healthz succeeds with no key.

    The app defines no /healthz route, so the 200 implies the middleware
    answers health checks itself before auth runs.
    """
    app = _build_test_app()
    async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
        response = await client.get("/healthz")
    assert response.status_code == 200
    assert response.json() == {"status": "ok"}
@pytest.mark.asyncio
async def test_mcp_http_auth_rejects_invalid_api_key(monkeypatch: pytest.MonkeyPatch) -> None:
    """A 403 from key validation surfaces as a generic 401 'Invalid API key'."""
    monkeypatch.setattr(
        mcp_http_auth,
        "validate_mcp_api_key",
        AsyncMock(side_effect=HTTPException(status_code=403, detail="Invalid credentials")),
    )
    app = _build_test_app()
    async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
        response = await client.post("/mcp", headers={"x-api-key": "bad-key"}, json={})
    assert response.status_code == 401
    assert response.json()["error"]["code"] == "UNAUTHORIZED"
    # Internal detail ("Invalid credentials") must not leak to the client.
    assert response.json()["error"]["message"] == "Invalid API key"
@pytest.mark.asyncio
async def test_mcp_http_auth_returns_500_on_non_auth_http_exception(monkeypatch: pytest.MonkeyPatch) -> None:
    """An HTTP 500 raised during validation is not an auth failure; map to INTERNAL_ERROR."""
    monkeypatch.setattr(
        mcp_http_auth,
        "validate_mcp_api_key",
        AsyncMock(side_effect=HTTPException(status_code=500, detail="db down")),
    )
    app = _build_test_app()
    async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
        response = await client.post("/mcp", headers={"x-api-key": "sk_live_abc"}, json={})
    assert response.status_code == 500
    assert response.json()["error"]["code"] == "INTERNAL_ERROR"
@pytest.mark.asyncio
async def test_mcp_http_auth_returns_503_on_transient_validation_exhaustion(monkeypatch: pytest.MonkeyPatch) -> None:
    """A 503 from validation propagates as SERVICE_UNAVAILABLE with its message intact."""
    monkeypatch.setattr(
        mcp_http_auth,
        "validate_mcp_api_key",
        AsyncMock(side_effect=HTTPException(status_code=503, detail="API key validation temporarily unavailable")),
    )
    app = _build_test_app()
    async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
        response = await client.post("/mcp", headers={"x-api-key": "sk_live_abc"}, json={})
    assert response.status_code == 503
    assert response.json()["error"]["code"] == "SERVICE_UNAVAILABLE"
    assert response.json()["error"]["message"] == "API key validation temporarily unavailable"
@pytest.mark.asyncio
async def test_mcp_http_auth_returns_500_on_unexpected_validation_error(monkeypatch: pytest.MonkeyPatch) -> None:
    """A non-HTTPException escaping validation is caught and mapped to 500."""
    monkeypatch.setattr(
        mcp_http_auth,
        "validate_mcp_api_key",
        AsyncMock(side_effect=RuntimeError("boom")),
    )
    app = _build_test_app()
    async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
        response = await client.post("/mcp", headers={"x-api-key": "sk_live_abc"}, json={})
    assert response.status_code == 500
    assert response.json()["error"]["code"] == "INTERNAL_ERROR"
@pytest.mark.asyncio
async def test_mcp_http_auth_sets_request_scoped_api_key(monkeypatch: pytest.MonkeyPatch) -> None:
    """The validated key/org are visible inside the request but do not linger after it."""
    monkeypatch.setattr(
        mcp_http_auth,
        "validate_mcp_api_key",
        AsyncMock(return_value="org_123"),
    )
    app = _build_test_app()
    async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
        response = await client.post("/mcp", headers={"x-api-key": "sk_live_abc"}, json={})
    assert response.status_code == 200
    # The echo endpoint saw the request-scoped key and the resolved org.
    assert response.json() == {
        "api_key": "sk_live_abc",
        "organization_id": "org_123",
    }
    # Outside the request scope the override must not persist.
    assert client_mod.get_active_api_key() != "sk_live_abc"
@pytest.mark.asyncio
async def test_validate_mcp_api_key_uses_ttl_cache(monkeypatch: pytest.MonkeyPatch) -> None:
    """A second validation of the same key is served from the TTL cache (one resolver call)."""
    calls = 0

    async def _resolve(_api_key: str, _db: object) -> object:
        nonlocal calls
        calls += 1
        return SimpleNamespace(organization=SimpleNamespace(organization_id="org_cached"))

    monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
    monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())
    first = await mcp_http_auth.validate_mcp_api_key("sk_test_cache")
    second = await mcp_http_auth.validate_mcp_api_key("sk_test_cache")
    assert first == "org_cached"
    assert second == "org_cached"
    assert calls == 1
@pytest.mark.asyncio
async def test_validate_mcp_api_key_cache_expires(monkeypatch: pytest.MonkeyPatch) -> None:
    """An expired cache entry forces a fresh resolver call."""
    calls = 0

    async def _resolve(_api_key: str, _db: object) -> object:
        nonlocal calls
        calls += 1
        return SimpleNamespace(organization=SimpleNamespace(organization_id=f"org_{calls}"))

    monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
    monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())
    first = await mcp_http_auth.validate_mcp_api_key("sk_test_cache_expire")
    # Backdate the entry's timestamp so the TTL check treats it as stale.
    cache_key = mcp_http_auth._cache_key("sk_test_cache_expire")
    mcp_http_auth._api_key_validation_cache[cache_key] = (first, 0.0)
    second = await mcp_http_auth.validate_mcp_api_key("sk_test_cache_expire")
    assert first == "org_1"
    assert second == "org_2"
    assert calls == 2
@pytest.mark.asyncio
async def test_validate_mcp_api_key_negative_caches_auth_failures(monkeypatch: pytest.MonkeyPatch) -> None:
    """Auth failures are negatively cached: the resolver runs once, then the cached rejection fires."""
    calls = 0

    async def _resolve(_api_key: str, _db: object) -> object:
        nonlocal calls
        calls += 1
        raise HTTPException(status_code=401, detail="Invalid credentials")

    monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
    monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())
    # First call hits the resolver and propagates its detail.
    with pytest.raises(HTTPException, match="Invalid credentials"):
        await mcp_http_auth.validate_mcp_api_key("sk_test_auth_failure")
    # Second call is answered from the negative cache with a generic message.
    with pytest.raises(HTTPException, match="Invalid API key"):
        await mcp_http_auth.validate_mcp_api_key("sk_test_auth_failure")
    assert calls == 1
@pytest.mark.asyncio
async def test_validate_mcp_api_key_retries_transient_failure_without_negative_cache(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A transient resolver error is retried; the recovery is cached positively, not negatively."""
    calls = 0

    async def _resolve(_api_key: str, _db: object) -> object:
        nonlocal calls
        calls += 1
        if calls == 1:
            raise RuntimeError("transient db error")
        return SimpleNamespace(organization=SimpleNamespace(organization_id="org_recovered"))

    monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
    monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())
    recovered_org = await mcp_http_auth.validate_mcp_api_key("sk_test_transient")
    cache_key = mcp_http_auth._cache_key("sk_test_transient")
    assert mcp_http_auth._api_key_validation_cache[cache_key][0] == "org_recovered"
    assert recovered_org == "org_recovered"
    assert calls == 2
@pytest.mark.asyncio
async def test_validate_mcp_api_key_concurrent_callers_all_succeed(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Multiple concurrent callers for the same key all succeed; the cache
    collapses subsequent calls after the first one populates it."""
    calls = 0

    async def _resolve(_api_key: str, _db: object) -> object:
        nonlocal calls
        calls += 1
        return SimpleNamespace(organization=SimpleNamespace(organization_id="org_concurrent"))

    monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
    monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())
    results = await asyncio.gather(*[mcp_http_auth.validate_mcp_api_key("sk_test_concurrent") for _ in range(5)])
    assert all(r == "org_concurrent" for r in results)
    # First call populates cache; remaining may or may not hit DB depending on
    # scheduling, but all must succeed.
    assert calls >= 1
@pytest.mark.asyncio
async def test_validate_mcp_api_key_returns_503_after_retry_exhaustion(monkeypatch: pytest.MonkeyPatch) -> None:
    """Persistent transient failures exhaust the retry budget and surface as HTTP 503."""
    calls = 0

    async def _resolve(_api_key: str, _db: object) -> object:
        nonlocal calls
        calls += 1
        raise RuntimeError("persistent db outage")

    monkeypatch.setattr(mcp_http_auth, "_MAX_VALIDATION_RETRIES", 2)
    monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
    monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())
    with pytest.raises(HTTPException, match="temporarily unavailable") as exc_info:
        await mcp_http_auth.validate_mcp_api_key("sk_test_transient_exhausted")
    assert exc_info.value.status_code == 503
    assert calls == 3  # initial + 2 retries
@pytest.mark.asyncio
async def test_close_auth_db_disposes_engine() -> None:
    """close_auth_db disposes the engine, drops the handle, and clears the validation cache."""
    dispose = AsyncMock()
    mcp_http_auth._auth_db = SimpleNamespace(engine=SimpleNamespace(dispose=dispose))
    mcp_http_auth._api_key_validation_cache["k"] = ("org", 123.0)
    await mcp_http_auth.close_auth_db()
    dispose.assert_awaited_once()
    assert mcp_http_auth._auth_db is None
    assert mcp_http_auth._api_key_validation_cache == {}
@pytest.mark.asyncio
async def test_close_auth_db_noop_when_uninitialized() -> None:
    """close_auth_db is a safe no-op when no auth DB was ever created."""
    mcp_http_auth._auth_db = None
    await mcp_http_auth.close_auth_db()
    assert mcp_http_auth._auth_db is None

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
import pytest
@@ -14,10 +15,13 @@ from skyvern.cli.mcp_tools import session as mcp_session
@pytest.fixture(autouse=True)
def _reset_singletons() -> None:
    """Clear client/session singletons so each test starts from a clean slate."""
    client_mod._skyvern_instance.set(None)
    client_mod._api_key_override.set(None)
    client_mod._global_skyvern_instance = None
    client_mod._api_key_clients.clear()
    session_manager._current_session.set(None)
    session_manager._global_session = None
    session_manager.set_stateless_http_mode(False)
    mcp_session.set_current_session(mcp_session.SessionState())
@@ -47,6 +51,115 @@ def test_get_skyvern_reuses_global_instance_across_contexts(monkeypatch: pytest.
assert len(created) == 1
def test_get_skyvern_reuses_override_instance_per_api_key(monkeypatch: pytest.MonkeyPatch) -> None:
    """With an API-key override active, get_skyvern() builds one client per key and reuses it."""
    created_keys: list[str] = []

    class FakeSkyvern:
        def __init__(self, *args: object, **kwargs: object) -> None:
            created_keys.append(kwargs["api_key"])

        @classmethod
        def local(cls) -> FakeSkyvern:
            return cls(api_key="local")

        async def aclose(self) -> None:
            return None

    monkeypatch.setattr(client_mod, "Skyvern", FakeSkyvern)
    monkeypatch.setattr(client_mod.settings, "SKYVERN_API_KEY", None)
    monkeypatch.setattr(client_mod.settings, "SKYVERN_BASE_URL", None)
    token = client_mod.set_api_key_override("sk_key_a")
    try:
        first = client_mod.get_skyvern()
        # Drop the contextvar-cached instance to force a lookup in the per-key cache.
        client_mod._skyvern_instance.set(None)
        second = client_mod.get_skyvern()
    finally:
        client_mod.reset_api_key_override(token)
    assert first is second
    assert created_keys == ["sk_key_a"]
def test_get_skyvern_override_client_cache_uses_lru_eviction(monkeypatch: pytest.MonkeyPatch) -> None:
    """The per-key client cache evicts the least-recently-used entry at capacity."""
    created_keys: list[str] = []

    class FakeSkyvern:
        def __init__(self, *args: object, **kwargs: object) -> None:
            created_keys.append(kwargs["api_key"])

        @classmethod
        def local(cls) -> FakeSkyvern:
            return cls(api_key="local")

        async def aclose(self) -> None:
            return None

    monkeypatch.setattr(client_mod, "Skyvern", FakeSkyvern)
    monkeypatch.setattr(client_mod.settings, "SKYVERN_API_KEY", None)
    monkeypatch.setattr(client_mod.settings, "SKYVERN_BASE_URL", None)
    monkeypatch.setattr(client_mod, "_API_KEY_CLIENT_CACHE_MAX", 2)
    for key in ("sk_key_a", "sk_key_b"):
        token = client_mod.set_api_key_override(key)
        try:
            client_mod.get_skyvern()
        finally:
            client_mod.reset_api_key_override(token)
    # Touch key_a so key_b becomes least-recently-used.
    token = client_mod.set_api_key_override("sk_key_a")
    try:
        client_mod._skyvern_instance.set(None)
        client_mod.get_skyvern()
    finally:
        client_mod.reset_api_key_override(token)
    # Adding key_c should evict key_b.
    token = client_mod.set_api_key_override("sk_key_c")
    try:
        client_mod.get_skyvern()
    finally:
        client_mod.reset_api_key_override(token)
    assert list(client_mod._api_key_clients.keys()) == [
        client_mod._cache_key("sk_key_a"),
        client_mod._cache_key("sk_key_c"),
    ]
    # key_a, key_b, key_c were created exactly once each.
    assert created_keys == ["sk_key_a", "sk_key_b", "sk_key_c"]
def test_get_skyvern_override_cache_closes_evicted_client(monkeypatch: pytest.MonkeyPatch) -> None:
    """An evicted per-key client has its aclose() invoked so connections are released."""
    closed_keys: list[str] = []

    class FakeSkyvern:
        def __init__(self, *args: object, **kwargs: object) -> None:
            self.api_key = kwargs["api_key"]

        @classmethod
        def local(cls) -> FakeSkyvern:
            return cls(api_key="local")

        async def aclose(self) -> None:
            closed_keys.append(self.api_key)

    monkeypatch.setattr(client_mod, "Skyvern", FakeSkyvern)
    monkeypatch.setattr(client_mod.settings, "SKYVERN_API_KEY", None)
    monkeypatch.setattr(client_mod.settings, "SKYVERN_BASE_URL", None)
    # Capacity of one: the second key forces eviction of the first.
    monkeypatch.setattr(client_mod, "_API_KEY_CLIENT_CACHE_MAX", 1)
    for key in ("sk_key_a", "sk_key_b"):
        token = client_mod.set_api_key_override(key)
        try:
            client_mod.get_skyvern()
        finally:
            client_mod.reset_api_key_override(token)
    assert list(client_mod._api_key_clients.keys()) == [client_mod._cache_key("sk_key_b")]
    assert closed_keys == ["sk_key_a"]
@pytest.mark.asyncio
async def test_close_skyvern_closes_singleton() -> None:
fake = MagicMock()
@@ -75,12 +188,53 @@ def test_get_current_session_falls_back_to_global_state() -> None:
assert recovered is state
def test_get_current_session_stateless_mode_ignores_global_state() -> None:
    """In stateless HTTP mode the global session must not leak into a request."""
    global_state = session_manager.SessionState(
        browser=MagicMock(),
        context=BrowserContext(mode="cloud_session", session_id="pbs_999"),
    )
    session_manager._global_session = global_state
    session_manager._current_session.set(None)
    session_manager.set_stateless_http_mode(True)
    try:
        recovered = session_manager.get_current_session()
    finally:
        session_manager.set_stateless_http_mode(False)
    # A fresh, empty state is handed out instead of the global one.
    assert recovered is not global_state
    assert recovered.browser is None
    assert recovered.context is None
def test_set_current_session_stateless_mode_does_not_override_global_state() -> None:
    """set_current_session in stateless mode writes only the contextvar, not the global."""
    global_state = session_manager.SessionState(
        browser=MagicMock(),
        context=BrowserContext(mode="cloud_session", session_id="pbs_global"),
    )
    session_manager._global_session = global_state
    replacement = session_manager.SessionState(
        browser=MagicMock(),
        context=BrowserContext(mode="cloud_session", session_id="pbs_request"),
    )
    session_manager.set_stateless_http_mode(True)
    try:
        session_manager.set_current_session(replacement)
    finally:
        session_manager.set_stateless_http_mode(False)
    assert session_manager._global_session is global_state
    assert session_manager._current_session.get() is replacement
@pytest.mark.asyncio
async def test_resolve_browser_reuses_matching_cloud_session(monkeypatch: pytest.MonkeyPatch) -> None:
current_browser = MagicMock()
current_state = session_manager.SessionState(
browser=current_browser,
context=BrowserContext(mode="cloud_session", session_id="pbs_123"),
api_key_hash=session_manager._api_key_hash(client_mod.get_active_api_key()),
)
session_manager.set_current_session(current_state)
@@ -95,6 +249,92 @@ async def test_resolve_browser_reuses_matching_cloud_session(monkeypatch: pytest
fake_skyvern.connect_to_cloud_browser_session.assert_not_awaited()
@pytest.mark.asyncio
async def test_resolve_browser_does_not_reuse_session_for_different_api_key(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A cached browser bound to another API key's hash is not reused; a fresh connect occurs."""
    current_browser = MagicMock()
    session_manager.set_current_session(
        session_manager.SessionState(
            browser=current_browser,
            context=BrowserContext(mode="cloud_session", session_id="pbs_123"),
            api_key_hash=session_manager._api_key_hash("sk_key_a"),
        )
    )
    replacement_browser = MagicMock()
    fake_skyvern = MagicMock()
    fake_skyvern.connect_to_cloud_browser_session = AsyncMock(return_value=replacement_browser)
    monkeypatch.setattr(session_manager, "get_skyvern", lambda: fake_skyvern)
    # Resolve under a different active API key than the cached state's.
    token = client_mod.set_api_key_override("sk_key_b")
    try:
        browser, ctx = await session_manager.resolve_browser(session_id="pbs_123")
    finally:
        client_mod.reset_api_key_override(token)
    assert browser is replacement_browser
    assert ctx.session_id == "pbs_123"
    fake_skyvern.connect_to_cloud_browser_session.assert_awaited_once_with("pbs_123")
@pytest.mark.asyncio
async def test_resolve_browser_stateless_mode_does_not_write_global_session(monkeypatch: pytest.MonkeyPatch) -> None:
    """Resolving a browser in stateless mode must not mutate the process-global session."""
    global_state = session_manager.SessionState(
        browser=MagicMock(),
        context=BrowserContext(mode="cloud_session", session_id="pbs_global"),
    )
    session_manager._global_session = global_state
    replacement_browser = MagicMock()
    fake_skyvern = MagicMock()
    fake_skyvern.connect_to_cloud_browser_session = AsyncMock(return_value=replacement_browser)
    monkeypatch.setattr(session_manager, "get_skyvern", lambda: fake_skyvern)
    session_manager.set_stateless_http_mode(True)
    try:
        browser, ctx = await session_manager.resolve_browser(session_id="pbs_123")
    finally:
        session_manager.set_stateless_http_mode(False)
    assert browser is replacement_browser
    assert ctx.session_id == "pbs_123"
    assert session_manager._global_session is global_state
@pytest.mark.asyncio
async def test_resolve_browser_blocks_implicit_session_in_stateless_mode() -> None:
    """Without an explicit session_id, stateless mode refuses to use cached browsers."""
    session_manager.set_current_session(
        session_manager.SessionState(
            browser=MagicMock(),
            context=BrowserContext(mode="cloud_session", session_id="pbs_123"),
            api_key_hash=session_manager._api_key_hash("sk_key_a"),
        )
    )
    session_manager.set_stateless_http_mode(True)
    try:
        with pytest.raises(session_manager.BrowserNotAvailableError):
            await session_manager.resolve_browser()
    finally:
        session_manager.set_stateless_http_mode(False)
@pytest.mark.asyncio
async def test_resolve_browser_raises_for_invalid_matching_state(monkeypatch: pytest.MonkeyPatch) -> None:
    """A state reported as 'matching' but missing its browser handle is a hard error."""
    session_manager.set_current_session(
        session_manager.SessionState(
            browser=None,
            context=BrowserContext(mode="cloud_session", session_id="pbs_123"),
        )
    )
    # Force the match check to succeed despite the inconsistent state.
    monkeypatch.setattr(session_manager, "_matches_current", lambda *args, **kwargs: True)
    monkeypatch.setattr(session_manager, "get_skyvern", lambda: MagicMock())
    with pytest.raises(RuntimeError, match="Expected active browser and context"):
        await session_manager.resolve_browser(session_id="pbs_123")
@pytest.mark.asyncio
async def test_session_close_with_matching_session_id_closes_browser_handle(monkeypatch: pytest.MonkeyPatch) -> None:
current_browser = MagicMock()
@@ -262,3 +502,73 @@ async def test_close_current_session_still_closes_browser_when_api_fails(monkeyp
# _browser_session_id should NOT be cleared (API close failed, let browser.close() try)
assert browser._browser_session_id == "pbs_fail"
assert session_manager.get_current_session().browser is None
# ---------------------------------------------------------------------------
# Tests for stateless HTTP mode session creation
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_session_create_stateless_mode_returns_session_without_persisting_browser(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """In stateless mode the cloud API creates the session; no local browser handle is kept."""
    session_manager.set_stateless_http_mode(True)
    fake_skyvern = MagicMock()
    fake_skyvern.create_browser_session = AsyncMock(return_value=SimpleNamespace(browser_session_id="pbs_abc"))
    monkeypatch.setattr(mcp_session, "get_skyvern", lambda: fake_skyvern)
    do_session_create = AsyncMock()
    monkeypatch.setattr(mcp_session, "do_session_create", do_session_create)
    try:
        result = await mcp_session.skyvern_session_create(timeout=45)
    finally:
        session_manager.set_stateless_http_mode(False)
    assert result["ok"] is True
    assert result["data"] == {"session_id": "pbs_abc", "timeout_minutes": 45}
    # The local browser-attaching path must be bypassed entirely.
    do_session_create.assert_not_awaited()
    assert mcp_session.get_current_session().browser is None
    assert mcp_session.get_current_session().context is None
@pytest.mark.asyncio
async def test_session_create_stateless_mode_rejects_local() -> None:
    """local=True is invalid input while the server runs in stateless HTTP mode."""
    session_manager.set_stateless_http_mode(True)
    try:
        result = await mcp_session.skyvern_session_create(local=True)
    finally:
        session_manager.set_stateless_http_mode(False)
    assert result["ok"] is False
    assert result["error"]["code"] == mcp_session.ErrorCode.INVALID_INPUT
@pytest.mark.asyncio
async def test_session_create_persists_active_api_key_hash_in_session_state(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A created session records a hash of the active API key — never the raw key."""
    fake_skyvern = MagicMock()
    monkeypatch.setattr(mcp_session, "get_skyvern", lambda: fake_skyvern)
    fake_browser = MagicMock()
    do_session_create = AsyncMock(
        return_value=(
            fake_browser,
            SimpleNamespace(local=False, session_id="pbs_123", timeout_minutes=60, headless=False),
        )
    )
    monkeypatch.setattr(mcp_session, "do_session_create", do_session_create)
    token = client_mod.set_api_key_override("sk_key_create")
    try:
        result = await mcp_session.skyvern_session_create(timeout=60)
    finally:
        client_mod.reset_api_key_override(token)
    assert result["ok"] is True
    current = mcp_session.get_current_session()
    assert current.browser is fake_browser
    assert current.context == BrowserContext(mode="cloud_session", session_id="pbs_123")
    assert current.api_key_hash == session_manager._api_key_hash("sk_key_create")
    # Sanity: the stored value is a hash, not the plaintext key.
    assert current.api_key_hash != "sk_key_create"

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
import asyncio
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import AsyncMock, MagicMock, call
import pytest
@@ -13,6 +13,42 @@ def _reset_cleanup_state() -> None:
run_commands._mcp_cleanup_done = False
@pytest.mark.asyncio
async def test_cleanup_mcp_resources_closes_auth_db(monkeypatch: pytest.MonkeyPatch) -> None:
    """_cleanup_mcp_resources awaits session, client, and auth-DB shutdown, each once."""
    close_current_session = AsyncMock()
    close_skyvern = AsyncMock()
    close_auth_db = AsyncMock()
    monkeypatch.setattr(run_commands, "close_current_session", close_current_session)
    monkeypatch.setattr(run_commands, "close_skyvern", close_skyvern)
    monkeypatch.setattr(run_commands, "close_auth_db", close_auth_db)
    await run_commands._cleanup_mcp_resources()
    close_current_session.assert_awaited_once()
    close_skyvern.assert_awaited_once()
    close_auth_db.assert_awaited_once()
@pytest.mark.asyncio
async def test_cleanup_mcp_resources_closes_auth_db_on_skyvern_close_error(monkeypatch: pytest.MonkeyPatch) -> None:
    """Auth-DB shutdown still runs even when closing the Skyvern client raises."""
    close_current_session = AsyncMock()
    close_auth_db = AsyncMock()

    async def _failing_close_skyvern() -> None:
        raise RuntimeError("close failed")

    monkeypatch.setattr(run_commands, "close_current_session", close_current_session)
    monkeypatch.setattr(run_commands, "close_skyvern", _failing_close_skyvern)
    monkeypatch.setattr(run_commands, "close_auth_db", close_auth_db)
    # The error propagates, but cleanup of the other resources must have happened.
    with pytest.raises(RuntimeError, match="close failed"):
        await run_commands._cleanup_mcp_resources()
    close_current_session.assert_awaited_once()
    close_auth_db.assert_awaited_once()
def test_cleanup_mcp_resources_sync_runs_without_running_loop(monkeypatch: pytest.MonkeyPatch) -> None:
cleanup = AsyncMock()
monkeypatch.setattr(run_commands, "_cleanup_mcp_resources", cleanup)
def test_run_mcp_calls_blocking_cleanup_in_finally(monkeypatch: pytest.MonkeyPatch) -> None:
    """Even when the server loop raises, run_mcp must still run the blocking cleanup."""
    blocking_cleanup = MagicMock()
    atexit_register = MagicMock()
    server_run = MagicMock(side_effect=RuntimeError("boom"))
    stateless_toggle = MagicMock()
    monkeypatch.setattr(run_commands, "_cleanup_mcp_resources_blocking", blocking_cleanup)
    monkeypatch.setattr(run_commands.atexit, "register", atexit_register)
    monkeypatch.setattr(run_commands.mcp, "run", server_run)
    monkeypatch.setattr(run_commands, "set_stateless_http_mode", stateless_toggle)

    with pytest.raises(RuntimeError, match="boom"):
        run_commands.run_mcp()

    atexit_register.assert_called_once_with(run_commands._cleanup_mcp_resources_sync)
    server_run.assert_called_once_with(transport="stdio")
    # stdio mode disables stateless HTTP both on entry and again in the finally block.
    stateless_toggle.assert_has_calls([call(False), call(False)])
    blocking_cleanup.assert_called_once()
def test_run_mcp_http_transport_wires_auth_middleware(monkeypatch: pytest.MonkeyPatch) -> None:
    """HTTP transport should normalize the path, enable stateless mode, and attach API-key auth."""
    blocking_cleanup = MagicMock()
    atexit_register = MagicMock()
    server_run = MagicMock()
    stateless_toggle = MagicMock()
    monkeypatch.setattr(run_commands, "_cleanup_mcp_resources_blocking", blocking_cleanup)
    monkeypatch.setattr(run_commands.atexit, "register", atexit_register)
    monkeypatch.setattr(run_commands.mcp, "run", server_run)
    monkeypatch.setattr(run_commands, "set_stateless_http_mode", stateless_toggle)

    run_commands.run_mcp(
        transport="streamable-http",
        host="127.0.0.1",
        port=9010,
        path="mcp",
        stateless_http=True,
    )

    atexit_register.assert_called_once_with(run_commands._cleanup_mcp_resources_sync)
    server_run.assert_called_once()
    kwargs = server_run.call_args.kwargs
    expected = {
        "transport": "streamable-http",
        "host": "127.0.0.1",
        "port": 9010,
        "path": "/mcp",  # bare "mcp" is normalized to a leading-slash path
    }
    for key, value in expected.items():
        assert kwargs[key] == value
    assert kwargs["stateless_http"] is True
    middleware_stack = kwargs["middleware"]
    assert len(middleware_stack) == 1
    assert middleware_stack[0].cls is run_commands.MCPAPIKeyMiddleware
    # Stateless mode is enabled for the HTTP run and reset in the finally block.
    stateless_toggle.assert_has_calls([call(True), call(False)])
    blocking_cleanup.assert_called_once()
def test_run_task_tool_registration_points_to_browser_module() -> None:
    """The registered skyvern_run_task tool must be backed by the browser tools module."""
    registry = run_commands.mcp._tool_manager._tools  # type: ignore[attr-defined]
    registered_tool = registry["skyvern_run_task"]
    assert registered_tool.fn.__module__ == "skyvern.cli.mcp_tools.browser"