align workflow CLI commands with MCP parity (#4792)
This commit is contained in:
@@ -68,6 +68,30 @@ Add this to your MCP client's configuration file:
|
|||||||
|
|
||||||
Replace `/usr/bin/python3` with the output of `which python3` on your machine. For local mode, set `SKYVERN_BASE_URL` to `http://localhost:8000` and find your API key in the `.env` file after running `skyvern init`.
|
Replace `/usr/bin/python3` with the output of `which python3` on your machine. For local mode, set `SKYVERN_BASE_URL` to `http://localhost:8000` and find your API key in the `.env` file after running `skyvern init`.
|
||||||
|
|
||||||
|
### Option C: Remote MCP over HTTPS (streamable HTTP)
|
||||||
|
|
||||||
|
Use this when your team provides a hosted MCP endpoint (for example: `https://mcp.skyvern.com/mcp`).
|
||||||
|
|
||||||
|
In remote HTTP mode:
|
||||||
|
- Clients must send `x-api-key` on every request.
|
||||||
|
- Use `skyvern_session_create` first, then pass `session_id` explicitly on subsequent browser tool calls.
|
||||||
|
|
||||||
|
If your MCP client supports native remote HTTP transport, configure it directly:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"mcpServers": {
|
||||||
|
"SkyvernRemote": {
|
||||||
|
"type": "streamable-http",
|
||||||
|
"url": "https://mcp.skyvern.com/mcp",
|
||||||
|
"headers": {
|
||||||
|
"x-api-key": "YOUR_SKYVERN_API_KEY"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
<Accordion title="Config file locations by client">
|
<Accordion title="Config file locations by client">
|
||||||
|
|
||||||
| Client | Path |
|
| Client | Path |
|
||||||
|
|||||||
29
skyvern/cli/core/api_key_hash.py
Normal file
29
skyvern/cli/core/api_key_hash.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_api_key_hash_iterations() -> int:
|
||||||
|
raw = os.environ.get("SKYVERN_MCP_API_KEY_HASH_ITERATIONS", "120000")
|
||||||
|
try:
|
||||||
|
return max(10_000, int(raw))
|
||||||
|
except ValueError:
|
||||||
|
return 120_000
|
||||||
|
|
||||||
|
|
||||||
|
_API_KEY_HASH_ITERATIONS = _resolve_api_key_hash_iterations()
|
||||||
|
_API_KEY_HASH_SALT = os.environ.get(
|
||||||
|
"SKYVERN_MCP_API_KEY_HASH_SALT",
|
||||||
|
"skyvern-mcp-api-key-cache-v1",
|
||||||
|
).encode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def hash_api_key_for_cache(api_key: str) -> str:
|
||||||
|
"""Derive a deterministic, non-reversible fingerprint for API-key keyed caches."""
|
||||||
|
return hashlib.pbkdf2_hmac(
|
||||||
|
"sha256",
|
||||||
|
api_key.encode("utf-8"),
|
||||||
|
_API_KEY_HASH_SALT,
|
||||||
|
_API_KEY_HASH_ITERATIONS,
|
||||||
|
).hex()
|
||||||
@@ -1,7 +1,10 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import os
|
import os
|
||||||
from contextvars import ContextVar
|
from collections import OrderedDict
|
||||||
|
from contextvars import ContextVar, Token
|
||||||
|
from threading import RLock
|
||||||
|
|
||||||
import structlog
|
import structlog
|
||||||
|
|
||||||
@@ -9,35 +12,126 @@ from skyvern.client import SkyvernEnvironment
|
|||||||
from skyvern.config import settings
|
from skyvern.config import settings
|
||||||
from skyvern.library.skyvern import Skyvern
|
from skyvern.library.skyvern import Skyvern
|
||||||
|
|
||||||
|
from .api_key_hash import hash_api_key_for_cache
|
||||||
|
|
||||||
_skyvern_instance: ContextVar[Skyvern | None] = ContextVar("skyvern_instance", default=None)
|
_skyvern_instance: ContextVar[Skyvern | None] = ContextVar("skyvern_instance", default=None)
|
||||||
|
_api_key_override: ContextVar[str | None] = ContextVar("skyvern_api_key_override", default=None)
|
||||||
_global_skyvern_instance: Skyvern | None = None
|
_global_skyvern_instance: Skyvern | None = None
|
||||||
|
_api_key_clients: OrderedDict[str, Skyvern] = OrderedDict()
|
||||||
|
_clients_lock = RLock()
|
||||||
LOG = structlog.get_logger(__name__)
|
LOG = structlog.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_api_key_cache_size() -> int:
|
||||||
|
raw = os.environ.get("SKYVERN_MCP_API_KEY_CLIENT_CACHE_SIZE", "128")
|
||||||
|
try:
|
||||||
|
return max(1, int(raw))
|
||||||
|
except ValueError:
|
||||||
|
return 128
|
||||||
|
|
||||||
|
|
||||||
|
_API_KEY_CLIENT_CACHE_MAX = _resolve_api_key_cache_size()
|
||||||
|
|
||||||
|
|
||||||
|
def _cache_key(api_key: str) -> str:
|
||||||
|
"""Hash API key so raw secrets are never stored as dict keys."""
|
||||||
|
return hash_api_key_for_cache(api_key)
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_api_key() -> str | None:
|
||||||
|
return settings.SKYVERN_API_KEY or os.environ.get("SKYVERN_API_KEY")
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_base_url() -> str | None:
|
||||||
|
return settings.SKYVERN_BASE_URL or os.environ.get("SKYVERN_BASE_URL")
|
||||||
|
|
||||||
|
|
||||||
|
def _build_cloud_client(api_key: str) -> Skyvern:
|
||||||
|
return Skyvern(
|
||||||
|
api_key=api_key,
|
||||||
|
environment=SkyvernEnvironment.CLOUD,
|
||||||
|
base_url=_resolve_base_url(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _close_skyvern_instance_best_effort(instance: Skyvern) -> None:
|
||||||
|
"""Close a Skyvern instance, regardless of whether an event loop is running."""
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
except RuntimeError:
|
||||||
|
try:
|
||||||
|
asyncio.run(instance.aclose())
|
||||||
|
except Exception:
|
||||||
|
LOG.debug("Failed to close evicted Skyvern client", exc_info=True)
|
||||||
|
return
|
||||||
|
|
||||||
|
task = loop.create_task(instance.aclose())
|
||||||
|
|
||||||
|
def _on_done(done: asyncio.Task[None]) -> None:
|
||||||
|
try:
|
||||||
|
done.result()
|
||||||
|
except Exception:
|
||||||
|
LOG.debug("Failed to close evicted Skyvern client", exc_info=True)
|
||||||
|
|
||||||
|
task.add_done_callback(_on_done)
|
||||||
|
|
||||||
|
|
||||||
|
def get_active_api_key() -> str | None:
|
||||||
|
"""Return the effective API key for this request/context."""
|
||||||
|
return _api_key_override.get() or _resolve_api_key()
|
||||||
|
|
||||||
|
|
||||||
|
def set_api_key_override(api_key: str | None) -> Token[str | None]:
|
||||||
|
"""Set request-scoped API key override for MCP HTTP requests."""
|
||||||
|
_skyvern_instance.set(None)
|
||||||
|
return _api_key_override.set(api_key)
|
||||||
|
|
||||||
|
|
||||||
|
def reset_api_key_override(token: Token[str | None]) -> None:
|
||||||
|
"""Reset request-scoped API key override."""
|
||||||
|
_api_key_override.reset(token)
|
||||||
|
_skyvern_instance.set(None)
|
||||||
|
|
||||||
|
|
||||||
def get_skyvern() -> Skyvern:
|
def get_skyvern() -> Skyvern:
|
||||||
"""Get or create a Skyvern client instance."""
|
"""Get or create a Skyvern client instance."""
|
||||||
global _global_skyvern_instance
|
global _global_skyvern_instance
|
||||||
|
|
||||||
instance = _skyvern_instance.get()
|
override_api_key = _api_key_override.get()
|
||||||
if instance is None:
|
if override_api_key:
|
||||||
instance = _global_skyvern_instance
|
instance = _skyvern_instance.get()
|
||||||
if instance is not None:
|
if instance is None:
|
||||||
|
key = _cache_key(override_api_key)
|
||||||
|
evicted_clients: list[Skyvern] = []
|
||||||
|
# Hold lock across lookup + build + insert to prevent two coroutines
|
||||||
|
# from both building a client for the same API key concurrently.
|
||||||
|
with _clients_lock:
|
||||||
|
instance = _api_key_clients.get(key)
|
||||||
|
if instance is not None:
|
||||||
|
_api_key_clients.move_to_end(key)
|
||||||
|
else:
|
||||||
|
instance = _build_cloud_client(override_api_key)
|
||||||
|
_api_key_clients[key] = instance
|
||||||
|
_api_key_clients.move_to_end(key)
|
||||||
|
while len(_api_key_clients) > _API_KEY_CLIENT_CACHE_MAX:
|
||||||
|
_, evicted = _api_key_clients.popitem(last=False)
|
||||||
|
evicted_clients.append(evicted)
|
||||||
|
for evicted in evicted_clients:
|
||||||
|
_close_skyvern_instance_best_effort(evicted)
|
||||||
_skyvern_instance.set(instance)
|
_skyvern_instance.set(instance)
|
||||||
return instance
|
return instance
|
||||||
|
|
||||||
api_key = settings.SKYVERN_API_KEY or os.environ.get("SKYVERN_API_KEY")
|
instance = _skyvern_instance.get()
|
||||||
base_url = settings.SKYVERN_BASE_URL or os.environ.get("SKYVERN_BASE_URL")
|
if instance is None:
|
||||||
|
with _clients_lock:
|
||||||
if api_key:
|
instance = _global_skyvern_instance
|
||||||
instance = Skyvern(
|
if instance is None:
|
||||||
api_key=api_key,
|
api_key = _resolve_api_key()
|
||||||
environment=SkyvernEnvironment.CLOUD,
|
if api_key:
|
||||||
base_url=base_url,
|
instance = _build_cloud_client(api_key)
|
||||||
)
|
else:
|
||||||
else:
|
instance = Skyvern.local()
|
||||||
instance = Skyvern.local()
|
_global_skyvern_instance = instance
|
||||||
|
|
||||||
_global_skyvern_instance = instance
|
|
||||||
_skyvern_instance.set(instance)
|
_skyvern_instance.set(instance)
|
||||||
return instance
|
return instance
|
||||||
|
|
||||||
@@ -48,7 +142,12 @@ async def close_skyvern() -> None:
|
|||||||
|
|
||||||
instances: list[Skyvern] = []
|
instances: list[Skyvern] = []
|
||||||
seen: set[int] = set()
|
seen: set[int] = set()
|
||||||
for candidate in (_skyvern_instance.get(), _global_skyvern_instance):
|
with _clients_lock:
|
||||||
|
candidates = (_skyvern_instance.get(), _global_skyvern_instance, *_api_key_clients.values())
|
||||||
|
_api_key_clients.clear()
|
||||||
|
_global_skyvern_instance = None
|
||||||
|
|
||||||
|
for candidate in candidates:
|
||||||
if candidate is None or id(candidate) in seen:
|
if candidate is None or id(candidate) in seen:
|
||||||
continue
|
continue
|
||||||
seen.add(id(candidate))
|
seen.add(id(candidate))
|
||||||
@@ -61,4 +160,3 @@ async def close_skyvern() -> None:
|
|||||||
LOG.warning("Failed to close Skyvern client", exc_info=True)
|
LOG.warning("Failed to close Skyvern client", exc_info=True)
|
||||||
|
|
||||||
_skyvern_instance.set(None)
|
_skyvern_instance.set(None)
|
||||||
_global_skyvern_instance = None
|
|
||||||
|
|||||||
206
skyvern/cli/core/mcp_http_auth.py
Normal file
206
skyvern/cli/core/mcp_http_auth.py
Normal file
@@ -0,0 +1,206 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from collections import OrderedDict
|
||||||
|
from threading import RLock
|
||||||
|
|
||||||
|
import structlog
|
||||||
|
from fastapi import HTTPException
|
||||||
|
from starlette.requests import Request
|
||||||
|
from starlette.responses import JSONResponse
|
||||||
|
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||||
|
|
||||||
|
from skyvern.config import settings
|
||||||
|
from skyvern.forge.sdk.db.agent_db import AgentDB
|
||||||
|
from skyvern.forge.sdk.services.org_auth_service import resolve_org_from_api_key
|
||||||
|
|
||||||
|
from .api_key_hash import hash_api_key_for_cache
|
||||||
|
from .client import reset_api_key_override, set_api_key_override
|
||||||
|
|
||||||
|
LOG = structlog.get_logger(__name__)
|
||||||
|
API_KEY_HEADER = "x-api-key"
|
||||||
|
HEALTH_PATHS = {"/health", "/healthz"}
|
||||||
|
_auth_db: AgentDB | None = None
|
||||||
|
_auth_db_lock = RLock()
|
||||||
|
_api_key_cache_lock = RLock()
|
||||||
|
_api_key_validation_cache: OrderedDict[str, tuple[str | None, float]] = OrderedDict()
|
||||||
|
_NEGATIVE_CACHE_TTL_SECONDS = 5.0
|
||||||
|
_VALIDATION_RETRY_EXHAUSTED_MESSAGE = "API key validation temporarily unavailable"
|
||||||
|
_MAX_VALIDATION_RETRIES = 2
|
||||||
|
_RETRY_DELAY_SECONDS = 0.25
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_api_key_cache_ttl_seconds() -> float:
|
||||||
|
raw = os.environ.get("SKYVERN_MCP_API_KEY_CACHE_TTL_SECONDS", "30")
|
||||||
|
try:
|
||||||
|
return max(1.0, float(raw))
|
||||||
|
except ValueError:
|
||||||
|
return 30.0
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_api_key_cache_max_size() -> int:
|
||||||
|
raw = os.environ.get("SKYVERN_MCP_API_KEY_CACHE_MAX_SIZE", "1024")
|
||||||
|
try:
|
||||||
|
return max(1, int(raw))
|
||||||
|
except ValueError:
|
||||||
|
return 1024
|
||||||
|
|
||||||
|
|
||||||
|
_API_KEY_CACHE_TTL_SECONDS = _resolve_api_key_cache_ttl_seconds()
|
||||||
|
_API_KEY_CACHE_MAX_SIZE = _resolve_api_key_cache_max_size()
|
||||||
|
|
||||||
|
|
||||||
|
def _get_auth_db() -> AgentDB:
|
||||||
|
global _auth_db
|
||||||
|
# Guard singleton init in case HTTP transport is served with threaded workers.
|
||||||
|
with _auth_db_lock:
|
||||||
|
if _auth_db is None:
|
||||||
|
_auth_db = AgentDB(settings.DATABASE_STRING, debug_enabled=settings.DEBUG_MODE)
|
||||||
|
return _auth_db
|
||||||
|
|
||||||
|
|
||||||
|
async def close_auth_db() -> None:
|
||||||
|
"""Dispose the auth DB engine used by HTTP middleware, if initialized."""
|
||||||
|
global _auth_db
|
||||||
|
with _auth_db_lock:
|
||||||
|
db = _auth_db
|
||||||
|
_auth_db = None
|
||||||
|
with _api_key_cache_lock:
|
||||||
|
_api_key_validation_cache.clear()
|
||||||
|
if db is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
await db.engine.dispose()
|
||||||
|
except Exception:
|
||||||
|
LOG.warning("Failed to dispose MCP auth DB engine", exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _cache_key(api_key: str) -> str:
|
||||||
|
return hash_api_key_for_cache(api_key)
|
||||||
|
|
||||||
|
|
||||||
|
async def validate_mcp_api_key(api_key: str) -> str:
|
||||||
|
"""Validate API key and return organization id for observability."""
|
||||||
|
key = _cache_key(api_key)
|
||||||
|
|
||||||
|
# Check cache first.
|
||||||
|
with _api_key_cache_lock:
|
||||||
|
cached = _api_key_validation_cache.get(key)
|
||||||
|
if cached is not None:
|
||||||
|
organization_id, expires_at = cached
|
||||||
|
if expires_at > time.monotonic():
|
||||||
|
_api_key_validation_cache.move_to_end(key)
|
||||||
|
if organization_id is None:
|
||||||
|
raise HTTPException(status_code=401, detail="Invalid API key")
|
||||||
|
return organization_id
|
||||||
|
_api_key_validation_cache.pop(key, None)
|
||||||
|
|
||||||
|
# Cache miss — do the DB lookup with simple retry on transient errors.
|
||||||
|
last_exc: Exception | None = None
|
||||||
|
for attempt in range(_MAX_VALIDATION_RETRIES + 1):
|
||||||
|
if attempt > 0:
|
||||||
|
await asyncio.sleep(_RETRY_DELAY_SECONDS)
|
||||||
|
try:
|
||||||
|
validation = await resolve_org_from_api_key(api_key, _get_auth_db())
|
||||||
|
organization_id = validation.organization.organization_id
|
||||||
|
with _api_key_cache_lock:
|
||||||
|
_api_key_validation_cache[key] = (
|
||||||
|
organization_id,
|
||||||
|
time.monotonic() + _API_KEY_CACHE_TTL_SECONDS,
|
||||||
|
)
|
||||||
|
_api_key_validation_cache.move_to_end(key)
|
||||||
|
while len(_api_key_validation_cache) > _API_KEY_CACHE_MAX_SIZE:
|
||||||
|
_api_key_validation_cache.popitem(last=False)
|
||||||
|
return organization_id
|
||||||
|
except HTTPException as e:
|
||||||
|
if e.status_code in {401, 403}:
|
||||||
|
with _api_key_cache_lock:
|
||||||
|
_api_key_validation_cache[key] = (None, time.monotonic() + _NEGATIVE_CACHE_TTL_SECONDS)
|
||||||
|
_api_key_validation_cache.move_to_end(key)
|
||||||
|
while len(_api_key_validation_cache) > _API_KEY_CACHE_MAX_SIZE:
|
||||||
|
_api_key_validation_cache.popitem(last=False)
|
||||||
|
raise
|
||||||
|
last_exc = e
|
||||||
|
except Exception as e:
|
||||||
|
last_exc = e
|
||||||
|
|
||||||
|
LOG.warning("API key validation retries exhausted", attempts=_MAX_VALIDATION_RETRIES + 1, exc_info=last_exc)
|
||||||
|
raise HTTPException(status_code=503, detail=_VALIDATION_RETRY_EXHAUSTED_MESSAGE)
|
||||||
|
|
||||||
|
|
||||||
|
def _unauthorized_response(message: str) -> JSONResponse:
|
||||||
|
return JSONResponse({"error": {"code": "UNAUTHORIZED", "message": message}}, status_code=401)
|
||||||
|
|
||||||
|
|
||||||
|
def _internal_error_response() -> JSONResponse:
|
||||||
|
return JSONResponse(
|
||||||
|
{"error": {"code": "INTERNAL_ERROR", "message": "Internal server error"}},
|
||||||
|
status_code=500,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _service_unavailable_response(message: str) -> JSONResponse:
|
||||||
|
return JSONResponse(
|
||||||
|
{"error": {"code": "SERVICE_UNAVAILABLE", "message": message}},
|
||||||
|
status_code=503,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MCPAPIKeyMiddleware:
|
||||||
|
"""Require x-api-key for MCP HTTP transport and scope requests to that key."""
|
||||||
|
|
||||||
|
def __init__(self, app: ASGIApp) -> None:
|
||||||
|
self.app = app
|
||||||
|
|
||||||
|
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||||
|
if scope["type"] != "http":
|
||||||
|
await self.app(scope, receive, send)
|
||||||
|
return
|
||||||
|
|
||||||
|
request = Request(scope, receive=receive)
|
||||||
|
if request.url.path in HEALTH_PATHS:
|
||||||
|
response = JSONResponse({"status": "ok"})
|
||||||
|
await response(scope, receive, send)
|
||||||
|
return
|
||||||
|
|
||||||
|
if request.method == "OPTIONS":
|
||||||
|
await self.app(scope, receive, send)
|
||||||
|
return
|
||||||
|
|
||||||
|
api_key = request.headers.get(API_KEY_HEADER)
|
||||||
|
if not api_key:
|
||||||
|
response = _unauthorized_response("Missing x-api-key header")
|
||||||
|
await response(scope, receive, send)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
organization_id = await validate_mcp_api_key(api_key)
|
||||||
|
scope.setdefault("state", {})
|
||||||
|
scope["state"]["organization_id"] = organization_id
|
||||||
|
except HTTPException as e:
|
||||||
|
if e.status_code in {401, 403}:
|
||||||
|
response = _unauthorized_response("Invalid API key")
|
||||||
|
await response(scope, receive, send)
|
||||||
|
return
|
||||||
|
if e.status_code == 503:
|
||||||
|
response = _service_unavailable_response(e.detail or _VALIDATION_RETRY_EXHAUSTED_MESSAGE)
|
||||||
|
await response(scope, receive, send)
|
||||||
|
return
|
||||||
|
LOG.warning("Unexpected HTTPException during MCP API key validation", status_code=e.status_code)
|
||||||
|
response = _internal_error_response()
|
||||||
|
await response(scope, receive, send)
|
||||||
|
return
|
||||||
|
except Exception:
|
||||||
|
LOG.exception("Unexpected MCP API key validation failure")
|
||||||
|
response = _internal_error_response()
|
||||||
|
await response(scope, receive, send)
|
||||||
|
return
|
||||||
|
|
||||||
|
token = set_api_key_override(api_key)
|
||||||
|
try:
|
||||||
|
await self.app(scope, receive, send)
|
||||||
|
finally:
|
||||||
|
reset_api_key_override(token)
|
||||||
@@ -7,7 +7,8 @@ from typing import TYPE_CHECKING, Any, AsyncIterator
|
|||||||
|
|
||||||
import structlog
|
import structlog
|
||||||
|
|
||||||
from .client import get_skyvern
|
from .api_key_hash import hash_api_key_for_cache
|
||||||
|
from .client import get_active_api_key, get_skyvern
|
||||||
from .result import BrowserContext, ErrorCode, make_error
|
from .result import BrowserContext, ErrorCode, make_error
|
||||||
|
|
||||||
LOG = structlog.get_logger(__name__)
|
LOG = structlog.get_logger(__name__)
|
||||||
@@ -21,6 +22,7 @@ if TYPE_CHECKING:
|
|||||||
class SessionState:
|
class SessionState:
|
||||||
browser: SkyvernBrowser | None = None
|
browser: SkyvernBrowser | None = None
|
||||||
context: BrowserContext | None = None
|
context: BrowserContext | None = None
|
||||||
|
api_key_hash: str | None = None
|
||||||
console_messages: list[dict[str, Any]] = field(default_factory=list)
|
console_messages: list[dict[str, Any]] = field(default_factory=list)
|
||||||
tracing_active: bool = False
|
tracing_active: bool = False
|
||||||
har_enabled: bool = False
|
har_enabled: bool = False
|
||||||
@@ -28,26 +30,52 @@ class SessionState:
|
|||||||
|
|
||||||
_current_session: ContextVar[SessionState | None] = ContextVar("mcp_session", default=None)
|
_current_session: ContextVar[SessionState | None] = ContextVar("mcp_session", default=None)
|
||||||
_global_session: SessionState | None = None
|
_global_session: SessionState | None = None
|
||||||
|
_stateless_http_mode = False
|
||||||
|
|
||||||
|
|
||||||
def get_current_session() -> SessionState:
|
def get_current_session() -> SessionState:
|
||||||
global _global_session
|
global _global_session
|
||||||
|
|
||||||
state = _current_session.get()
|
state = _current_session.get()
|
||||||
if state is None:
|
if state is not None:
|
||||||
if _global_session is None:
|
return state
|
||||||
_global_session = SessionState()
|
|
||||||
state = _global_session
|
# In stateless HTTP mode, avoid process-wide fallback state so requests
|
||||||
|
# cannot inherit session context from other requests.
|
||||||
|
if _stateless_http_mode:
|
||||||
|
state = SessionState()
|
||||||
_current_session.set(state)
|
_current_session.set(state)
|
||||||
|
return state
|
||||||
|
|
||||||
|
if _global_session is None:
|
||||||
|
_global_session = SessionState()
|
||||||
|
state = _global_session
|
||||||
|
_current_session.set(state)
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
|
||||||
def set_current_session(state: SessionState) -> None:
|
def set_current_session(state: SessionState) -> None:
|
||||||
global _global_session
|
global _global_session
|
||||||
_global_session = state
|
if not _stateless_http_mode:
|
||||||
|
_global_session = state
|
||||||
_current_session.set(state)
|
_current_session.set(state)
|
||||||
|
|
||||||
|
|
||||||
|
def set_stateless_http_mode(enabled: bool) -> None:
|
||||||
|
global _stateless_http_mode
|
||||||
|
_stateless_http_mode = enabled
|
||||||
|
|
||||||
|
|
||||||
|
def is_stateless_http_mode() -> bool:
|
||||||
|
return _stateless_http_mode
|
||||||
|
|
||||||
|
|
||||||
|
def _api_key_hash(api_key: str | None) -> str | None:
|
||||||
|
if not api_key:
|
||||||
|
return None
|
||||||
|
return hash_api_key_for_cache(api_key)
|
||||||
|
|
||||||
|
|
||||||
def _matches_current(
|
def _matches_current(
|
||||||
current: SessionState,
|
current: SessionState,
|
||||||
*,
|
*,
|
||||||
@@ -57,6 +85,8 @@ def _matches_current(
|
|||||||
) -> bool:
|
) -> bool:
|
||||||
if current.browser is None or current.context is None:
|
if current.browser is None or current.context is None:
|
||||||
return False
|
return False
|
||||||
|
if current.api_key_hash != _api_key_hash(get_active_api_key()):
|
||||||
|
return False
|
||||||
|
|
||||||
if session_id:
|
if session_id:
|
||||||
return current.context.mode == "cloud_session" and current.context.session_id == session_id
|
return current.context.mode == "cloud_session" and current.context.session_id == session_id
|
||||||
@@ -84,35 +114,39 @@ async def resolve_browser(
|
|||||||
skyvern = get_skyvern()
|
skyvern = get_skyvern()
|
||||||
current = get_current_session()
|
current = get_current_session()
|
||||||
|
|
||||||
|
if _stateless_http_mode and not (session_id or cdp_url or local or create_session):
|
||||||
|
raise BrowserNotAvailableError()
|
||||||
|
|
||||||
if _matches_current(current, session_id=session_id, cdp_url=cdp_url, local=local):
|
if _matches_current(current, session_id=session_id, cdp_url=cdp_url, local=local):
|
||||||
# _matches_current() guarantees both are non-None
|
if current.browser is None or current.context is None:
|
||||||
assert current.browser is not None and current.context is not None
|
raise RuntimeError("Expected active browser and context for matching session")
|
||||||
return current.browser, current.context
|
return current.browser, current.context
|
||||||
|
|
||||||
|
active_api_key_hash = _api_key_hash(get_active_api_key())
|
||||||
browser: SkyvernBrowser | None = None
|
browser: SkyvernBrowser | None = None
|
||||||
try:
|
try:
|
||||||
if session_id:
|
if session_id:
|
||||||
browser = await skyvern.connect_to_cloud_browser_session(session_id)
|
browser = await skyvern.connect_to_cloud_browser_session(session_id)
|
||||||
ctx = BrowserContext(mode="cloud_session", session_id=session_id)
|
ctx = BrowserContext(mode="cloud_session", session_id=session_id)
|
||||||
set_current_session(SessionState(browser=browser, context=ctx))
|
set_current_session(SessionState(browser=browser, context=ctx, api_key_hash=active_api_key_hash))
|
||||||
return browser, ctx
|
return browser, ctx
|
||||||
|
|
||||||
if cdp_url:
|
if cdp_url:
|
||||||
browser = await skyvern.connect_to_browser_over_cdp(cdp_url)
|
browser = await skyvern.connect_to_browser_over_cdp(cdp_url)
|
||||||
ctx = BrowserContext(mode="cdp", cdp_url=cdp_url)
|
ctx = BrowserContext(mode="cdp", cdp_url=cdp_url)
|
||||||
set_current_session(SessionState(browser=browser, context=ctx))
|
set_current_session(SessionState(browser=browser, context=ctx, api_key_hash=active_api_key_hash))
|
||||||
return browser, ctx
|
return browser, ctx
|
||||||
|
|
||||||
if local:
|
if local:
|
||||||
browser = await skyvern.launch_local_browser(headless=headless)
|
browser = await skyvern.launch_local_browser(headless=headless)
|
||||||
ctx = BrowserContext(mode="local")
|
ctx = BrowserContext(mode="local")
|
||||||
set_current_session(SessionState(browser=browser, context=ctx))
|
set_current_session(SessionState(browser=browser, context=ctx, api_key_hash=active_api_key_hash))
|
||||||
return browser, ctx
|
return browser, ctx
|
||||||
|
|
||||||
if create_session:
|
if create_session:
|
||||||
browser = await skyvern.launch_cloud_browser(timeout=timeout)
|
browser = await skyvern.launch_cloud_browser(timeout=timeout)
|
||||||
ctx = BrowserContext(mode="cloud_session", session_id=browser.browser_session_id)
|
ctx = BrowserContext(mode="cloud_session", session_id=browser.browser_session_id)
|
||||||
set_current_session(SessionState(browser=browser, context=ctx))
|
set_current_session(SessionState(browser=browser, context=ctx, api_key_hash=active_api_key_hash))
|
||||||
return browser, ctx
|
return browser, ctx
|
||||||
except Exception:
|
except Exception:
|
||||||
if browser is not None:
|
if browser is not None:
|
||||||
|
|||||||
@@ -4,7 +4,11 @@ from typing import Annotated, Any
|
|||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
|
from skyvern.cli.core.api_key_hash import hash_api_key_for_cache
|
||||||
|
from skyvern.cli.core.client import get_active_api_key
|
||||||
|
from skyvern.cli.core.session_manager import is_stateless_http_mode
|
||||||
from skyvern.cli.core.session_ops import do_session_close, do_session_create, do_session_list
|
from skyvern.cli.core.session_ops import do_session_close, do_session_create, do_session_list
|
||||||
|
from skyvern.schemas.runs import ProxyLocation
|
||||||
|
|
||||||
from ._common import BrowserContext, ErrorCode, Timer, make_error, make_result
|
from ._common import BrowserContext, ErrorCode, Timer, make_error, make_result
|
||||||
from ._session import (
|
from ._session import (
|
||||||
@@ -16,6 +20,13 @@ from ._session import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _session_api_key_hash() -> str | None:
|
||||||
|
api_key = get_active_api_key()
|
||||||
|
if not api_key:
|
||||||
|
return None
|
||||||
|
return hash_api_key_for_cache(api_key)
|
||||||
|
|
||||||
|
|
||||||
async def skyvern_session_create(
|
async def skyvern_session_create(
|
||||||
timeout: Annotated[int | None, Field(description="Session timeout in minutes (5-1440)")] = 60,
|
timeout: Annotated[int | None, Field(description="Session timeout in minutes (5-1440)")] = 60,
|
||||||
proxy_location: Annotated[str | None, Field(description="Proxy location: RESIDENTIAL, US, etc.")] = None,
|
proxy_location: Annotated[str | None, Field(description="Proxy location: RESIDENTIAL, US, etc.")] = None,
|
||||||
@@ -29,7 +40,33 @@ async def skyvern_session_create(
|
|||||||
"""
|
"""
|
||||||
with Timer() as timer:
|
with Timer() as timer:
|
||||||
try:
|
try:
|
||||||
|
if is_stateless_http_mode() and local:
|
||||||
|
return make_result(
|
||||||
|
"skyvern_session_create",
|
||||||
|
ok=False,
|
||||||
|
error=make_error(
|
||||||
|
ErrorCode.INVALID_INPUT,
|
||||||
|
"Local browser sessions are not supported in stateless HTTP mode",
|
||||||
|
"Use cloud sessions for remote MCP transport",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
skyvern = get_skyvern()
|
skyvern = get_skyvern()
|
||||||
|
if is_stateless_http_mode():
|
||||||
|
proxy = ProxyLocation(proxy_location) if proxy_location else None
|
||||||
|
session = await skyvern.create_browser_session(timeout=timeout or 60, proxy_location=proxy)
|
||||||
|
timer.mark("sdk")
|
||||||
|
ctx = BrowserContext(mode="cloud_session", session_id=session.browser_session_id)
|
||||||
|
return make_result(
|
||||||
|
"skyvern_session_create",
|
||||||
|
browser_context=ctx,
|
||||||
|
data={
|
||||||
|
"session_id": session.browser_session_id,
|
||||||
|
"timeout_minutes": timeout or 60,
|
||||||
|
},
|
||||||
|
timing_ms=timer.timing_ms,
|
||||||
|
)
|
||||||
|
|
||||||
browser, result = await do_session_create(
|
browser, result = await do_session_create(
|
||||||
skyvern,
|
skyvern,
|
||||||
timeout=timeout or 60,
|
timeout=timeout or 60,
|
||||||
@@ -43,7 +80,7 @@ async def skyvern_session_create(
|
|||||||
ctx = BrowserContext(mode="local")
|
ctx = BrowserContext(mode="local")
|
||||||
else:
|
else:
|
||||||
ctx = BrowserContext(mode="cloud_session", session_id=result.session_id)
|
ctx = BrowserContext(mode="cloud_session", session_id=result.session_id)
|
||||||
set_current_session(SessionState(browser=browser, context=ctx))
|
set_current_session(SessionState(browser=browser, context=ctx, api_key_hash=_session_api_key_hash()))
|
||||||
|
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
return make_result(
|
return make_result(
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
from typing import Any, List, Optional
|
from typing import Annotated, List, Literal, Optional
|
||||||
|
|
||||||
import psutil
|
import psutil
|
||||||
import typer
|
import typer
|
||||||
@@ -13,17 +13,17 @@ import uvicorn
|
|||||||
from dotenv import load_dotenv, set_key
|
from dotenv import load_dotenv, set_key
|
||||||
from rich.panel import Panel
|
from rich.panel import Panel
|
||||||
from rich.prompt import Confirm
|
from rich.prompt import Confirm
|
||||||
|
from starlette.middleware import Middleware
|
||||||
|
|
||||||
from skyvern.cli.console import console
|
from skyvern.cli.console import console
|
||||||
from skyvern.cli.core.client import close_skyvern
|
from skyvern.cli.core.client import close_skyvern
|
||||||
from skyvern.cli.core.session_manager import close_current_session
|
from skyvern.cli.core.mcp_http_auth import MCPAPIKeyMiddleware, close_auth_db
|
||||||
|
from skyvern.cli.core.session_manager import close_current_session, set_stateless_http_mode
|
||||||
from skyvern.cli.mcp_tools import mcp # Uses standalone fastmcp (v2.x)
|
from skyvern.cli.mcp_tools import mcp # Uses standalone fastmcp (v2.x)
|
||||||
from skyvern.cli.utils import start_services
|
from skyvern.cli.utils import start_services
|
||||||
from skyvern.client import SkyvernEnvironment
|
|
||||||
from skyvern.config import settings
|
from skyvern.config import settings
|
||||||
from skyvern.forge.sdk.core import skyvern_context
|
from skyvern.forge.sdk.core import skyvern_context
|
||||||
from skyvern.forge.sdk.forge_log import setup_logger
|
from skyvern.forge.sdk.forge_log import setup_logger
|
||||||
from skyvern.library.skyvern import Skyvern
|
|
||||||
from skyvern.services.script_service import run_script
|
from skyvern.services.script_service import run_script
|
||||||
from skyvern.utils import detect_os
|
from skyvern.utils import detect_os
|
||||||
from skyvern.utils.env_paths import resolve_backend_env_path, resolve_frontend_env_path
|
from skyvern.utils.env_paths import resolve_backend_env_path, resolve_frontend_env_path
|
||||||
@@ -36,7 +36,10 @@ async def _cleanup_mcp_resources() -> None:
|
|||||||
try:
|
try:
|
||||||
await close_current_session()
|
await close_current_session()
|
||||||
finally:
|
finally:
|
||||||
await close_skyvern()
|
try:
|
||||||
|
await close_skyvern()
|
||||||
|
finally:
|
||||||
|
await close_auth_db()
|
||||||
|
|
||||||
|
|
||||||
def _cleanup_mcp_resources_blocking() -> None:
|
def _cleanup_mcp_resources_blocking() -> None:
|
||||||
@@ -67,39 +70,6 @@ def _cleanup_mcp_resources_sync() -> None:
|
|||||||
logger.debug("Skipping MCP cleanup because event loop is still running")
|
logger.debug("Skipping MCP cleanup because event loop is still running")
|
||||||
|
|
||||||
|
|
||||||
@mcp.tool()
|
|
||||||
async def skyvern_run_task(prompt: str, url: str) -> dict[str, Any]:
|
|
||||||
"""Use Skyvern to execute anything in the browser. Useful for accomplishing tasks that require browser automation.
|
|
||||||
|
|
||||||
This tool uses Skyvern's browser automation to navigate websites and perform actions to achieve
|
|
||||||
the user's intended outcome. It can handle tasks like form filling, clicking buttons, data extraction,
|
|
||||||
and multi-step workflows.
|
|
||||||
|
|
||||||
It can even help you find updated data on the internet if your model information is outdated.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
prompt: A natural language description of what needs to be accomplished (e.g. "Book a flight from
|
|
||||||
NYC to LA", "Sign up for the newsletter", "Find the price of item X", "Apply to a job")
|
|
||||||
url: The starting URL of the website where the task should be performed
|
|
||||||
"""
|
|
||||||
skyvern_agent = Skyvern(
|
|
||||||
environment=SkyvernEnvironment.CLOUD,
|
|
||||||
base_url=settings.SKYVERN_BASE_URL,
|
|
||||||
api_key=settings.SKYVERN_API_KEY,
|
|
||||||
)
|
|
||||||
res = await skyvern_agent.run_task(prompt=prompt, url=url, user_agent="skyvern-mcp", wait_for_completion=True)
|
|
||||||
|
|
||||||
output = res.model_dump()["output"]
|
|
||||||
if res.app_url:
|
|
||||||
task_url = res.app_url
|
|
||||||
else:
|
|
||||||
if res.run_id and res.run_id.startswith("wr_"):
|
|
||||||
task_url = f"{settings.SKYVERN_APP_URL.rstrip('/')}/runs/{res.run_id}/overview"
|
|
||||||
else:
|
|
||||||
task_url = f"{settings.SKYVERN_APP_URL.rstrip('/')}/tasks/{res.run_id}/actions"
|
|
||||||
return {"output": output, "task_url": task_url, "run_id": res.run_id}
|
|
||||||
|
|
||||||
|
|
||||||
def get_pids_on_port(port: int) -> List[int]:
|
def get_pids_on_port(port: int) -> List[int]:
|
||||||
"""Return a list of PIDs listening on the given port."""
|
"""Return a list of PIDs listening on the given port."""
|
||||||
pids = []
|
pids = []
|
||||||
@@ -295,20 +265,61 @@ def run_dev() -> None:
|
|||||||
|
|
||||||
|
|
||||||
@run_app.command(name="mcp")
|
@run_app.command(name="mcp")
|
||||||
def run_mcp() -> None:
|
def run_mcp(
|
||||||
"""Run the MCP server."""
|
transport: Annotated[
|
||||||
# This breaks the MCP processing because it expects json output only
|
Literal["stdio", "sse", "streamable-http"],
|
||||||
# console.print(Panel("[bold green]Starting MCP Server...[/bold green]", border_style="green"))
|
typer.Option(
|
||||||
|
"--transport",
|
||||||
|
help="MCP transport: stdio (default), sse, or streamable-http.",
|
||||||
|
),
|
||||||
|
] = "stdio",
|
||||||
|
host: Annotated[str, typer.Option("--host", help="Host for HTTP transports.")] = "0.0.0.0",
|
||||||
|
port: Annotated[int, typer.Option("--port", help="Port for HTTP transports.")] = 8000,
|
||||||
|
path: Annotated[str, typer.Option("--path", help="HTTP endpoint path for MCP transport.")] = "/mcp",
|
||||||
|
stateless_http: Annotated[
|
||||||
|
bool,
|
||||||
|
typer.Option(
|
||||||
|
"--stateless-http/--no-stateless-http",
|
||||||
|
help="Use stateless HTTP semantics for HTTP transports (ignored for stdio).",
|
||||||
|
),
|
||||||
|
] = True,
|
||||||
|
) -> None:
|
||||||
|
"""Run the MCP server with configurable transport for local or remote hosting."""
|
||||||
|
path = _normalize_mcp_path(path)
|
||||||
|
stateless_http_enabled = transport != "stdio" and stateless_http
|
||||||
# atexit covers signal-based exits (SIGTERM); finally covers normal
|
# atexit covers signal-based exits (SIGTERM); finally covers normal
|
||||||
# mcp.run() completion or unhandled exceptions. Both are needed because
|
# mcp.run() completion or unhandled exceptions. Both are needed because
|
||||||
# atexit doesn't fire on normal return and finally doesn't fire on signals.
|
# atexit doesn't fire on normal return and finally doesn't fire on signals.
|
||||||
atexit.register(_cleanup_mcp_resources_sync)
|
atexit.register(_cleanup_mcp_resources_sync)
|
||||||
|
set_stateless_http_mode(stateless_http_enabled)
|
||||||
try:
|
try:
|
||||||
mcp.run(transport="stdio")
|
if transport == "stdio":
|
||||||
|
mcp.run(transport="stdio")
|
||||||
|
return
|
||||||
|
|
||||||
|
middleware = [Middleware(MCPAPIKeyMiddleware)]
|
||||||
|
mcp.run(
|
||||||
|
transport=transport,
|
||||||
|
host=host,
|
||||||
|
port=port,
|
||||||
|
path=path,
|
||||||
|
middleware=middleware,
|
||||||
|
stateless_http=stateless_http_enabled,
|
||||||
|
)
|
||||||
finally:
|
finally:
|
||||||
|
set_stateless_http_mode(False)
|
||||||
_cleanup_mcp_resources_blocking()
|
_cleanup_mcp_resources_blocking()
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_mcp_path(path: str) -> str:
|
||||||
|
path = path.strip()
|
||||||
|
if not path:
|
||||||
|
return "/mcp"
|
||||||
|
if not path.startswith("/"):
|
||||||
|
return f"/{path}"
|
||||||
|
return path
|
||||||
|
|
||||||
|
|
||||||
@run_app.command(
|
@run_app.command(
|
||||||
name="code",
|
name="code",
|
||||||
context_settings={"allow_interspersed_args": False},
|
context_settings={"allow_interspersed_args": False},
|
||||||
|
|||||||
@@ -1,27 +1,80 @@
|
|||||||
"""Workflow-related CLI helpers."""
|
"""Workflow-related CLI commands with MCP-parity flags and output."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import os
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Callable, Coroutine
|
||||||
|
|
||||||
import typer
|
import typer
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from rich.panel import Panel
|
|
||||||
|
|
||||||
from skyvern.client import Skyvern
|
|
||||||
from skyvern.config import settings
|
from skyvern.config import settings
|
||||||
from skyvern.utils.env_paths import resolve_backend_env_path
|
from skyvern.utils.env_paths import resolve_backend_env_path
|
||||||
|
|
||||||
from .console import console
|
from .commands._output import output, output_error
|
||||||
from .tasks import _list_workflow_tasks
|
from .mcp_tools.workflow import skyvern_workflow_cancel as tool_workflow_cancel
|
||||||
|
from .mcp_tools.workflow import skyvern_workflow_create as tool_workflow_create
|
||||||
|
from .mcp_tools.workflow import skyvern_workflow_delete as tool_workflow_delete
|
||||||
|
from .mcp_tools.workflow import skyvern_workflow_get as tool_workflow_get
|
||||||
|
from .mcp_tools.workflow import skyvern_workflow_list as tool_workflow_list
|
||||||
|
from .mcp_tools.workflow import skyvern_workflow_run as tool_workflow_run
|
||||||
|
from .mcp_tools.workflow import skyvern_workflow_status as tool_workflow_status
|
||||||
|
from .mcp_tools.workflow import skyvern_workflow_update as tool_workflow_update
|
||||||
|
|
||||||
workflow_app = typer.Typer(help="Manage Skyvern workflows.")
|
workflow_app = typer.Typer(help="Manage Skyvern workflows.", no_args_is_help=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _emit_tool_result(result: dict[str, Any], *, json_output: bool) -> None:
|
||||||
|
if json_output:
|
||||||
|
json.dump(result, sys.stdout, indent=2, default=str)
|
||||||
|
sys.stdout.write("\n")
|
||||||
|
if not result.get("ok", False):
|
||||||
|
raise SystemExit(1)
|
||||||
|
return
|
||||||
|
|
||||||
|
if result.get("ok", False):
|
||||||
|
output(result.get("data"), action=str(result.get("action", "")), json_mode=False)
|
||||||
|
return
|
||||||
|
|
||||||
|
err = result.get("error") or {}
|
||||||
|
output_error(str(err.get("message", "Unknown error")), hint=str(err.get("hint", "")), json_mode=False)
|
||||||
|
|
||||||
|
|
||||||
|
def _run_tool(
|
||||||
|
runner: Callable[[], Coroutine[Any, Any, dict[str, Any]]],
|
||||||
|
*,
|
||||||
|
json_output: bool,
|
||||||
|
hint_on_exception: str,
|
||||||
|
) -> None:
|
||||||
|
try:
|
||||||
|
result: dict[str, Any] = asyncio.run(runner())
|
||||||
|
_emit_tool_result(result, json_output=json_output)
|
||||||
|
except typer.BadParameter:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
output_error(str(e), hint=hint_on_exception, json_mode=json_output)
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_inline_or_file(value: str | None, *, param_name: str) -> str | None:
|
||||||
|
if value is None or not value.startswith("@"):
|
||||||
|
return value
|
||||||
|
|
||||||
|
file_path = value[1:]
|
||||||
|
if not file_path:
|
||||||
|
raise typer.BadParameter(f"{param_name} file path cannot be empty after '@'.")
|
||||||
|
|
||||||
|
path = Path(file_path).expanduser()
|
||||||
|
try:
|
||||||
|
return path.read_text(encoding="utf-8")
|
||||||
|
except OSError as e:
|
||||||
|
raise typer.BadParameter(f"Unable to read {param_name} file '{path}': {e}") from e
|
||||||
|
|
||||||
|
|
||||||
@workflow_app.callback()
|
@workflow_app.callback()
|
||||||
def workflow_callback(
|
def workflow_callback(
|
||||||
ctx: typer.Context,
|
|
||||||
api_key: str | None = typer.Option(
|
api_key: str | None = typer.Option(
|
||||||
None,
|
None,
|
||||||
"--api-key",
|
"--api-key",
|
||||||
@@ -29,86 +82,188 @@ def workflow_callback(
|
|||||||
envvar="SKYVERN_API_KEY",
|
envvar="SKYVERN_API_KEY",
|
||||||
),
|
),
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Store the provided API key in the Typer context."""
|
"""Load workflow CLI environment and optional API key override."""
|
||||||
ctx.obj = {"api_key": api_key}
|
|
||||||
|
|
||||||
|
|
||||||
def _get_client(api_key: str | None = None) -> Skyvern:
|
|
||||||
"""Instantiate a Skyvern SDK client using environment variables."""
|
|
||||||
load_dotenv(resolve_backend_env_path())
|
load_dotenv(resolve_backend_env_path())
|
||||||
key = api_key or os.getenv("SKYVERN_API_KEY") or settings.SKYVERN_API_KEY
|
if api_key:
|
||||||
return Skyvern(base_url=settings.SKYVERN_BASE_URL, api_key=key)
|
settings.SKYVERN_API_KEY = api_key
|
||||||
|
|
||||||
|
|
||||||
|
@workflow_app.command("list")
|
||||||
|
def workflow_list(
|
||||||
|
search: str | None = typer.Option(
|
||||||
|
None,
|
||||||
|
"--search",
|
||||||
|
help="Search across workflow titles, folder names, and parameter metadata.",
|
||||||
|
),
|
||||||
|
page: int = typer.Option(1, "--page", min=1, help="Page number (1-based)."),
|
||||||
|
page_size: int = typer.Option(10, "--page-size", min=1, max=100, help="Results per page."),
|
||||||
|
only_workflows: bool = typer.Option(
|
||||||
|
False,
|
||||||
|
"--only-workflows",
|
||||||
|
help="Only return multi-step workflows (exclude saved tasks).",
|
||||||
|
),
|
||||||
|
json_output: bool = typer.Option(False, "--json", help="Output as JSON."),
|
||||||
|
) -> None:
|
||||||
|
"""List workflows."""
|
||||||
|
|
||||||
|
async def _run() -> dict[str, Any]:
|
||||||
|
return await tool_workflow_list(
|
||||||
|
search=search,
|
||||||
|
page=page,
|
||||||
|
page_size=page_size,
|
||||||
|
only_workflows=only_workflows,
|
||||||
|
)
|
||||||
|
|
||||||
|
_run_tool(_run, json_output=json_output, hint_on_exception="Check your API key and workflow list filters.")
|
||||||
|
|
||||||
|
|
||||||
|
@workflow_app.command("get")
|
||||||
|
def workflow_get(
|
||||||
|
workflow_id: str = typer.Option(..., "--id", help="Workflow permanent ID (wpid_...)."),
|
||||||
|
version: int | None = typer.Option(None, "--version", min=1, help="Specific version to retrieve."),
|
||||||
|
json_output: bool = typer.Option(False, "--json", help="Output as JSON."),
|
||||||
|
) -> None:
|
||||||
|
"""Get a workflow definition by ID."""
|
||||||
|
|
||||||
|
async def _run() -> dict[str, Any]:
|
||||||
|
return await tool_workflow_get(workflow_id=workflow_id, version=version)
|
||||||
|
|
||||||
|
_run_tool(_run, json_output=json_output, hint_on_exception="Check your API key and workflow ID.")
|
||||||
|
|
||||||
|
|
||||||
|
@workflow_app.command("create")
|
||||||
|
def workflow_create(
|
||||||
|
definition: str = typer.Option(
|
||||||
|
...,
|
||||||
|
"--definition",
|
||||||
|
help="Workflow definition as YAML/JSON string or @file path.",
|
||||||
|
),
|
||||||
|
definition_format: str = typer.Option(
|
||||||
|
"auto",
|
||||||
|
"--format",
|
||||||
|
help="Definition format: json, yaml, or auto.",
|
||||||
|
),
|
||||||
|
folder_id: str | None = typer.Option(None, "--folder-id", help="Folder ID (fld_...) for the workflow."),
|
||||||
|
json_output: bool = typer.Option(False, "--json", help="Output as JSON."),
|
||||||
|
) -> None:
|
||||||
|
"""Create a workflow."""
|
||||||
|
|
||||||
|
async def _run() -> dict[str, Any]:
|
||||||
|
resolved_definition = _resolve_inline_or_file(definition, param_name="definition")
|
||||||
|
assert resolved_definition is not None
|
||||||
|
return await tool_workflow_create(
|
||||||
|
definition=resolved_definition,
|
||||||
|
format=definition_format,
|
||||||
|
folder_id=folder_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
_run_tool(_run, json_output=json_output, hint_on_exception="Check the workflow definition syntax.")
|
||||||
|
|
||||||
|
|
||||||
|
@workflow_app.command("update")
|
||||||
|
def workflow_update(
|
||||||
|
workflow_id: str = typer.Option(..., "--id", help="Workflow permanent ID (wpid_...)."),
|
||||||
|
definition: str = typer.Option(
|
||||||
|
...,
|
||||||
|
"--definition",
|
||||||
|
help="Updated workflow definition as YAML/JSON string or @file path.",
|
||||||
|
),
|
||||||
|
definition_format: str = typer.Option(
|
||||||
|
"auto",
|
||||||
|
"--format",
|
||||||
|
help="Definition format: json, yaml, or auto.",
|
||||||
|
),
|
||||||
|
json_output: bool = typer.Option(False, "--json", help="Output as JSON."),
|
||||||
|
) -> None:
|
||||||
|
"""Update a workflow definition."""
|
||||||
|
|
||||||
|
async def _run() -> dict[str, Any]:
|
||||||
|
resolved_definition = _resolve_inline_or_file(definition, param_name="definition")
|
||||||
|
assert resolved_definition is not None
|
||||||
|
return await tool_workflow_update(
|
||||||
|
workflow_id=workflow_id,
|
||||||
|
definition=resolved_definition,
|
||||||
|
format=definition_format,
|
||||||
|
)
|
||||||
|
|
||||||
|
_run_tool(_run, json_output=json_output, hint_on_exception="Check the workflow ID and definition syntax.")
|
||||||
|
|
||||||
|
|
||||||
|
@workflow_app.command("delete")
|
||||||
|
def workflow_delete(
|
||||||
|
workflow_id: str = typer.Option(..., "--id", help="Workflow permanent ID (wpid_...)."),
|
||||||
|
force: bool = typer.Option(False, "--force", help="Confirm permanent deletion."),
|
||||||
|
json_output: bool = typer.Option(False, "--json", help="Output as JSON."),
|
||||||
|
) -> None:
|
||||||
|
"""Delete a workflow."""
|
||||||
|
|
||||||
|
async def _run() -> dict[str, Any]:
|
||||||
|
return await tool_workflow_delete(workflow_id=workflow_id, force=force)
|
||||||
|
|
||||||
|
_run_tool(_run, json_output=json_output, hint_on_exception="Check the workflow ID and your permissions.")
|
||||||
|
|
||||||
|
|
||||||
@workflow_app.command("run")
|
@workflow_app.command("run")
|
||||||
def run_workflow(
|
def workflow_run(
|
||||||
ctx: typer.Context,
|
workflow_id: str = typer.Option(..., "--id", help="Workflow permanent ID (wpid_...)."),
|
||||||
workflow_id: str = typer.Argument(..., help="Workflow permanent ID"),
|
params: str | None = typer.Option(
|
||||||
parameters: str = typer.Option("{}", "--parameters", "-p", help="JSON parameters for the workflow"),
|
None,
|
||||||
title: str | None = typer.Option(None, "--title", help="Title for the workflow run"),
|
"--params",
|
||||||
max_steps: int | None = typer.Option(None, "--max-steps", help="Override the workflow max steps"),
|
"--parameters",
|
||||||
|
"-p",
|
||||||
|
help="Workflow parameters as JSON string or @file path.",
|
||||||
|
),
|
||||||
|
session: str | None = typer.Option(None, "--session", help="Browser session ID (pbs_...) to reuse."),
|
||||||
|
webhook: str | None = typer.Option(None, "--webhook", help="Status webhook callback URL."),
|
||||||
|
proxy: str | None = typer.Option(None, "--proxy", help="Proxy location (e.g., RESIDENTIAL)."),
|
||||||
|
wait: bool = typer.Option(False, "--wait", help="Wait for workflow completion before returning."),
|
||||||
|
timeout: int = typer.Option(
|
||||||
|
300,
|
||||||
|
"--timeout",
|
||||||
|
min=10,
|
||||||
|
max=3600,
|
||||||
|
help="Max wait time in seconds when --wait is set.",
|
||||||
|
),
|
||||||
|
json_output: bool = typer.Option(False, "--json", help="Output as JSON."),
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run a workflow."""
|
"""Run a workflow."""
|
||||||
try:
|
|
||||||
params_dict = json.loads(parameters) if parameters else {}
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
console.print(f"[red]Invalid JSON for parameters: {parameters}[/red]")
|
|
||||||
raise typer.Exit(code=1)
|
|
||||||
|
|
||||||
client = _get_client(ctx.obj.get("api_key") if ctx.obj else None)
|
async def _run() -> dict[str, Any]:
|
||||||
run_resp = client.run_workflow(
|
resolved_params = _resolve_inline_or_file(params, param_name="params")
|
||||||
workflow_id=workflow_id,
|
return await tool_workflow_run(
|
||||||
parameters=params_dict,
|
workflow_id=workflow_id,
|
||||||
title=title,
|
parameters=resolved_params,
|
||||||
max_steps_override=max_steps,
|
browser_session_id=session,
|
||||||
)
|
webhook_url=webhook,
|
||||||
console.print(
|
proxy_location=proxy,
|
||||||
Panel(
|
wait=wait,
|
||||||
f"Started workflow run [bold]{run_resp.run_id}[/bold]",
|
timeout_seconds=timeout,
|
||||||
border_style="green",
|
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
|
_run_tool(_run, json_output=json_output, hint_on_exception="Check the workflow ID and run parameters.")
|
||||||
@workflow_app.command("cancel")
|
|
||||||
def cancel_workflow(
|
|
||||||
ctx: typer.Context,
|
|
||||||
run_id: str = typer.Argument(..., help="ID of the workflow run"),
|
|
||||||
) -> None:
|
|
||||||
"""Cancel a running workflow."""
|
|
||||||
client = _get_client(ctx.obj.get("api_key") if ctx.obj else None)
|
|
||||||
client.cancel_run(run_id=run_id)
|
|
||||||
console.print(Panel(f"Cancel signal sent for run {run_id}", border_style="red"))
|
|
||||||
|
|
||||||
|
|
||||||
@workflow_app.command("status")
|
@workflow_app.command("status")
|
||||||
def workflow_status(
|
def workflow_status(
|
||||||
ctx: typer.Context,
|
run_id: str = typer.Option(..., "--run-id", help="Run ID (wr_... or tsk_v2_...)."),
|
||||||
run_id: str = typer.Argument(..., help="ID of the workflow run"),
|
json_output: bool = typer.Option(False, "--json", help="Output as JSON."),
|
||||||
tasks: bool = typer.Option(False, "--tasks", help="Show task executions"),
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Retrieve status information for a workflow run."""
|
"""Get workflow run status."""
|
||||||
client = _get_client(ctx.obj.get("api_key") if ctx.obj else None)
|
|
||||||
run = client.get_run(run_id=run_id)
|
async def _run() -> dict[str, Any]:
|
||||||
console.print(Panel(run.model_dump_json(indent=2), border_style="cyan"))
|
return await tool_workflow_status(run_id=run_id)
|
||||||
if tasks:
|
|
||||||
task_list = _list_workflow_tasks(client, run_id)
|
_run_tool(_run, json_output=json_output, hint_on_exception="Check the run ID and API key.")
|
||||||
console.print(Panel(json.dumps(task_list, indent=2), border_style="magenta"))
|
|
||||||
|
|
||||||
|
|
||||||
@workflow_app.command("list")
|
@workflow_app.command("cancel")
|
||||||
def list_workflows(
|
def workflow_cancel(
|
||||||
ctx: typer.Context,
|
run_id: str = typer.Option(..., "--run-id", help="Run ID (wr_... or tsk_v2_...)."),
|
||||||
page: int = typer.Option(1, "--page", help="Page number"),
|
json_output: bool = typer.Option(False, "--json", help="Output as JSON."),
|
||||||
page_size: int = typer.Option(10, "--page-size", help="Number of workflows to return"),
|
|
||||||
template: bool = typer.Option(False, "--template", help="List template workflows"),
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""List workflows for the organization."""
|
"""Cancel a workflow run."""
|
||||||
client = _get_client(ctx.obj.get("api_key") if ctx.obj else None)
|
|
||||||
resp = client._client_wrapper.httpx_client.request(
|
async def _run() -> dict[str, Any]:
|
||||||
"api/v1/workflows",
|
return await tool_workflow_cancel(run_id=run_id)
|
||||||
method="GET",
|
|
||||||
params={"page": page, "page_size": page_size, "template": template},
|
_run_tool(_run, json_output=json_output, hint_on_exception="Check the run ID and API key.")
|
||||||
)
|
|
||||||
resp.raise_for_status()
|
|
||||||
console.print(Panel(json.dumps(resp.json(), indent=2), border_style="cyan"))
|
|
||||||
|
|||||||
@@ -296,3 +296,175 @@ class TestBrowserCommands:
|
|||||||
parsed = json.loads(capsys.readouterr().out)
|
parsed = json.loads(capsys.readouterr().out)
|
||||||
assert parsed["ok"] is False
|
assert parsed["ok"] is False
|
||||||
assert "Invalid state" in parsed["error"]["message"]
|
assert "Invalid state" in parsed["error"]["message"]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Workflow command behavior
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestWorkflowCommands:
|
||||||
|
def test_workflow_get_outputs_mcp_envelope_in_json_mode(
|
||||||
|
self, monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture
|
||||||
|
) -> None:
|
||||||
|
from skyvern.cli import workflow as workflow_cmd
|
||||||
|
|
||||||
|
expected = {
|
||||||
|
"ok": True,
|
||||||
|
"action": "skyvern_workflow_get",
|
||||||
|
"browser_context": {"mode": "none", "session_id": None, "cdp_url": None},
|
||||||
|
"data": {"workflow_permanent_id": "wpid_123"},
|
||||||
|
"artifacts": [],
|
||||||
|
"timing_ms": {},
|
||||||
|
"warnings": [],
|
||||||
|
"error": None,
|
||||||
|
}
|
||||||
|
tool = AsyncMock(return_value=expected)
|
||||||
|
monkeypatch.setattr(workflow_cmd, "tool_workflow_get", tool)
|
||||||
|
|
||||||
|
workflow_cmd.workflow_get(workflow_id="wpid_123", version=2, json_output=True)
|
||||||
|
|
||||||
|
parsed = json.loads(capsys.readouterr().out)
|
||||||
|
assert parsed == expected
|
||||||
|
assert tool.await_args.kwargs == {"workflow_id": "wpid_123", "version": 2}
|
||||||
|
|
||||||
|
def test_workflow_create_reads_definition_from_file(
|
||||||
|
self,
|
||||||
|
tmp_path: Path,
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
capsys: pytest.CaptureFixture,
|
||||||
|
) -> None:
|
||||||
|
from skyvern.cli import workflow as workflow_cmd
|
||||||
|
|
||||||
|
definition_file = tmp_path / "workflow.json"
|
||||||
|
definition_text = '{"title": "Example", "workflow_definition": {"blocks": []}}'
|
||||||
|
definition_file.write_text(definition_text)
|
||||||
|
|
||||||
|
tool = AsyncMock(
|
||||||
|
return_value={
|
||||||
|
"ok": True,
|
||||||
|
"action": "skyvern_workflow_create",
|
||||||
|
"browser_context": {"mode": "none", "session_id": None, "cdp_url": None},
|
||||||
|
"data": {"workflow_permanent_id": "wpid_new"},
|
||||||
|
"artifacts": [],
|
||||||
|
"timing_ms": {},
|
||||||
|
"warnings": [],
|
||||||
|
"error": None,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(workflow_cmd, "tool_workflow_create", tool)
|
||||||
|
|
||||||
|
workflow_cmd.workflow_create(
|
||||||
|
definition=f"@{definition_file}",
|
||||||
|
definition_format="json",
|
||||||
|
folder_id="fld_123",
|
||||||
|
json_output=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert tool.await_args.kwargs == {
|
||||||
|
"definition": definition_text,
|
||||||
|
"format": "json",
|
||||||
|
"folder_id": "fld_123",
|
||||||
|
}
|
||||||
|
parsed = json.loads(capsys.readouterr().out)
|
||||||
|
assert parsed["ok"] is True
|
||||||
|
assert parsed["data"]["workflow_permanent_id"] == "wpid_new"
|
||||||
|
|
||||||
|
def test_workflow_run_reads_params_file_and_maps_options(
|
||||||
|
self,
|
||||||
|
tmp_path: Path,
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
capsys: pytest.CaptureFixture,
|
||||||
|
) -> None:
|
||||||
|
from skyvern.cli import workflow as workflow_cmd
|
||||||
|
|
||||||
|
params_file = tmp_path / "params.json"
|
||||||
|
params_file.write_text('{"company": "Acme"}')
|
||||||
|
|
||||||
|
tool = AsyncMock(
|
||||||
|
return_value={
|
||||||
|
"ok": True,
|
||||||
|
"action": "skyvern_workflow_run",
|
||||||
|
"browser_context": {"mode": "none", "session_id": None, "cdp_url": None},
|
||||||
|
"data": {"run_id": "wr_123", "status": "queued"},
|
||||||
|
"artifacts": [],
|
||||||
|
"timing_ms": {},
|
||||||
|
"warnings": [],
|
||||||
|
"error": None,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(workflow_cmd, "tool_workflow_run", tool)
|
||||||
|
|
||||||
|
workflow_cmd.workflow_run(
|
||||||
|
workflow_id="wpid_123",
|
||||||
|
params=f"@{params_file}",
|
||||||
|
session="pbs_456",
|
||||||
|
webhook="https://example.com/webhook",
|
||||||
|
proxy="RESIDENTIAL",
|
||||||
|
wait=True,
|
||||||
|
timeout=450,
|
||||||
|
json_output=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert tool.await_args.kwargs == {
|
||||||
|
"workflow_id": "wpid_123",
|
||||||
|
"parameters": '{"company": "Acme"}',
|
||||||
|
"browser_session_id": "pbs_456",
|
||||||
|
"webhook_url": "https://example.com/webhook",
|
||||||
|
"proxy_location": "RESIDENTIAL",
|
||||||
|
"wait": True,
|
||||||
|
"timeout_seconds": 450,
|
||||||
|
}
|
||||||
|
parsed = json.loads(capsys.readouterr().out)
|
||||||
|
assert parsed["ok"] is True
|
||||||
|
assert parsed["data"]["run_id"] == "wr_123"
|
||||||
|
|
||||||
|
def test_workflow_status_json_error_exits_nonzero(
|
||||||
|
self, monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture
|
||||||
|
) -> None:
|
||||||
|
from skyvern.cli import workflow as workflow_cmd
|
||||||
|
|
||||||
|
tool = AsyncMock(
|
||||||
|
return_value={
|
||||||
|
"ok": False,
|
||||||
|
"action": "skyvern_workflow_status",
|
||||||
|
"browser_context": {"mode": "none", "session_id": None, "cdp_url": None},
|
||||||
|
"data": None,
|
||||||
|
"artifacts": [],
|
||||||
|
"timing_ms": {},
|
||||||
|
"warnings": [],
|
||||||
|
"error": {
|
||||||
|
"code": "RUN_NOT_FOUND",
|
||||||
|
"message": "Run 'wr_missing' not found",
|
||||||
|
"hint": "Verify the run ID",
|
||||||
|
"details": {},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(workflow_cmd, "tool_workflow_status", tool)
|
||||||
|
|
||||||
|
with pytest.raises(SystemExit, match="1"):
|
||||||
|
workflow_cmd.workflow_status(run_id="wr_missing", json_output=True)
|
||||||
|
|
||||||
|
parsed = json.loads(capsys.readouterr().out)
|
||||||
|
assert parsed["ok"] is False
|
||||||
|
assert parsed["error"]["code"] == "RUN_NOT_FOUND"
|
||||||
|
|
||||||
|
def test_workflow_update_missing_definition_file_raises_bad_parameter(
|
||||||
|
self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||||
|
) -> None:
|
||||||
|
from skyvern.cli import workflow as workflow_cmd
|
||||||
|
|
||||||
|
tool = AsyncMock()
|
||||||
|
monkeypatch.setattr(workflow_cmd, "tool_workflow_update", tool)
|
||||||
|
missing_file = tmp_path / "missing-definition.json"
|
||||||
|
|
||||||
|
with pytest.raises(typer.BadParameter, match="Unable to read definition file"):
|
||||||
|
workflow_cmd.workflow_update(
|
||||||
|
workflow_id="wpid_123",
|
||||||
|
definition=f"@{missing_file}",
|
||||||
|
definition_format="json",
|
||||||
|
json_output=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
tool.assert_not_called()
|
||||||
|
|||||||
304
tests/unit/test_mcp_http_auth.py
Normal file
304
tests/unit/test_mcp_http_auth.py
Normal file
@@ -0,0 +1,304 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
|
from fastapi import HTTPException
|
||||||
|
from starlette.applications import Starlette
|
||||||
|
from starlette.middleware import Middleware
|
||||||
|
from starlette.requests import Request
|
||||||
|
from starlette.responses import JSONResponse
|
||||||
|
from starlette.routing import Route
|
||||||
|
|
||||||
|
from skyvern.cli.core import client as client_mod
|
||||||
|
from skyvern.cli.core import mcp_http_auth
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _reset_auth_context() -> None:
|
||||||
|
client_mod._api_key_override.set(None)
|
||||||
|
mcp_http_auth._auth_db = None
|
||||||
|
mcp_http_auth._api_key_validation_cache.clear()
|
||||||
|
mcp_http_auth._API_KEY_CACHE_TTL_SECONDS = 30.0
|
||||||
|
mcp_http_auth._API_KEY_CACHE_MAX_SIZE = 1024
|
||||||
|
mcp_http_auth._MAX_VALIDATION_RETRIES = 2
|
||||||
|
mcp_http_auth._RETRY_DELAY_SECONDS = 0.0 # no delay in tests
|
||||||
|
|
||||||
|
|
||||||
|
async def _echo_request_context(request: Request) -> JSONResponse:
|
||||||
|
return JSONResponse(
|
||||||
|
{
|
||||||
|
"api_key": client_mod.get_active_api_key(),
|
||||||
|
"organization_id": getattr(request.state, "organization_id", None),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_test_app() -> Starlette:
|
||||||
|
return Starlette(
|
||||||
|
routes=[Route("/mcp", endpoint=_echo_request_context, methods=["POST"])],
|
||||||
|
middleware=[Middleware(mcp_http_auth.MCPAPIKeyMiddleware)],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_mcp_http_auth_rejects_missing_api_key() -> None:
|
||||||
|
app = _build_test_app()
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
|
||||||
|
response = await client.post("/mcp", json={})
|
||||||
|
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert response.json()["error"]["code"] == "UNAUTHORIZED"
|
||||||
|
assert "x-api-key" in response.json()["error"]["message"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_mcp_http_auth_allows_health_checks_without_api_key() -> None:
|
||||||
|
app = _build_test_app()
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
|
||||||
|
response = await client.get("/healthz")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"status": "ok"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_mcp_http_auth_rejects_invalid_api_key(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
monkeypatch.setattr(
|
||||||
|
mcp_http_auth,
|
||||||
|
"validate_mcp_api_key",
|
||||||
|
AsyncMock(side_effect=HTTPException(status_code=403, detail="Invalid credentials")),
|
||||||
|
)
|
||||||
|
app = _build_test_app()
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
|
||||||
|
response = await client.post("/mcp", headers={"x-api-key": "bad-key"}, json={})
|
||||||
|
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert response.json()["error"]["code"] == "UNAUTHORIZED"
|
||||||
|
assert response.json()["error"]["message"] == "Invalid API key"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_mcp_http_auth_returns_500_on_non_auth_http_exception(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
monkeypatch.setattr(
|
||||||
|
mcp_http_auth,
|
||||||
|
"validate_mcp_api_key",
|
||||||
|
AsyncMock(side_effect=HTTPException(status_code=500, detail="db down")),
|
||||||
|
)
|
||||||
|
app = _build_test_app()
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
|
||||||
|
response = await client.post("/mcp", headers={"x-api-key": "sk_live_abc"}, json={})
|
||||||
|
|
||||||
|
assert response.status_code == 500
|
||||||
|
assert response.json()["error"]["code"] == "INTERNAL_ERROR"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_mcp_http_auth_returns_503_on_transient_validation_exhaustion(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
monkeypatch.setattr(
|
||||||
|
mcp_http_auth,
|
||||||
|
"validate_mcp_api_key",
|
||||||
|
AsyncMock(side_effect=HTTPException(status_code=503, detail="API key validation temporarily unavailable")),
|
||||||
|
)
|
||||||
|
app = _build_test_app()
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
|
||||||
|
response = await client.post("/mcp", headers={"x-api-key": "sk_live_abc"}, json={})
|
||||||
|
|
||||||
|
assert response.status_code == 503
|
||||||
|
assert response.json()["error"]["code"] == "SERVICE_UNAVAILABLE"
|
||||||
|
assert response.json()["error"]["message"] == "API key validation temporarily unavailable"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_mcp_http_auth_returns_500_on_unexpected_validation_error(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
monkeypatch.setattr(
|
||||||
|
mcp_http_auth,
|
||||||
|
"validate_mcp_api_key",
|
||||||
|
AsyncMock(side_effect=RuntimeError("boom")),
|
||||||
|
)
|
||||||
|
app = _build_test_app()
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
|
||||||
|
response = await client.post("/mcp", headers={"x-api-key": "sk_live_abc"}, json={})
|
||||||
|
|
||||||
|
assert response.status_code == 500
|
||||||
|
assert response.json()["error"]["code"] == "INTERNAL_ERROR"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_mcp_http_auth_sets_request_scoped_api_key(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
monkeypatch.setattr(
|
||||||
|
mcp_http_auth,
|
||||||
|
"validate_mcp_api_key",
|
||||||
|
AsyncMock(return_value="org_123"),
|
||||||
|
)
|
||||||
|
app = _build_test_app()
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") as client:
|
||||||
|
response = await client.post("/mcp", headers={"x-api-key": "sk_live_abc"}, json={})
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {
|
||||||
|
"api_key": "sk_live_abc",
|
||||||
|
"organization_id": "org_123",
|
||||||
|
}
|
||||||
|
assert client_mod.get_active_api_key() != "sk_live_abc"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_validate_mcp_api_key_uses_ttl_cache(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
calls = 0
|
||||||
|
|
||||||
|
async def _resolve(_api_key: str, _db: object) -> object:
|
||||||
|
nonlocal calls
|
||||||
|
calls += 1
|
||||||
|
return SimpleNamespace(organization=SimpleNamespace(organization_id="org_cached"))
|
||||||
|
|
||||||
|
monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
|
||||||
|
monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())
|
||||||
|
|
||||||
|
first = await mcp_http_auth.validate_mcp_api_key("sk_test_cache")
|
||||||
|
second = await mcp_http_auth.validate_mcp_api_key("sk_test_cache")
|
||||||
|
|
||||||
|
assert first == "org_cached"
|
||||||
|
assert second == "org_cached"
|
||||||
|
assert calls == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_validate_mcp_api_key_cache_expires(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
calls = 0
|
||||||
|
|
||||||
|
async def _resolve(_api_key: str, _db: object) -> object:
|
||||||
|
nonlocal calls
|
||||||
|
calls += 1
|
||||||
|
return SimpleNamespace(organization=SimpleNamespace(organization_id=f"org_{calls}"))
|
||||||
|
|
||||||
|
monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
|
||||||
|
monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())
|
||||||
|
|
||||||
|
first = await mcp_http_auth.validate_mcp_api_key("sk_test_cache_expire")
|
||||||
|
cache_key = mcp_http_auth._cache_key("sk_test_cache_expire")
|
||||||
|
mcp_http_auth._api_key_validation_cache[cache_key] = (first, 0.0)
|
||||||
|
second = await mcp_http_auth.validate_mcp_api_key("sk_test_cache_expire")
|
||||||
|
|
||||||
|
assert first == "org_1"
|
||||||
|
assert second == "org_2"
|
||||||
|
assert calls == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_validate_mcp_api_key_negative_caches_auth_failures(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
calls = 0
|
||||||
|
|
||||||
|
async def _resolve(_api_key: str, _db: object) -> object:
|
||||||
|
nonlocal calls
|
||||||
|
calls += 1
|
||||||
|
raise HTTPException(status_code=401, detail="Invalid credentials")
|
||||||
|
|
||||||
|
monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
|
||||||
|
monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())
|
||||||
|
|
||||||
|
with pytest.raises(HTTPException, match="Invalid credentials"):
|
||||||
|
await mcp_http_auth.validate_mcp_api_key("sk_test_auth_failure")
|
||||||
|
|
||||||
|
with pytest.raises(HTTPException, match="Invalid API key"):
|
||||||
|
await mcp_http_auth.validate_mcp_api_key("sk_test_auth_failure")
|
||||||
|
|
||||||
|
assert calls == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_validate_mcp_api_key_retries_transient_failure_without_negative_cache(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
|
calls = 0
|
||||||
|
|
||||||
|
async def _resolve(_api_key: str, _db: object) -> object:
|
||||||
|
nonlocal calls
|
||||||
|
calls += 1
|
||||||
|
if calls == 1:
|
||||||
|
raise RuntimeError("transient db error")
|
||||||
|
return SimpleNamespace(organization=SimpleNamespace(organization_id="org_recovered"))
|
||||||
|
|
||||||
|
monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
|
||||||
|
monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())
|
||||||
|
|
||||||
|
recovered_org = await mcp_http_auth.validate_mcp_api_key("sk_test_transient")
|
||||||
|
|
||||||
|
cache_key = mcp_http_auth._cache_key("sk_test_transient")
|
||||||
|
assert mcp_http_auth._api_key_validation_cache[cache_key][0] == "org_recovered"
|
||||||
|
|
||||||
|
assert recovered_org == "org_recovered"
|
||||||
|
assert calls == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_validate_mcp_api_key_concurrent_callers_all_succeed(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
|
"""Multiple concurrent callers for the same key all succeed; the cache
|
||||||
|
collapses subsequent calls after the first one populates it."""
|
||||||
|
calls = 0
|
||||||
|
|
||||||
|
async def _resolve(_api_key: str, _db: object) -> object:
|
||||||
|
nonlocal calls
|
||||||
|
calls += 1
|
||||||
|
return SimpleNamespace(organization=SimpleNamespace(organization_id="org_concurrent"))
|
||||||
|
|
||||||
|
monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
|
||||||
|
monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())
|
||||||
|
|
||||||
|
results = await asyncio.gather(*[mcp_http_auth.validate_mcp_api_key("sk_test_concurrent") for _ in range(5)])
|
||||||
|
assert all(r == "org_concurrent" for r in results)
|
||||||
|
# First call populates cache; remaining may or may not hit DB depending on
|
||||||
|
# scheduling, but all must succeed.
|
||||||
|
assert calls >= 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_validate_mcp_api_key_returns_503_after_retry_exhaustion(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
calls = 0
|
||||||
|
|
||||||
|
async def _resolve(_api_key: str, _db: object) -> object:
|
||||||
|
nonlocal calls
|
||||||
|
calls += 1
|
||||||
|
raise RuntimeError("persistent db outage")
|
||||||
|
|
||||||
|
monkeypatch.setattr(mcp_http_auth, "_MAX_VALIDATION_RETRIES", 2)
|
||||||
|
monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
|
||||||
|
monkeypatch.setattr(mcp_http_auth, "_get_auth_db", lambda: object())
|
||||||
|
|
||||||
|
with pytest.raises(HTTPException, match="temporarily unavailable") as exc_info:
|
||||||
|
await mcp_http_auth.validate_mcp_api_key("sk_test_transient_exhausted")
|
||||||
|
|
||||||
|
assert exc_info.value.status_code == 503
|
||||||
|
assert calls == 3 # initial + 2 retries
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_close_auth_db_disposes_engine() -> None:
|
||||||
|
dispose = AsyncMock()
|
||||||
|
mcp_http_auth._auth_db = SimpleNamespace(engine=SimpleNamespace(dispose=dispose))
|
||||||
|
mcp_http_auth._api_key_validation_cache["k"] = ("org", 123.0)
|
||||||
|
|
||||||
|
await mcp_http_auth.close_auth_db()
|
||||||
|
|
||||||
|
dispose.assert_awaited_once()
|
||||||
|
assert mcp_http_auth._auth_db is None
|
||||||
|
assert mcp_http_auth._api_key_validation_cache == {}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_close_auth_db_noop_when_uninitialized() -> None:
|
||||||
|
mcp_http_auth._auth_db = None
|
||||||
|
await mcp_http_auth.close_auth_db()
|
||||||
|
assert mcp_http_auth._auth_db is None
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from types import SimpleNamespace
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -14,10 +15,13 @@ from skyvern.cli.mcp_tools import session as mcp_session
|
|||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def _reset_singletons() -> None:
|
def _reset_singletons() -> None:
|
||||||
client_mod._skyvern_instance.set(None)
|
client_mod._skyvern_instance.set(None)
|
||||||
|
client_mod._api_key_override.set(None)
|
||||||
client_mod._global_skyvern_instance = None
|
client_mod._global_skyvern_instance = None
|
||||||
|
client_mod._api_key_clients.clear()
|
||||||
|
|
||||||
session_manager._current_session.set(None)
|
session_manager._current_session.set(None)
|
||||||
session_manager._global_session = None
|
session_manager._global_session = None
|
||||||
|
session_manager.set_stateless_http_mode(False)
|
||||||
mcp_session.set_current_session(mcp_session.SessionState())
|
mcp_session.set_current_session(mcp_session.SessionState())
|
||||||
|
|
||||||
|
|
||||||
@@ -47,6 +51,115 @@ def test_get_skyvern_reuses_global_instance_across_contexts(monkeypatch: pytest.
|
|||||||
assert len(created) == 1
|
assert len(created) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_skyvern_reuses_override_instance_per_api_key(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
created_keys: list[str] = []
|
||||||
|
|
||||||
|
class FakeSkyvern:
|
||||||
|
def __init__(self, *args: object, **kwargs: object) -> None:
|
||||||
|
created_keys.append(kwargs["api_key"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def local(cls) -> FakeSkyvern:
|
||||||
|
return cls(api_key="local")
|
||||||
|
|
||||||
|
async def aclose(self) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
monkeypatch.setattr(client_mod, "Skyvern", FakeSkyvern)
|
||||||
|
monkeypatch.setattr(client_mod.settings, "SKYVERN_API_KEY", None)
|
||||||
|
monkeypatch.setattr(client_mod.settings, "SKYVERN_BASE_URL", None)
|
||||||
|
|
||||||
|
token = client_mod.set_api_key_override("sk_key_a")
|
||||||
|
try:
|
||||||
|
first = client_mod.get_skyvern()
|
||||||
|
client_mod._skyvern_instance.set(None)
|
||||||
|
second = client_mod.get_skyvern()
|
||||||
|
finally:
|
||||||
|
client_mod.reset_api_key_override(token)
|
||||||
|
|
||||||
|
assert first is second
|
||||||
|
assert created_keys == ["sk_key_a"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_skyvern_override_client_cache_uses_lru_eviction(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
created_keys: list[str] = []
|
||||||
|
|
||||||
|
class FakeSkyvern:
|
||||||
|
def __init__(self, *args: object, **kwargs: object) -> None:
|
||||||
|
created_keys.append(kwargs["api_key"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def local(cls) -> FakeSkyvern:
|
||||||
|
return cls(api_key="local")
|
||||||
|
|
||||||
|
async def aclose(self) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
monkeypatch.setattr(client_mod, "Skyvern", FakeSkyvern)
|
||||||
|
monkeypatch.setattr(client_mod.settings, "SKYVERN_API_KEY", None)
|
||||||
|
monkeypatch.setattr(client_mod.settings, "SKYVERN_BASE_URL", None)
|
||||||
|
monkeypatch.setattr(client_mod, "_API_KEY_CLIENT_CACHE_MAX", 2)
|
||||||
|
|
||||||
|
for key in ("sk_key_a", "sk_key_b"):
|
||||||
|
token = client_mod.set_api_key_override(key)
|
||||||
|
try:
|
||||||
|
client_mod.get_skyvern()
|
||||||
|
finally:
|
||||||
|
client_mod.reset_api_key_override(token)
|
||||||
|
|
||||||
|
# Touch key_a so key_b becomes least-recently-used.
|
||||||
|
token = client_mod.set_api_key_override("sk_key_a")
|
||||||
|
try:
|
||||||
|
client_mod._skyvern_instance.set(None)
|
||||||
|
client_mod.get_skyvern()
|
||||||
|
finally:
|
||||||
|
client_mod.reset_api_key_override(token)
|
||||||
|
|
||||||
|
# Adding key_c should evict key_b.
|
||||||
|
token = client_mod.set_api_key_override("sk_key_c")
|
||||||
|
try:
|
||||||
|
client_mod.get_skyvern()
|
||||||
|
finally:
|
||||||
|
client_mod.reset_api_key_override(token)
|
||||||
|
|
||||||
|
assert list(client_mod._api_key_clients.keys()) == [
|
||||||
|
client_mod._cache_key("sk_key_a"),
|
||||||
|
client_mod._cache_key("sk_key_c"),
|
||||||
|
]
|
||||||
|
# key_a, key_b, key_c were created exactly once each.
|
||||||
|
assert created_keys == ["sk_key_a", "sk_key_b", "sk_key_c"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_skyvern_override_cache_closes_evicted_client(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
closed_keys: list[str] = []
|
||||||
|
|
||||||
|
class FakeSkyvern:
|
||||||
|
def __init__(self, *args: object, **kwargs: object) -> None:
|
||||||
|
self.api_key = kwargs["api_key"]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def local(cls) -> FakeSkyvern:
|
||||||
|
return cls(api_key="local")
|
||||||
|
|
||||||
|
async def aclose(self) -> None:
|
||||||
|
closed_keys.append(self.api_key)
|
||||||
|
|
||||||
|
monkeypatch.setattr(client_mod, "Skyvern", FakeSkyvern)
|
||||||
|
monkeypatch.setattr(client_mod.settings, "SKYVERN_API_KEY", None)
|
||||||
|
monkeypatch.setattr(client_mod.settings, "SKYVERN_BASE_URL", None)
|
||||||
|
monkeypatch.setattr(client_mod, "_API_KEY_CLIENT_CACHE_MAX", 1)
|
||||||
|
|
||||||
|
for key in ("sk_key_a", "sk_key_b"):
|
||||||
|
token = client_mod.set_api_key_override(key)
|
||||||
|
try:
|
||||||
|
client_mod.get_skyvern()
|
||||||
|
finally:
|
||||||
|
client_mod.reset_api_key_override(token)
|
||||||
|
|
||||||
|
assert list(client_mod._api_key_clients.keys()) == [client_mod._cache_key("sk_key_b")]
|
||||||
|
assert closed_keys == ["sk_key_a"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_close_skyvern_closes_singleton() -> None:
|
async def test_close_skyvern_closes_singleton() -> None:
|
||||||
fake = MagicMock()
|
fake = MagicMock()
|
||||||
@@ -75,12 +188,53 @@ def test_get_current_session_falls_back_to_global_state() -> None:
|
|||||||
assert recovered is state
|
assert recovered is state
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_current_session_stateless_mode_ignores_global_state() -> None:
|
||||||
|
global_state = session_manager.SessionState(
|
||||||
|
browser=MagicMock(),
|
||||||
|
context=BrowserContext(mode="cloud_session", session_id="pbs_999"),
|
||||||
|
)
|
||||||
|
session_manager._global_session = global_state
|
||||||
|
session_manager._current_session.set(None)
|
||||||
|
|
||||||
|
session_manager.set_stateless_http_mode(True)
|
||||||
|
try:
|
||||||
|
recovered = session_manager.get_current_session()
|
||||||
|
finally:
|
||||||
|
session_manager.set_stateless_http_mode(False)
|
||||||
|
|
||||||
|
assert recovered is not global_state
|
||||||
|
assert recovered.browser is None
|
||||||
|
assert recovered.context is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_current_session_stateless_mode_does_not_override_global_state() -> None:
|
||||||
|
global_state = session_manager.SessionState(
|
||||||
|
browser=MagicMock(),
|
||||||
|
context=BrowserContext(mode="cloud_session", session_id="pbs_global"),
|
||||||
|
)
|
||||||
|
session_manager._global_session = global_state
|
||||||
|
replacement = session_manager.SessionState(
|
||||||
|
browser=MagicMock(),
|
||||||
|
context=BrowserContext(mode="cloud_session", session_id="pbs_request"),
|
||||||
|
)
|
||||||
|
|
||||||
|
session_manager.set_stateless_http_mode(True)
|
||||||
|
try:
|
||||||
|
session_manager.set_current_session(replacement)
|
||||||
|
finally:
|
||||||
|
session_manager.set_stateless_http_mode(False)
|
||||||
|
|
||||||
|
assert session_manager._global_session is global_state
|
||||||
|
assert session_manager._current_session.get() is replacement
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_resolve_browser_reuses_matching_cloud_session(monkeypatch: pytest.MonkeyPatch) -> None:
|
async def test_resolve_browser_reuses_matching_cloud_session(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
current_browser = MagicMock()
|
current_browser = MagicMock()
|
||||||
current_state = session_manager.SessionState(
|
current_state = session_manager.SessionState(
|
||||||
browser=current_browser,
|
browser=current_browser,
|
||||||
context=BrowserContext(mode="cloud_session", session_id="pbs_123"),
|
context=BrowserContext(mode="cloud_session", session_id="pbs_123"),
|
||||||
|
api_key_hash=session_manager._api_key_hash(client_mod.get_active_api_key()),
|
||||||
)
|
)
|
||||||
session_manager.set_current_session(current_state)
|
session_manager.set_current_session(current_state)
|
||||||
|
|
||||||
@@ -95,6 +249,92 @@ async def test_resolve_browser_reuses_matching_cloud_session(monkeypatch: pytest
|
|||||||
fake_skyvern.connect_to_cloud_browser_session.assert_not_awaited()
|
fake_skyvern.connect_to_cloud_browser_session.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_resolve_browser_does_not_reuse_session_for_different_api_key(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
|
current_browser = MagicMock()
|
||||||
|
session_manager.set_current_session(
|
||||||
|
session_manager.SessionState(
|
||||||
|
browser=current_browser,
|
||||||
|
context=BrowserContext(mode="cloud_session", session_id="pbs_123"),
|
||||||
|
api_key_hash=session_manager._api_key_hash("sk_key_a"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
replacement_browser = MagicMock()
|
||||||
|
fake_skyvern = MagicMock()
|
||||||
|
fake_skyvern.connect_to_cloud_browser_session = AsyncMock(return_value=replacement_browser)
|
||||||
|
monkeypatch.setattr(session_manager, "get_skyvern", lambda: fake_skyvern)
|
||||||
|
|
||||||
|
token = client_mod.set_api_key_override("sk_key_b")
|
||||||
|
try:
|
||||||
|
browser, ctx = await session_manager.resolve_browser(session_id="pbs_123")
|
||||||
|
finally:
|
||||||
|
client_mod.reset_api_key_override(token)
|
||||||
|
|
||||||
|
assert browser is replacement_browser
|
||||||
|
assert ctx.session_id == "pbs_123"
|
||||||
|
fake_skyvern.connect_to_cloud_browser_session.assert_awaited_once_with("pbs_123")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_resolve_browser_stateless_mode_does_not_write_global_session(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
global_state = session_manager.SessionState(
|
||||||
|
browser=MagicMock(),
|
||||||
|
context=BrowserContext(mode="cloud_session", session_id="pbs_global"),
|
||||||
|
)
|
||||||
|
session_manager._global_session = global_state
|
||||||
|
|
||||||
|
replacement_browser = MagicMock()
|
||||||
|
fake_skyvern = MagicMock()
|
||||||
|
fake_skyvern.connect_to_cloud_browser_session = AsyncMock(return_value=replacement_browser)
|
||||||
|
monkeypatch.setattr(session_manager, "get_skyvern", lambda: fake_skyvern)
|
||||||
|
|
||||||
|
session_manager.set_stateless_http_mode(True)
|
||||||
|
try:
|
||||||
|
browser, ctx = await session_manager.resolve_browser(session_id="pbs_123")
|
||||||
|
finally:
|
||||||
|
session_manager.set_stateless_http_mode(False)
|
||||||
|
|
||||||
|
assert browser is replacement_browser
|
||||||
|
assert ctx.session_id == "pbs_123"
|
||||||
|
assert session_manager._global_session is global_state
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_resolve_browser_blocks_implicit_session_in_stateless_mode() -> None:
|
||||||
|
session_manager.set_current_session(
|
||||||
|
session_manager.SessionState(
|
||||||
|
browser=MagicMock(),
|
||||||
|
context=BrowserContext(mode="cloud_session", session_id="pbs_123"),
|
||||||
|
api_key_hash=session_manager._api_key_hash("sk_key_a"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
session_manager.set_stateless_http_mode(True)
|
||||||
|
try:
|
||||||
|
with pytest.raises(session_manager.BrowserNotAvailableError):
|
||||||
|
await session_manager.resolve_browser()
|
||||||
|
finally:
|
||||||
|
session_manager.set_stateless_http_mode(False)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_resolve_browser_raises_for_invalid_matching_state(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
session_manager.set_current_session(
|
||||||
|
session_manager.SessionState(
|
||||||
|
browser=None,
|
||||||
|
context=BrowserContext(mode="cloud_session", session_id="pbs_123"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.setattr(session_manager, "_matches_current", lambda *args, **kwargs: True)
|
||||||
|
monkeypatch.setattr(session_manager, "get_skyvern", lambda: MagicMock())
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match="Expected active browser and context"):
|
||||||
|
await session_manager.resolve_browser(session_id="pbs_123")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_session_close_with_matching_session_id_closes_browser_handle(monkeypatch: pytest.MonkeyPatch) -> None:
|
async def test_session_close_with_matching_session_id_closes_browser_handle(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
current_browser = MagicMock()
|
current_browser = MagicMock()
|
||||||
@@ -262,3 +502,73 @@ async def test_close_current_session_still_closes_browser_when_api_fails(monkeyp
|
|||||||
# _browser_session_id should NOT be cleared (API close failed, let browser.close() try)
|
# _browser_session_id should NOT be cleared (API close failed, let browser.close() try)
|
||||||
assert browser._browser_session_id == "pbs_fail"
|
assert browser._browser_session_id == "pbs_fail"
|
||||||
assert session_manager.get_current_session().browser is None
|
assert session_manager.get_current_session().browser is None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests for stateless HTTP mode session creation
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_session_create_stateless_mode_returns_session_without_persisting_browser(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
|
session_manager.set_stateless_http_mode(True)
|
||||||
|
fake_skyvern = MagicMock()
|
||||||
|
fake_skyvern.create_browser_session = AsyncMock(return_value=SimpleNamespace(browser_session_id="pbs_abc"))
|
||||||
|
monkeypatch.setattr(mcp_session, "get_skyvern", lambda: fake_skyvern)
|
||||||
|
do_session_create = AsyncMock()
|
||||||
|
monkeypatch.setattr(mcp_session, "do_session_create", do_session_create)
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await mcp_session.skyvern_session_create(timeout=45)
|
||||||
|
finally:
|
||||||
|
session_manager.set_stateless_http_mode(False)
|
||||||
|
|
||||||
|
assert result["ok"] is True
|
||||||
|
assert result["data"] == {"session_id": "pbs_abc", "timeout_minutes": 45}
|
||||||
|
do_session_create.assert_not_awaited()
|
||||||
|
assert mcp_session.get_current_session().browser is None
|
||||||
|
assert mcp_session.get_current_session().context is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_session_create_stateless_mode_rejects_local() -> None:
|
||||||
|
session_manager.set_stateless_http_mode(True)
|
||||||
|
try:
|
||||||
|
result = await mcp_session.skyvern_session_create(local=True)
|
||||||
|
finally:
|
||||||
|
session_manager.set_stateless_http_mode(False)
|
||||||
|
|
||||||
|
assert result["ok"] is False
|
||||||
|
assert result["error"]["code"] == mcp_session.ErrorCode.INVALID_INPUT
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_session_create_persists_active_api_key_hash_in_session_state(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
|
fake_skyvern = MagicMock()
|
||||||
|
monkeypatch.setattr(mcp_session, "get_skyvern", lambda: fake_skyvern)
|
||||||
|
|
||||||
|
fake_browser = MagicMock()
|
||||||
|
do_session_create = AsyncMock(
|
||||||
|
return_value=(
|
||||||
|
fake_browser,
|
||||||
|
SimpleNamespace(local=False, session_id="pbs_123", timeout_minutes=60, headless=False),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(mcp_session, "do_session_create", do_session_create)
|
||||||
|
|
||||||
|
token = client_mod.set_api_key_override("sk_key_create")
|
||||||
|
try:
|
||||||
|
result = await mcp_session.skyvern_session_create(timeout=60)
|
||||||
|
finally:
|
||||||
|
client_mod.reset_api_key_override(token)
|
||||||
|
|
||||||
|
assert result["ok"] is True
|
||||||
|
current = mcp_session.get_current_session()
|
||||||
|
assert current.browser is fake_browser
|
||||||
|
assert current.context == BrowserContext(mode="cloud_session", session_id="pbs_123")
|
||||||
|
assert current.api_key_hash == session_manager._api_key_hash("sk_key_create")
|
||||||
|
assert current.api_key_hash != "sk_key_create"
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
from unittest.mock import AsyncMock, MagicMock, call
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@@ -13,6 +13,42 @@ def _reset_cleanup_state() -> None:
|
|||||||
run_commands._mcp_cleanup_done = False
|
run_commands._mcp_cleanup_done = False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cleanup_mcp_resources_closes_auth_db(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
close_current_session = AsyncMock()
|
||||||
|
close_skyvern = AsyncMock()
|
||||||
|
close_auth_db = AsyncMock()
|
||||||
|
|
||||||
|
monkeypatch.setattr(run_commands, "close_current_session", close_current_session)
|
||||||
|
monkeypatch.setattr(run_commands, "close_skyvern", close_skyvern)
|
||||||
|
monkeypatch.setattr(run_commands, "close_auth_db", close_auth_db)
|
||||||
|
|
||||||
|
await run_commands._cleanup_mcp_resources()
|
||||||
|
|
||||||
|
close_current_session.assert_awaited_once()
|
||||||
|
close_skyvern.assert_awaited_once()
|
||||||
|
close_auth_db.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cleanup_mcp_resources_closes_auth_db_on_skyvern_close_error(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
close_current_session = AsyncMock()
|
||||||
|
close_auth_db = AsyncMock()
|
||||||
|
|
||||||
|
async def _failing_close_skyvern() -> None:
|
||||||
|
raise RuntimeError("close failed")
|
||||||
|
|
||||||
|
monkeypatch.setattr(run_commands, "close_current_session", close_current_session)
|
||||||
|
monkeypatch.setattr(run_commands, "close_skyvern", _failing_close_skyvern)
|
||||||
|
monkeypatch.setattr(run_commands, "close_auth_db", close_auth_db)
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match="close failed"):
|
||||||
|
await run_commands._cleanup_mcp_resources()
|
||||||
|
|
||||||
|
close_current_session.assert_awaited_once()
|
||||||
|
close_auth_db.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
def test_cleanup_mcp_resources_sync_runs_without_running_loop(monkeypatch: pytest.MonkeyPatch) -> None:
|
def test_cleanup_mcp_resources_sync_runs_without_running_loop(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
cleanup = AsyncMock()
|
cleanup = AsyncMock()
|
||||||
monkeypatch.setattr(run_commands, "_cleanup_mcp_resources", cleanup)
|
monkeypatch.setattr(run_commands, "_cleanup_mcp_resources", cleanup)
|
||||||
@@ -50,14 +86,56 @@ def test_run_mcp_calls_blocking_cleanup_in_finally(monkeypatch: pytest.MonkeyPat
|
|||||||
cleanup_blocking = MagicMock()
|
cleanup_blocking = MagicMock()
|
||||||
register = MagicMock()
|
register = MagicMock()
|
||||||
run = MagicMock(side_effect=RuntimeError("boom"))
|
run = MagicMock(side_effect=RuntimeError("boom"))
|
||||||
|
set_stateless = MagicMock()
|
||||||
|
|
||||||
monkeypatch.setattr(run_commands, "_cleanup_mcp_resources_blocking", cleanup_blocking)
|
monkeypatch.setattr(run_commands, "_cleanup_mcp_resources_blocking", cleanup_blocking)
|
||||||
monkeypatch.setattr(run_commands.atexit, "register", register)
|
monkeypatch.setattr(run_commands.atexit, "register", register)
|
||||||
monkeypatch.setattr(run_commands.mcp, "run", run)
|
monkeypatch.setattr(run_commands.mcp, "run", run)
|
||||||
|
monkeypatch.setattr(run_commands, "set_stateless_http_mode", set_stateless)
|
||||||
|
|
||||||
with pytest.raises(RuntimeError, match="boom"):
|
with pytest.raises(RuntimeError, match="boom"):
|
||||||
run_commands.run_mcp()
|
run_commands.run_mcp()
|
||||||
|
|
||||||
register.assert_called_once_with(run_commands._cleanup_mcp_resources_sync)
|
register.assert_called_once_with(run_commands._cleanup_mcp_resources_sync)
|
||||||
run.assert_called_once_with(transport="stdio")
|
run.assert_called_once_with(transport="stdio")
|
||||||
|
set_stateless.assert_has_calls([call(False), call(False)])
|
||||||
cleanup_blocking.assert_called_once()
|
cleanup_blocking.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_mcp_http_transport_wires_auth_middleware(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
cleanup_blocking = MagicMock()
|
||||||
|
register = MagicMock()
|
||||||
|
run = MagicMock()
|
||||||
|
set_stateless = MagicMock()
|
||||||
|
|
||||||
|
monkeypatch.setattr(run_commands, "_cleanup_mcp_resources_blocking", cleanup_blocking)
|
||||||
|
monkeypatch.setattr(run_commands.atexit, "register", register)
|
||||||
|
monkeypatch.setattr(run_commands.mcp, "run", run)
|
||||||
|
monkeypatch.setattr(run_commands, "set_stateless_http_mode", set_stateless)
|
||||||
|
|
||||||
|
run_commands.run_mcp(
|
||||||
|
transport="streamable-http",
|
||||||
|
host="127.0.0.1",
|
||||||
|
port=9010,
|
||||||
|
path="mcp",
|
||||||
|
stateless_http=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
register.assert_called_once_with(run_commands._cleanup_mcp_resources_sync)
|
||||||
|
run.assert_called_once()
|
||||||
|
kwargs = run.call_args.kwargs
|
||||||
|
assert kwargs["transport"] == "streamable-http"
|
||||||
|
assert kwargs["host"] == "127.0.0.1"
|
||||||
|
assert kwargs["port"] == 9010
|
||||||
|
assert kwargs["path"] == "/mcp"
|
||||||
|
assert kwargs["stateless_http"] is True
|
||||||
|
middleware = kwargs["middleware"]
|
||||||
|
assert len(middleware) == 1
|
||||||
|
assert middleware[0].cls is run_commands.MCPAPIKeyMiddleware
|
||||||
|
set_stateless.assert_has_calls([call(True), call(False)])
|
||||||
|
cleanup_blocking.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_task_tool_registration_points_to_browser_module() -> None:
|
||||||
|
tool = run_commands.mcp._tool_manager._tools["skyvern_run_task"] # type: ignore[attr-defined]
|
||||||
|
assert tool.fn.__module__ == "skyvern.cli.mcp_tools.browser"
|
||||||
|
|||||||
Reference in New Issue
Block a user