align workflow CLI commands with MCP parity (#4792)
This commit is contained in:
@@ -68,6 +68,30 @@ Add this to your MCP client's configuration file:
|
||||
|
||||
Replace `/usr/bin/python3` with the output of `which python3` on your machine. For local mode, set `SKYVERN_BASE_URL` to `http://localhost:8000` and find your API key in the `.env` file after running `skyvern init`.
|
||||
|
||||
### Option C: Remote MCP over HTTPS (streamable HTTP)
|
||||
|
||||
Use this when your team provides a hosted MCP endpoint (for example: `https://mcp.skyvern.com/mcp`).
|
||||
|
||||
In remote HTTP mode:
|
||||
- Clients must send `x-api-key` on every request.
|
||||
- Use `skyvern_session_create` first, then pass `session_id` explicitly on subsequent browser tool calls.
|
||||
|
||||
If your MCP client supports native remote HTTP transport, configure it directly:
|
||||
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"SkyvernRemote": {
|
||||
"type": "streamable-http",
|
||||
"url": "https://mcp.skyvern.com/mcp",
|
||||
"headers": {
|
||||
"x-api-key": "YOUR_SKYVERN_API_KEY"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
<Accordion title="Config file locations by client">
|
||||
|
||||
| Client | Path |
|
||||
|
||||
29
skyvern/cli/core/api_key_hash.py
Normal file
29
skyvern/cli/core/api_key_hash.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
|
||||
|
||||
def _resolve_api_key_hash_iterations() -> int:
|
||||
raw = os.environ.get("SKYVERN_MCP_API_KEY_HASH_ITERATIONS", "120000")
|
||||
try:
|
||||
return max(10_000, int(raw))
|
||||
except ValueError:
|
||||
return 120_000
|
||||
|
||||
|
||||
_API_KEY_HASH_ITERATIONS = _resolve_api_key_hash_iterations()
|
||||
_API_KEY_HASH_SALT = os.environ.get(
|
||||
"SKYVERN_MCP_API_KEY_HASH_SALT",
|
||||
"skyvern-mcp-api-key-cache-v1",
|
||||
).encode("utf-8")
|
||||
|
||||
|
||||
def hash_api_key_for_cache(api_key: str) -> str:
|
||||
"""Derive a deterministic, non-reversible fingerprint for API-key keyed caches."""
|
||||
return hashlib.pbkdf2_hmac(
|
||||
"sha256",
|
||||
api_key.encode("utf-8"),
|
||||
_API_KEY_HASH_SALT,
|
||||
_API_KEY_HASH_ITERATIONS,
|
||||
).hex()
|
||||
@@ -1,7 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from contextvars import ContextVar
|
||||
from collections import OrderedDict
|
||||
from contextvars import ContextVar, Token
|
||||
from threading import RLock
|
||||
|
||||
import structlog
|
||||
|
||||
@@ -9,35 +12,126 @@ from skyvern.client import SkyvernEnvironment
|
||||
from skyvern.config import settings
|
||||
from skyvern.library.skyvern import Skyvern
|
||||
|
||||
from .api_key_hash import hash_api_key_for_cache
|
||||
|
||||
_skyvern_instance: ContextVar[Skyvern | None] = ContextVar("skyvern_instance", default=None)
|
||||
_api_key_override: ContextVar[str | None] = ContextVar("skyvern_api_key_override", default=None)
|
||||
_global_skyvern_instance: Skyvern | None = None
|
||||
_api_key_clients: OrderedDict[str, Skyvern] = OrderedDict()
|
||||
_clients_lock = RLock()
|
||||
LOG = structlog.get_logger(__name__)
|
||||
|
||||
|
||||
def _resolve_api_key_cache_size() -> int:
|
||||
raw = os.environ.get("SKYVERN_MCP_API_KEY_CLIENT_CACHE_SIZE", "128")
|
||||
try:
|
||||
return max(1, int(raw))
|
||||
except ValueError:
|
||||
return 128
|
||||
|
||||
|
||||
_API_KEY_CLIENT_CACHE_MAX = _resolve_api_key_cache_size()
|
||||
|
||||
|
||||
def _cache_key(api_key: str) -> str:
|
||||
"""Hash API key so raw secrets are never stored as dict keys."""
|
||||
return hash_api_key_for_cache(api_key)
|
||||
|
||||
|
||||
def _resolve_api_key() -> str | None:
|
||||
return settings.SKYVERN_API_KEY or os.environ.get("SKYVERN_API_KEY")
|
||||
|
||||
|
||||
def _resolve_base_url() -> str | None:
|
||||
return settings.SKYVERN_BASE_URL or os.environ.get("SKYVERN_BASE_URL")
|
||||
|
||||
|
||||
def _build_cloud_client(api_key: str) -> Skyvern:
|
||||
return Skyvern(
|
||||
api_key=api_key,
|
||||
environment=SkyvernEnvironment.CLOUD,
|
||||
base_url=_resolve_base_url(),
|
||||
)
|
||||
|
||||
|
||||
def _close_skyvern_instance_best_effort(instance: Skyvern) -> None:
|
||||
"""Close a Skyvern instance, regardless of whether an event loop is running."""
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
try:
|
||||
asyncio.run(instance.aclose())
|
||||
except Exception:
|
||||
LOG.debug("Failed to close evicted Skyvern client", exc_info=True)
|
||||
return
|
||||
|
||||
task = loop.create_task(instance.aclose())
|
||||
|
||||
def _on_done(done: asyncio.Task[None]) -> None:
|
||||
try:
|
||||
done.result()
|
||||
except Exception:
|
||||
LOG.debug("Failed to close evicted Skyvern client", exc_info=True)
|
||||
|
||||
task.add_done_callback(_on_done)
|
||||
|
||||
|
||||
def get_active_api_key() -> str | None:
|
||||
"""Return the effective API key for this request/context."""
|
||||
return _api_key_override.get() or _resolve_api_key()
|
||||
|
||||
|
||||
def set_api_key_override(api_key: str | None) -> Token[str | None]:
|
||||
"""Set request-scoped API key override for MCP HTTP requests."""
|
||||
_skyvern_instance.set(None)
|
||||
return _api_key_override.set(api_key)
|
||||
|
||||
|
||||
def reset_api_key_override(token: Token[str | None]) -> None:
|
||||
"""Reset request-scoped API key override."""
|
||||
_api_key_override.reset(token)
|
||||
_skyvern_instance.set(None)
|
||||
|
||||
|
||||
def get_skyvern() -> Skyvern:
|
||||
"""Get or create a Skyvern client instance."""
|
||||
global _global_skyvern_instance
|
||||
|
||||
instance = _skyvern_instance.get()
|
||||
if instance is None:
|
||||
instance = _global_skyvern_instance
|
||||
if instance is not None:
|
||||
override_api_key = _api_key_override.get()
|
||||
if override_api_key:
|
||||
instance = _skyvern_instance.get()
|
||||
if instance is None:
|
||||
key = _cache_key(override_api_key)
|
||||
evicted_clients: list[Skyvern] = []
|
||||
# Hold lock across lookup + build + insert to prevent two coroutines
|
||||
# from both building a client for the same API key concurrently.
|
||||
with _clients_lock:
|
||||
instance = _api_key_clients.get(key)
|
||||
if instance is not None:
|
||||
_api_key_clients.move_to_end(key)
|
||||
else:
|
||||
instance = _build_cloud_client(override_api_key)
|
||||
_api_key_clients[key] = instance
|
||||
_api_key_clients.move_to_end(key)
|
||||
while len(_api_key_clients) > _API_KEY_CLIENT_CACHE_MAX:
|
||||
_, evicted = _api_key_clients.popitem(last=False)
|
||||
evicted_clients.append(evicted)
|
||||
for evicted in evicted_clients:
|
||||
_close_skyvern_instance_best_effort(evicted)
|
||||
_skyvern_instance.set(instance)
|
||||
return instance
|
||||
|
||||
api_key = settings.SKYVERN_API_KEY or os.environ.get("SKYVERN_API_KEY")
|
||||
base_url = settings.SKYVERN_BASE_URL or os.environ.get("SKYVERN_BASE_URL")
|
||||
|
||||
if api_key:
|
||||
instance = Skyvern(
|
||||
api_key=api_key,
|
||||
environment=SkyvernEnvironment.CLOUD,
|
||||
base_url=base_url,
|
||||
)
|
||||
else:
|
||||
instance = Skyvern.local()
|
||||
|
||||
_global_skyvern_instance = instance
|
||||
instance = _skyvern_instance.get()
|
||||
if instance is None:
|
||||
with _clients_lock:
|
||||
instance = _global_skyvern_instance
|
||||
if instance is None:
|
||||
api_key = _resolve_api_key()
|
||||
if api_key:
|
||||
instance = _build_cloud_client(api_key)
|
||||
else:
|
||||
instance = Skyvern.local()
|
||||
_global_skyvern_instance = instance
|
||||
_skyvern_instance.set(instance)
|
||||
return instance
|
||||
|
||||
@@ -48,7 +142,12 @@ async def close_skyvern() -> None:
|
||||
|
||||
instances: list[Skyvern] = []
|
||||
seen: set[int] = set()
|
||||
for candidate in (_skyvern_instance.get(), _global_skyvern_instance):
|
||||
with _clients_lock:
|
||||
candidates = (_skyvern_instance.get(), _global_skyvern_instance, *_api_key_clients.values())
|
||||
_api_key_clients.clear()
|
||||
_global_skyvern_instance = None
|
||||
|
||||
for candidate in candidates:
|
||||
if candidate is None or id(candidate) in seen:
|
||||
continue
|
||||
seen.add(id(candidate))
|
||||
@@ -61,4 +160,3 @@ async def close_skyvern() -> None:
|
||||
LOG.warning("Failed to close Skyvern client", exc_info=True)
|
||||
|
||||
_skyvern_instance.set(None)
|
||||
_global_skyvern_instance = None
|
||||
|
||||
206
skyvern/cli/core/mcp_http_auth.py
Normal file
206
skyvern/cli/core/mcp_http_auth.py
Normal file
@@ -0,0 +1,206 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from threading import RLock
|
||||
|
||||
import structlog
|
||||
from fastapi import HTTPException
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
from skyvern.config import settings
|
||||
from skyvern.forge.sdk.db.agent_db import AgentDB
|
||||
from skyvern.forge.sdk.services.org_auth_service import resolve_org_from_api_key
|
||||
|
||||
from .api_key_hash import hash_api_key_for_cache
|
||||
from .client import reset_api_key_override, set_api_key_override
|
||||
|
||||
LOG = structlog.get_logger(__name__)
|
||||
API_KEY_HEADER = "x-api-key"
|
||||
HEALTH_PATHS = {"/health", "/healthz"}
|
||||
_auth_db: AgentDB | None = None
|
||||
_auth_db_lock = RLock()
|
||||
_api_key_cache_lock = RLock()
|
||||
_api_key_validation_cache: OrderedDict[str, tuple[str | None, float]] = OrderedDict()
|
||||
_NEGATIVE_CACHE_TTL_SECONDS = 5.0
|
||||
_VALIDATION_RETRY_EXHAUSTED_MESSAGE = "API key validation temporarily unavailable"
|
||||
_MAX_VALIDATION_RETRIES = 2
|
||||
_RETRY_DELAY_SECONDS = 0.25
|
||||
|
||||
|
||||
def _resolve_api_key_cache_ttl_seconds() -> float:
|
||||
raw = os.environ.get("SKYVERN_MCP_API_KEY_CACHE_TTL_SECONDS", "30")
|
||||
try:
|
||||
return max(1.0, float(raw))
|
||||
except ValueError:
|
||||
return 30.0
|
||||
|
||||
|
||||
def _resolve_api_key_cache_max_size() -> int:
|
||||
raw = os.environ.get("SKYVERN_MCP_API_KEY_CACHE_MAX_SIZE", "1024")
|
||||
try:
|
||||
return max(1, int(raw))
|
||||
except ValueError:
|
||||
return 1024
|
||||
|
||||
|
||||
_API_KEY_CACHE_TTL_SECONDS = _resolve_api_key_cache_ttl_seconds()
|
||||
_API_KEY_CACHE_MAX_SIZE = _resolve_api_key_cache_max_size()
|
||||
|
||||
|
||||
def _get_auth_db() -> AgentDB:
|
||||
global _auth_db
|
||||
# Guard singleton init in case HTTP transport is served with threaded workers.
|
||||
with _auth_db_lock:
|
||||
if _auth_db is None:
|
||||
_auth_db = AgentDB(settings.DATABASE_STRING, debug_enabled=settings.DEBUG_MODE)
|
||||
return _auth_db
|
||||
|
||||
|
||||
async def close_auth_db() -> None:
|
||||
"""Dispose the auth DB engine used by HTTP middleware, if initialized."""
|
||||
global _auth_db
|
||||
with _auth_db_lock:
|
||||
db = _auth_db
|
||||
_auth_db = None
|
||||
with _api_key_cache_lock:
|
||||
_api_key_validation_cache.clear()
|
||||
if db is None:
|
||||
return
|
||||
|
||||
try:
|
||||
await db.engine.dispose()
|
||||
except Exception:
|
||||
LOG.warning("Failed to dispose MCP auth DB engine", exc_info=True)
|
||||
|
||||
|
||||
def _cache_key(api_key: str) -> str:
|
||||
return hash_api_key_for_cache(api_key)
|
||||
|
||||
|
||||
async def validate_mcp_api_key(api_key: str) -> str:
|
||||
"""Validate API key and return organization id for observability."""
|
||||
key = _cache_key(api_key)
|
||||
|
||||
# Check cache first.
|
||||
with _api_key_cache_lock:
|
||||
cached = _api_key_validation_cache.get(key)
|
||||
if cached is not None:
|
||||
organization_id, expires_at = cached
|
||||
if expires_at > time.monotonic():
|
||||
_api_key_validation_cache.move_to_end(key)
|
||||
if organization_id is None:
|
||||
raise HTTPException(status_code=401, detail="Invalid API key")
|
||||
return organization_id
|
||||
_api_key_validation_cache.pop(key, None)
|
||||
|
||||
# Cache miss — do the DB lookup with simple retry on transient errors.
|
||||
last_exc: Exception | None = None
|
||||
for attempt in range(_MAX_VALIDATION_RETRIES + 1):
|
||||
if attempt > 0:
|
||||
await asyncio.sleep(_RETRY_DELAY_SECONDS)
|
||||
try:
|
||||
validation = await resolve_org_from_api_key(api_key, _get_auth_db())
|
||||
organization_id = validation.organization.organization_id
|
||||
with _api_key_cache_lock:
|
||||
_api_key_validation_cache[key] = (
|
||||
organization_id,
|
||||
time.monotonic() + _API_KEY_CACHE_TTL_SECONDS,
|
||||
)
|
||||
_api_key_validation_cache.move_to_end(key)
|
||||
while len(_api_key_validation_cache) > _API_KEY_CACHE_MAX_SIZE:
|
||||
_api_key_validation_cache.popitem(last=False)
|
||||
return organization_id
|
||||
except HTTPException as e:
|
||||
if e.status_code in {401, 403}:
|
||||
with _api_key_cache_lock:
|
||||
_api_key_validation_cache[key] = (None, time.monotonic() + _NEGATIVE_CACHE_TTL_SECONDS)
|
||||
_api_key_validation_cache.move_to_end(key)
|
||||
while len(_api_key_validation_cache) > _API_KEY_CACHE_MAX_SIZE:
|
||||
_api_key_validation_cache.popitem(last=False)
|
||||
raise
|
||||
last_exc = e
|
||||
except Exception as e:
|
||||
last_exc = e
|
||||
|
||||
LOG.warning("API key validation retries exhausted", attempts=_MAX_VALIDATION_RETRIES + 1, exc_info=last_exc)
|
||||
raise HTTPException(status_code=503, detail=_VALIDATION_RETRY_EXHAUSTED_MESSAGE)
|
||||
|
||||
|
||||
def _unauthorized_response(message: str) -> JSONResponse:
|
||||
return JSONResponse({"error": {"code": "UNAUTHORIZED", "message": message}}, status_code=401)
|
||||
|
||||
|
||||
def _internal_error_response() -> JSONResponse:
|
||||
return JSONResponse(
|
||||
{"error": {"code": "INTERNAL_ERROR", "message": "Internal server error"}},
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
|
||||
def _service_unavailable_response(message: str) -> JSONResponse:
|
||||
return JSONResponse(
|
||||
{"error": {"code": "SERVICE_UNAVAILABLE", "message": message}},
|
||||
status_code=503,
|
||||
)
|
||||
|
||||
|
||||
class MCPAPIKeyMiddleware:
|
||||
"""Require x-api-key for MCP HTTP transport and scope requests to that key."""
|
||||
|
||||
def __init__(self, app: ASGIApp) -> None:
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] != "http":
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
request = Request(scope, receive=receive)
|
||||
if request.url.path in HEALTH_PATHS:
|
||||
response = JSONResponse({"status": "ok"})
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
|
||||
if request.method == "OPTIONS":
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
api_key = request.headers.get(API_KEY_HEADER)
|
||||
if not api_key:
|
||||
response = _unauthorized_response("Missing x-api-key header")
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
|
||||
try:
|
||||
organization_id = await validate_mcp_api_key(api_key)
|
||||
scope.setdefault("state", {})
|
||||
scope["state"]["organization_id"] = organization_id
|
||||
except HTTPException as e:
|
||||
if e.status_code in {401, 403}:
|
||||
response = _unauthorized_response("Invalid API key")
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
if e.status_code == 503:
|
||||
response = _service_unavailable_response(e.detail or _VALIDATION_RETRY_EXHAUSTED_MESSAGE)
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
LOG.warning("Unexpected HTTPException during MCP API key validation", status_code=e.status_code)
|
||||
response = _internal_error_response()
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
except Exception:
|
||||
LOG.exception("Unexpected MCP API key validation failure")
|
||||
response = _internal_error_response()
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
|
||||
token = set_api_key_override(api_key)
|
||||
try:
|
||||
await self.app(scope, receive, send)
|
||||
finally:
|
||||
reset_api_key_override(token)
|
||||
@@ -7,7 +7,8 @@ from typing import TYPE_CHECKING, Any, AsyncIterator
|
||||
|
||||
import structlog
|
||||
|
||||
from .client import get_skyvern
|
||||
from .api_key_hash import hash_api_key_for_cache
|
||||
from .client import get_active_api_key, get_skyvern
|
||||
from .result import BrowserContext, ErrorCode, make_error
|
||||
|
||||
LOG = structlog.get_logger(__name__)
|
||||
@@ -21,6 +22,7 @@ if TYPE_CHECKING:
|
||||
class SessionState:
|
||||
browser: SkyvernBrowser | None = None
|
||||
context: BrowserContext | None = None
|
||||
api_key_hash: str | None = None
|
||||
console_messages: list[dict[str, Any]] = field(default_factory=list)
|
||||
tracing_active: bool = False
|
||||
har_enabled: bool = False
|
||||
@@ -28,26 +30,52 @@ class SessionState:
|
||||
|
||||
_current_session: ContextVar[SessionState | None] = ContextVar("mcp_session", default=None)
|
||||
_global_session: SessionState | None = None
|
||||
_stateless_http_mode = False
|
||||
|
||||
|
||||
def get_current_session() -> SessionState:
|
||||
global _global_session
|
||||
|
||||
state = _current_session.get()
|
||||
if state is None:
|
||||
if _global_session is None:
|
||||
_global_session = SessionState()
|
||||
state = _global_session
|
||||
if state is not None:
|
||||
return state
|
||||
|
||||
# In stateless HTTP mode, avoid process-wide fallback state so requests
|
||||
# cannot inherit session context from other requests.
|
||||
if _stateless_http_mode:
|
||||
state = SessionState()
|
||||
_current_session.set(state)
|
||||
return state
|
||||
|
||||
if _global_session is None:
|
||||
_global_session = SessionState()
|
||||
state = _global_session
|
||||
_current_session.set(state)
|
||||
return state
|
||||
|
||||
|
||||
def set_current_session(state: SessionState) -> None:
|
||||
global _global_session
|
||||
_global_session = state
|
||||
if not _stateless_http_mode:
|
||||
_global_session = state
|
||||
_current_session.set(state)
|
||||
|
||||
|
||||
def set_stateless_http_mode(enabled: bool) -> None:
|
||||
global _stateless_http_mode
|
||||
_stateless_http_mode = enabled
|
||||
|
||||
|
||||
def is_stateless_http_mode() -> bool:
|
||||
return _stateless_http_mode
|
||||
|
||||
|
||||
def _api_key_hash(api_key: str | None) -> str | None:
|
||||
if not api_key:
|
||||
return None
|
||||
return hash_api_key_for_cache(api_key)
|
||||
|
||||
|
||||
def _matches_current(
|
||||
current: SessionState,
|
||||
*,
|
||||
@@ -57,6 +85,8 @@ def _matches_current(
|
||||
) -> bool:
|
||||
if current.browser is None or current.context is None:
|
||||
return False
|
||||
if current.api_key_hash != _api_key_hash(get_active_api_key()):
|
||||
return False
|
||||
|
||||
if session_id:
|
||||
return current.context.mode == "cloud_session" and current.context.session_id == session_id
|
||||
@@ -84,35 +114,39 @@ async def resolve_browser(
|
||||
skyvern = get_skyvern()
|
||||
current = get_current_session()
|
||||
|
||||
if _stateless_http_mode and not (session_id or cdp_url or local or create_session):
|
||||
raise BrowserNotAvailableError()
|
||||
|
||||
if _matches_current(current, session_id=session_id, cdp_url=cdp_url, local=local):
|
||||
# _matches_current() guarantees both are non-None
|
||||
assert current.browser is not None and current.context is not None
|
||||
if current.browser is None or current.context is None:
|
||||
raise RuntimeError("Expected active browser and context for matching session")
|
||||
return current.browser, current.context
|
||||
|
||||
active_api_key_hash = _api_key_hash(get_active_api_key())
|
||||
browser: SkyvernBrowser | None = None
|
||||
try:
|
||||
if session_id:
|
||||
browser = await skyvern.connect_to_cloud_browser_session(session_id)
|
||||
ctx = BrowserContext(mode="cloud_session", session_id=session_id)
|
||||
set_current_session(SessionState(browser=browser, context=ctx))
|
||||
set_current_session(SessionState(browser=browser, context=ctx, api_key_hash=active_api_key_hash))
|
||||
return browser, ctx
|
||||
|
||||
if cdp_url:
|
||||
browser = await skyvern.connect_to_browser_over_cdp(cdp_url)
|
||||
ctx = BrowserContext(mode="cdp", cdp_url=cdp_url)
|
||||
set_current_session(SessionState(browser=browser, context=ctx))
|
||||
set_current_session(SessionState(browser=browser, context=ctx, api_key_hash=active_api_key_hash))
|
||||
return browser, ctx
|
||||
|
||||
if local:
|
||||
browser = await skyvern.launch_local_browser(headless=headless)
|
||||
ctx = BrowserContext(mode="local")
|
||||
set_current_session(SessionState(browser=browser, context=ctx))
|
||||
set_current_session(SessionState(browser=browser, context=ctx, api_key_hash=active_api_key_hash))
|
||||
return browser, ctx
|
||||
|
||||
if create_session:
|
||||
browser = await skyvern.launch_cloud_browser(timeout=timeout)
|
||||
ctx = BrowserContext(mode="cloud_session", session_id=browser.browser_session_id)
|
||||
set_current_session(SessionState(browser=browser, context=ctx))
|
||||
set_current_session(SessionState(browser=browser, context=ctx, api_key_hash=active_api_key_hash))
|
||||
return browser, ctx
|
||||
except Exception:
|
||||
if browser is not None:
|
||||
|
||||
@@ -4,7 +4,11 @@ from typing import Annotated, Any
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from skyvern.cli.core.api_key_hash import hash_api_key_for_cache
|
||||
from skyvern.cli.core.client import get_active_api_key
|
||||
from skyvern.cli.core.session_manager import is_stateless_http_mode
|
||||
from skyvern.cli.core.session_ops import do_session_close, do_session_create, do_session_list
|
||||
from skyvern.schemas.runs import ProxyLocation
|
||||
|
||||
from ._common import BrowserContext, ErrorCode, Timer, make_error, make_result
|
||||
from ._session import (
|
||||
@@ -16,6 +20,13 @@ from ._session import (
|
||||
)
|
||||
|
||||
|
||||
def _session_api_key_hash() -> str | None:
|
||||
api_key = get_active_api_key()
|
||||
if not api_key:
|
||||
return None
|
||||
return hash_api_key_for_cache(api_key)
|
||||
|
||||
|
||||
async def skyvern_session_create(
|
||||
timeout: Annotated[int | None, Field(description="Session timeout in minutes (5-1440)")] = 60,
|
||||
proxy_location: Annotated[str | None, Field(description="Proxy location: RESIDENTIAL, US, etc.")] = None,
|
||||
@@ -29,7 +40,33 @@ async def skyvern_session_create(
|
||||
"""
|
||||
with Timer() as timer:
|
||||
try:
|
||||
if is_stateless_http_mode() and local:
|
||||
return make_result(
|
||||
"skyvern_session_create",
|
||||
ok=False,
|
||||
error=make_error(
|
||||
ErrorCode.INVALID_INPUT,
|
||||
"Local browser sessions are not supported in stateless HTTP mode",
|
||||
"Use cloud sessions for remote MCP transport",
|
||||
),
|
||||
)
|
||||
|
||||
skyvern = get_skyvern()
|
||||
if is_stateless_http_mode():
|
||||
proxy = ProxyLocation(proxy_location) if proxy_location else None
|
||||
session = await skyvern.create_browser_session(timeout=timeout or 60, proxy_location=proxy)
|
||||
timer.mark("sdk")
|
||||
ctx = BrowserContext(mode="cloud_session", session_id=session.browser_session_id)
|
||||
return make_result(
|
||||
"skyvern_session_create",
|
||||
browser_context=ctx,
|
||||
data={
|
||||
"session_id": session.browser_session_id,
|
||||
"timeout_minutes": timeout or 60,
|
||||
},
|
||||
timing_ms=timer.timing_ms,
|
||||
)
|
||||
|
||||
browser, result = await do_session_create(
|
||||
skyvern,
|
||||
timeout=timeout or 60,
|
||||
@@ -43,7 +80,7 @@ async def skyvern_session_create(
|
||||
ctx = BrowserContext(mode="local")
|
||||
else:
|
||||
ctx = BrowserContext(mode="cloud_session", session_id=result.session_id)
|
||||
set_current_session(SessionState(browser=browser, context=ctx))
|
||||
set_current_session(SessionState(browser=browser, context=ctx, api_key_hash=_session_api_key_hash()))
|
||||
|
||||
except ValueError as e:
|
||||
return make_result(
|
||||
|
||||
@@ -5,7 +5,7 @@ import logging
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
from typing import Any, List, Optional
|
||||
from typing import Annotated, List, Literal, Optional
|
||||
|
||||
import psutil
|
||||
import typer
|
||||
@@ -13,17 +13,17 @@ import uvicorn
|
||||
from dotenv import load_dotenv, set_key
|
||||
from rich.panel import Panel
|
||||
from rich.prompt import Confirm
|
||||
from starlette.middleware import Middleware
|
||||
|
||||
from skyvern.cli.console import console
|
||||
from skyvern.cli.core.client import close_skyvern
|
||||
from skyvern.cli.core.session_manager import close_current_session
|
||||
from skyvern.cli.core.mcp_http_auth import MCPAPIKeyMiddleware, close_auth_db
|
||||
from skyvern.cli.core.session_manager import close_current_session, set_stateless_http_mode
|
||||
from skyvern.cli.mcp_tools import mcp # Uses standalone fastmcp (v2.x)
|
||||
from skyvern.cli.utils import start_services
|
||||
from skyvern.client import SkyvernEnvironment
|
||||
from skyvern.config import settings
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.forge.sdk.forge_log import setup_logger
|
||||
from skyvern.library.skyvern import Skyvern
|
||||
from skyvern.services.script_service import run_script
|
||||
from skyvern.utils import detect_os
|
||||
from skyvern.utils.env_paths import resolve_backend_env_path, resolve_frontend_env_path
|
||||
@@ -36,7 +36,10 @@ async def _cleanup_mcp_resources() -> None:
|
||||
try:
|
||||
await close_current_session()
|
||||
finally:
|
||||
await close_skyvern()
|
||||
try:
|
||||
await close_skyvern()
|
||||
finally:
|
||||
await close_auth_db()
|
||||
|
||||
|
||||
def _cleanup_mcp_resources_blocking() -> None:
|
||||
@@ -67,39 +70,6 @@ def _cleanup_mcp_resources_sync() -> None:
|
||||
logger.debug("Skipping MCP cleanup because event loop is still running")
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def skyvern_run_task(prompt: str, url: str) -> dict[str, Any]:
|
||||
"""Use Skyvern to execute anything in the browser. Useful for accomplishing tasks that require browser automation.
|
||||
|
||||
This tool uses Skyvern's browser automation to navigate websites and perform actions to achieve
|
||||
the user's intended outcome. It can handle tasks like form filling, clicking buttons, data extraction,
|
||||
and multi-step workflows.
|
||||
|
||||
It can even help you find updated data on the internet if your model information is outdated.
|
||||
|
||||
Args:
|
||||
prompt: A natural language description of what needs to be accomplished (e.g. "Book a flight from
|
||||
NYC to LA", "Sign up for the newsletter", "Find the price of item X", "Apply to a job")
|
||||
url: The starting URL of the website where the task should be performed
|
||||
"""
|
||||
skyvern_agent = Skyvern(
|
||||
environment=SkyvernEnvironment.CLOUD,
|
||||
base_url=settings.SKYVERN_BASE_URL,
|
||||
api_key=settings.SKYVERN_API_KEY,
|
||||
)
|
||||
res = await skyvern_agent.run_task(prompt=prompt, url=url, user_agent="skyvern-mcp", wait_for_completion=True)
|
||||
|
||||
output = res.model_dump()["output"]
|
||||
if res.app_url:
|
||||
task_url = res.app_url
|
||||
else:
|
||||
if res.run_id and res.run_id.startswith("wr_"):
|
||||
task_url = f"{settings.SKYVERN_APP_URL.rstrip('/')}/runs/{res.run_id}/overview"
|
||||
else:
|
||||
task_url = f"{settings.SKYVERN_APP_URL.rstrip('/')}/tasks/{res.run_id}/actions"
|
||||
return {"output": output, "task_url": task_url, "run_id": res.run_id}
|
||||
|
||||
|
||||
def get_pids_on_port(port: int) -> List[int]:
|
||||
"""Return a list of PIDs listening on the given port."""
|
||||
pids = []
|
||||
@@ -295,20 +265,61 @@ def run_dev() -> None:
|
||||
|
||||
|
||||
@run_app.command(name="mcp")
|
||||
def run_mcp() -> None:
|
||||
"""Run the MCP server."""
|
||||
# This breaks the MCP processing because it expects json output only
|
||||
# console.print(Panel("[bold green]Starting MCP Server...[/bold green]", border_style="green"))
|
||||
def run_mcp(
|
||||
transport: Annotated[
|
||||
Literal["stdio", "sse", "streamable-http"],
|
||||
typer.Option(
|
||||
"--transport",
|
||||
help="MCP transport: stdio (default), sse, or streamable-http.",
|
||||
),
|
||||
] = "stdio",
|
||||
host: Annotated[str, typer.Option("--host", help="Host for HTTP transports.")] = "0.0.0.0",
|
||||
port: Annotated[int, typer.Option("--port", help="Port for HTTP transports.")] = 8000,
|
||||
path: Annotated[str, typer.Option("--path", help="HTTP endpoint path for MCP transport.")] = "/mcp",
|
||||
stateless_http: Annotated[
|
||||
bool,
|
||||
typer.Option(
|
||||
"--stateless-http/--no-stateless-http",
|
||||
help="Use stateless HTTP semantics for HTTP transports (ignored for stdio).",
|
||||
),
|
||||
] = True,
|
||||
) -> None:
|
||||
"""Run the MCP server with configurable transport for local or remote hosting."""
|
||||
path = _normalize_mcp_path(path)
|
||||
stateless_http_enabled = transport != "stdio" and stateless_http
|
||||
# atexit covers signal-based exits (SIGTERM); finally covers normal
|
||||
# mcp.run() completion or unhandled exceptions. Both are needed because
|
||||
# atexit doesn't fire on normal return and finally doesn't fire on signals.
|
||||
atexit.register(_cleanup_mcp_resources_sync)
|
||||
set_stateless_http_mode(stateless_http_enabled)
|
||||
try:
|
||||
mcp.run(transport="stdio")
|
||||
if transport == "stdio":
|
||||
mcp.run(transport="stdio")
|
||||
return
|
||||
|
||||
middleware = [Middleware(MCPAPIKeyMiddleware)]
|
||||
mcp.run(
|
||||
transport=transport,
|
||||
host=host,
|
||||
port=port,
|
||||
path=path,
|
||||
middleware=middleware,
|
||||
stateless_http=stateless_http_enabled,
|
||||
)
|
||||
finally:
|
||||
set_stateless_http_mode(False)
|
||||
_cleanup_mcp_resources_blocking()
|
||||
|
||||
|
||||
def _normalize_mcp_path(path: str) -> str:
|
||||
path = path.strip()
|
||||
if not path:
|
||||
return "/mcp"
|
||||
if not path.startswith("/"):
|
||||
return f"/{path}"
|
||||
return path
|
||||
|
||||
|
||||
@run_app.command(
|
||||
name="code",
|
||||
context_settings={"allow_interspersed_args": False},
|
||||
|
||||
@@ -1,27 +1,80 @@
|
||||
"""Workflow-related CLI helpers."""
|
||||
"""Workflow-related CLI commands with MCP-parity flags and output."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Coroutine
|
||||
|
||||
import typer
|
||||
from dotenv import load_dotenv
|
||||
from rich.panel import Panel
|
||||
|
||||
from skyvern.client import Skyvern
|
||||
from skyvern.config import settings
|
||||
from skyvern.utils.env_paths import resolve_backend_env_path
|
||||
|
||||
from .console import console
|
||||
from .tasks import _list_workflow_tasks
|
||||
from .commands._output import output, output_error
|
||||
from .mcp_tools.workflow import skyvern_workflow_cancel as tool_workflow_cancel
|
||||
from .mcp_tools.workflow import skyvern_workflow_create as tool_workflow_create
|
||||
from .mcp_tools.workflow import skyvern_workflow_delete as tool_workflow_delete
|
||||
from .mcp_tools.workflow import skyvern_workflow_get as tool_workflow_get
|
||||
from .mcp_tools.workflow import skyvern_workflow_list as tool_workflow_list
|
||||
from .mcp_tools.workflow import skyvern_workflow_run as tool_workflow_run
|
||||
from .mcp_tools.workflow import skyvern_workflow_status as tool_workflow_status
|
||||
from .mcp_tools.workflow import skyvern_workflow_update as tool_workflow_update
|
||||
|
||||
workflow_app = typer.Typer(help="Manage Skyvern workflows.")
|
||||
workflow_app = typer.Typer(help="Manage Skyvern workflows.", no_args_is_help=True)
|
||||
|
||||
|
||||
def _emit_tool_result(result: dict[str, Any], *, json_output: bool) -> None:
|
||||
if json_output:
|
||||
json.dump(result, sys.stdout, indent=2, default=str)
|
||||
sys.stdout.write("\n")
|
||||
if not result.get("ok", False):
|
||||
raise SystemExit(1)
|
||||
return
|
||||
|
||||
if result.get("ok", False):
|
||||
output(result.get("data"), action=str(result.get("action", "")), json_mode=False)
|
||||
return
|
||||
|
||||
err = result.get("error") or {}
|
||||
output_error(str(err.get("message", "Unknown error")), hint=str(err.get("hint", "")), json_mode=False)
|
||||
|
||||
|
||||
def _run_tool(
|
||||
runner: Callable[[], Coroutine[Any, Any, dict[str, Any]]],
|
||||
*,
|
||||
json_output: bool,
|
||||
hint_on_exception: str,
|
||||
) -> None:
|
||||
try:
|
||||
result: dict[str, Any] = asyncio.run(runner())
|
||||
_emit_tool_result(result, json_output=json_output)
|
||||
except typer.BadParameter:
|
||||
raise
|
||||
except Exception as e:
|
||||
output_error(str(e), hint=hint_on_exception, json_mode=json_output)
|
||||
|
||||
|
||||
def _resolve_inline_or_file(value: str | None, *, param_name: str) -> str | None:
|
||||
if value is None or not value.startswith("@"):
|
||||
return value
|
||||
|
||||
file_path = value[1:]
|
||||
if not file_path:
|
||||
raise typer.BadParameter(f"{param_name} file path cannot be empty after '@'.")
|
||||
|
||||
path = Path(file_path).expanduser()
|
||||
try:
|
||||
return path.read_text(encoding="utf-8")
|
||||
except OSError as e:
|
||||
raise typer.BadParameter(f"Unable to read {param_name} file '{path}': {e}") from e
|
||||
|
||||
|
||||
@workflow_app.callback()
|
||||
def workflow_callback(
|
||||
ctx: typer.Context,
|
||||
api_key: str | None = typer.Option(
|
||||
None,
|
||||
"--api-key",
|
||||
@@ -29,86 +82,188 @@ def workflow_callback(
|
||||
envvar="SKYVERN_API_KEY",
|
||||
),
|
||||
) -> None:
|
||||
"""Store the provided API key in the Typer context."""
|
||||
ctx.obj = {"api_key": api_key}
|
||||
|
||||
|
||||
def _get_client(api_key: str | None = None) -> Skyvern:
|
||||
"""Instantiate a Skyvern SDK client using environment variables."""
|
||||
"""Load workflow CLI environment and optional API key override."""
|
||||
load_dotenv(resolve_backend_env_path())
|
||||
key = api_key or os.getenv("SKYVERN_API_KEY") or settings.SKYVERN_API_KEY
|
||||
return Skyvern(base_url=settings.SKYVERN_BASE_URL, api_key=key)
|
||||
if api_key:
|
||||
settings.SKYVERN_API_KEY = api_key
|
||||
|
||||
|
||||
@workflow_app.command("list")
|
||||
def workflow_list(
|
||||
search: str | None = typer.Option(
|
||||
None,
|
||||
"--search",
|
||||
help="Search across workflow titles, folder names, and parameter metadata.",
|
||||
),
|
||||
page: int = typer.Option(1, "--page", min=1, help="Page number (1-based)."),
|
||||
page_size: int = typer.Option(10, "--page-size", min=1, max=100, help="Results per page."),
|
||||
only_workflows: bool = typer.Option(
|
||||
False,
|
||||
"--only-workflows",
|
||||
help="Only return multi-step workflows (exclude saved tasks).",
|
||||
),
|
||||
json_output: bool = typer.Option(False, "--json", help="Output as JSON."),
|
||||
) -> None:
|
||||
"""List workflows."""
|
||||
|
||||
async def _run() -> dict[str, Any]:
|
||||
return await tool_workflow_list(
|
||||
search=search,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
only_workflows=only_workflows,
|
||||
)
|
||||
|
||||
_run_tool(_run, json_output=json_output, hint_on_exception="Check your API key and workflow list filters.")
|
||||
|
||||
|
||||
@workflow_app.command("get")
|
||||
def workflow_get(
|
||||
workflow_id: str = typer.Option(..., "--id", help="Workflow permanent ID (wpid_...)."),
|
||||
version: int | None = typer.Option(None, "--version", min=1, help="Specific version to retrieve."),
|
||||
json_output: bool = typer.Option(False, "--json", help="Output as JSON."),
|
||||
) -> None:
|
||||
"""Get a workflow definition by ID."""
|
||||
|
||||
async def _run() -> dict[str, Any]:
|
||||
return await tool_workflow_get(workflow_id=workflow_id, version=version)
|
||||
|
||||
_run_tool(_run, json_output=json_output, hint_on_exception="Check your API key and workflow ID.")
|
||||
|
||||
|
||||
@workflow_app.command("create")
|
||||
def workflow_create(
|
||||
definition: str = typer.Option(
|
||||
...,
|
||||
"--definition",
|
||||
help="Workflow definition as YAML/JSON string or @file path.",
|
||||
),
|
||||
definition_format: str = typer.Option(
|
||||
"auto",
|
||||
"--format",
|
||||
help="Definition format: json, yaml, or auto.",
|
||||
),
|
||||
folder_id: str | None = typer.Option(None, "--folder-id", help="Folder ID (fld_...) for the workflow."),
|
||||
json_output: bool = typer.Option(False, "--json", help="Output as JSON."),
|
||||
) -> None:
|
||||
"""Create a workflow."""
|
||||
|
||||
async def _run() -> dict[str, Any]:
|
||||
resolved_definition = _resolve_inline_or_file(definition, param_name="definition")
|
||||
assert resolved_definition is not None
|
||||
return await tool_workflow_create(
|
||||
definition=resolved_definition,
|
||||
format=definition_format,
|
||||
folder_id=folder_id,
|
||||
)
|
||||
|
||||
_run_tool(_run, json_output=json_output, hint_on_exception="Check the workflow definition syntax.")
|
||||
|
||||
|
||||
@workflow_app.command("update")
|
||||
def workflow_update(
|
||||
workflow_id: str = typer.Option(..., "--id", help="Workflow permanent ID (wpid_...)."),
|
||||
definition: str = typer.Option(
|
||||
...,
|
||||
"--definition",
|
||||
help="Updated workflow definition as YAML/JSON string or @file path.",
|
||||
),
|
||||
definition_format: str = typer.Option(
|
||||
"auto",
|
||||
"--format",
|
||||
help="Definition format: json, yaml, or auto.",
|
||||
),
|
||||
json_output: bool = typer.Option(False, "--json", help="Output as JSON."),
|
||||
) -> None:
|
||||
"""Update a workflow definition."""
|
||||
|
||||
async def _run() -> dict[str, Any]:
|
||||
resolved_definition = _resolve_inline_or_file(definition, param_name="definition")
|
||||
assert resolved_definition is not None
|
||||
return await tool_workflow_update(
|
||||
workflow_id=workflow_id,
|
||||
definition=resolved_definition,
|
||||
format=definition_format,
|
||||
)
|
||||
|
||||
_run_tool(_run, json_output=json_output, hint_on_exception="Check the workflow ID and definition syntax.")
|
||||
|
||||
|
||||
@workflow_app.command("delete")
|
||||
def workflow_delete(
|
||||
workflow_id: str = typer.Option(..., "--id", help="Workflow permanent ID (wpid_...)."),
|
||||
force: bool = typer.Option(False, "--force", help="Confirm permanent deletion."),
|
||||
json_output: bool = typer.Option(False, "--json", help="Output as JSON."),
|
||||
) -> None:
|
||||
"""Delete a workflow."""
|
||||
|
||||
async def _run() -> dict[str, Any]:
|
||||
return await tool_workflow_delete(workflow_id=workflow_id, force=force)
|
||||
|
||||
_run_tool(_run, json_output=json_output, hint_on_exception="Check the workflow ID and your permissions.")
|
||||
|
||||
|
||||
@workflow_app.command("run")
|
||||
def run_workflow(
|
||||
ctx: typer.Context,
|
||||
workflow_id: str = typer.Argument(..., help="Workflow permanent ID"),
|
||||
parameters: str = typer.Option("{}", "--parameters", "-p", help="JSON parameters for the workflow"),
|
||||
title: str | None = typer.Option(None, "--title", help="Title for the workflow run"),
|
||||
max_steps: int | None = typer.Option(None, "--max-steps", help="Override the workflow max steps"),
|
||||
def workflow_run(
|
||||
workflow_id: str = typer.Option(..., "--id", help="Workflow permanent ID (wpid_...)."),
|
||||
params: str | None = typer.Option(
|
||||
None,
|
||||
"--params",
|
||||
"--parameters",
|
||||
"-p",
|
||||
help="Workflow parameters as JSON string or @file path.",
|
||||
),
|
||||
session: str | None = typer.Option(None, "--session", help="Browser session ID (pbs_...) to reuse."),
|
||||
webhook: str | None = typer.Option(None, "--webhook", help="Status webhook callback URL."),
|
||||
proxy: str | None = typer.Option(None, "--proxy", help="Proxy location (e.g., RESIDENTIAL)."),
|
||||
wait: bool = typer.Option(False, "--wait", help="Wait for workflow completion before returning."),
|
||||
timeout: int = typer.Option(
|
||||
300,
|
||||
"--timeout",
|
||||
min=10,
|
||||
max=3600,
|
||||
help="Max wait time in seconds when --wait is set.",
|
||||
),
|
||||
json_output: bool = typer.Option(False, "--json", help="Output as JSON."),
|
||||
) -> None:
|
||||
"""Run a workflow."""
|
||||
try:
|
||||
params_dict = json.loads(parameters) if parameters else {}
|
||||
except json.JSONDecodeError:
|
||||
console.print(f"[red]Invalid JSON for parameters: {parameters}[/red]")
|
||||
raise typer.Exit(code=1)
|
||||
|
||||
client = _get_client(ctx.obj.get("api_key") if ctx.obj else None)
|
||||
run_resp = client.run_workflow(
|
||||
workflow_id=workflow_id,
|
||||
parameters=params_dict,
|
||||
title=title,
|
||||
max_steps_override=max_steps,
|
||||
)
|
||||
console.print(
|
||||
Panel(
|
||||
f"Started workflow run [bold]{run_resp.run_id}[/bold]",
|
||||
border_style="green",
|
||||
async def _run() -> dict[str, Any]:
|
||||
resolved_params = _resolve_inline_or_file(params, param_name="params")
|
||||
return await tool_workflow_run(
|
||||
workflow_id=workflow_id,
|
||||
parameters=resolved_params,
|
||||
browser_session_id=session,
|
||||
webhook_url=webhook,
|
||||
proxy_location=proxy,
|
||||
wait=wait,
|
||||
timeout_seconds=timeout,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@workflow_app.command("cancel")
|
||||
def cancel_workflow(
|
||||
ctx: typer.Context,
|
||||
run_id: str = typer.Argument(..., help="ID of the workflow run"),
|
||||
) -> None:
|
||||
"""Cancel a running workflow."""
|
||||
client = _get_client(ctx.obj.get("api_key") if ctx.obj else None)
|
||||
client.cancel_run(run_id=run_id)
|
||||
console.print(Panel(f"Cancel signal sent for run {run_id}", border_style="red"))
|
||||
_run_tool(_run, json_output=json_output, hint_on_exception="Check the workflow ID and run parameters.")
|
||||
|
||||
|
||||
@workflow_app.command("status")
|
||||
def workflow_status(
|
||||
ctx: typer.Context,
|
||||
run_id: str = typer.Argument(..., help="ID of the workflow run"),
|
||||
tasks: bool = typer.Option(False, "--tasks", help="Show task executions"),
|
||||
run_id: str = typer.Option(..., "--run-id", help="Run ID (wr_... or tsk_v2_...)."),
|
||||
json_output: bool = typer.Option(False, "--json", help="Output as JSON."),
|
||||
) -> None:
|
||||
"""Retrieve status information for a workflow run."""
|
||||
client = _get_client(ctx.obj.get("api_key") if ctx.obj else None)
|
||||
run = client.get_run(run_id=run_id)
|
||||
console.print(Panel(run.model_dump_json(indent=2), border_style="cyan"))
|
||||
if tasks:
|
||||
task_list = _list_workflow_tasks(client, run_id)
|
||||
console.print(Panel(json.dumps(task_list, indent=2), border_style="magenta"))
|
||||
"""Get workflow run status."""
|
||||
|
||||
async def _run() -> dict[str, Any]:
|
||||
return await tool_workflow_status(run_id=run_id)
|
||||
|
||||
_run_tool(_run, json_output=json_output, hint_on_exception="Check the run ID and API key.")
|
||||
|
||||
|
||||
@workflow_app.command("list")
|
||||
def list_workflows(
|
||||
ctx: typer.Context,
|
||||
page: int = typer.Option(1, "--page", help="Page number"),
|
||||
page_size: int = typer.Option(10, "--page-size", help="Number of workflows to return"),
|
||||
template: bool = typer.Option(False, "--template", help="List template workflows"),
|
||||
@workflow_app.command("cancel")
|
||||
def workflow_cancel(
|
||||
run_id: str = typer.Option(..., "--run-id", help="Run ID (wr_... or tsk_v2_...)."),
|
||||
json_output: bool = typer.Option(False, "--json", help="Output as JSON."),
|
||||
) -> None:
|
||||
"""List workflows for the organization."""
|
||||
client = _get_client(ctx.obj.get("api_key") if ctx.obj else None)
|
||||
resp = client._client_wrapper.httpx_client.request(
|
||||
"api/v1/workflows",
|
||||
method="GET",
|
||||
params={"page": page, "page_size": page_size, "template": template},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
console.print(Panel(json.dumps(resp.json(), indent=2), border_style="cyan"))
|
||||
"""Cancel a workflow run."""
|
||||
|
||||
async def _run() -> dict[str, Any]:
|
||||
return await tool_workflow_cancel(run_id=run_id)
|
||||
|
||||
_run_tool(_run, json_output=json_output, hint_on_exception="Check the run ID and API key.")
|
||||
|
||||
@@ -296,3 +296,175 @@ class TestBrowserCommands:
|
||||
parsed = json.loads(capsys.readouterr().out)
|
||||
assert parsed["ok"] is False
|
||||
assert "Invalid state" in parsed["error"]["message"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Workflow command behavior
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestWorkflowCommands:
|
||||
def test_workflow_get_outputs_mcp_envelope_in_json_mode(
|
||||
self, monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture
|
||||
) -> None:
|
||||
from skyvern.cli import workflow as workflow_cmd
|
||||
|
||||
expected = {
|
||||
"ok": True,
|
||||
"action": "skyvern_workflow_get",
|
||||
"browser_context": {"mode": "none", "session_id": None, "cdp_url": None},
|
||||
"data": {"workflow_permanent_id": "wpid_123"},
|
||||
"artifacts": [],
|
||||
"timing_ms": {},
|
||||
"warnings": [],
|
||||
"error": None,
|
||||
}
|
||||
tool = AsyncMock(return_value=expected)
|
||||
monkeypatch.setattr(workflow_cmd, "tool_workflow_get", tool)
|
||||
|
||||
workflow_cmd.workflow_get(workflow_id="wpid_123", version=2, json_output=True)
|
||||
|
||||
parsed = json.loads(capsys.readouterr().out)
|
||||
assert parsed == expected
|
||||
assert tool.await_args.kwargs == {"workflow_id": "wpid_123", "version": 2}
|
||||
|
||||
def test_workflow_create_reads_definition_from_file(
|
||||
self,
|
||||
tmp_path: Path,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
capsys: pytest.CaptureFixture,
|
||||
) -> None:
|
||||
from skyvern.cli import workflow as workflow_cmd
|
||||
|
||||
definition_file = tmp_path / "workflow.json"
|
||||
definition_text = '{"title": "Example", "workflow_definition": {"blocks": []}}'
|
||||
definition_file.write_text(definition_text)
|
||||
|
||||
tool = AsyncMock(
|
||||
return_value={
|
||||
"ok": True,
|
||||
"action": "skyvern_workflow_create",
|
||||
"browser_context": {"mode": "none", "session_id": None, "cdp_url": None},
|
||||
"data": {"workflow_permanent_id": "wpid_new"},
|
||||
"artifacts": [],
|
||||
"timing_ms": {},
|
||||
"warnings": [],
|
||||
"error": None,
|
||||
}
|
||||
)
|
||||
monkeypatch.setattr(workflow_cmd, "tool_workflow_create", tool)
|
||||
|
||||
workflow_cmd.workflow_create(
|
||||
definition=f"@{definition_file}",
|
||||
definition_format="json",
|
||||
folder_id="fld_123",
|
||||
json_output=True,
|
||||
)
|
||||
|
||||
assert tool.await_args.kwargs == {
|
||||
"definition": definition_text,
|
||||
"format": "json",
|
||||
"folder_id": "fld_123",
|
||||
}
|
||||
parsed = json.loads(capsys.readouterr().out)
|
||||
assert parsed["ok"] is True
|
||||
assert parsed["data"]["workflow_permanent_id"] == "wpid_new"
|
||||
|
||||
def test_workflow_run_reads_params_file_and_maps_options(
|
||||
self,
|
||||
tmp_path: Path,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
capsys: pytest.CaptureFixture,
|
||||
) -> None:
|
||||
from skyvern.cli import workflow as workflow_cmd
|
||||
|
||||
params_file = tmp_path / "params.json"
|
||||
params_file.write_text('{"company": "Acme"}')
|
||||
|
||||
tool = AsyncMock(
|
||||
return_value={
|
||||
"ok": True,
|
||||
"action": "skyvern_workflow_run",
|
||||
"browser_context": {"mode": "none", "session_id": None, "cdp_url": None},
|
||||
"data": {"run_id": "wr_123", "status": "queued"},
|
||||
"artifacts": [],
|
||||
"timing_ms": {},
|
||||
"warnings": [],
|
||||
"error": None,
|
||||
}
|
||||
)
|
||||
monkeypatch.setattr(workflow_cmd, "tool_workflow_run", tool)
|
||||
|
||||
workflow_cmd.workflow_run(
|
||||
workflow_id="wpid_123",
|
||||
params=f"@{params_file}",
|
||||
session="pbs_456",
|
||||
webhook="https://example.com/webhook",
|
||||
proxy="RESIDENTIAL",
|
||||
wait=True,
|
||||
timeout=450,
|
||||
json_output=True,
|
||||
)
|
||||
|
||||
assert tool.await_args.kwargs == {
|
||||
"workflow_id": "wpid_123",
|
||||
"parameters": '{"company": "Acme"}',
|
||||
"browser_session_id": "pbs_456",
|
||||
"webhook_url": "https://example.com/webhook",
|
||||
"proxy_location": "RESIDENTIAL",
|
||||
"wait": True,
|
||||
"timeout_seconds": 450,
|
||||
}
|
||||
parsed = json.loads(capsys.readouterr().out)
|
||||
assert parsed["ok"] is True
|
||||
assert parsed["data"]["run_id"] == "wr_123"
|
||||
|
||||
def test_workflow_status_json_error_exits_nonzero(
|
||||
self, monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture
|
||||
) -> None:
|
||||
from skyvern.cli import workflow as workflow_cmd
|
||||
|
||||
tool = AsyncMock(
|
||||
return_value={
|
||||
"ok": False,
|
||||
"action": "skyvern_workflow_status",
|
||||
"browser_context": {"mode": "none", "session_id": None, "cdp_url": None},
|
||||
"data": None,
|
||||
"artifacts": [],
|
||||
"timing_ms": {},
|
||||
"warnings": [],
|
||||
"error": {
|
||||
"code": "RUN_NOT_FOUND",
|
||||
"message": "Run 'wr_missing' not found",
|
||||
"hint": "Verify the run ID",
|
||||
"details": {},
|
||||
},
|
||||
}
|
||||
)
|
||||
monkeypatch.setattr(workflow_cmd, "tool_workflow_status", tool)
|
||||
|
||||
with pytest.raises(SystemExit, match="1"):
|
||||
workflow_cmd.workflow_status(run_id="wr_missing", json_output=True)
|
||||
|
||||
parsed = json.loads(capsys.readouterr().out)
|
||||
assert parsed["ok"] is False
|
||||
assert parsed["error"]["code"] == "RUN_NOT_FOUND"
|
||||
|
||||
def test_workflow_update_missing_definition_file_raises_bad_parameter(
|
||||
self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
from skyvern.cli import workflow as workflow_cmd
|
||||
|
||||
tool = AsyncMock()
|
||||
monkeypatch.setattr(workflow_cmd, "tool_workflow_update", tool)
|
||||
missing_file = tmp_path / "missing-definition.json"
|
||||
|
||||
with pytest.raises(typer.BadParameter, match="Unable to read definition file"):
|
||||
workflow_cmd.workflow_update(
|
||||
workflow_id="wpid_123",
|
||||
definition=f"@{missing_file}",
|
||||
definition_format="json",
|
||||
json_output=False,
|
||||
)
|
||||
|
||||
tool.assert_not_called()
|
||||
|
||||
304
tests/unit/test_mcp_http_auth.py
Normal file
304
tests/unit/test_mcp_http_auth.py
Normal file
@@ -0,0 +1,304 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from starlette.applications import Starlette
|
||||
from starlette.middleware import Middleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.routing import Route
|
||||
|
||||
from skyvern.cli.core import client as client_mod
|
||||
from skyvern.cli.core import mcp_http_auth
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_auth_context() -> None:
|
||||
client_mod._api_key_override.set(None)
|
||||
mcp_http_auth._auth_db = None
|
||||
mcp_http_auth._api_key_validation_cache.clear()
|
||||
mcp_http_auth._API_KEY_CACHE_TTL_SECONDS = 30.0
|
||||
mcp_http_auth._API_KEY_CACHE_MAX_SIZE = 1024
|
||||
mcp_http_auth._MAX_VALIDATION_RETRIES = 2
|
||||
mcp_http_auth._RETRY_DELAY_SECONDS = 0.0 # no delay in tests
|
||||
|
||||
|
||||
async def _echo_request_context(request: Request) -> JSONResponse:
|
||||
return JSONResponse(
|
||||
{
|
||||
"api_key": client_mod.get_active_api_key(),
|
||||
"organization_id": getattr(request.state, "organization_id", None),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _build_test_app() -> Starlette:
|
||||
return Starlette(
|
||||
routes=[Route("/mcp", endpoint=_echo_request_context, methods=["POST"])],
|
||||
middleware=[Middleware(mcp_http_auth.MCPAPIKeyMiddleware)],
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_http_auth_rejects_missing_api_key() -> None:
|
||||
app = _build_test_app()
|
||||
|
||||
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
|
||||
response = await client.post("/mcp", json={})
|
||||
|
||||
assert response.status_code == 401
|
||||
assert response.json()["error"]["code"] == "UNAUTHORIZED"
|
||||
assert "x-api-key" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_http_auth_allows_health_checks_without_api_key() -> None:
|
||||
app = _build_test_app()
|
||||
|
||||
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
|
||||
response = await client.get("/healthz")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"status": "ok"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_http_auth_rejects_invalid_api_key(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
mcp_http_auth,
|
||||
"validate_mcp_api_key",
|
||||
AsyncMock(side_effect=HTTPException(status_code=403, detail="Invalid credentials")),
|
||||
)
|
||||
app = _build_test_app()
|
||||
|
||||
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
|
||||
response = await client.post("/mcp", headers={"x-api-key": "bad-key"}, json={})
|
||||
|
||||
assert response.status_code == 401
|
||||
assert response.json()["error"]["code"] == "UNAUTHORIZED"
|
||||
assert response.json()["error"]["message"] == "Invalid API key"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_http_auth_returns_500_on_non_auth_http_exception(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
mcp_http_auth,
|
||||
"validate_mcp_api_key",
|
||||
AsyncMock(side_effect=HTTPException(status_code=500, detail="db down")),
|
||||
)
|
||||
app = _build_test_app()
|
||||
|
||||
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
|
||||
response = await client.post("/mcp", headers={"x-api-key": "sk_live_abc"}, json={})
|
||||
|
||||
assert response.status_code == 500
|
||||
assert response.json()["error"]["code"] == "INTERNAL_ERROR"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_http_auth_returns_503_on_transient_validation_exhaustion(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
mcp_http_auth,
|
||||
"validate_mcp_api_key",
|
||||
AsyncMock(side_effect=HTTPException(status_code=503, detail="API key validation temporarily unavailable")),
|
||||
)
|
||||
app = _build_test_app()
|
||||
|
||||
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
|
||||
response = await client.post("/mcp", headers={"x-api-key": "sk_live_abc"}, json={})
|
||||
|
||||
assert response.status_code == 503
|
||||
assert response.json()["error"]["code"] == "SERVICE_UNAVAILABLE"
|
||||
assert response.json()["error"]["message"] == "API key validation temporarily unavailable"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_http_auth_returns_500_on_unexpected_validation_error(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
mcp_http_auth,
|
||||
"validate_mcp_api_key",
|
||||
AsyncMock(side_effect=RuntimeError("boom")),
|
||||
)
|
||||
app = _build_test_app()
|
||||
|
||||
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
|
||||
response = await client.post("/mcp", headers={"x-api-key": "sk_live_abc"}, json={})
|
||||
|
||||
assert response.status_code == 500
|
||||
assert response.json()["error"]["code"] == "INTERNAL_ERROR"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_http_auth_sets_request_scoped_api_key(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
mcp_http_auth,
|
||||
"validate_mcp_api_key",
|
||||
AsyncMock(return_value="org_123"),
|
||||
)
|
||||
app = _build_test_app()
|
||||
|
||||
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
|
||||
response = await client.post("/mcp", headers={"x-api-key": "sk_live_abc"}, json={})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"api_key": "sk_live_abc",
|
||||
"organization_id": "org_123",
|
||||
}
|
||||
assert client_mod.get_active_api_key() != "sk_live_abc"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_mcp_api_key_uses_ttl_cache(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
calls = 0
|
||||
|
||||
async def _resolve(_api_key: str, _db: object) -> object:
|
||||
nonlocal calls
|
||||
calls += 1
|
||||
return SimpleNamespace(organization=SimpleNamespace(organization_id="org_cached"))
|
||||
|
||||
monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
|
||||
monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())
|
||||
|
||||
first = await mcp_http_auth.validate_mcp_api_key("sk_test_cache")
|
||||
second = await mcp_http_auth.validate_mcp_api_key("sk_test_cache")
|
||||
|
||||
assert first == "org_cached"
|
||||
assert second == "org_cached"
|
||||
assert calls == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_mcp_api_key_cache_expires(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
calls = 0
|
||||
|
||||
async def _resolve(_api_key: str, _db: object) -> object:
|
||||
nonlocal calls
|
||||
calls += 1
|
||||
return SimpleNamespace(organization=SimpleNamespace(organization_id=f"org_{calls}"))
|
||||
|
||||
monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
|
||||
monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())
|
||||
|
||||
first = await mcp_http_auth.validate_mcp_api_key("sk_test_cache_expire")
|
||||
cache_key = mcp_http_auth._cache_key("sk_test_cache_expire")
|
||||
mcp_http_auth._api_key_validation_cache[cache_key] = (first, 0.0)
|
||||
second = await mcp_http_auth.validate_mcp_api_key("sk_test_cache_expire")
|
||||
|
||||
assert first == "org_1"
|
||||
assert second == "org_2"
|
||||
assert calls == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_mcp_api_key_negative_caches_auth_failures(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
calls = 0
|
||||
|
||||
async def _resolve(_api_key: str, _db: object) -> object:
|
||||
nonlocal calls
|
||||
calls += 1
|
||||
raise HTTPException(status_code=401, detail="Invalid credentials")
|
||||
|
||||
monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
|
||||
monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())
|
||||
|
||||
with pytest.raises(HTTPException, match="Invalid credentials"):
|
||||
await mcp_http_auth.validate_mcp_api_key("sk_test_auth_failure")
|
||||
|
||||
with pytest.raises(HTTPException, match="Invalid API key"):
|
||||
await mcp_http_auth.validate_mcp_api_key("sk_test_auth_failure")
|
||||
|
||||
assert calls == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_mcp_api_key_retries_transient_failure_without_negative_cache(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
calls = 0
|
||||
|
||||
async def _resolve(_api_key: str, _db: object) -> object:
|
||||
nonlocal calls
|
||||
calls += 1
|
||||
if calls == 1:
|
||||
raise RuntimeError("transient db error")
|
||||
return SimpleNamespace(organization=SimpleNamespace(organization_id="org_recovered"))
|
||||
|
||||
monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
|
||||
monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())
|
||||
|
||||
recovered_org = await mcp_http_auth.validate_mcp_api_key("sk_test_transient")
|
||||
|
||||
cache_key = mcp_http_auth._cache_key("sk_test_transient")
|
||||
assert mcp_http_auth._api_key_validation_cache[cache_key][0] == "org_recovered"
|
||||
|
||||
assert recovered_org == "org_recovered"
|
||||
assert calls == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_mcp_api_key_concurrent_callers_all_succeed(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Multiple concurrent callers for the same key all succeed; the cache
|
||||
collapses subsequent calls after the first one populates it."""
|
||||
calls = 0
|
||||
|
||||
async def _resolve(_api_key: str, _db: object) -> object:
|
||||
nonlocal calls
|
||||
calls += 1
|
||||
return SimpleNamespace(organization=SimpleNamespace(organization_id="org_concurrent"))
|
||||
|
||||
monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
|
||||
monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())
|
||||
|
||||
results = await asyncio.gather(*[mcp_http_auth.validate_mcp_api_key("sk_test_concurrent") for _ in range(5)])
|
||||
assert all(r == "org_concurrent" for r in results)
|
||||
# First call populates cache; remaining may or may not hit DB depending on
|
||||
# scheduling, but all must succeed.
|
||||
assert calls >= 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_mcp_api_key_returns_503_after_retry_exhaustion(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
calls = 0
|
||||
|
||||
async def _resolve(_api_key: str, _db: object) -> object:
|
||||
nonlocal calls
|
||||
calls += 1
|
||||
raise RuntimeError("persistent db outage")
|
||||
|
||||
monkeypatch.setattr(mcp_http_auth, "_MAX_VALIDATION_RETRIES", 2)
|
||||
monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
|
||||
monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())
|
||||
|
||||
with pytest.raises(HTTPException, match="temporarily unavailable") as exc_info:
|
||||
await mcp_http_auth.validate_mcp_api_key("sk_test_transient_exhausted")
|
||||
|
||||
assert exc_info.value.status_code == 503
|
||||
assert calls == 3 # initial + 2 retries
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_auth_db_disposes_engine() -> None:
|
||||
dispose = AsyncMock()
|
||||
mcp_http_auth._auth_db = SimpleNamespace(engine=SimpleNamespace(dispose=dispose))
|
||||
mcp_http_auth._api_key_validation_cache["k"] = ("org", 123.0)
|
||||
|
||||
await mcp_http_auth.close_auth_db()
|
||||
|
||||
dispose.assert_awaited_once()
|
||||
assert mcp_http_auth._auth_db is None
|
||||
assert mcp_http_auth._api_key_validation_cache == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_auth_db_noop_when_uninitialized() -> None:
|
||||
mcp_http_auth._auth_db = None
|
||||
await mcp_http_auth.close_auth_db()
|
||||
assert mcp_http_auth._auth_db is None
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
@@ -14,10 +15,13 @@ from skyvern.cli.mcp_tools import session as mcp_session
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_singletons() -> None:
|
||||
client_mod._skyvern_instance.set(None)
|
||||
client_mod._api_key_override.set(None)
|
||||
client_mod._global_skyvern_instance = None
|
||||
client_mod._api_key_clients.clear()
|
||||
|
||||
session_manager._current_session.set(None)
|
||||
session_manager._global_session = None
|
||||
session_manager.set_stateless_http_mode(False)
|
||||
mcp_session.set_current_session(mcp_session.SessionState())
|
||||
|
||||
|
||||
@@ -47,6 +51,115 @@ def test_get_skyvern_reuses_global_instance_across_contexts(monkeypatch: pytest.
|
||||
assert len(created) == 1
|
||||
|
||||
|
||||
def test_get_skyvern_reuses_override_instance_per_api_key(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
created_keys: list[str] = []
|
||||
|
||||
class FakeSkyvern:
|
||||
def __init__(self, *args: object, **kwargs: object) -> None:
|
||||
created_keys.append(kwargs["api_key"])
|
||||
|
||||
@classmethod
|
||||
def local(cls) -> FakeSkyvern:
|
||||
return cls(api_key="local")
|
||||
|
||||
async def aclose(self) -> None:
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(client_mod, "Skyvern", FakeSkyvern)
|
||||
monkeypatch.setattr(client_mod.settings, "SKYVERN_API_KEY", None)
|
||||
monkeypatch.setattr(client_mod.settings, "SKYVERN_BASE_URL", None)
|
||||
|
||||
token = client_mod.set_api_key_override("sk_key_a")
|
||||
try:
|
||||
first = client_mod.get_skyvern()
|
||||
client_mod._skyvern_instance.set(None)
|
||||
second = client_mod.get_skyvern()
|
||||
finally:
|
||||
client_mod.reset_api_key_override(token)
|
||||
|
||||
assert first is second
|
||||
assert created_keys == ["sk_key_a"]
|
||||
|
||||
|
||||
def test_get_skyvern_override_client_cache_uses_lru_eviction(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
created_keys: list[str] = []
|
||||
|
||||
class FakeSkyvern:
|
||||
def __init__(self, *args: object, **kwargs: object) -> None:
|
||||
created_keys.append(kwargs["api_key"])
|
||||
|
||||
@classmethod
|
||||
def local(cls) -> FakeSkyvern:
|
||||
return cls(api_key="local")
|
||||
|
||||
async def aclose(self) -> None:
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(client_mod, "Skyvern", FakeSkyvern)
|
||||
monkeypatch.setattr(client_mod.settings, "SKYVERN_API_KEY", None)
|
||||
monkeypatch.setattr(client_mod.settings, "SKYVERN_BASE_URL", None)
|
||||
monkeypatch.setattr(client_mod, "_API_KEY_CLIENT_CACHE_MAX", 2)
|
||||
|
||||
for key in ("sk_key_a", "sk_key_b"):
|
||||
token = client_mod.set_api_key_override(key)
|
||||
try:
|
||||
client_mod.get_skyvern()
|
||||
finally:
|
||||
client_mod.reset_api_key_override(token)
|
||||
|
||||
# Touch key_a so key_b becomes least-recently-used.
|
||||
token = client_mod.set_api_key_override("sk_key_a")
|
||||
try:
|
||||
client_mod._skyvern_instance.set(None)
|
||||
client_mod.get_skyvern()
|
||||
finally:
|
||||
client_mod.reset_api_key_override(token)
|
||||
|
||||
# Adding key_c should evict key_b.
|
||||
token = client_mod.set_api_key_override("sk_key_c")
|
||||
try:
|
||||
client_mod.get_skyvern()
|
||||
finally:
|
||||
client_mod.reset_api_key_override(token)
|
||||
|
||||
assert list(client_mod._api_key_clients.keys()) == [
|
||||
client_mod._cache_key("sk_key_a"),
|
||||
client_mod._cache_key("sk_key_c"),
|
||||
]
|
||||
# key_a, key_b, key_c were created exactly once each.
|
||||
assert created_keys == ["sk_key_a", "sk_key_b", "sk_key_c"]
|
||||
|
||||
|
||||
def test_get_skyvern_override_cache_closes_evicted_client(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
closed_keys: list[str] = []
|
||||
|
||||
class FakeSkyvern:
|
||||
def __init__(self, *args: object, **kwargs: object) -> None:
|
||||
self.api_key = kwargs["api_key"]
|
||||
|
||||
@classmethod
|
||||
def local(cls) -> FakeSkyvern:
|
||||
return cls(api_key="local")
|
||||
|
||||
async def aclose(self) -> None:
|
||||
closed_keys.append(self.api_key)
|
||||
|
||||
monkeypatch.setattr(client_mod, "Skyvern", FakeSkyvern)
|
||||
monkeypatch.setattr(client_mod.settings, "SKYVERN_API_KEY", None)
|
||||
monkeypatch.setattr(client_mod.settings, "SKYVERN_BASE_URL", None)
|
||||
monkeypatch.setattr(client_mod, "_API_KEY_CLIENT_CACHE_MAX", 1)
|
||||
|
||||
for key in ("sk_key_a", "sk_key_b"):
|
||||
token = client_mod.set_api_key_override(key)
|
||||
try:
|
||||
client_mod.get_skyvern()
|
||||
finally:
|
||||
client_mod.reset_api_key_override(token)
|
||||
|
||||
assert list(client_mod._api_key_clients.keys()) == [client_mod._cache_key("sk_key_b")]
|
||||
assert closed_keys == ["sk_key_a"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_skyvern_closes_singleton() -> None:
|
||||
fake = MagicMock()
|
||||
@@ -75,12 +188,53 @@ def test_get_current_session_falls_back_to_global_state() -> None:
|
||||
assert recovered is state
|
||||
|
||||
|
||||
def test_get_current_session_stateless_mode_ignores_global_state() -> None:
|
||||
global_state = session_manager.SessionState(
|
||||
browser=MagicMock(),
|
||||
context=BrowserContext(mode="cloud_session", session_id="pbs_999"),
|
||||
)
|
||||
session_manager._global_session = global_state
|
||||
session_manager._current_session.set(None)
|
||||
|
||||
session_manager.set_stateless_http_mode(True)
|
||||
try:
|
||||
recovered = session_manager.get_current_session()
|
||||
finally:
|
||||
session_manager.set_stateless_http_mode(False)
|
||||
|
||||
assert recovered is not global_state
|
||||
assert recovered.browser is None
|
||||
assert recovered.context is None
|
||||
|
||||
|
||||
def test_set_current_session_stateless_mode_does_not_override_global_state() -> None:
|
||||
global_state = session_manager.SessionState(
|
||||
browser=MagicMock(),
|
||||
context=BrowserContext(mode="cloud_session", session_id="pbs_global"),
|
||||
)
|
||||
session_manager._global_session = global_state
|
||||
replacement = session_manager.SessionState(
|
||||
browser=MagicMock(),
|
||||
context=BrowserContext(mode="cloud_session", session_id="pbs_request"),
|
||||
)
|
||||
|
||||
session_manager.set_stateless_http_mode(True)
|
||||
try:
|
||||
session_manager.set_current_session(replacement)
|
||||
finally:
|
||||
session_manager.set_stateless_http_mode(False)
|
||||
|
||||
assert session_manager._global_session is global_state
|
||||
assert session_manager._current_session.get() is replacement
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_browser_reuses_matching_cloud_session(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
current_browser = MagicMock()
|
||||
current_state = session_manager.SessionState(
|
||||
browser=current_browser,
|
||||
context=BrowserContext(mode="cloud_session", session_id="pbs_123"),
|
||||
api_key_hash=session_manager._api_key_hash(client_mod.get_active_api_key()),
|
||||
)
|
||||
session_manager.set_current_session(current_state)
|
||||
|
||||
@@ -95,6 +249,92 @@ async def test_resolve_browser_reuses_matching_cloud_session(monkeypatch: pytest
|
||||
fake_skyvern.connect_to_cloud_browser_session.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_browser_does_not_reuse_session_for_different_api_key(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
current_browser = MagicMock()
|
||||
session_manager.set_current_session(
|
||||
session_manager.SessionState(
|
||||
browser=current_browser,
|
||||
context=BrowserContext(mode="cloud_session", session_id="pbs_123"),
|
||||
api_key_hash=session_manager._api_key_hash("sk_key_a"),
|
||||
)
|
||||
)
|
||||
|
||||
replacement_browser = MagicMock()
|
||||
fake_skyvern = MagicMock()
|
||||
fake_skyvern.connect_to_cloud_browser_session = AsyncMock(return_value=replacement_browser)
|
||||
monkeypatch.setattr(session_manager, "get_skyvern", lambda: fake_skyvern)
|
||||
|
||||
token = client_mod.set_api_key_override("sk_key_b")
|
||||
try:
|
||||
browser, ctx = await session_manager.resolve_browser(session_id="pbs_123")
|
||||
finally:
|
||||
client_mod.reset_api_key_override(token)
|
||||
|
||||
assert browser is replacement_browser
|
||||
assert ctx.session_id == "pbs_123"
|
||||
fake_skyvern.connect_to_cloud_browser_session.assert_awaited_once_with("pbs_123")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_browser_stateless_mode_does_not_write_global_session(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
global_state = session_manager.SessionState(
|
||||
browser=MagicMock(),
|
||||
context=BrowserContext(mode="cloud_session", session_id="pbs_global"),
|
||||
)
|
||||
session_manager._global_session = global_state
|
||||
|
||||
replacement_browser = MagicMock()
|
||||
fake_skyvern = MagicMock()
|
||||
fake_skyvern.connect_to_cloud_browser_session = AsyncMock(return_value=replacement_browser)
|
||||
monkeypatch.setattr(session_manager, "get_skyvern", lambda: fake_skyvern)
|
||||
|
||||
session_manager.set_stateless_http_mode(True)
|
||||
try:
|
||||
browser, ctx = await session_manager.resolve_browser(session_id="pbs_123")
|
||||
finally:
|
||||
session_manager.set_stateless_http_mode(False)
|
||||
|
||||
assert browser is replacement_browser
|
||||
assert ctx.session_id == "pbs_123"
|
||||
assert session_manager._global_session is global_state
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_browser_blocks_implicit_session_in_stateless_mode() -> None:
|
||||
session_manager.set_current_session(
|
||||
session_manager.SessionState(
|
||||
browser=MagicMock(),
|
||||
context=BrowserContext(mode="cloud_session", session_id="pbs_123"),
|
||||
api_key_hash=session_manager._api_key_hash("sk_key_a"),
|
||||
)
|
||||
)
|
||||
session_manager.set_stateless_http_mode(True)
|
||||
try:
|
||||
with pytest.raises(session_manager.BrowserNotAvailableError):
|
||||
await session_manager.resolve_browser()
|
||||
finally:
|
||||
session_manager.set_stateless_http_mode(False)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_browser_raises_for_invalid_matching_state(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
session_manager.set_current_session(
|
||||
session_manager.SessionState(
|
||||
browser=None,
|
||||
context=BrowserContext(mode="cloud_session", session_id="pbs_123"),
|
||||
)
|
||||
)
|
||||
|
||||
monkeypatch.setattr(session_manager, "_matches_current", lambda *args, **kwargs: True)
|
||||
monkeypatch.setattr(session_manager, "get_skyvern", lambda: MagicMock())
|
||||
|
||||
with pytest.raises(RuntimeError, match="Expected active browser and context"):
|
||||
await session_manager.resolve_browser(session_id="pbs_123")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_close_with_matching_session_id_closes_browser_handle(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
current_browser = MagicMock()
|
||||
@@ -262,3 +502,73 @@ async def test_close_current_session_still_closes_browser_when_api_fails(monkeyp
|
||||
# _browser_session_id should NOT be cleared (API close failed, let browser.close() try)
|
||||
assert browser._browser_session_id == "pbs_fail"
|
||||
assert session_manager.get_current_session().browser is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests for stateless HTTP mode session creation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_create_stateless_mode_returns_session_without_persisting_browser(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
session_manager.set_stateless_http_mode(True)
|
||||
fake_skyvern = MagicMock()
|
||||
fake_skyvern.create_browser_session = AsyncMock(return_value=SimpleNamespace(browser_session_id="pbs_abc"))
|
||||
monkeypatch.setattr(mcp_session, "get_skyvern", lambda: fake_skyvern)
|
||||
do_session_create = AsyncMock()
|
||||
monkeypatch.setattr(mcp_session, "do_session_create", do_session_create)
|
||||
|
||||
try:
|
||||
result = await mcp_session.skyvern_session_create(timeout=45)
|
||||
finally:
|
||||
session_manager.set_stateless_http_mode(False)
|
||||
|
||||
assert result["ok"] is True
|
||||
assert result["data"] == {"session_id": "pbs_abc", "timeout_minutes": 45}
|
||||
do_session_create.assert_not_awaited()
|
||||
assert mcp_session.get_current_session().browser is None
|
||||
assert mcp_session.get_current_session().context is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_create_stateless_mode_rejects_local() -> None:
|
||||
session_manager.set_stateless_http_mode(True)
|
||||
try:
|
||||
result = await mcp_session.skyvern_session_create(local=True)
|
||||
finally:
|
||||
session_manager.set_stateless_http_mode(False)
|
||||
|
||||
assert result["ok"] is False
|
||||
assert result["error"]["code"] == mcp_session.ErrorCode.INVALID_INPUT
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_create_persists_active_api_key_hash_in_session_state(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
fake_skyvern = MagicMock()
|
||||
monkeypatch.setattr(mcp_session, "get_skyvern", lambda: fake_skyvern)
|
||||
|
||||
fake_browser = MagicMock()
|
||||
do_session_create = AsyncMock(
|
||||
return_value=(
|
||||
fake_browser,
|
||||
SimpleNamespace(local=False, session_id="pbs_123", timeout_minutes=60, headless=False),
|
||||
)
|
||||
)
|
||||
monkeypatch.setattr(mcp_session, "do_session_create", do_session_create)
|
||||
|
||||
token = client_mod.set_api_key_override("sk_key_create")
|
||||
try:
|
||||
result = await mcp_session.skyvern_session_create(timeout=60)
|
||||
finally:
|
||||
client_mod.reset_api_key_override(token)
|
||||
|
||||
assert result["ok"] is True
|
||||
current = mcp_session.get_current_session()
|
||||
assert current.browser is fake_browser
|
||||
assert current.context == BrowserContext(mode="cloud_session", session_id="pbs_123")
|
||||
assert current.api_key_hash == session_manager._api_key_hash("sk_key_create")
|
||||
assert current.api_key_hash != "sk_key_create"
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from unittest.mock import AsyncMock, MagicMock, call
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -13,6 +13,42 @@ def _reset_cleanup_state() -> None:
|
||||
run_commands._mcp_cleanup_done = False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_mcp_resources_closes_auth_db(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
close_current_session = AsyncMock()
|
||||
close_skyvern = AsyncMock()
|
||||
close_auth_db = AsyncMock()
|
||||
|
||||
monkeypatch.setattr(run_commands, "close_current_session", close_current_session)
|
||||
monkeypatch.setattr(run_commands, "close_skyvern", close_skyvern)
|
||||
monkeypatch.setattr(run_commands, "close_auth_db", close_auth_db)
|
||||
|
||||
await run_commands._cleanup_mcp_resources()
|
||||
|
||||
close_current_session.assert_awaited_once()
|
||||
close_skyvern.assert_awaited_once()
|
||||
close_auth_db.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_mcp_resources_closes_auth_db_on_skyvern_close_error(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
close_current_session = AsyncMock()
|
||||
close_auth_db = AsyncMock()
|
||||
|
||||
async def _failing_close_skyvern() -> None:
|
||||
raise RuntimeError("close failed")
|
||||
|
||||
monkeypatch.setattr(run_commands, "close_current_session", close_current_session)
|
||||
monkeypatch.setattr(run_commands, "close_skyvern", _failing_close_skyvern)
|
||||
monkeypatch.setattr(run_commands, "close_auth_db", close_auth_db)
|
||||
|
||||
with pytest.raises(RuntimeError, match="close failed"):
|
||||
await run_commands._cleanup_mcp_resources()
|
||||
|
||||
close_current_session.assert_awaited_once()
|
||||
close_auth_db.assert_awaited_once()
|
||||
|
||||
|
||||
def test_cleanup_mcp_resources_sync_runs_without_running_loop(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
cleanup = AsyncMock()
|
||||
monkeypatch.setattr(run_commands, "_cleanup_mcp_resources", cleanup)
|
||||
@@ -50,14 +86,56 @@ def test_run_mcp_calls_blocking_cleanup_in_finally(monkeypatch: pytest.MonkeyPat
|
||||
cleanup_blocking = MagicMock()
|
||||
register = MagicMock()
|
||||
run = MagicMock(side_effect=RuntimeError("boom"))
|
||||
set_stateless = MagicMock()
|
||||
|
||||
monkeypatch.setattr(run_commands, "_cleanup_mcp_resources_blocking", cleanup_blocking)
|
||||
monkeypatch.setattr(run_commands.atexit, "register", register)
|
||||
monkeypatch.setattr(run_commands.mcp, "run", run)
|
||||
monkeypatch.setattr(run_commands, "set_stateless_http_mode", set_stateless)
|
||||
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
run_commands.run_mcp()
|
||||
|
||||
register.assert_called_once_with(run_commands._cleanup_mcp_resources_sync)
|
||||
run.assert_called_once_with(transport="stdio")
|
||||
set_stateless.assert_has_calls([call(False), call(False)])
|
||||
cleanup_blocking.assert_called_once()
|
||||
|
||||
|
||||
def test_run_mcp_http_transport_wires_auth_middleware(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
cleanup_blocking = MagicMock()
|
||||
register = MagicMock()
|
||||
run = MagicMock()
|
||||
set_stateless = MagicMock()
|
||||
|
||||
monkeypatch.setattr(run_commands, "_cleanup_mcp_resources_blocking", cleanup_blocking)
|
||||
monkeypatch.setattr(run_commands.atexit, "register", register)
|
||||
monkeypatch.setattr(run_commands.mcp, "run", run)
|
||||
monkeypatch.setattr(run_commands, "set_stateless_http_mode", set_stateless)
|
||||
|
||||
run_commands.run_mcp(
|
||||
transport="streamable-http",
|
||||
host="127.0.0.1",
|
||||
port=9010,
|
||||
path="mcp",
|
||||
stateless_http=True,
|
||||
)
|
||||
|
||||
register.assert_called_once_with(run_commands._cleanup_mcp_resources_sync)
|
||||
run.assert_called_once()
|
||||
kwargs = run.call_args.kwargs
|
||||
assert kwargs["transport"] == "streamable-http"
|
||||
assert kwargs["host"] == "127.0.0.1"
|
||||
assert kwargs["port"] == 9010
|
||||
assert kwargs["path"] == "/mcp"
|
||||
assert kwargs["stateless_http"] is True
|
||||
middleware = kwargs["middleware"]
|
||||
assert len(middleware) == 1
|
||||
assert middleware[0].cls is run_commands.MCPAPIKeyMiddleware
|
||||
set_stateless.assert_has_calls([call(True), call(False)])
|
||||
cleanup_blocking.assert_called_once()
|
||||
|
||||
|
||||
def test_run_task_tool_registration_points_to_browser_module() -> None:
|
||||
tool = run_commands.mcp._tool_manager._tools["skyvern_run_task"] # type: ignore[attr-defined]
|
||||
assert tool.fn.__module__ == "skyvern.cli.mcp_tools.browser"
|
||||
|
||||
Reference in New Issue
Block a user