Add PR A CLI browser command parity with MCP (#4789)
This commit is contained in:
@@ -13,7 +13,17 @@ from skyvern.cli.commands._state import CLIState, clear_state, load_state, save_
|
||||
from skyvern.cli.core.artifacts import save_artifact
|
||||
from skyvern.cli.core.browser_ops import do_act, do_extract, do_navigate, do_screenshot
|
||||
from skyvern.cli.core.client import get_skyvern
|
||||
from skyvern.cli.core.guards import GuardError, check_password_prompt, validate_wait_until
|
||||
from skyvern.cli.core.guards import (
|
||||
CREDENTIAL_HINT,
|
||||
PASSWORD_PATTERN,
|
||||
VALID_ELEMENT_STATES,
|
||||
GuardError,
|
||||
check_js_password,
|
||||
check_password_prompt,
|
||||
resolve_ai_mode,
|
||||
validate_button,
|
||||
validate_wait_until,
|
||||
)
|
||||
from skyvern.cli.core.session_ops import do_session_close, do_session_create, do_session_list
|
||||
|
||||
browser_app = typer.Typer(help="Browser automation commands.", no_args_is_help=True)
|
||||
@@ -64,6 +74,24 @@ async def _connect_browser(connection: ConnectionTarget) -> Any:
|
||||
return await skyvern.connect_to_browser_over_cdp(connection.cdp_url)
|
||||
|
||||
|
||||
def _resolve_ai_target(selector: str | None, intent: str | None, *, operation: str) -> str | None:
    """Resolve the AI targeting mode for an element-level command.

    Delegates to ``resolve_ai_mode(selector, intent)``, which yields a
    ``(mode, error)`` pair. When an error is reported, a ``GuardError`` is
    raised whose hint mentions *operation* (e.g. "click", "hover").

    Returns:
        The mode from ``resolve_ai_mode``; callers treat ``None`` as
        "use the plain selector path".

    Raises:
        GuardError: if ``resolve_ai_mode`` reports an error.
    """
    mode, problem = resolve_ai_mode(selector, intent)
    if not problem:
        return mode

    hint = (
        f"Use intent='describe what to {operation}' for AI-powered targeting, "
        "or selector='#css-selector' for precise targeting"
    )
    raise GuardError("Must provide intent, selector, or both", hint)
|
||||
|
||||
|
||||
def _validate_wait_state(state: str) -> None:
    """Raise ``GuardError`` unless *state* is a member of ``VALID_ELEMENT_STATES``."""
    if state in VALID_ELEMENT_STATES:
        return
    raise GuardError(f"Invalid state: {state}", "Use visible, hidden, attached, or detached")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Session commands
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -190,6 +218,35 @@ def session_list(
|
||||
output_error(str(e), hint="Check your API key and network connection.", json_mode=json_output)
|
||||
|
||||
|
||||
@session_app.command("get")
def session_get(
    session: str = typer.Option(..., "--session", "--id", help="Browser session ID."),
    json_output: bool = typer.Option(False, "--json", help="Output as JSON."),
) -> None:
    """Get details for a browser session."""

    def _iso(moment: Any) -> Any:
        # Timestamps are rendered as ISO-8601 strings; missing ones stay None.
        return moment.isoformat() if moment else None

    async def _fetch() -> dict:
        client = get_skyvern()
        details = await client.get_browser_session(session)

        # Flag the session as "current" when the saved CLI state points at it.
        saved = load_state()
        is_current = bool(saved and saved.mode == "cloud" and saved.session_id == session)

        return {
            "session_id": details.browser_session_id,
            "status": details.status,
            "started_at": _iso(details.started_at),
            "completed_at": _iso(details.completed_at),
            "timeout": details.timeout,
            "runnable_id": details.runnable_id,
            "is_current": is_current,
        }

    try:
        payload = asyncio.run(_fetch())
        output(payload, action="session_get", json_mode=json_output)
    except Exception as e:
        output_error(str(e), hint="Verify the session ID exists and is accessible.", json_mode=json_output)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Browser commands
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -269,6 +326,392 @@ def screenshot(
|
||||
output_error(str(e), hint="Ensure the session is active and the page has loaded.", json_mode=json_output)
|
||||
|
||||
|
||||
@browser_app.command("evaluate")
def evaluate(
    expression: str = typer.Option(..., help="JavaScript expression to evaluate."),
    session: str | None = typer.Option(None, help="Browser session ID."),
    cdp: str | None = typer.Option(None, "--cdp", help="CDP WebSocket URL."),
    json_output: bool = typer.Option(False, "--json", help="Output as JSON."),
) -> None:
    """Run JavaScript on the current page."""

    async def _execute() -> dict:
        # Guard runs before any connection is attempted.
        check_js_password(expression)

        target = _resolve_connection(session, cdp)
        browser = await _connect_browser(target)
        page = await browser.get_working_page()
        return {"result": await page.evaluate(expression)}

    try:
        payload = asyncio.run(_execute())
        output(payload, action="evaluate", json_mode=json_output)
    except GuardError as e:
        output_error(str(e), hint=e.hint, json_mode=json_output)
    except typer.BadParameter:
        # Let typer render parameter errors itself.
        raise
    except Exception as e:
        output_error(str(e), hint="Check JavaScript syntax and page state.", json_mode=json_output)
|
||||
|
||||
|
||||
@browser_app.command("click")
def click(
    intent: str | None = typer.Option(None, help="Natural language description of the element to click."),
    selector: str | None = typer.Option(None, help="CSS selector or XPath for the element to click."),
    session: str | None = typer.Option(None, help="Browser session ID."),
    cdp: str | None = typer.Option(None, "--cdp", help="CDP WebSocket URL."),
    timeout: int = typer.Option(30000, help="Max wait time in milliseconds."),
    button: str | None = typer.Option(None, help="Mouse button: left, right, or middle."),
    click_count: int | None = typer.Option(None, "--click-count", help="Number of clicks (2 for double-click)."),
    json_output: bool = typer.Option(False, "--json", help="Output as JSON."),
) -> None:
    """Click an element using selector, intent, or both."""

    async def _perform() -> dict:
        # Input guards run before any browser connection is made.
        validate_button(button)
        ai_mode = _resolve_ai_target(selector, intent, operation="click")

        target = _resolve_connection(session, cdp)
        browser = await _connect_browser(target)
        page = await browser.get_working_page()

        # Only forward optional click settings the user actually supplied.
        click_options: dict[str, Any] = {"timeout": timeout}
        if button:
            click_options["button"] = button
        if click_count is not None:
            click_options["click_count"] = click_count

        if ai_mode is None:
            assert selector is not None
        else:
            click_options["prompt"] = intent
            click_options["ai"] = ai_mode
        resolved = await page.click(selector=selector, **click_options)  # type: ignore[arg-type]

        summary: dict[str, Any] = {"selector": selector, "intent": intent, "ai_mode": ai_mode}
        # Report the selector the page actually used when it differs from the input.
        if resolved and resolved != selector:
            summary["resolved_selector"] = resolved
        return summary

    try:
        payload = asyncio.run(_perform())
        output(payload, action="click", json_mode=json_output)
    except GuardError as e:
        output_error(str(e), hint=e.hint, json_mode=json_output)
    except typer.BadParameter:
        raise
    except Exception as e:
        output_error(str(e), hint="Element may be hidden, disabled, or not yet available.", json_mode=json_output)
|
||||
|
||||
|
||||
@browser_app.command("hover")
def hover(
    intent: str | None = typer.Option(None, help="Natural language description of the element to hover."),
    selector: str | None = typer.Option(None, help="CSS selector or XPath for the element to hover."),
    session: str | None = typer.Option(None, help="Browser session ID."),
    cdp: str | None = typer.Option(None, "--cdp", help="CDP WebSocket URL."),
    timeout: int = typer.Option(30000, help="Max wait time in milliseconds."),
    json_output: bool = typer.Option(False, "--json", help="Output as JSON."),
) -> None:
    """Hover over an element using selector, intent, or both."""

    async def _perform() -> dict:
        ai_mode = _resolve_ai_target(selector, intent, operation="hover")

        target = _resolve_connection(session, cdp)
        browser = await _connect_browser(target)
        page = await browser.get_working_page()

        if ai_mode is None:
            # Plain selector path: no AI assistance requested.
            assert selector is not None
            element = page.locator(selector)
        else:
            element = page.locator(selector=selector, prompt=intent, ai=ai_mode)  # type: ignore[arg-type]

        await element.hover(timeout=timeout)
        return {"selector": selector, "intent": intent, "ai_mode": ai_mode}

    try:
        payload = asyncio.run(_perform())
        output(payload, action="hover", json_mode=json_output)
    except GuardError as e:
        output_error(str(e), hint=e.hint, json_mode=json_output)
    except typer.BadParameter:
        raise
    except Exception as e:
        output_error(str(e), hint="Element may be hidden or not interactable.", json_mode=json_output)
|
||||
|
||||
|
||||
@browser_app.command("type")
def type_text(
    text: str = typer.Option(..., help="Text to type into the input."),
    intent: str | None = typer.Option(None, help="Natural language description of the input field."),
    selector: str | None = typer.Option(None, help="CSS selector or XPath for the input field."),
    session: str | None = typer.Option(None, help="Browser session ID."),
    cdp: str | None = typer.Option(None, "--cdp", help="CDP WebSocket URL."),
    timeout: int = typer.Option(30000, help="Max wait time in milliseconds."),
    clear: bool = typer.Option(True, "--clear/--no-clear", help="Clear existing content before typing."),
    delay: int | None = typer.Option(None, help="Delay between keystrokes in milliseconds."),
    json_output: bool = typer.Option(False, "--json", help="Output as JSON."),
) -> None:
    """Type into an input field using selector, intent, or both."""

    def _password_field_error() -> GuardError:
        # Single source for the credential-guard error raised by both checks below.
        return GuardError(
            "Cannot type into password fields — credentials must not be passed through tool calls",
            CREDENTIAL_HINT,
        )

    async def _perform() -> dict:
        # First guard: textual match on the intent/selector, before connecting.
        if PASSWORD_PATTERN.search(f"{intent or ''} {selector or ''}"):
            raise _password_field_error()

        ai_mode = _resolve_ai_target(selector, intent, operation="type")

        target = _resolve_connection(session, cdp)
        browser = await _connect_browser(target)
        page = await browser.get_working_page()

        # Second guard: ask the page whether the selector hits a password input.
        # Any evaluation failure is treated as "not a password field".
        targets_password_input = False
        if selector:
            try:
                targets_password_input = await page.evaluate(
                    "(s) => { const el = document.querySelector(s); return !!(el && el.type === 'password'); }",
                    selector,
                )
            except Exception:
                targets_password_input = False
        if targets_password_input:
            raise _password_field_error()

        if not clear:
            # Append mode: keystroke typing, optionally with a per-key delay.
            type_options: dict[str, Any] = {"timeout": timeout}
            if delay is not None:
                type_options["delay"] = delay
            if ai_mode is None:
                assert selector is not None
                await page.type(selector, text, **type_options)
            else:
                field = page.locator(selector=selector, prompt=intent, ai=ai_mode)  # type: ignore[arg-type]
                await field.type(text, **type_options)
        elif ai_mode is None:
            assert selector is not None
            await page.fill(selector, text, timeout=timeout)
        else:
            await page.fill(selector=selector, value=text, prompt=intent, ai=ai_mode, timeout=timeout)  # type: ignore[arg-type]

        # Only the length is reported; the typed text itself is not echoed.
        return {"selector": selector, "intent": intent, "ai_mode": ai_mode, "text_length": len(text)}

    try:
        payload = asyncio.run(_perform())
        output(payload, action="type", json_mode=json_output)
    except GuardError as e:
        output_error(str(e), hint=e.hint, json_mode=json_output)
    except typer.BadParameter:
        raise
    except Exception as e:
        output_error(str(e), hint="Element may not be editable or may be obscured.", json_mode=json_output)
|
||||
|
||||
|
||||
@browser_app.command("scroll")
def scroll(
    direction: str = typer.Option(..., help="Direction: up, down, left, right."),
    session: str | None = typer.Option(None, help="Browser session ID."),
    cdp: str | None = typer.Option(None, "--cdp", help="CDP WebSocket URL."),
    amount: int | None = typer.Option(None, help="Pixels to scroll (default 500)."),
    intent: str | None = typer.Option(None, help="Natural language element to scroll into view."),
    selector: str | None = typer.Option(None, help="CSS selector of scrollable element."),
    json_output: bool = typer.Option(False, "--json", help="Output as JSON."),
) -> None:
    """Scroll the page or scroll a targeted element into view."""

    async def _perform() -> dict:
        # Direction is only validated for pixel scrolling; intent mode ignores it.
        if not intent and direction not in ("up", "down", "left", "right"):
            raise GuardError(f"Invalid direction: {direction}", "Use up, down, left, or right")

        target = _resolve_connection(session, cdp)
        browser = await _connect_browser(target)
        page = await browser.get_working_page()

        if intent:
            # With a selector the AI acts as a fallback; without one it leads.
            if selector:
                ai_mode = "fallback"
            else:
                ai_mode = "proactive"
            element = page.locator(selector=selector, prompt=intent, ai=ai_mode)
            await element.scroll_into_view_if_needed()
            return {"direction": "into_view", "intent": intent, "selector": selector, "ai_mode": ai_mode}

        pixels = amount if amount else 500
        offsets = {
            "up": (0, -pixels),
            "down": (0, pixels),
            "left": (-pixels, 0),
            "right": (pixels, 0),
        }
        dx, dy = offsets[direction]

        if not selector:
            await page.evaluate(f"window.scrollBy({dx}, {dy})")
        else:
            await page.locator(selector).evaluate(f"el => el.scrollBy({dx}, {dy})")

        return {"direction": direction, "pixels": pixels, "selector": selector}

    try:
        payload = asyncio.run(_perform())
        output(payload, action="scroll", json_mode=json_output)
    except GuardError as e:
        output_error(str(e), hint=e.hint, json_mode=json_output)
    except typer.BadParameter:
        raise
    except Exception as e:
        output_error(str(e), hint="Scroll failed; check selector and page readiness.", json_mode=json_output)
|
||||
|
||||
|
||||
@browser_app.command("select")
def select(
    value: str = typer.Option(..., help="Option value to select."),
    intent: str | None = typer.Option(None, help="Natural language description of the dropdown."),
    selector: str | None = typer.Option(None, help="CSS selector for the dropdown."),
    session: str | None = typer.Option(None, help="Browser session ID."),
    cdp: str | None = typer.Option(None, "--cdp", help="CDP WebSocket URL."),
    timeout: int = typer.Option(30000, help="Max wait time in milliseconds."),
    by_label: bool = typer.Option(False, "--by-label", help="Select by visible label instead of value."),
    json_output: bool = typer.Option(False, "--json", help="Output as JSON."),
) -> None:
    """Select an option from a dropdown."""

    async def _perform() -> dict:
        ai_mode = _resolve_ai_target(selector, intent, operation="select")

        target = _resolve_connection(session, cdp)
        browser = await _connect_browser(target)
        page = await browser.get_working_page()

        if ai_mode is None:
            assert selector is not None
            if not by_label:
                await page.select_option(selector, value=value, timeout=timeout)
            else:
                # Label matching goes through the wrapped page's locator.
                await page.page.locator(selector).select_option(label=value, timeout=timeout)
        else:
            # NOTE(review): by_label is not consulted on the AI path — confirm that is intended.
            await page.select_option(selector=selector, value=value, prompt=intent, ai=ai_mode, timeout=timeout)  # type: ignore[arg-type]

        return {"selector": selector, "intent": intent, "ai_mode": ai_mode, "value": value, "by_label": by_label}

    try:
        payload = asyncio.run(_perform())
        output(payload, action="select", json_mode=json_output)
    except GuardError as e:
        output_error(str(e), hint=e.hint, json_mode=json_output)
    except typer.BadParameter:
        raise
    except Exception as e:
        output_error(str(e), hint="Check dropdown selector and available options.", json_mode=json_output)
|
||||
|
||||
|
||||
@browser_app.command("press-key")
def press_key(
    key: str = typer.Option(..., help="Key to press (e.g., Enter, Tab, Escape)."),
    session: str | None = typer.Option(None, help="Browser session ID."),
    cdp: str | None = typer.Option(None, "--cdp", help="CDP WebSocket URL."),
    intent: str | None = typer.Option(None, help="Natural language description of element to focus first."),
    selector: str | None = typer.Option(None, help="CSS selector to focus first."),
    json_output: bool = typer.Option(False, "--json", help="Output as JSON."),
) -> None:
    """Press a keyboard key."""

    async def _perform() -> dict:
        target = _resolve_connection(session, cdp)
        browser = await _connect_browser(target)
        page = await browser.get_working_page()

        if not (intent or selector):
            # No target given: send the key to the page's keyboard directly.
            await page.keyboard.press(key)
        else:
            ai_mode, problem = resolve_ai_mode(selector, intent)
            if problem:
                raise GuardError(
                    "Must provide intent, selector, or both",
                    "Use intent='describe where to press' or selector='#css-selector'",
                )
            if ai_mode is None:
                assert selector is not None
                element = page.locator(selector)
            else:
                element = page.locator(selector=selector, prompt=intent, ai=ai_mode)  # type: ignore[arg-type]
            await element.press(key)

        return {"key": key, "selector": selector, "intent": intent}

    try:
        payload = asyncio.run(_perform())
        output(payload, action="press_key", json_mode=json_output)
    except GuardError as e:
        output_error(str(e), hint=e.hint, json_mode=json_output)
    except typer.BadParameter:
        raise
    except Exception as e:
        output_error(str(e), hint="Check key name and focused target.", json_mode=json_output)
|
||||
|
||||
|
||||
@browser_app.command("wait")
def wait(
    session: str | None = typer.Option(None, help="Browser session ID."),
    cdp: str | None = typer.Option(None, "--cdp", help="CDP WebSocket URL."),
    time_ms: int | None = typer.Option(None, "--time", help="Milliseconds to wait."),
    intent: str | None = typer.Option(None, help="Natural language condition to wait for."),
    selector: str | None = typer.Option(None, help="CSS selector to wait for."),
    state: str = typer.Option("visible", help="Element state: visible, hidden, attached, detached."),
    timeout: int = typer.Option(30000, help="Max wait time in milliseconds."),
    poll_interval: int = typer.Option(5000, "--poll-interval", help="Polling interval for intent waits in ms."),
    json_output: bool = typer.Option(False, "--json", help="Output as JSON."),
) -> None:
    """Wait for time, selector state, or AI condition."""

    # Precedence when several options are given: --time, then --intent, then --selector.

    async def _run() -> dict:
        _validate_wait_state(state)
        if time_ms is None and not selector and not intent:
            raise GuardError(
                "Must provide intent, selector, or time_ms",
                "Use --time, --selector, or --intent to specify what to wait for",
            )

        connection = _resolve_connection(session, cdp)
        browser = await _connect_browser(connection)
        page = await browser.get_working_page()

        waited_for = ""
        if time_ms is not None:
            await page.wait_for_timeout(time_ms)
            waited_for = "time"
        elif intent:
            # Poll the AI validation until it succeeds or the deadline passes.
            loop = asyncio.get_running_loop()
            deadline = loop.time() + timeout / 1000
            last_error: Exception | None = None
            while True:
                try:
                    ready = await page.validate(intent)
                    last_error = None
                except Exception as poll_error:
                    # A failing check counts as "not ready"; the most recent
                    # error is surfaced if the deadline is reached.
                    ready = False
                    last_error = poll_error

                if ready:
                    waited_for = "intent"
                    break

                remaining_ms = (deadline - loop.time()) * 1000
                if remaining_ms <= 0:
                    if last_error is not None:
                        # Keep the original exception as the cause for debugging.
                        raise RuntimeError(str(last_error)) from last_error
                    raise TimeoutError(f"Condition not met within {timeout}ms: {intent}")

                # Never sleep past the deadline, so the total wait honours
                # --timeout instead of overshooting by up to one poll interval.
                await page.wait_for_timeout(min(poll_interval, remaining_ms))
        else:
            assert selector is not None
            await page.wait_for_selector(selector, state=state, timeout=timeout)
            waited_for = "selector"

        return {"waited_for": waited_for, "state": state, "selector": selector, "intent": intent}

    try:
        data = asyncio.run(_run())
        output(data, action="wait", json_mode=json_output)
    except GuardError as e:
        output_error(str(e), hint=e.hint, json_mode=json_output)
    except typer.BadParameter:
        raise
    except Exception as e:
        output_error(str(e), hint="Condition was not met within timeout.", json_mode=json_output)
|
||||
|
||||
|
||||
@browser_app.command("act")
|
||||
def act(
|
||||
prompt: str = typer.Option(..., help="Natural language action to perform."),
|
||||
@@ -323,3 +766,28 @@ def extract(
|
||||
raise
|
||||
except Exception as e:
|
||||
output_error(str(e), hint="Simplify the prompt or provide a JSON schema.", json_mode=json_output)
|
||||
|
||||
|
||||
@browser_app.command("validate")
def validate(
    prompt: str = typer.Option(..., help="Validation condition to check."),
    session: str | None = typer.Option(None, help="Browser session ID."),
    cdp: str | None = typer.Option(None, "--cdp", help="CDP WebSocket URL."),
    json_output: bool = typer.Option(False, "--json", help="Output as JSON."),
) -> None:
    """Check whether a natural language condition is true on the current page."""

    async def _check() -> dict:
        target = _resolve_connection(session, cdp)
        browser = await _connect_browser(target)
        page = await browser.get_working_page()
        verdict = await page.validate(prompt)
        return {"prompt": prompt, "valid": verdict}

    try:
        payload = asyncio.run(_check())
        output(payload, action="validate", json_mode=json_output)
    except typer.BadParameter:
        # Let typer render parameter errors itself.
        raise
    except Exception as e:
        output_error(str(e), hint="Check the page state and validation prompt.", json_mode=json_output)
|
||||
|
||||
@@ -3,17 +3,26 @@ from __future__ import annotations
|
||||
import os
|
||||
from contextvars import ContextVar
|
||||
|
||||
import structlog
|
||||
|
||||
from skyvern.client import SkyvernEnvironment
|
||||
from skyvern.config import settings
|
||||
from skyvern.library.skyvern import Skyvern
|
||||
|
||||
_skyvern_instance: ContextVar[Skyvern | None] = ContextVar("skyvern_instance", default=None)
|
||||
_global_skyvern_instance: Skyvern | None = None
|
||||
LOG = structlog.get_logger(__name__)
|
||||
|
||||
|
||||
def get_skyvern() -> Skyvern:
|
||||
"""Get or create a Skyvern client instance."""
|
||||
global _global_skyvern_instance
|
||||
|
||||
instance = _skyvern_instance.get()
|
||||
if instance is None:
|
||||
instance = _global_skyvern_instance
|
||||
if instance is not None:
|
||||
_skyvern_instance.set(instance)
|
||||
return instance
|
||||
|
||||
api_key = settings.SKYVERN_API_KEY or os.environ.get("SKYVERN_API_KEY")
|
||||
@@ -28,5 +37,28 @@ def get_skyvern() -> Skyvern:
|
||||
else:
|
||||
instance = Skyvern.local()
|
||||
|
||||
_global_skyvern_instance = instance
|
||||
_skyvern_instance.set(instance)
|
||||
return instance
|
||||
|
||||
|
||||
async def close_skyvern() -> None:
    """Close active Skyvern client(s) and release Playwright resources."""
    global _global_skyvern_instance

    # The context-local and module-level slots may hold the same object;
    # close each distinct client exactly once, context-local first.
    scoped = _skyvern_instance.get()
    shared = _global_skyvern_instance

    pending: list[Skyvern] = []
    if scoped is not None:
        pending.append(scoped)
    if shared is not None and shared is not scoped:
        pending.append(shared)

    for client in pending:
        try:
            await client.aclose()
        except Exception:
            # Best-effort: a failing close must not prevent clearing the slots.
            LOG.warning("Failed to close Skyvern client", exc_info=True)

    _skyvern_instance.set(None)
    _global_skyvern_instance = None
|
||||
|
||||
@@ -5,9 +5,13 @@ from contextvars import ContextVar
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, AsyncIterator
|
||||
|
||||
import structlog
|
||||
|
||||
from .client import get_skyvern
|
||||
from .result import BrowserContext, ErrorCode, make_error
|
||||
|
||||
LOG = structlog.get_logger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from skyvern.library.skyvern_browser import SkyvernBrowser
|
||||
from skyvern.library.skyvern_browser_page import SkyvernBrowserPage
|
||||
@@ -23,20 +27,46 @@ class SessionState:
|
||||
|
||||
|
||||
_current_session: ContextVar[SessionState | None] = ContextVar("mcp_session", default=None)
|
||||
_global_session: SessionState | None = None
|
||||
|
||||
|
||||
def get_current_session() -> SessionState:
    """Return the session state for the current context.

    If the context variable holds no state, the module-level
    ``_global_session`` is used (created on first use) and stored into the
    context variable before being returned.
    """
    global _global_session

    state = _current_session.get()
    if state is None:
        # Fall back to the shared module-level state rather than creating a
        # throwaway per-call SessionState.
        if _global_session is None:
            _global_session = SessionState()
        state = _global_session
        _current_session.set(state)
    return state
|
||||
|
||||
|
||||
def set_current_session(state: SessionState) -> None:
    """Install *state* as both the context-local and the module-level session."""
    global _global_session

    _current_session.set(state)
    _global_session = state
|
||||
|
||||
|
||||
def _matches_current(
|
||||
current: SessionState,
|
||||
*,
|
||||
session_id: str | None = None,
|
||||
cdp_url: str | None = None,
|
||||
local: bool = False,
|
||||
) -> bool:
|
||||
if current.browser is None or current.context is None:
|
||||
return False
|
||||
|
||||
if session_id:
|
||||
return current.context.mode == "cloud_session" and current.context.session_id == session_id
|
||||
if cdp_url:
|
||||
return current.context.mode == "cdp" and current.context.cdp_url == cdp_url
|
||||
if local:
|
||||
return current.context.mode == "local"
|
||||
return False
|
||||
|
||||
|
||||
async def resolve_browser(
|
||||
session_id: str | None = None,
|
||||
cdp_url: str | None = None,
|
||||
@@ -54,6 +84,11 @@ async def resolve_browser(
|
||||
skyvern = get_skyvern()
|
||||
current = get_current_session()
|
||||
|
||||
if _matches_current(current, session_id=session_id, cdp_url=cdp_url, local=local):
|
||||
# _matches_current() guarantees both are non-None
|
||||
assert current.browser is not None and current.context is not None
|
||||
return current.browser, current.context
|
||||
|
||||
browser: SkyvernBrowser | None = None
|
||||
try:
|
||||
if session_id:
|
||||
@@ -94,6 +129,31 @@ async def resolve_browser(
|
||||
raise BrowserNotAvailableError()
|
||||
|
||||
|
||||
async def close_current_session() -> None:
    """Close the active browser session (if any) and clear local session state."""
    from .session_ops import do_session_close

    current = get_current_session()
    ctx = current.context
    is_cloud_session = bool(ctx and ctx.mode == "cloud_session" and ctx.session_id)

    try:
        if is_cloud_session:
            try:
                client = get_skyvern()
                await do_session_close(client, ctx.session_id)
                # Prevent SkyvernBrowser.close() from making a redundant API call
                if current.browser is not None:
                    current.browser._browser_session_id = None
            except Exception:
                # Remote close is best-effort; still close the local browser below.
                LOG.warning(
                    "Best-effort cloud session close failed",
                    session_id=ctx.session_id,
                    exc_info=True,
                )

        if current.browser is not None:
            await current.browser.close()
    finally:
        # Always reset to a fresh state, even when closing raised.
        set_current_session(SessionState())
|
||||
|
||||
|
||||
async def get_page(
|
||||
session_id: str | None = None,
|
||||
cdp_url: str | None = None,
|
||||
|
||||
@@ -10,6 +10,7 @@ from skyvern.cli.core.session_manager import (
|
||||
BrowserNotAvailableError,
|
||||
SessionState,
|
||||
browser_session,
|
||||
close_current_session,
|
||||
get_current_session,
|
||||
get_page,
|
||||
no_browser_error,
|
||||
@@ -21,6 +22,7 @@ __all__ = [
|
||||
"BrowserNotAvailableError",
|
||||
"SessionState",
|
||||
"browser_session",
|
||||
"close_current_session",
|
||||
"get_current_session",
|
||||
"get_page",
|
||||
"get_skyvern",
|
||||
|
||||
@@ -95,10 +95,40 @@ async def skyvern_session_close(
|
||||
with Timer() as timer:
|
||||
try:
|
||||
if session_id:
|
||||
matching_cloud_session = (
|
||||
current.context is not None
|
||||
and current.context.mode == "cloud_session"
|
||||
and current.context.session_id == session_id
|
||||
)
|
||||
|
||||
skyvern = get_skyvern()
|
||||
result = await do_session_close(skyvern, session_id)
|
||||
if current.context and current.context.session_id == session_id:
|
||||
result = None
|
||||
close_error: Exception | None = None
|
||||
try:
|
||||
result = await do_session_close(skyvern, session_id)
|
||||
except Exception as e:
|
||||
close_error = e
|
||||
|
||||
if matching_cloud_session:
|
||||
if current.browser is None:
|
||||
set_current_session(SessionState())
|
||||
raise RuntimeError("Expected active browser for matching cloud session")
|
||||
try:
|
||||
await current.browser.close()
|
||||
except Exception as browser_err:
|
||||
if close_error is not None:
|
||||
raise browser_err from close_error
|
||||
raise
|
||||
finally:
|
||||
set_current_session(SessionState())
|
||||
elif current.context and current.context.session_id == session_id:
|
||||
set_current_session(SessionState())
|
||||
|
||||
if close_error is not None:
|
||||
raise close_error
|
||||
if result is None:
|
||||
raise RuntimeError("Expected session close result after successful close operation")
|
||||
|
||||
timer.mark("sdk")
|
||||
return make_result(
|
||||
"skyvern_session_close",
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import atexit
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
@@ -14,6 +15,8 @@ from rich.panel import Panel
|
||||
from rich.prompt import Confirm
|
||||
|
||||
from skyvern.cli.console import console
|
||||
from skyvern.cli.core.client import close_skyvern
|
||||
from skyvern.cli.core.session_manager import close_current_session
|
||||
from skyvern.cli.mcp_tools import mcp # Uses standalone fastmcp (v2.x)
|
||||
from skyvern.cli.utils import start_services
|
||||
from skyvern.client import SkyvernEnvironment
|
||||
@@ -26,6 +29,42 @@ from skyvern.utils import detect_os
|
||||
from skyvern.utils.env_paths import resolve_backend_env_path, resolve_frontend_env_path
|
||||
|
||||
run_app = typer.Typer(help="Commands to run Skyvern services such as the API server or UI.")
|
||||
_mcp_cleanup_done = False
|
||||
|
||||
|
||||
async def _cleanup_mcp_resources() -> None:
    """Close the current browser session, then the Skyvern client.

    The client is closed even when closing the session raises.
    """
    try:
        await close_current_session()
    finally:
        await close_skyvern()
|
||||
|
||||
|
||||
def _cleanup_mcp_resources_blocking() -> None:
    """Run MCP cleanup to completion on a fresh event loop, at most once.

    A successful run sets ``_mcp_cleanup_done`` so later calls return
    immediately; a failed run is logged and leaves the flag unset.
    """
    global _mcp_cleanup_done

    if _mcp_cleanup_done:
        return

    try:
        asyncio.run(_cleanup_mcp_resources())
    except Exception:
        logging.getLogger(__name__).warning("MCP cleanup failed", exc_info=True)
    else:
        _mcp_cleanup_done = True
|
||||
|
||||
|
||||
def _cleanup_mcp_resources_sync() -> None:
    """Atexit callback for MCP cleanup.

    asyncio.run() cannot be called inside a running loop, so cleanup is
    skipped when a loop is still running. Cleanup is therefore best-effort
    for signal-based exits (e.g. SIGTERM) that fire atexit while the MCP
    server's loop is still alive -- the finally block in run_mcp() handles
    normal shutdown instead.
    """
    logger = logging.getLogger(__name__)

    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # get_running_loop() raises when no loop is running: safe to block.
        loop_running = False
    else:
        loop_running = True

    if loop_running:
        logger.debug("Skipping MCP cleanup because event loop is still running")
        return

    _cleanup_mcp_resources_blocking()
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
@@ -260,7 +299,14 @@ def run_mcp() -> None:
|
||||
"""Run the MCP server."""
|
||||
# This breaks the MCP processing because it expects json output only
|
||||
# console.print(Panel("[bold green]Starting MCP Server...[/bold green]", border_style="green"))
|
||||
mcp.run(transport="stdio")
|
||||
# atexit covers signal-based exits (SIGTERM); finally covers normal
|
||||
# mcp.run() completion or unhandled exceptions. Both are needed because
|
||||
# atexit doesn't fire on normal return and finally doesn't fire on signals.
|
||||
atexit.register(_cleanup_mcp_resources_sync)
|
||||
try:
|
||||
mcp.run(transport="stdio")
|
||||
finally:
|
||||
_cleanup_mcp_resources_blocking()
|
||||
|
||||
|
||||
@run_app.command(
|
||||
|
||||
@@ -3,7 +3,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
import typer
|
||||
@@ -147,3 +150,149 @@ class TestResolveConnection:
|
||||
monkeypatch.setattr("skyvern.cli.commands._state.STATE_FILE", tmp_path / "nonexistent.json")
|
||||
with pytest.raises(typer.BadParameter, match="No active browser connection"):
|
||||
_resolve_connection(None, None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Browser command helpers and command behavior
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBrowserCommandGuards:
    """Unit tests for the validation helpers used by the browser commands."""

    def test_resolve_ai_target_requires_selector_or_intent(self) -> None:
        """Calling without a selector and without an intent raises GuardError."""
        from skyvern.cli.commands import browser as browser_cmd
        from skyvern.cli.core import guards

        with pytest.raises(guards.GuardError, match="Must provide intent, selector, or both"):
            browser_cmd._resolve_ai_target(None, None, operation="click")

    def test_validate_wait_state_rejects_invalid(self) -> None:
        """An unrecognised wait state raises GuardError."""
        from skyvern.cli.commands import browser as browser_cmd
        from skyvern.cli.core import guards

        with pytest.raises(guards.GuardError, match="Invalid state"):
            browser_cmd._validate_wait_state("bad-state")
|
||||
|
||||
|
||||
class TestBrowserCommands:
    """Tests that call the browser CLI command functions directly with JSON output."""

    @staticmethod
    def _forbid_connection(monkeypatch: pytest.MonkeyPatch, module: object) -> None:
        """Patch ``_resolve_connection`` on *module* so that calling it fails the test."""

        def _fail(_session: object, _cdp: object) -> None:
            raise AssertionError("should not resolve connection")

        monkeypatch.setattr(module, "_resolve_connection", _fail)

    @staticmethod
    def _read_json(capsys: pytest.CaptureFixture) -> dict:
        """Parse the captured stdout as a single JSON document."""
        return json.loads(capsys.readouterr().out)

    def test_session_get_outputs_session_details(
        self, monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture
    ) -> None:
        """session_get prints the session details and flags the current session."""
        from skyvern.cli.commands import browser as browser_cmd

        fake_session = SimpleNamespace(
            browser_session_id="pbs_123",
            status="active",
            started_at=datetime(2026, 2, 17, 12, 0, tzinfo=timezone.utc),
            completed_at=None,
            timeout=60,
            runnable_id=None,
        )
        client = SimpleNamespace(get_browser_session=AsyncMock(return_value=fake_session))
        monkeypatch.setattr(browser_cmd, "get_skyvern", lambda: client)
        monkeypatch.setattr(browser_cmd, "load_state", lambda: CLIState(session_id="pbs_123", mode="cloud"))

        browser_cmd.session_get(session="pbs_123", json_output=True)

        payload = self._read_json(capsys)
        assert payload["ok"] is True
        assert payload["action"] == "session_get"
        assert payload["data"]["session_id"] == "pbs_123"
        assert payload["data"]["is_current"] is True

    def test_evaluate_blocks_password_js_before_connection(
        self, monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture
    ) -> None:
        """JavaScript that writes a password field is rejected before connecting."""
        from skyvern.cli.commands import browser as browser_cmd

        self._forbid_connection(monkeypatch, browser_cmd)

        with pytest.raises(SystemExit, match="1"):
            browser_cmd.evaluate(
                expression='document.querySelector("input[type=password]").value = ""', json_output=True
            )

        payload = self._read_json(capsys)
        assert payload["ok"] is False
        assert "Cannot set password field values" in payload["error"]["message"]

    def test_click_requires_target_before_connection(
        self, monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture
    ) -> None:
        """click with neither intent nor selector fails before connecting."""
        from skyvern.cli.commands import browser as browser_cmd

        self._forbid_connection(monkeypatch, browser_cmd)

        click_kwargs = {
            "intent": None,
            "selector": None,
            "session": None,
            "cdp": None,
            "timeout": 30000,
            "button": None,
            "click_count": None,
            "json_output": True,
        }
        with pytest.raises(SystemExit, match="1"):
            browser_cmd.click(**click_kwargs)

        payload = self._read_json(capsys)
        assert payload["ok"] is False
        assert "Must provide intent, selector, or both" in payload["error"]["message"]

    def test_click_with_intent_uses_proactive_ai_mode(
        self, monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture
    ) -> None:
        """An intent-only click reports proactive AI mode and the resolved selector."""
        from skyvern.cli.commands import browser as browser_cmd

        fake_page = MagicMock()
        fake_page.click = AsyncMock(return_value="xpath=//button[@id='submit']")
        fake_browser = SimpleNamespace(get_working_page=AsyncMock(return_value=fake_page))

        monkeypatch.setattr(
            browser_cmd,
            "_resolve_connection",
            lambda _session, _cdp: browser_cmd.ConnectionTarget(mode="cloud", session_id="pbs_123"),
        )
        monkeypatch.setattr(browser_cmd, "_connect_browser", AsyncMock(return_value=fake_browser))

        click_kwargs = {
            "intent": "the Submit button",
            "selector": None,
            "session": "pbs_123",
            "cdp": None,
            "timeout": 30000,
            "button": None,
            "click_count": None,
            "json_output": True,
        }
        browser_cmd.click(**click_kwargs)

        payload = self._read_json(capsys)
        assert payload["ok"] is True
        assert payload["action"] == "click"
        assert payload["data"]["ai_mode"] == "proactive"
        assert payload["data"]["resolved_selector"] == "xpath=//button[@id='submit']"

    def test_wait_rejects_invalid_state_before_connection(
        self, monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture
    ) -> None:
        """wait with an unknown element state fails before connecting."""
        from skyvern.cli.commands import browser as browser_cmd

        self._forbid_connection(monkeypatch, browser_cmd)

        with pytest.raises(SystemExit, match="1"):
            browser_cmd.wait(state="bad-state", time_ms=1000, json_output=True)

        payload = self._read_json(capsys)
        assert payload["ok"] is False
        assert "Invalid state" in payload["error"]["message"]
|
||||
|
||||
264
tests/unit/test_mcp_session_lifecycle.py
Normal file
264
tests/unit/test_mcp_session_lifecycle.py
Normal file
@@ -0,0 +1,264 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from skyvern.cli.core import client as client_mod
|
||||
from skyvern.cli.core import session_manager
|
||||
from skyvern.cli.core.result import BrowserContext
|
||||
from skyvern.cli.core.session_ops import SessionCloseResult
|
||||
from skyvern.cli.mcp_tools import session as mcp_session
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _reset_singletons() -> None:
    """Reset the client and session singletons before every test in this module."""
    # Skyvern client: context-local slot first, then the module-level fallback.
    client_mod._skyvern_instance.set(None)
    client_mod._global_skyvern_instance = None

    # Session manager: same two-level reset.
    session_manager._current_session.set(None)
    session_manager._global_session = None

    # Finally install a fresh, empty state through the MCP session module.
    blank_state = mcp_session.SessionState()
    mcp_session.set_current_session(blank_state)
|
||||
|
||||
|
||||
def test_get_skyvern_reuses_global_instance_across_contexts(monkeypatch: pytest.MonkeyPatch) -> None:
    """get_skyvern() returns the same client after the context-local slot is cleared."""
    instances: list[object] = []

    class FakeSkyvern:
        """Stand-in client that records every construction."""

        def __init__(self, *args: object, **kwargs: object) -> None:
            instances.append(self)

        @classmethod
        def local(cls) -> FakeSkyvern:
            return cls()

        async def aclose(self) -> None:
            return None

    monkeypatch.setattr(client_mod, "Skyvern", FakeSkyvern)
    for setting_name in ("SKYVERN_API_KEY", "SKYVERN_BASE_URL"):
        monkeypatch.setattr(client_mod.settings, setting_name, None)

    before = client_mod.get_skyvern()
    # Simulate a new async context by clearing only the context-local slot.
    client_mod._skyvern_instance.set(None)
    after = client_mod.get_skyvern()

    assert before is after
    assert len(instances) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_close_skyvern_closes_singleton() -> None:
    """close_skyvern() awaits aclose() once and clears both singleton slots."""
    client = MagicMock()
    client.aclose = AsyncMock()

    # Install the fake in the context-local slot and the module-level fallback.
    client_mod._skyvern_instance.set(client)
    client_mod._global_skyvern_instance = client

    await client_mod.close_skyvern()

    client.aclose.assert_awaited_once()
    assert client_mod._skyvern_instance.get() is None
    assert client_mod._global_skyvern_instance is None
|
||||
|
||||
|
||||
def test_get_current_session_falls_back_to_global_state() -> None:
    """get_current_session() returns the stored state after the context-local slot is cleared."""
    cloud_context = BrowserContext(mode="cloud_session", session_id="pbs_123")
    stored = session_manager.SessionState(browser=MagicMock(), context=cloud_context)
    session_manager.set_current_session(stored)

    # Simulate a new async context by clearing only the context-local slot.
    session_manager._current_session.set(None)

    assert session_manager.get_current_session() is stored
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_resolve_browser_reuses_matching_cloud_session(monkeypatch: pytest.MonkeyPatch) -> None:
    """resolve_browser() returns the already-held browser for a matching session id."""
    held_browser = MagicMock()
    session_manager.set_current_session(
        session_manager.SessionState(
            browser=held_browser,
            context=BrowserContext(mode="cloud_session", session_id="pbs_123"),
        )
    )

    sdk = MagicMock()
    sdk.connect_to_cloud_browser_session = AsyncMock()
    monkeypatch.setattr(session_manager, "get_skyvern", lambda: sdk)

    resolved_browser, resolved_ctx = await session_manager.resolve_browser(session_id="pbs_123")

    assert resolved_browser is held_browser
    assert resolved_ctx.session_id == "pbs_123"
    # No new connection should have been opened.
    sdk.connect_to_cloud_browser_session.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_session_close_with_matching_session_id_closes_browser_handle(monkeypatch: pytest.MonkeyPatch) -> None:
    """Closing the current session by id closes the API session and the browser handle."""
    held_browser = MagicMock()
    held_browser.close = AsyncMock()
    active_state = mcp_session.SessionState(
        browser=held_browser,
        context=BrowserContext(mode="cloud_session", session_id="pbs_456"),
    )
    mcp_session.set_current_session(active_state)

    sdk = MagicMock()
    monkeypatch.setattr(mcp_session, "get_skyvern", lambda: sdk)

    close_api = AsyncMock(return_value=SessionCloseResult(session_id="pbs_456", closed=True))
    monkeypatch.setattr(mcp_session, "do_session_close", close_api)

    outcome = await mcp_session.skyvern_session_close(session_id="pbs_456")

    assert outcome["ok"] is True
    assert outcome["data"] == {"session_id": "pbs_456", "closed": True}
    held_browser.close.assert_awaited_once()
    close_api.assert_awaited_once_with(sdk, "pbs_456")
    assert mcp_session.get_current_session().browser is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_session_close_chains_exceptions_when_both_api_and_browser_fail(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Both the API close and browser.close() raise during session close.

    The tool should return an error result carrying the browser failure
    message and still clear the session state.

    NOTE(review): the intent is that the browser exception chains the API
    exception via __cause__ so neither is lost; this test does not assert
    the chaining itself.
    """
    browser_failure = RuntimeError("browser close failed")
    failing_browser = MagicMock()
    failing_browser.close = AsyncMock(side_effect=browser_failure)
    active_state = mcp_session.SessionState(
        browser=failing_browser,
        context=BrowserContext(mode="cloud_session", session_id="pbs_dual"),
    )
    mcp_session.set_current_session(active_state)

    sdk = MagicMock()
    monkeypatch.setattr(mcp_session, "get_skyvern", lambda: sdk)

    api_failure = ConnectionError("API close failed")
    close_api = AsyncMock(side_effect=api_failure)
    monkeypatch.setattr(mcp_session, "do_session_close", close_api)

    outcome = await mcp_session.skyvern_session_close(session_id="pbs_dual")

    # The failure is reported as an error result rather than raised.
    assert outcome["ok"] is False
    assert "browser close failed" in outcome["error"]["message"]
    # Session state is cleaned up regardless.
    assert mcp_session.get_current_session().browser is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_session_close_matching_context_without_browser_returns_error(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A matching cloud context with no browser handle yields an SDK_ERROR result."""
    browserless_state = mcp_session.SessionState(
        browser=None,
        context=BrowserContext(mode="cloud_session", session_id="pbs_999"),
    )
    mcp_session.set_current_session(browserless_state)

    sdk = MagicMock()
    monkeypatch.setattr(mcp_session, "get_skyvern", lambda: sdk)

    close_api = AsyncMock(return_value=SessionCloseResult(session_id="pbs_999", closed=True))
    monkeypatch.setattr(mcp_session, "do_session_close", close_api)

    outcome = await mcp_session.skyvern_session_close(session_id="pbs_999")

    assert outcome["ok"] is False
    error = outcome["error"]
    assert error["code"] == mcp_session.ErrorCode.SDK_ERROR
    assert "Expected active browser for matching cloud session" in error["message"]
    # The API close is still attempted and the stale context is dropped.
    close_api.assert_awaited_once_with(sdk, "pbs_999")
    assert mcp_session.get_current_session().context is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests for close_current_session() — cloud session API cleanup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_close_current_session_calls_api_close_for_cloud_session(monkeypatch: pytest.MonkeyPatch) -> None:
    """close_current_session() closes a cloud session through the API.

    It should call do_session_close and then clear _browser_session_id, to
    avoid a duplicate API call from browser.close().
    """
    cloud_browser = MagicMock()
    cloud_browser.close = AsyncMock()
    cloud_browser._browser_session_id = "pbs_api"
    cloud_state = session_manager.SessionState(
        browser=cloud_browser,
        context=BrowserContext(mode="cloud_session", session_id="pbs_api"),
    )
    session_manager.set_current_session(cloud_state)

    sdk = MagicMock()
    monkeypatch.setattr(session_manager, "get_skyvern", lambda: sdk)

    close_api = AsyncMock(return_value=SessionCloseResult(session_id="pbs_api", closed=True))
    monkeypatch.setattr("skyvern.cli.core.session_ops.do_session_close", close_api)

    await session_manager.close_current_session()

    close_api.assert_awaited_once_with(sdk, "pbs_api")
    cloud_browser.close.assert_awaited_once()
    # The id is cleared so browser.close() does not repeat the API call.
    assert cloud_browser._browser_session_id is None
    assert session_manager.get_current_session().browser is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_close_current_session_skips_api_close_for_local_session(monkeypatch: pytest.MonkeyPatch) -> None:
    """close_current_session() must not call do_session_close for a local session."""
    local_browser = MagicMock()
    local_browser.close = AsyncMock()
    local_state = session_manager.SessionState(
        browser=local_browser,
        context=BrowserContext(mode="local"),
    )
    session_manager.set_current_session(local_state)

    close_api = AsyncMock()
    monkeypatch.setattr("skyvern.cli.core.session_ops.do_session_close", close_api)

    await session_manager.close_current_session()

    close_api.assert_not_awaited()
    local_browser.close.assert_awaited_once()
    assert session_manager.get_current_session().browser is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_close_current_session_still_closes_browser_when_api_fails(monkeypatch: pytest.MonkeyPatch) -> None:
    """A failing do_session_close must not prevent browser.close() or the state reset."""
    cloud_browser = MagicMock()
    cloud_browser.close = AsyncMock()
    cloud_browser._browser_session_id = "pbs_fail"
    cloud_state = session_manager.SessionState(
        browser=cloud_browser,
        context=BrowserContext(mode="cloud_session", session_id="pbs_fail"),
    )
    session_manager.set_current_session(cloud_state)

    sdk = MagicMock()
    monkeypatch.setattr(session_manager, "get_skyvern", lambda: sdk)

    close_api = AsyncMock(side_effect=ConnectionError("API unreachable"))
    monkeypatch.setattr("skyvern.cli.core.session_ops.do_session_close", close_api)

    await session_manager.close_current_session()

    close_api.assert_awaited_once_with(sdk, "pbs_fail")
    # browser.close() still runs despite the API failure.
    cloud_browser.close.assert_awaited_once()
    # The id is left in place: the API close failed, so browser.close() may try.
    assert cloud_browser._browser_session_id == "pbs_fail"
    assert session_manager.get_current_session().browser is None
|
||||
63
tests/unit/test_run_commands_cleanup.py
Normal file
63
tests/unit/test_run_commands_cleanup.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from skyvern.cli import run_commands
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _reset_cleanup_state() -> None:
    """Clear the module-level cleanup-done flag before every test."""
    setattr(run_commands, "_mcp_cleanup_done", False)
|
||||
|
||||
|
||||
def test_cleanup_mcp_resources_sync_runs_without_running_loop(monkeypatch: pytest.MonkeyPatch) -> None:
    """With no event loop running, the sync hook runs the cleanup and marks it done."""
    fake_cleanup = AsyncMock()
    monkeypatch.setattr(run_commands, "_cleanup_mcp_resources", fake_cleanup)

    run_commands._cleanup_mcp_resources_sync()

    fake_cleanup.assert_awaited_once()
    assert run_commands._mcp_cleanup_done is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_cleanup_mcp_resources_sync_skips_when_loop_running(monkeypatch: pytest.MonkeyPatch) -> None:
    """Inside a running event loop, the sync hook does nothing."""
    fake_cleanup = AsyncMock()
    monkeypatch.setattr(run_commands, "_cleanup_mcp_resources", fake_cleanup)

    run_commands._cleanup_mcp_resources_sync()
    # Yield once to the loop before checking that nothing was awaited.
    await asyncio.sleep(0)

    fake_cleanup.assert_not_awaited()
    assert run_commands._mcp_cleanup_done is False
|
||||
|
||||
|
||||
def test_cleanup_mcp_resources_sync_keeps_retry_possible_on_task_failure(monkeypatch: pytest.MonkeyPatch) -> None:
    """A cleanup that raises leaves the done flag unset so a retry stays possible."""

    async def _raising_cleanup() -> None:
        raise RuntimeError("cleanup failed")

    monkeypatch.setattr(run_commands, "_cleanup_mcp_resources", _raising_cleanup)

    run_commands._cleanup_mcp_resources_sync()

    assert run_commands._mcp_cleanup_done is False
|
||||
|
||||
|
||||
def test_run_mcp_calls_blocking_cleanup_in_finally(monkeypatch: pytest.MonkeyPatch) -> None:
    """run_mcp() registers the atexit hook and still cleans up when mcp.run() raises."""
    fake_blocking_cleanup = MagicMock()
    fake_register = MagicMock()
    fake_run = MagicMock(side_effect=RuntimeError("boom"))

    monkeypatch.setattr(run_commands, "_cleanup_mcp_resources_blocking", fake_blocking_cleanup)
    monkeypatch.setattr(run_commands.atexit, "register", fake_register)
    monkeypatch.setattr(run_commands.mcp, "run", fake_run)

    # The server failure propagates to the caller.
    with pytest.raises(RuntimeError, match="boom"):
        run_commands.run_mcp()

    fake_register.assert_called_once_with(run_commands._cleanup_mcp_resources_sync)
    fake_run.assert_called_once_with(transport="stdio")
    fake_blocking_cleanup.assert_called_once()
|
||||
Reference in New Issue
Block a user