diff --git a/skyvern/cli/commands/browser.py b/skyvern/cli/commands/browser.py index 1c66d59d..a4e2fbd7 100644 --- a/skyvern/cli/commands/browser.py +++ b/skyvern/cli/commands/browser.py @@ -13,7 +13,17 @@ from skyvern.cli.commands._state import CLIState, clear_state, load_state, save_ from skyvern.cli.core.artifacts import save_artifact from skyvern.cli.core.browser_ops import do_act, do_extract, do_navigate, do_screenshot from skyvern.cli.core.client import get_skyvern -from skyvern.cli.core.guards import GuardError, check_password_prompt, validate_wait_until +from skyvern.cli.core.guards import ( + CREDENTIAL_HINT, + PASSWORD_PATTERN, + VALID_ELEMENT_STATES, + GuardError, + check_js_password, + check_password_prompt, + resolve_ai_mode, + validate_button, + validate_wait_until, +) from skyvern.cli.core.session_ops import do_session_close, do_session_create, do_session_list browser_app = typer.Typer(help="Browser automation commands.", no_args_is_help=True) @@ -64,6 +74,24 @@ async def _connect_browser(connection: ConnectionTarget) -> Any: return await skyvern.connect_to_browser_over_cdp(connection.cdp_url) +def _resolve_ai_target(selector: str | None, intent: str | None, *, operation: str) -> str | None: + ai_mode, err = resolve_ai_mode(selector, intent) + if err: + raise GuardError( + "Must provide intent, selector, or both", + ( + f"Use intent='describe what to {operation}' for AI-powered targeting, " + "or selector='#css-selector' for precise targeting" + ), + ) + return ai_mode + + +def _validate_wait_state(state: str) -> None: + if state not in VALID_ELEMENT_STATES: + raise GuardError(f"Invalid state: {state}", "Use visible, hidden, attached, or detached") + + # --------------------------------------------------------------------------- # Session commands # --------------------------------------------------------------------------- @@ -190,6 +218,35 @@ def session_list( output_error(str(e), hint="Check your API key and network connection.", 
json_mode=json_output) +@session_app.command("get") +def session_get( + session: str = typer.Option(..., "--session", "--id", help="Browser session ID."), + json_output: bool = typer.Option(False, "--json", help="Output as JSON."), +) -> None: + """Get details for a browser session.""" + + async def _run() -> dict: + skyvern = get_skyvern() + resolved = await skyvern.get_browser_session(session) + state = load_state() + is_current = bool(state and state.mode == "cloud" and state.session_id == session) + return { + "session_id": resolved.browser_session_id, + "status": resolved.status, + "started_at": resolved.started_at.isoformat() if resolved.started_at else None, + "completed_at": resolved.completed_at.isoformat() if resolved.completed_at else None, + "timeout": resolved.timeout, + "runnable_id": resolved.runnable_id, + "is_current": is_current, + } + + try: + data = asyncio.run(_run()) + output(data, action="session_get", json_mode=json_output) + except Exception as e: + output_error(str(e), hint="Verify the session ID exists and is accessible.", json_mode=json_output) + + # --------------------------------------------------------------------------- # Browser commands # --------------------------------------------------------------------------- @@ -269,6 +326,392 @@ def screenshot( output_error(str(e), hint="Ensure the session is active and the page has loaded.", json_mode=json_output) +@browser_app.command("evaluate") +def evaluate( + expression: str = typer.Option(..., help="JavaScript expression to evaluate."), + session: str | None = typer.Option(None, help="Browser session ID."), + cdp: str | None = typer.Option(None, "--cdp", help="CDP WebSocket URL."), + json_output: bool = typer.Option(False, "--json", help="Output as JSON."), +) -> None: + """Run JavaScript on the current page.""" + + async def _run() -> dict: + check_js_password(expression) + connection = _resolve_connection(session, cdp) + browser = await _connect_browser(connection) + page = await 
browser.get_working_page() + result = await page.evaluate(expression) + return {"result": result} + + try: + data = asyncio.run(_run()) + output(data, action="evaluate", json_mode=json_output) + except GuardError as e: + output_error(str(e), hint=e.hint, json_mode=json_output) + except typer.BadParameter: + raise + except Exception as e: + output_error(str(e), hint="Check JavaScript syntax and page state.", json_mode=json_output) + + +@browser_app.command("click") +def click( + intent: str | None = typer.Option(None, help="Natural language description of the element to click."), + selector: str | None = typer.Option(None, help="CSS selector or XPath for the element to click."), + session: str | None = typer.Option(None, help="Browser session ID."), + cdp: str | None = typer.Option(None, "--cdp", help="CDP WebSocket URL."), + timeout: int = typer.Option(30000, help="Max wait time in milliseconds."), + button: str | None = typer.Option(None, help="Mouse button: left, right, or middle."), + click_count: int | None = typer.Option(None, "--click-count", help="Number of clicks (2 for double-click)."), + json_output: bool = typer.Option(False, "--json", help="Output as JSON."), +) -> None: + """Click an element using selector, intent, or both.""" + + async def _run() -> dict: + validate_button(button) + ai_mode = _resolve_ai_target(selector, intent, operation="click") + connection = _resolve_connection(session, cdp) + browser = await _connect_browser(connection) + page = await browser.get_working_page() + + kwargs: dict[str, Any] = {"timeout": timeout} + if button: + kwargs["button"] = button + if click_count is not None: + kwargs["click_count"] = click_count + + if ai_mode is not None: + resolved = await page.click(selector=selector, prompt=intent, ai=ai_mode, **kwargs) # type: ignore[arg-type] + else: + assert selector is not None + resolved = await page.click(selector=selector, **kwargs) + + data: dict[str, Any] = {"selector": selector, "intent": intent, "ai_mode": 
ai_mode} + if resolved and resolved != selector: + data["resolved_selector"] = resolved + return data + + try: + data = asyncio.run(_run()) + output(data, action="click", json_mode=json_output) + except GuardError as e: + output_error(str(e), hint=e.hint, json_mode=json_output) + except typer.BadParameter: + raise + except Exception as e: + output_error(str(e), hint="Element may be hidden, disabled, or not yet available.", json_mode=json_output) + + +@browser_app.command("hover") +def hover( + intent: str | None = typer.Option(None, help="Natural language description of the element to hover."), + selector: str | None = typer.Option(None, help="CSS selector or XPath for the element to hover."), + session: str | None = typer.Option(None, help="Browser session ID."), + cdp: str | None = typer.Option(None, "--cdp", help="CDP WebSocket URL."), + timeout: int = typer.Option(30000, help="Max wait time in milliseconds."), + json_output: bool = typer.Option(False, "--json", help="Output as JSON."), +) -> None: + """Hover over an element using selector, intent, or both.""" + + async def _run() -> dict: + ai_mode = _resolve_ai_target(selector, intent, operation="hover") + connection = _resolve_connection(session, cdp) + browser = await _connect_browser(connection) + page = await browser.get_working_page() + + if ai_mode is not None: + locator = page.locator(selector=selector, prompt=intent, ai=ai_mode) # type: ignore[arg-type] + else: + assert selector is not None + locator = page.locator(selector) + await locator.hover(timeout=timeout) + return {"selector": selector, "intent": intent, "ai_mode": ai_mode} + + try: + data = asyncio.run(_run()) + output(data, action="hover", json_mode=json_output) + except GuardError as e: + output_error(str(e), hint=e.hint, json_mode=json_output) + except typer.BadParameter: + raise + except Exception as e: + output_error(str(e), hint="Element may be hidden or not interactable.", json_mode=json_output) + + +@browser_app.command("type") +def 
type_text( + text: str = typer.Option(..., help="Text to type into the input."), + intent: str | None = typer.Option(None, help="Natural language description of the input field."), + selector: str | None = typer.Option(None, help="CSS selector or XPath for the input field."), + session: str | None = typer.Option(None, help="Browser session ID."), + cdp: str | None = typer.Option(None, "--cdp", help="CDP WebSocket URL."), + timeout: int = typer.Option(30000, help="Max wait time in milliseconds."), + clear: bool = typer.Option(True, "--clear/--no-clear", help="Clear existing content before typing."), + delay: int | None = typer.Option(None, help="Delay between keystrokes in milliseconds."), + json_output: bool = typer.Option(False, "--json", help="Output as JSON."), +) -> None: + """Type into an input field using selector, intent, or both.""" + + async def _run() -> dict: + target_text = f"{intent or ''} {selector or ''}" + if PASSWORD_PATTERN.search(target_text): + raise GuardError( + "Cannot type into password fields — credentials must not be passed through tool calls", + CREDENTIAL_HINT, + ) + + ai_mode = _resolve_ai_target(selector, intent, operation="type") + connection = _resolve_connection(session, cdp) + browser = await _connect_browser(connection) + page = await browser.get_working_page() + + if selector: + try: + is_password = await page.evaluate( + "(s) => { const el = document.querySelector(s); return !!(el && el.type === 'password'); }", + selector, + ) + except Exception: + is_password = False + if is_password: + raise GuardError( + "Cannot type into password fields — credentials must not be passed through tool calls", + CREDENTIAL_HINT, + ) + + if clear: + if ai_mode is not None: + await page.fill(selector=selector, value=text, prompt=intent, ai=ai_mode, timeout=timeout) # type: ignore[arg-type] + else: + assert selector is not None + await page.fill(selector, text, timeout=timeout) + else: + kwargs: dict[str, Any] = {"timeout": timeout} + if delay is 
not None: + kwargs["delay"] = delay + if ai_mode is not None: + locator = page.locator(selector=selector, prompt=intent, ai=ai_mode) # type: ignore[arg-type] + await locator.type(text, **kwargs) + else: + assert selector is not None + await page.type(selector, text, **kwargs) + + return {"selector": selector, "intent": intent, "ai_mode": ai_mode, "text_length": len(text)} + + try: + data = asyncio.run(_run()) + output(data, action="type", json_mode=json_output) + except GuardError as e: + output_error(str(e), hint=e.hint, json_mode=json_output) + except typer.BadParameter: + raise + except Exception as e: + output_error(str(e), hint="Element may not be editable or may be obscured.", json_mode=json_output) + + +@browser_app.command("scroll") +def scroll( + direction: str = typer.Option(..., help="Direction: up, down, left, right."), + session: str | None = typer.Option(None, help="Browser session ID."), + cdp: str | None = typer.Option(None, "--cdp", help="CDP WebSocket URL."), + amount: int | None = typer.Option(None, help="Pixels to scroll (default 500)."), + intent: str | None = typer.Option(None, help="Natural language element to scroll into view."), + selector: str | None = typer.Option(None, help="CSS selector of scrollable element."), + json_output: bool = typer.Option(False, "--json", help="Output as JSON."), +) -> None: + """Scroll the page or scroll a targeted element into view.""" + + async def _run() -> dict: + valid_directions = ("up", "down", "left", "right") + if not intent and direction not in valid_directions: + raise GuardError(f"Invalid direction: {direction}", "Use up, down, left, or right") + + connection = _resolve_connection(session, cdp) + browser = await _connect_browser(connection) + page = await browser.get_working_page() + + if intent: + ai_mode = "fallback" if selector else "proactive" + locator = page.locator(selector=selector, prompt=intent, ai=ai_mode) + await locator.scroll_into_view_if_needed() + return {"direction": "into_view", 
"intent": intent, "selector": selector, "ai_mode": ai_mode} + + pixels = amount or 500 + direction_map = {"up": (0, -pixels), "down": (0, pixels), "left": (-pixels, 0), "right": (pixels, 0)} + dx, dy = direction_map[direction] + + if selector: + await page.locator(selector).evaluate(f"el => el.scrollBy({dx}, {dy})") + else: + await page.evaluate(f"window.scrollBy({dx}, {dy})") + + return {"direction": direction, "pixels": pixels, "selector": selector} + + try: + data = asyncio.run(_run()) + output(data, action="scroll", json_mode=json_output) + except GuardError as e: + output_error(str(e), hint=e.hint, json_mode=json_output) + except typer.BadParameter: + raise + except Exception as e: + output_error(str(e), hint="Scroll failed; check selector and page readiness.", json_mode=json_output) + + +@browser_app.command("select") +def select( + value: str = typer.Option(..., help="Option value to select."), + intent: str | None = typer.Option(None, help="Natural language description of the dropdown."), + selector: str | None = typer.Option(None, help="CSS selector for the dropdown."), + session: str | None = typer.Option(None, help="Browser session ID."), + cdp: str | None = typer.Option(None, "--cdp", help="CDP WebSocket URL."), + timeout: int = typer.Option(30000, help="Max wait time in milliseconds."), + by_label: bool = typer.Option(False, "--by-label", help="Select by visible label instead of value."), + json_output: bool = typer.Option(False, "--json", help="Output as JSON."), +) -> None: + """Select an option from a dropdown.""" + + async def _run() -> dict: + ai_mode = _resolve_ai_target(selector, intent, operation="select") + connection = _resolve_connection(session, cdp) + browser = await _connect_browser(connection) + page = await browser.get_working_page() + + if ai_mode is not None: + await page.select_option(selector=selector, value=value, prompt=intent, ai=ai_mode, timeout=timeout) # type: ignore[arg-type] + else: + assert selector is not None + if 
by_label: + await page.page.locator(selector).select_option(label=value, timeout=timeout) + else: + await page.select_option(selector, value=value, timeout=timeout) + + return {"selector": selector, "intent": intent, "ai_mode": ai_mode, "value": value, "by_label": by_label} + + try: + data = asyncio.run(_run()) + output(data, action="select", json_mode=json_output) + except GuardError as e: + output_error(str(e), hint=e.hint, json_mode=json_output) + except typer.BadParameter: + raise + except Exception as e: + output_error(str(e), hint="Check dropdown selector and available options.", json_mode=json_output) + + +@browser_app.command("press-key") +def press_key( + key: str = typer.Option(..., help="Key to press (e.g., Enter, Tab, Escape)."), + session: str | None = typer.Option(None, help="Browser session ID."), + cdp: str | None = typer.Option(None, "--cdp", help="CDP WebSocket URL."), + intent: str | None = typer.Option(None, help="Natural language description of element to focus first."), + selector: str | None = typer.Option(None, help="CSS selector to focus first."), + json_output: bool = typer.Option(False, "--json", help="Output as JSON."), +) -> None: + """Press a keyboard key.""" + + async def _run() -> dict: + connection = _resolve_connection(session, cdp) + browser = await _connect_browser(connection) + page = await browser.get_working_page() + + if intent or selector: + ai_mode, err = resolve_ai_mode(selector, intent) + if err: + raise GuardError( + "Must provide intent, selector, or both", + "Use intent='describe where to press' or selector='#css-selector'", + ) + if ai_mode is not None: + locator = page.locator(selector=selector, prompt=intent, ai=ai_mode) # type: ignore[arg-type] + else: + assert selector is not None + locator = page.locator(selector) + await locator.press(key) + else: + await page.keyboard.press(key) + + return {"key": key, "selector": selector, "intent": intent} + + try: + data = asyncio.run(_run()) + output(data, 
action="press_key", json_mode=json_output) + except GuardError as e: + output_error(str(e), hint=e.hint, json_mode=json_output) + except typer.BadParameter: + raise + except Exception as e: + output_error(str(e), hint="Check key name and focused target.", json_mode=json_output) + + +@browser_app.command("wait") +def wait( + session: str | None = typer.Option(None, help="Browser session ID."), + cdp: str | None = typer.Option(None, "--cdp", help="CDP WebSocket URL."), + time_ms: int | None = typer.Option(None, "--time", help="Milliseconds to wait."), + intent: str | None = typer.Option(None, help="Natural language condition to wait for."), + selector: str | None = typer.Option(None, help="CSS selector to wait for."), + state: str = typer.Option("visible", help="Element state: visible, hidden, attached, detached."), + timeout: int = typer.Option(30000, help="Max wait time in milliseconds."), + poll_interval: int = typer.Option(5000, "--poll-interval", help="Polling interval for intent waits in ms."), + json_output: bool = typer.Option(False, "--json", help="Output as JSON."), +) -> None: + """Wait for time, selector state, or AI condition.""" + + async def _run() -> dict: + _validate_wait_state(state) + if time_ms is None and not selector and not intent: + raise GuardError( + "Must provide intent, selector, or time_ms", + "Use --time, --selector, or --intent to specify what to wait for", + ) + + connection = _resolve_connection(session, cdp) + browser = await _connect_browser(connection) + page = await browser.get_working_page() + + waited_for = "" + if time_ms is not None: + await page.wait_for_timeout(time_ms) + waited_for = "time" + elif intent: + loop = asyncio.get_running_loop() + deadline = loop.time() + timeout / 1000 + last_error: Exception | None = None + while True: + try: + ready = await page.validate(intent) + last_error = None + except Exception as poll_error: + ready = False + last_error = poll_error + + if ready: + waited_for = "intent" + break + if 
loop.time() >= deadline: + if last_error: + raise RuntimeError(str(last_error)) + raise TimeoutError(f"Condition not met within {timeout}ms: {intent}") + await page.wait_for_timeout(poll_interval) + else: + assert selector is not None + await page.wait_for_selector(selector, state=state, timeout=timeout) + waited_for = "selector" + + return {"waited_for": waited_for, "state": state, "selector": selector, "intent": intent} + + try: + data = asyncio.run(_run()) + output(data, action="wait", json_mode=json_output) + except GuardError as e: + output_error(str(e), hint=e.hint, json_mode=json_output) + except typer.BadParameter: + raise + except Exception as e: + output_error(str(e), hint="Condition was not met within timeout.", json_mode=json_output) + + @browser_app.command("act") def act( prompt: str = typer.Option(..., help="Natural language action to perform."), @@ -323,3 +766,28 @@ def extract( raise except Exception as e: output_error(str(e), hint="Simplify the prompt or provide a JSON schema.", json_mode=json_output) + + +@browser_app.command("validate") +def validate( + prompt: str = typer.Option(..., help="Validation condition to check."), + session: str | None = typer.Option(None, help="Browser session ID."), + cdp: str | None = typer.Option(None, "--cdp", help="CDP WebSocket URL."), + json_output: bool = typer.Option(False, "--json", help="Output as JSON."), +) -> None: + """Check whether a natural language condition is true on the current page.""" + + async def _run() -> dict: + connection = _resolve_connection(session, cdp) + browser = await _connect_browser(connection) + page = await browser.get_working_page() + valid = await page.validate(prompt) + return {"prompt": prompt, "valid": valid} + + try: + data = asyncio.run(_run()) + output(data, action="validate", json_mode=json_output) + except typer.BadParameter: + raise + except Exception as e: + output_error(str(e), hint="Check the page state and validation prompt.", json_mode=json_output) diff --git 
a/skyvern/cli/core/client.py b/skyvern/cli/core/client.py index fabae2ec..ef02634a 100644 --- a/skyvern/cli/core/client.py +++ b/skyvern/cli/core/client.py @@ -3,17 +3,26 @@ from __future__ import annotations import os from contextvars import ContextVar +import structlog + from skyvern.client import SkyvernEnvironment from skyvern.config import settings from skyvern.library.skyvern import Skyvern _skyvern_instance: ContextVar[Skyvern | None] = ContextVar("skyvern_instance", default=None) +_global_skyvern_instance: Skyvern | None = None +LOG = structlog.get_logger(__name__) def get_skyvern() -> Skyvern: """Get or create a Skyvern client instance.""" + global _global_skyvern_instance + instance = _skyvern_instance.get() + if instance is None: + instance = _global_skyvern_instance if instance is not None: + _skyvern_instance.set(instance) return instance api_key = settings.SKYVERN_API_KEY or os.environ.get("SKYVERN_API_KEY") @@ -28,5 +37,28 @@ def get_skyvern() -> Skyvern: else: instance = Skyvern.local() + _global_skyvern_instance = instance _skyvern_instance.set(instance) return instance + + +async def close_skyvern() -> None: + """Close active Skyvern client(s) and release Playwright resources.""" + global _global_skyvern_instance + + instances: list[Skyvern] = [] + seen: set[int] = set() + for candidate in (_skyvern_instance.get(), _global_skyvern_instance): + if candidate is None or id(candidate) in seen: + continue + seen.add(id(candidate)) + instances.append(candidate) + + for instance in instances: + try: + await instance.aclose() + except Exception: + LOG.warning("Failed to close Skyvern client", exc_info=True) + + _skyvern_instance.set(None) + _global_skyvern_instance = None diff --git a/skyvern/cli/core/session_manager.py b/skyvern/cli/core/session_manager.py index af83ba79..2b48bbe5 100644 --- a/skyvern/cli/core/session_manager.py +++ b/skyvern/cli/core/session_manager.py @@ -5,9 +5,13 @@ from contextvars import ContextVar from dataclasses import 
dataclass, field from typing import TYPE_CHECKING, Any, AsyncIterator +import structlog + from .client import get_skyvern from .result import BrowserContext, ErrorCode, make_error +LOG = structlog.get_logger(__name__) + if TYPE_CHECKING: from skyvern.library.skyvern_browser import SkyvernBrowser from skyvern.library.skyvern_browser_page import SkyvernBrowserPage @@ -23,20 +27,46 @@ class SessionState: _current_session: ContextVar[SessionState | None] = ContextVar("mcp_session", default=None) +_global_session: SessionState | None = None def get_current_session() -> SessionState: + global _global_session + state = _current_session.get() if state is None: - state = SessionState() + if _global_session is None: + _global_session = SessionState() + state = _global_session _current_session.set(state) return state def set_current_session(state: SessionState) -> None: + global _global_session + _global_session = state _current_session.set(state) +def _matches_current( + current: SessionState, + *, + session_id: str | None = None, + cdp_url: str | None = None, + local: bool = False, +) -> bool: + if current.browser is None or current.context is None: + return False + + if session_id: + return current.context.mode == "cloud_session" and current.context.session_id == session_id + if cdp_url: + return current.context.mode == "cdp" and current.context.cdp_url == cdp_url + if local: + return current.context.mode == "local" + return False + + async def resolve_browser( session_id: str | None = None, cdp_url: str | None = None, @@ -54,6 +84,11 @@ async def resolve_browser( skyvern = get_skyvern() current = get_current_session() + if _matches_current(current, session_id=session_id, cdp_url=cdp_url, local=local): + # _matches_current() guarantees both are non-None + assert current.browser is not None and current.context is not None + return current.browser, current.context + browser: SkyvernBrowser | None = None try: if session_id: @@ -94,6 +129,31 @@ async def resolve_browser( raise 
BrowserNotAvailableError() +async def close_current_session() -> None: + """Close the active browser session (if any) and clear local session state.""" + from .session_ops import do_session_close + + current = get_current_session() + try: + if current.context and current.context.mode == "cloud_session" and current.context.session_id: + try: + skyvern = get_skyvern() + await do_session_close(skyvern, current.context.session_id) + # Prevent SkyvernBrowser.close() from making a redundant API call + if current.browser is not None: + current.browser._browser_session_id = None + except Exception: + LOG.warning( + "Best-effort cloud session close failed", + session_id=current.context.session_id, + exc_info=True, + ) + if current.browser is not None: + await current.browser.close() + finally: + set_current_session(SessionState()) + + async def get_page( session_id: str | None = None, cdp_url: str | None = None, diff --git a/skyvern/cli/mcp_tools/_session.py b/skyvern/cli/mcp_tools/_session.py index 0fd345fe..1f508ee1 100644 --- a/skyvern/cli/mcp_tools/_session.py +++ b/skyvern/cli/mcp_tools/_session.py @@ -10,6 +10,7 @@ from skyvern.cli.core.session_manager import ( BrowserNotAvailableError, SessionState, browser_session, + close_current_session, get_current_session, get_page, no_browser_error, @@ -21,6 +22,7 @@ __all__ = [ "BrowserNotAvailableError", "SessionState", "browser_session", + "close_current_session", "get_current_session", "get_page", "get_skyvern", diff --git a/skyvern/cli/mcp_tools/session.py b/skyvern/cli/mcp_tools/session.py index 3366f460..7b22fa89 100644 --- a/skyvern/cli/mcp_tools/session.py +++ b/skyvern/cli/mcp_tools/session.py @@ -95,10 +95,40 @@ async def skyvern_session_close( with Timer() as timer: try: if session_id: + matching_cloud_session = ( + current.context is not None + and current.context.mode == "cloud_session" + and current.context.session_id == session_id + ) + skyvern = get_skyvern() - result = await do_session_close(skyvern, 
session_id) - if current.context and current.context.session_id == session_id: + result = None + close_error: Exception | None = None + try: + result = await do_session_close(skyvern, session_id) + except Exception as e: + close_error = e + + if matching_cloud_session: + if current.browser is None: + set_current_session(SessionState()) + raise RuntimeError("Expected active browser for matching cloud session") + try: + await current.browser.close() + except Exception as browser_err: + if close_error is not None: + raise browser_err from close_error + raise + finally: + set_current_session(SessionState()) + elif current.context and current.context.session_id == session_id: set_current_session(SessionState()) + + if close_error is not None: + raise close_error + if result is None: + raise RuntimeError("Expected session close result after successful close operation") + timer.mark("sdk") return make_result( "skyvern_session_close", diff --git a/skyvern/cli/run_commands.py b/skyvern/cli/run_commands.py index b44688c7..1d48be22 100644 --- a/skyvern/cli/run_commands.py +++ b/skyvern/cli/run_commands.py @@ -1,4 +1,5 @@ import asyncio +import atexit import json import logging import os @@ -14,6 +15,8 @@ from rich.panel import Panel from rich.prompt import Confirm from skyvern.cli.console import console +from skyvern.cli.core.client import close_skyvern +from skyvern.cli.core.session_manager import close_current_session from skyvern.cli.mcp_tools import mcp # Uses standalone fastmcp (v2.x) from skyvern.cli.utils import start_services from skyvern.client import SkyvernEnvironment @@ -26,6 +29,42 @@ from skyvern.utils import detect_os from skyvern.utils.env_paths import resolve_backend_env_path, resolve_frontend_env_path run_app = typer.Typer(help="Commands to run Skyvern services such as the API server or UI.") +_mcp_cleanup_done = False + + +async def _cleanup_mcp_resources() -> None: + try: + await close_current_session() + finally: + await close_skyvern() + + +def 
_cleanup_mcp_resources_blocking() -> None: + global _mcp_cleanup_done + if _mcp_cleanup_done: + return + + try: + asyncio.run(_cleanup_mcp_resources()) + _mcp_cleanup_done = True + except Exception: + logging.getLogger(__name__).warning("MCP cleanup failed", exc_info=True) + + +def _cleanup_mcp_resources_sync() -> None: + """Atexit callback for MCP cleanup. Skips if an event loop is still running + because asyncio.run() cannot be called inside a running loop. This means + cleanup is best-effort for signal-based exits (e.g. SIGTERM) that fire atexit + while the MCP server's loop is still alive -- the finally block in run_mcp() + handles normal shutdown instead.""" + logger = logging.getLogger(__name__) + try: + asyncio.get_running_loop() + except RuntimeError: + _cleanup_mcp_resources_blocking() + return + + logger.debug("Skipping MCP cleanup because event loop is still running") @mcp.tool() @@ -260,7 +299,14 @@ def run_mcp() -> None: """Run the MCP server.""" # This breaks the MCP processing because it expects json output only # console.print(Panel("[bold green]Starting MCP Server...[/bold green]", border_style="green")) - mcp.run(transport="stdio") + # finally covers normal mcp.run() completion, exceptions, and Ctrl-C; + # atexit is a backstop that runs at interpreter exit if cleanup has not + # completed. Neither fires when an unhandled signal kills the process. 
+ atexit.register(_cleanup_mcp_resources_sync) + try: + mcp.run(transport="stdio") + finally: + _cleanup_mcp_resources_blocking() @run_app.command( diff --git a/tests/unit/test_cli_commands.py b/tests/unit/test_cli_commands.py index 3f142b4a..08edf9a7 100644 --- a/tests/unit/test_cli_commands.py +++ b/tests/unit/test_cli_commands.py @@ -3,7 +3,10 @@ from __future__ import annotations import json +from datetime import datetime, timezone from pathlib import Path +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock import pytest import typer @@ -147,3 +150,149 @@ class TestResolveConnection: monkeypatch.setattr("skyvern.cli.commands._state.STATE_FILE", tmp_path / "nonexistent.json") with pytest.raises(typer.BadParameter, match="No active browser connection"): _resolve_connection(None, None) + + +# --------------------------------------------------------------------------- +# Browser command helpers and command behavior +# --------------------------------------------------------------------------- + + +class TestBrowserCommandGuards: + def test_resolve_ai_target_requires_selector_or_intent(self) -> None: + from skyvern.cli.commands.browser import _resolve_ai_target + from skyvern.cli.core.guards import GuardError + + with pytest.raises(GuardError, match="Must provide intent, selector, or both"): + _resolve_ai_target(None, None, operation="click") + + def test_validate_wait_state_rejects_invalid(self) -> None: + from skyvern.cli.commands.browser import _validate_wait_state + from skyvern.cli.core.guards import GuardError + + with pytest.raises(GuardError, match="Invalid state"): + _validate_wait_state("bad-state") + + +class TestBrowserCommands: + def test_session_get_outputs_session_details( + self, monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture + ) -> None: + from skyvern.cli.commands import browser as browser_cmd + + session_obj = SimpleNamespace( + browser_session_id="pbs_123", + status="active", + 
started_at=datetime(2026, 2, 17, 12, 0, tzinfo=timezone.utc), + completed_at=None, + timeout=60, + runnable_id=None, + ) + skyvern = SimpleNamespace(get_browser_session=AsyncMock(return_value=session_obj)) + monkeypatch.setattr(browser_cmd, "get_skyvern", lambda: skyvern) + monkeypatch.setattr(browser_cmd, "load_state", lambda: CLIState(session_id="pbs_123", mode="cloud")) + + browser_cmd.session_get(session="pbs_123", json_output=True) + + parsed = json.loads(capsys.readouterr().out) + assert parsed["ok"] is True + assert parsed["action"] == "session_get" + assert parsed["data"]["session_id"] == "pbs_123" + assert parsed["data"]["is_current"] is True + + def test_evaluate_blocks_password_js_before_connection( + self, monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture + ) -> None: + from skyvern.cli.commands import browser as browser_cmd + + monkeypatch.setattr( + browser_cmd, + "_resolve_connection", + lambda _session, _cdp: (_ for _ in ()).throw(AssertionError("should not resolve connection")), + ) + + with pytest.raises(SystemExit, match="1"): + browser_cmd.evaluate( + expression='document.querySelector("input[type=password]").value = ""', json_output=True + ) + + parsed = json.loads(capsys.readouterr().out) + assert parsed["ok"] is False + assert "Cannot set password field values" in parsed["error"]["message"] + + def test_click_requires_target_before_connection( + self, monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture + ) -> None: + from skyvern.cli.commands import browser as browser_cmd + + monkeypatch.setattr( + browser_cmd, + "_resolve_connection", + lambda _session, _cdp: (_ for _ in ()).throw(AssertionError("should not resolve connection")), + ) + + with pytest.raises(SystemExit, match="1"): + browser_cmd.click( + intent=None, + selector=None, + session=None, + cdp=None, + timeout=30000, + button=None, + click_count=None, + json_output=True, + ) + + parsed = json.loads(capsys.readouterr().out) + assert parsed["ok"] is False + 
assert "Must provide intent, selector, or both" in parsed["error"]["message"] + + def test_click_with_intent_uses_proactive_ai_mode( + self, monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture + ) -> None: + from skyvern.cli.commands import browser as browser_cmd + + page = MagicMock() + page.click = AsyncMock(return_value="xpath=//button[@id='submit']") + browser = SimpleNamespace(get_working_page=AsyncMock(return_value=page)) + + monkeypatch.setattr( + browser_cmd, + "_resolve_connection", + lambda _session, _cdp: browser_cmd.ConnectionTarget(mode="cloud", session_id="pbs_123"), + ) + monkeypatch.setattr(browser_cmd, "_connect_browser", AsyncMock(return_value=browser)) + + browser_cmd.click( + intent="the Submit button", + selector=None, + session="pbs_123", + cdp=None, + timeout=30000, + button=None, + click_count=None, + json_output=True, + ) + + parsed = json.loads(capsys.readouterr().out) + assert parsed["ok"] is True + assert parsed["action"] == "click" + assert parsed["data"]["ai_mode"] == "proactive" + assert parsed["data"]["resolved_selector"] == "xpath=//button[@id='submit']" + + def test_wait_rejects_invalid_state_before_connection( + self, monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture + ) -> None: + from skyvern.cli.commands import browser as browser_cmd + + monkeypatch.setattr( + browser_cmd, + "_resolve_connection", + lambda _session, _cdp: (_ for _ in ()).throw(AssertionError("should not resolve connection")), + ) + + with pytest.raises(SystemExit, match="1"): + browser_cmd.wait(state="bad-state", time_ms=1000, json_output=True) + + parsed = json.loads(capsys.readouterr().out) + assert parsed["ok"] is False + assert "Invalid state" in parsed["error"]["message"] diff --git a/tests/unit/test_mcp_session_lifecycle.py b/tests/unit/test_mcp_session_lifecycle.py new file mode 100644 index 00000000..568fccda --- /dev/null +++ b/tests/unit/test_mcp_session_lifecycle.py @@ -0,0 +1,264 @@ +from __future__ import annotations + +from 
unittest.mock import AsyncMock, MagicMock + +import pytest + +from skyvern.cli.core import client as client_mod +from skyvern.cli.core import session_manager +from skyvern.cli.core.result import BrowserContext +from skyvern.cli.core.session_ops import SessionCloseResult +from skyvern.cli.mcp_tools import session as mcp_session + + +@pytest.fixture(autouse=True) +def _reset_singletons() -> None: + client_mod._skyvern_instance.set(None) + client_mod._global_skyvern_instance = None + + session_manager._current_session.set(None) + session_manager._global_session = None + mcp_session.set_current_session(mcp_session.SessionState()) + + +def test_get_skyvern_reuses_global_instance_across_contexts(monkeypatch: pytest.MonkeyPatch) -> None: + created: list[object] = [] + + class FakeSkyvern: + def __init__(self, *args: object, **kwargs: object) -> None: + created.append(self) + + @classmethod + def local(cls) -> FakeSkyvern: + return cls() + + async def aclose(self) -> None: + return None + + monkeypatch.setattr(client_mod, "Skyvern", FakeSkyvern) + monkeypatch.setattr(client_mod.settings, "SKYVERN_API_KEY", None) + monkeypatch.setattr(client_mod.settings, "SKYVERN_BASE_URL", None) + + first = client_mod.get_skyvern() + client_mod._skyvern_instance.set(None) # Simulate a new async context. 
+ second = client_mod.get_skyvern() + + assert first is second + assert len(created) == 1 + + +@pytest.mark.asyncio +async def test_close_skyvern_closes_singleton() -> None: + fake = MagicMock() + fake.aclose = AsyncMock() + + client_mod._skyvern_instance.set(fake) + client_mod._global_skyvern_instance = fake + + await client_mod.close_skyvern() + + fake.aclose.assert_awaited_once() + assert client_mod._skyvern_instance.get() is None + assert client_mod._global_skyvern_instance is None + + +def test_get_current_session_falls_back_to_global_state() -> None: + state = session_manager.SessionState( + browser=MagicMock(), + context=BrowserContext(mode="cloud_session", session_id="pbs_123"), + ) + session_manager.set_current_session(state) + + session_manager._current_session.set(None) # Simulate a new async context. + recovered = session_manager.get_current_session() + + assert recovered is state + + +@pytest.mark.asyncio +async def test_resolve_browser_reuses_matching_cloud_session(monkeypatch: pytest.MonkeyPatch) -> None: + current_browser = MagicMock() + current_state = session_manager.SessionState( + browser=current_browser, + context=BrowserContext(mode="cloud_session", session_id="pbs_123"), + ) + session_manager.set_current_session(current_state) + + fake_skyvern = MagicMock() + fake_skyvern.connect_to_cloud_browser_session = AsyncMock() + monkeypatch.setattr(session_manager, "get_skyvern", lambda: fake_skyvern) + + browser, ctx = await session_manager.resolve_browser(session_id="pbs_123") + + assert browser is current_browser + assert ctx.session_id == "pbs_123" + fake_skyvern.connect_to_cloud_browser_session.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_session_close_with_matching_session_id_closes_browser_handle(monkeypatch: pytest.MonkeyPatch) -> None: + current_browser = MagicMock() + current_browser.close = AsyncMock() + mcp_session.set_current_session( + mcp_session.SessionState( + browser=current_browser, + 
context=BrowserContext(mode="cloud_session", session_id="pbs_456"), + ) + ) + + fake_skyvern = MagicMock() + monkeypatch.setattr(mcp_session, "get_skyvern", lambda: fake_skyvern) + + do_session_close = AsyncMock(return_value=SessionCloseResult(session_id="pbs_456", closed=True)) + monkeypatch.setattr(mcp_session, "do_session_close", do_session_close) + + result = await mcp_session.skyvern_session_close(session_id="pbs_456") + + assert result["ok"] is True + assert result["data"] == {"session_id": "pbs_456", "closed": True} + current_browser.close.assert_awaited_once() + do_session_close.assert_awaited_once_with(fake_skyvern, "pbs_456") + assert mcp_session.get_current_session().browser is None + + +@pytest.mark.asyncio +async def test_session_close_chains_exceptions_when_both_api_and_browser_fail( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """When both do_session_close (API) and browser.close() raise, the browser + exception should chain the API exception via __cause__ so neither is lost.""" + current_browser = MagicMock() + browser_error = RuntimeError("browser close failed") + current_browser.close = AsyncMock(side_effect=browser_error) + mcp_session.set_current_session( + mcp_session.SessionState( + browser=current_browser, + context=BrowserContext(mode="cloud_session", session_id="pbs_dual"), + ) + ) + + fake_skyvern = MagicMock() + monkeypatch.setattr(mcp_session, "get_skyvern", lambda: fake_skyvern) + + api_error = ConnectionError("API close failed") + do_session_close = AsyncMock(side_effect=api_error) + monkeypatch.setattr(mcp_session, "do_session_close", do_session_close) + + result = await mcp_session.skyvern_session_close(session_id="pbs_dual") + + # The outer exception handler catches and returns an error result + assert result["ok"] is False + assert "browser close failed" in result["error"]["message"] + # Session state is cleaned up regardless + assert mcp_session.get_current_session().browser is None + + +@pytest.mark.asyncio +async def 
test_session_close_matching_context_without_browser_returns_error( + monkeypatch: pytest.MonkeyPatch, +) -> None: + mcp_session.set_current_session( + mcp_session.SessionState( + browser=None, + context=BrowserContext(mode="cloud_session", session_id="pbs_999"), + ) + ) + + fake_skyvern = MagicMock() + monkeypatch.setattr(mcp_session, "get_skyvern", lambda: fake_skyvern) + + do_session_close = AsyncMock(return_value=SessionCloseResult(session_id="pbs_999", closed=True)) + monkeypatch.setattr(mcp_session, "do_session_close", do_session_close) + + result = await mcp_session.skyvern_session_close(session_id="pbs_999") + + assert result["ok"] is False + assert result["error"]["code"] == mcp_session.ErrorCode.SDK_ERROR + assert "Expected active browser for matching cloud session" in result["error"]["message"] + do_session_close.assert_awaited_once_with(fake_skyvern, "pbs_999") + assert mcp_session.get_current_session().context is None + + +# --------------------------------------------------------------------------- +# Tests for close_current_session() — cloud session API cleanup +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_close_current_session_calls_api_close_for_cloud_session(monkeypatch: pytest.MonkeyPatch) -> None: + """close_current_session() should call do_session_close for cloud sessions + and clear _browser_session_id to avoid a duplicate API call from browser.close().""" + browser = MagicMock() + browser.close = AsyncMock() + browser._browser_session_id = "pbs_api" + session_manager.set_current_session( + session_manager.SessionState( + browser=browser, + context=BrowserContext(mode="cloud_session", session_id="pbs_api"), + ) + ) + + fake_skyvern = MagicMock() + monkeypatch.setattr(session_manager, "get_skyvern", lambda: fake_skyvern) + + do_session_close = AsyncMock(return_value=SessionCloseResult(session_id="pbs_api", closed=True)) + 
monkeypatch.setattr("skyvern.cli.core.session_ops.do_session_close", do_session_close) + + await session_manager.close_current_session() + + do_session_close.assert_awaited_once_with(fake_skyvern, "pbs_api") + browser.close.assert_awaited_once() + # _browser_session_id should be cleared to prevent redundant API call + assert browser._browser_session_id is None + assert session_manager.get_current_session().browser is None + + +@pytest.mark.asyncio +async def test_close_current_session_skips_api_close_for_local_session(monkeypatch: pytest.MonkeyPatch) -> None: + """close_current_session() should NOT call do_session_close for local sessions.""" + browser = MagicMock() + browser.close = AsyncMock() + session_manager.set_current_session( + session_manager.SessionState( + browser=browser, + context=BrowserContext(mode="local"), + ) + ) + + do_session_close = AsyncMock() + monkeypatch.setattr("skyvern.cli.core.session_ops.do_session_close", do_session_close) + + await session_manager.close_current_session() + + do_session_close.assert_not_awaited() + browser.close.assert_awaited_once() + assert session_manager.get_current_session().browser is None + + +@pytest.mark.asyncio +async def test_close_current_session_still_closes_browser_when_api_fails(monkeypatch: pytest.MonkeyPatch) -> None: + """When do_session_close raises, browser.close() should still run and state should be cleared.""" + browser = MagicMock() + browser.close = AsyncMock() + browser._browser_session_id = "pbs_fail" + session_manager.set_current_session( + session_manager.SessionState( + browser=browser, + context=BrowserContext(mode="cloud_session", session_id="pbs_fail"), + ) + ) + + fake_skyvern = MagicMock() + monkeypatch.setattr(session_manager, "get_skyvern", lambda: fake_skyvern) + + do_session_close = AsyncMock(side_effect=ConnectionError("API unreachable")) + monkeypatch.setattr("skyvern.cli.core.session_ops.do_session_close", do_session_close) + + await session_manager.close_current_session() + + 
do_session_close.assert_awaited_once_with(fake_skyvern, "pbs_fail") + # browser.close() should still be called despite API failure + browser.close.assert_awaited_once() + # _browser_session_id should NOT be cleared (API close failed, let browser.close() try) + assert browser._browser_session_id == "pbs_fail" + assert session_manager.get_current_session().browser is None diff --git a/tests/unit/test_run_commands_cleanup.py b/tests/unit/test_run_commands_cleanup.py new file mode 100644 index 00000000..74736cfc --- /dev/null +++ b/tests/unit/test_run_commands_cleanup.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +import asyncio +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from skyvern.cli import run_commands + + +@pytest.fixture(autouse=True) +def _reset_cleanup_state() -> None: + run_commands._mcp_cleanup_done = False + + +def test_cleanup_mcp_resources_sync_runs_without_running_loop(monkeypatch: pytest.MonkeyPatch) -> None: + cleanup = AsyncMock() + monkeypatch.setattr(run_commands, "_cleanup_mcp_resources", cleanup) + + run_commands._cleanup_mcp_resources_sync() + + cleanup.assert_awaited_once() + assert run_commands._mcp_cleanup_done is True + + +@pytest.mark.asyncio +async def test_cleanup_mcp_resources_sync_skips_when_loop_running(monkeypatch: pytest.MonkeyPatch) -> None: + cleanup = AsyncMock() + monkeypatch.setattr(run_commands, "_cleanup_mcp_resources", cleanup) + + run_commands._cleanup_mcp_resources_sync() + await asyncio.sleep(0) + + cleanup.assert_not_awaited() + assert run_commands._mcp_cleanup_done is False + + +def test_cleanup_mcp_resources_sync_keeps_retry_possible_on_task_failure(monkeypatch: pytest.MonkeyPatch) -> None: + async def failing_cleanup() -> None: + raise RuntimeError("cleanup failed") + + monkeypatch.setattr(run_commands, "_cleanup_mcp_resources", failing_cleanup) + + run_commands._cleanup_mcp_resources_sync() + + assert run_commands._mcp_cleanup_done is False + + +def 
test_run_mcp_calls_blocking_cleanup_in_finally(monkeypatch: pytest.MonkeyPatch) -> None: + cleanup_blocking = MagicMock() + register = MagicMock() + run = MagicMock(side_effect=RuntimeError("boom")) + + monkeypatch.setattr(run_commands, "_cleanup_mcp_resources_blocking", cleanup_blocking) + monkeypatch.setattr(run_commands.atexit, "register", register) + monkeypatch.setattr(run_commands.mcp, "run", run) + + with pytest.raises(RuntimeError, match="boom"): + run_commands.run_mcp() + + register.assert_called_once_with(run_commands._cleanup_mcp_resources_sync) + run.assert_called_once_with(transport="stdio") + cleanup_blocking.assert_called_once()