diff --git a/skyvern/cli/commands.py b/skyvern/cli/commands/__init__.py similarity index 60% rename from skyvern/cli/commands.py rename to skyvern/cli/commands/__init__.py index cb9f1ca3..97ede4df 100644 --- a/skyvern/cli/commands.py +++ b/skyvern/cli/commands/__init__.py @@ -1,23 +1,51 @@ +import logging + import typer from dotenv import load_dotenv +from skyvern.forge.sdk.forge_log import setup_logger as _setup_logger from skyvern.utils.env_paths import resolve_backend_env_path -from .credentials import credentials_app -from .docs import docs_app -from .init_command import init_browser, init_env -from .quickstart import quickstart_app -from .run_commands import run_app -from .status import status_app -from .stop_commands import stop_app -from .tasks import tasks_app -from .workflow import workflow_app +from ..credentials import credentials_app +from ..docs import docs_app +from ..init_command import init_browser, init_env +from ..quickstart import quickstart_app +from ..run_commands import run_app +from ..status import status_app +from ..stop_commands import stop_app +from ..tasks import tasks_app +from ..workflow import workflow_app +from .browser import browser_app + +_cli_logging_configured = False + + +def configure_cli_logging() -> None: + """Configure CLI log levels once at runtime (not at import time).""" + global _cli_logging_configured + if _cli_logging_configured: + return + _cli_logging_configured = True + + # Suppress noisy SDK/third-party logs for CLI execution only. 
+ for logger_name in ("skyvern", "httpx", "litellm", "playwright", "httpcore"): + logging.getLogger(logger_name).setLevel(logging.WARNING) + _setup_logger() + cli_app = typer.Typer( help=("""[bold]Skyvern CLI[/bold]\nManage and run your local Skyvern environment."""), no_args_is_help=True, rich_markup_mode="rich", ) + + +@cli_app.callback() +def cli_callback() -> None: + """Configure CLI logging before command execution.""" + configure_cli_logging() + + cli_app.add_typer( run_app, name="run", @@ -40,6 +68,9 @@ cli_app.add_typer( quickstart_app, name="quickstart", help="One-command setup and start for Skyvern (combines init and run)." ) +# Browser automation commands +cli_app.add_typer(browser_app, name="browser", help="Browser automation commands.") + @init_app.callback() def init_callback( diff --git a/skyvern/cli/commands/__main__.py b/skyvern/cli/commands/__main__.py new file mode 100644 index 00000000..3ab570f1 --- /dev/null +++ b/skyvern/cli/commands/__main__.py @@ -0,0 +1,9 @@ +from dotenv import load_dotenv + +from skyvern.utils.env_paths import resolve_backend_env_path + +from . 
def output(
    data: Any,
    *,
    action: str = "",
    json_mode: bool = False,
) -> None:
    """Print a successful command result.

    Args:
        data: The payload to show. A non-empty list of dicts is rendered as a
            table, a dict as ``key: value`` lines, anything else via ``str()``.
        action: Name of the command that produced the result; only emitted in
            the JSON envelope.
        json_mode: When true, write a JSON envelope
            (``ok`` / ``action`` / ``data`` / ``error``) to stdout instead of
            rich-formatted text.
    """
    if json_mode:
        envelope: dict[str, Any] = {"ok": True, "action": action, "data": data, "error": None}
        # default=str keeps the envelope serialisable for values json cannot encode natively.
        json.dump(envelope, sys.stdout, indent=2, default=str)
        sys.stdout.write("\n")
        return

    # Imported here because only the human-readable path needs it.
    from rich.markup import escape

    # Every row must be a dict; a mixed list falls through to the plain str() branch
    # instead of failing on a missing .get()/.values().
    if isinstance(data, list) and data and all(isinstance(row, dict) for row in data):
        # Columns are defined by the first row; cells are looked up by key so that
        # rows with a different key order (or a missing key) stay aligned.
        headers = list(data[0])
        table = Table()
        for key in headers:
            table.add_column(escape(str(key).replace("_", " ").title()))
        for row in data:
            table.add_row(*(escape(str(row.get(key, ""))) for key in headers))
        console.print(table)
    elif isinstance(data, dict):
        for key, value in data.items():
            # Escape payload text so brackets in it are not parsed as rich markup.
            console.print(f"[bold]{escape(str(key))}:[/bold] {escape(str(value))}")
    else:
        console.print(escape(str(data)))


def output_error(
    message: str,
    *,
    hint: str = "",
    json_mode: bool = False,
    exit_code: int = 1,
    action: str = "",
) -> None:
    """Report a failure and terminate the command.

    This function never returns normally: it always raises ``SystemExit``.

    Args:
        message: Human-readable error description.
        hint: Optional suggestion shown after the message.
        json_mode: When true, write a JSON error envelope to stdout instead of
            rich-formatted text.
        exit_code: Process exit status carried by the ``SystemExit``.
        action: Name of the failing command for the JSON envelope. Defaults to
            ``""``, which matches the envelope produced before this parameter
            existed.

    Raises:
        SystemExit: Always, with ``exit_code``.
    """
    if json_mode:
        envelope: dict[str, Any] = {
            "ok": False,
            "action": action,
            "data": None,
            "error": {"message": message, "hint": hint},
        }
        json.dump(envelope, sys.stdout, indent=2, default=str)
        sys.stdout.write("\n")
    else:
        from rich.markup import escape

        # Error text often carries brackets (selectors, reprs); escape it so the
        # markup parser cannot drop or choke on them while we are reporting a failure.
        console.print(f"[red]Error: {escape(message)}[/red]")
        if hint:
            console.print(f"[yellow]Hint: {escape(hint)}[/yellow]")
    raise SystemExit(exit_code)
STATE_DIR = Path.home() / ".skyvern"
STATE_FILE = STATE_DIR / "state.json"

_TTL_SECONDS = 86400  # 24 hours


@dataclass
class CLIState:
    """Connection details persisted between CLI invocations."""

    session_id: str | None = None
    cdp_url: str | None = None
    mode: str | None = None  # "cloud", "local", or "cdp"
    created_at: str | None = None  # ISO-8601 timestamp, stamped by save_state()


def save_state(state: CLIState) -> None:
    """Persist ``state`` to ``STATE_FILE``, stamping it with the current UTC time.

    The state directory is created if needed and restricted to mode 0o700.
    The file is written to a private temporary file in the same directory and
    then moved into place, so it is never visible with wider permissions than
    0o600 and a concurrent reader never sees a half-written file.

    Args:
        state: The state to store. Its ``created_at`` value is ignored and
            replaced with the time of the save.

    Raises:
        OSError: If the directory or file cannot be created or written.
    """
    import os
    import tempfile

    STATE_DIR.mkdir(parents=True, exist_ok=True)
    STATE_DIR.chmod(0o700)

    payload = asdict(state)
    payload["created_at"] = datetime.now(timezone.utc).isoformat()

    # mkstemp creates the file readable/writable by the owner only, which closes
    # the window that write-then-chmod leaves open.
    fd, tmp_name = tempfile.mkstemp(dir=STATE_DIR, prefix=f".{STATE_FILE.name}.", suffix=".tmp")
    tmp_path = Path(tmp_name)
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as handle:
            json.dump(payload, handle)
        tmp_path.chmod(0o600)
        # Atomic rename within the same directory.
        os.replace(tmp_path, STATE_FILE)
    except BaseException:
        # Do not leave stray temp files behind on any failure, then re-raise.
        tmp_path.unlink(missing_ok=True)
        raise


def load_state() -> CLIState | None:
    """Return the stored state, or ``None`` if it is missing, unreadable, or stale.

    A state whose ``created_at`` is older than ``_TTL_SECONDS`` is treated as
    expired. A state without ``created_at`` is accepted as-is. Unknown keys in
    the file are ignored.
    """
    try:
        # Read directly instead of checking exists() first: a missing file is
        # simply an OSError here, with no check-then-use race.
        data = json.loads(STATE_FILE.read_text())
        if not isinstance(data, dict):
            return None
        created_at = data.get("created_at")
        if created_at:
            age = (datetime.now(timezone.utc) - datetime.fromisoformat(created_at)).total_seconds()
            if age > _TTL_SECONDS:
                return None
    except (OSError, ValueError, TypeError):
        # OSError: missing/unreadable file. ValueError: bad JSON, bad encoding or
        # bad timestamp. TypeError: non-string or timezone-naive timestamp.
        return None

    known = {name: value for name, value in data.items() if name in CLIState.__dataclass_fields__}
    return CLIState(**known)


def clear_state() -> None:
    """Delete the stored state file; a missing file is not an error."""
    STATE_FILE.unlink(missing_ok=True)
+from skyvern.cli.core.session_ops import do_session_close, do_session_create, do_session_list + +browser_app = typer.Typer(help="Browser automation commands.", no_args_is_help=True) +session_app = typer.Typer(help="Manage browser sessions.", no_args_is_help=True) +browser_app.add_typer(session_app, name="session") + + +@dataclass(frozen=True) +class ConnectionTarget: + mode: Literal["cloud", "cdp"] + session_id: str | None = None + cdp_url: str | None = None + + +def _resolve_connection(session: str | None, cdp: str | None) -> ConnectionTarget: + if session and cdp: + raise typer.BadParameter("Pass only one of --session or --cdp.") + + if session: + return ConnectionTarget(mode="cloud", session_id=session) + if cdp: + return ConnectionTarget(mode="cdp", cdp_url=cdp) + + state = load_state() + if state: + if state.mode == "cdp" and state.cdp_url: + return ConnectionTarget(mode="cdp", cdp_url=state.cdp_url) + if state.session_id: + return ConnectionTarget(mode="cloud", session_id=state.session_id) + if state.cdp_url: + return ConnectionTarget(mode="cdp", cdp_url=state.cdp_url) + + raise typer.BadParameter( + "No active browser connection. Create one with: skyvern browser session create\n" + "Or connect with: skyvern browser session connect --cdp ws://...\n" + "Or specify: --session pbs_... / --cdp ws://..." 
+ ) + + +async def _connect_browser(connection: ConnectionTarget) -> Any: + skyvern = get_skyvern() + if connection.mode == "cloud": + if not connection.session_id: + raise typer.BadParameter("Cloud mode requires --session or an active cloud session in state.") + return await skyvern.connect_to_cloud_browser_session(connection.session_id) + if not connection.cdp_url: + raise typer.BadParameter("CDP mode requires --cdp or an active CDP URL in state.") + return await skyvern.connect_to_browser_over_cdp(connection.cdp_url) + + +# --------------------------------------------------------------------------- +# Session commands +# --------------------------------------------------------------------------- + + +@session_app.command("create") +def session_create( + timeout: int = typer.Option(60, help="Session timeout in minutes."), + proxy: str | None = typer.Option(None, help="Proxy location (e.g. RESIDENTIAL)."), + local: bool = typer.Option(False, "--local", help="Launch a local browser instead of cloud."), + headless: bool = typer.Option(False, "--headless", help="Run local browser headless."), + json_output: bool = typer.Option(False, "--json", help="Output as JSON."), +) -> None: + """Create a new browser session.""" + if local: + output_error( + "Local browser sessions are not yet supported in CLI mode.", + hint="Use MCP (skyvern run mcp) for local browser sessions, or omit --local for cloud sessions.", + json_mode=json_output, + ) + + async def _run() -> dict: + skyvern = get_skyvern() + _browser, result = await do_session_create( + skyvern, + timeout=timeout, + proxy_location=proxy, + ) + save_state(CLIState(session_id=result.session_id, cdp_url=None, mode="cloud")) + return { + "session_id": result.session_id, + "mode": "cloud", + "timeout_minutes": result.timeout_minutes, + } + + try: + data = asyncio.run(_run()) + output(data, action="session_create", json_mode=json_output) + except GuardError as e: + output_error(str(e), hint=e.hint, json_mode=json_output) + 
except typer.BadParameter: + raise + except Exception as e: + output_error(str(e), hint="Check your API key and network connection.", json_mode=json_output) + + +@session_app.command("close") +def session_close( + session: str | None = typer.Option(None, help="Browser session ID to close."), + cdp: str | None = typer.Option(None, "--cdp", help="CDP WebSocket URL to detach from."), + json_output: bool = typer.Option(False, "--json", help="Output as JSON."), +) -> None: + """Close a browser session.""" + + async def _run() -> dict: + connection = _resolve_connection(session, cdp) + if connection.mode == "cdp": + clear_state() + return {"cdp_url": connection.cdp_url, "closed": False, "detached": True} + + if not connection.session_id: + raise typer.BadParameter("Cloud mode requires a browser session ID.") + + skyvern = get_skyvern() + result = await do_session_close(skyvern, connection.session_id) + clear_state() + return {"session_id": result.session_id, "closed": result.closed} + + try: + data = asyncio.run(_run()) + output(data, action="session_close", json_mode=json_output) + except typer.BadParameter: + raise + except Exception as e: + output_error(str(e), hint="Verify the session ID or CDP URL is correct.", json_mode=json_output) + + +@session_app.command("connect") +def session_connect( + session: str | None = typer.Option(None, help="Cloud browser session ID."), + cdp: str | None = typer.Option(None, "--cdp", help="CDP WebSocket URL."), + json_output: bool = typer.Option(False, "--json", help="Output as JSON."), +) -> None: + """Connect to an existing browser session (cloud or CDP) and persist it as active state.""" + if not session and not cdp: + raise typer.BadParameter("Specify one of --session or --cdp.") + + async def _run() -> dict: + connection = _resolve_connection(session, cdp) + browser = await _connect_browser(connection) + await browser.get_working_page() + + if connection.mode == "cdp": + save_state(CLIState(session_id=None, 
cdp_url=connection.cdp_url, mode="cdp")) + return {"connected": True, "mode": "cdp", "cdp_url": connection.cdp_url} + + save_state(CLIState(session_id=connection.session_id, cdp_url=None, mode="cloud")) + return {"connected": True, "mode": "cloud", "session_id": connection.session_id} + + try: + data = asyncio.run(_run()) + output(data, action="session_connect", json_mode=json_output) + except typer.BadParameter: + raise + except Exception as e: + output_error(str(e), hint="Verify the session ID or CDP URL is reachable.", json_mode=json_output) + + +@session_app.command("list") +def session_list( + json_output: bool = typer.Option(False, "--json", help="Output as JSON."), +) -> None: + """List all browser sessions.""" + + async def _run() -> list[dict]: + skyvern = get_skyvern() + sessions = await do_session_list(skyvern) + return [asdict(s) for s in sessions] + + try: + data = asyncio.run(_run()) + output(data, action="session_list", json_mode=json_output) + except Exception as e: + output_error(str(e), hint="Check your API key and network connection.", json_mode=json_output) + + +# --------------------------------------------------------------------------- +# Browser commands +# --------------------------------------------------------------------------- + + +@browser_app.command("navigate") +def navigate( + url: str = typer.Option(..., help="URL to navigate to."), + session: str | None = typer.Option(None, help="Browser session ID."), + cdp: str | None = typer.Option(None, "--cdp", help="CDP WebSocket URL."), + timeout: int = typer.Option(30000, help="Navigation timeout in milliseconds."), + wait_until: str | None = typer.Option(None, help="Wait condition: load, domcontentloaded, networkidle, commit."), + json_output: bool = typer.Option(False, "--json", help="Output as JSON."), +) -> None: + """Navigate to a URL in the browser session.""" + + async def _run() -> dict: + validate_wait_until(wait_until) + connection = _resolve_connection(session, cdp) + browser = 
await _connect_browser(connection) + page = await browser.get_working_page() + result = await do_navigate(page, url, timeout=timeout, wait_until=wait_until) + return {"url": result.url, "title": result.title} + + try: + data = asyncio.run(_run()) + output(data, action="navigate", json_mode=json_output) + except GuardError as e: + output_error(str(e), hint=e.hint, json_mode=json_output) + except typer.BadParameter: + raise + except Exception as e: + output_error(str(e), hint="Check the URL is valid and the session is active.", json_mode=json_output) + + +@browser_app.command("screenshot") +def screenshot( + session: str | None = typer.Option(None, help="Browser session ID."), + cdp: str | None = typer.Option(None, "--cdp", help="CDP WebSocket URL."), + full_page: bool = typer.Option(False, "--full-page", help="Capture the full scrollable page."), + selector: str | None = typer.Option(None, help="CSS selector to screenshot."), + output_path: str | None = typer.Option(None, "--output", help="Custom output file path."), + json_output: bool = typer.Option(False, "--json", help="Output as JSON."), +) -> None: + """Take a screenshot of the current page.""" + + async def _run() -> dict: + connection = _resolve_connection(session, cdp) + browser = await _connect_browser(connection) + page = await browser.get_working_page() + result = await do_screenshot(page, full_page=full_page, selector=selector) + + if output_path: + path = Path(output_path) + path.parent.mkdir(parents=True, exist_ok=True) + path.write_bytes(result.data) + return {"path": str(path), "bytes": len(result.data), "full_page": result.full_page} + + timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S") + artifact = save_artifact( + content=result.data, + kind="screenshot", + filename=f"screenshot_{timestamp}.png", + mime="image/png", + session_id=connection.session_id, + ) + return {"path": artifact.path, "bytes": artifact.bytes, "full_page": result.full_page} + + try: + data = asyncio.run(_run()) + 
output(data, action="screenshot", json_mode=json_output) + except GuardError as e: + output_error(str(e), hint=e.hint, json_mode=json_output) + except typer.BadParameter: + raise + except Exception as e: + output_error(str(e), hint="Ensure the session is active and the page has loaded.", json_mode=json_output) + + +@browser_app.command("act") +def act( + prompt: str = typer.Option(..., help="Natural language action to perform."), + session: str | None = typer.Option(None, help="Browser session ID."), + cdp: str | None = typer.Option(None, "--cdp", help="CDP WebSocket URL."), + json_output: bool = typer.Option(False, "--json", help="Output as JSON."), +) -> None: + """Perform a natural language action on the current page.""" + + async def _run() -> dict: + check_password_prompt(prompt) + connection = _resolve_connection(session, cdp) + browser = await _connect_browser(connection) + page = await browser.get_working_page() + result = await do_act(page, prompt) + return {"prompt": result.prompt, "completed": result.completed} + + try: + data = asyncio.run(_run()) + output(data, action="act", json_mode=json_output) + except GuardError as e: + output_error(str(e), hint=e.hint, json_mode=json_output) + except typer.BadParameter: + raise + except Exception as e: + output_error(str(e), hint="Simplify the prompt or break into steps.", json_mode=json_output) + + +@browser_app.command("extract") +def extract( + prompt: str = typer.Option(..., help="What data to extract from the page."), + session: str | None = typer.Option(None, help="Browser session ID."), + cdp: str | None = typer.Option(None, "--cdp", help="CDP WebSocket URL."), + schema: str | None = typer.Option(None, help="JSON schema for structured extraction."), + json_output: bool = typer.Option(False, "--json", help="Output as JSON."), +) -> None: + """Extract data from the current page using natural language.""" + + async def _run() -> dict: + connection = _resolve_connection(session, cdp) + browser = await 
_connect_browser(connection) + page = await browser.get_working_page() + result = await do_extract(page, prompt, schema=schema) + return {"prompt": prompt, "extracted": result.extracted} + + try: + data = asyncio.run(_run()) + output(data, action="extract", json_mode=json_output) + except GuardError as e: + output_error(str(e), hint=e.hint, json_mode=json_output) + except typer.BadParameter: + raise + except Exception as e: + output_error(str(e), hint="Simplify the prompt or provide a JSON schema.", json_mode=json_output) diff --git a/skyvern/cli/core/browser_ops.py b/skyvern/cli/core/browser_ops.py new file mode 100644 index 00000000..e034bad5 --- /dev/null +++ b/skyvern/cli/core/browser_ops.py @@ -0,0 +1,87 @@ +"""Shared browser operations for MCP tools and CLI commands. + +Each function: validate inputs -> call SDK -> return typed result. +Session resolution and output formatting are caller responsibilities. +""" + +from __future__ import annotations + +import json +from dataclasses import dataclass +from typing import Any + +from .guards import GuardError + + +@dataclass +class NavigateResult: + url: str + title: str + + +@dataclass +class ScreenshotResult: + data: bytes + full_page: bool = False + + +@dataclass +class ActResult: + prompt: str + completed: bool = True + + +@dataclass +class ExtractResult: + extracted: Any = None + + +def parse_extract_schema(schema: str | dict[str, Any] | None) -> dict[str, Any] | None: + """Parse and validate an extraction schema payload.""" + if schema is None: + return None + if isinstance(schema, dict): + return schema + + try: + return json.loads(schema) + except (json.JSONDecodeError, TypeError) as e: + raise GuardError(f"Invalid JSON schema: {e}", "Provide schema as a valid JSON string") + + +async def do_navigate( + page: Any, + url: str, + timeout: int = 30000, + wait_until: str | None = None, +) -> NavigateResult: + await page.goto(url, timeout=timeout, wait_until=wait_until) + return NavigateResult(url=page.url, 
title=await page.title()) + + +async def do_screenshot( + page: Any, + full_page: bool = False, + selector: str | None = None, +) -> ScreenshotResult: + if selector: + element = page.locator(selector) + data = await element.screenshot() + else: + data = await page.screenshot(full_page=full_page) + return ScreenshotResult(data=data, full_page=full_page) + + +async def do_act(page: Any, prompt: str) -> ActResult: + await page.act(prompt) + return ActResult(prompt=prompt, completed=True) + + +async def do_extract( + page: Any, + prompt: str, + schema: str | dict[str, Any] | None = None, +) -> ExtractResult: + parsed_schema = parse_extract_schema(schema) + extracted = await page.extract(prompt=prompt, schema=parsed_schema) + return ExtractResult(extracted=extracted) diff --git a/skyvern/cli/core/guards.py b/skyvern/cli/core/guards.py new file mode 100644 index 00000000..773ab8ac --- /dev/null +++ b/skyvern/cli/core/guards.py @@ -0,0 +1,81 @@ +"""Shared input validation guards for both MCP and CLI surfaces.""" + +from __future__ import annotations + +import re + +PASSWORD_PATTERN = re.compile( + r"\bpass(?:word|phrase|code)s?\b|\bsecret\b|\bcredential\b|\bpin\s*(?:code)?\b|\bpwd\b|\bpasswd\b", + re.IGNORECASE, +) + +JS_PASSWORD_PATTERN = re.compile( + r"""(?:type\s*=\s*['"]?password|\.type\s*===?\s*['"]password|input\[type=password\]).*?\.value\s*=""", + re.IGNORECASE, +) + +CREDENTIAL_HINT = ( + "Use skyvern_login with a stored credential to authenticate. " + "Create credentials via CLI: skyvern credentials add. " + "Never pass passwords through tool calls." 
+) + +VALID_WAIT_UNTIL = ("load", "domcontentloaded", "networkidle", "commit") +VALID_BUTTONS = ("left", "right", "middle") +VALID_ELEMENT_STATES = ("visible", "hidden", "attached", "detached") + + +class GuardError(Exception): + """Raised when an input guard blocks an operation.""" + + def __init__(self, message: str, hint: str = "") -> None: + super().__init__(message) + self.hint = hint + + +def check_password_prompt(text: str) -> None: + """Block prompts containing password/credential terms.""" + if PASSWORD_PATTERN.search(text): + raise GuardError( + "Cannot perform password/credential actions — credentials must not be passed through tool calls", + CREDENTIAL_HINT, + ) + + +def check_js_password(expression: str) -> None: + """Block JS expressions that set password field values.""" + if JS_PASSWORD_PATTERN.search(expression): + raise GuardError( + "Cannot set password field values via JavaScript — credentials must not be passed through tool calls", + CREDENTIAL_HINT, + ) + + +def validate_wait_until(value: str | None) -> None: + if value is not None and value not in VALID_WAIT_UNTIL: + raise GuardError( + f"Invalid wait_until: {value}", + "Use load, domcontentloaded, networkidle, or commit", + ) + + +def validate_button(value: str | None) -> None: + if value is not None and value not in VALID_BUTTONS: + raise GuardError(f"Invalid button: {value}", "Use left, right, or middle") + + +def resolve_ai_mode( + selector: str | None, + intent: str | None, +) -> tuple[str | None, str | None]: + """Determine AI mode from selector/intent combination. + + Returns (ai_mode, error_code) -- if error_code is set, the call should fail. 
+ """ + if intent and not selector: + return "proactive", None + if intent and selector: + return "fallback", None + if selector and not intent: + return None, None + return None, "INVALID_INPUT" diff --git a/skyvern/cli/core/session_ops.py b/skyvern/cli/core/session_ops.py new file mode 100644 index 00000000..4e3014eb --- /dev/null +++ b/skyvern/cli/core/session_ops.py @@ -0,0 +1,74 @@ +"""Shared session operations for MCP tools and CLI commands.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +from skyvern.schemas.runs import ProxyLocation + + +@dataclass +class SessionCreateResult: + session_id: str | None + local: bool = False + headless: bool = False + timeout_minutes: int | None = None + + +@dataclass +class SessionCloseResult: + session_id: str | None + closed: bool = True + + +@dataclass +class SessionInfo: + session_id: str + status: str | None + started_at: str | None + timeout: int | None + runnable_id: str | None = None + available: bool = False + + +async def do_session_create( + skyvern: Any, + timeout: int = 60, + proxy_location: str | None = None, + local: bool = False, + headless: bool = False, +) -> tuple[Any, SessionCreateResult]: + """Create browser session. 
Returns (browser, result).""" + if local: + browser = await skyvern.launch_local_browser(headless=headless) + return browser, SessionCreateResult(session_id=None, local=True, headless=headless) + + proxy = ProxyLocation(proxy_location) if proxy_location else None + browser = await skyvern.launch_cloud_browser(timeout=timeout, proxy_location=proxy) + return browser, SessionCreateResult( + session_id=browser.browser_session_id, + timeout_minutes=timeout, + ) + + +async def do_session_close(skyvern: Any, session_id: str) -> SessionCloseResult: + """Close a browser session by ID.""" + await skyvern.close_browser_session(session_id) + return SessionCloseResult(session_id=session_id) + + +async def do_session_list(skyvern: Any) -> list[SessionInfo]: + """List all browser sessions.""" + sessions = await skyvern.get_browser_sessions() + return [ + SessionInfo( + session_id=s.browser_session_id, + status=s.status, + started_at=s.started_at.isoformat() if s.started_at else None, + timeout=s.timeout, + runnable_id=s.runnable_id, + available=s.runnable_id is None and s.browser_address is not None, + ) + for s in sessions + ] diff --git a/skyvern/cli/mcp_tools/browser.py b/skyvern/cli/mcp_tools/browser.py index 61ee5c33..9c41b999 100644 --- a/skyvern/cli/mcp_tools/browser.py +++ b/skyvern/cli/mcp_tools/browser.py @@ -4,13 +4,24 @@ import asyncio import base64 import json import logging -import re from datetime import datetime, timezone from typing import Annotated, Any from playwright.async_api import TimeoutError as PlaywrightTimeoutError from pydantic import Field +from skyvern.cli.core.browser_ops import do_act, do_extract, do_navigate, do_screenshot, parse_extract_schema +from skyvern.cli.core.guards import ( + CREDENTIAL_HINT, + JS_PASSWORD_PATTERN, + PASSWORD_PATTERN, + GuardError, + check_password_prompt, +) +from skyvern.cli.core.guards import resolve_ai_mode as _resolve_ai_mode +from skyvern.cli.core.guards import ( + validate_wait_until, +) from 
skyvern.schemas.run_blocks import CredentialType from ._common import ( @@ -24,39 +35,6 @@ from ._session import BrowserNotAvailableError, get_page, no_browser_error LOG = logging.getLogger(__name__) -_PASSWORD_PATTERN = re.compile( - r"\bpass(?:word|phrase|code)s?\b|\bsecret\b|\bcredential\b|\bpin\s*(?:code)?\b|\bpwd\b|\bpasswd\b", - re.IGNORECASE, -) - -_CREDENTIAL_ERROR_HINT = ( - "Use skyvern_login with a stored credential to authenticate. " - "Create credentials via CLI: skyvern credentials add. " - "Never pass passwords through tool calls." -) - -_JS_PASSWORD_PATTERN = re.compile( - r"""(?:type\s*=\s*['"]?password|\.type\s*===?\s*['"]password|input\[type=password\]).*?\.value\s*=""", - re.IGNORECASE, -) - - -def _resolve_ai_mode( - selector: str | None, - intent: str | None, -) -> tuple[str | None, str | None]: - """Determine AI mode from selector/intent combination. - - Returns (ai_mode, error_code) — if error_code is set, the call should fail. - """ - if intent and not selector: - return "proactive", None - if intent and selector: - return "fallback", None - if selector and not intent: - return None, None - return None, "INVALID_INPUT" - async def skyvern_navigate( url: Annotated[str, "The URL to navigate to"], @@ -80,15 +58,13 @@ async def skyvern_navigate( Returns the final URL (after redirects) and page title. After navigating, use skyvern_screenshot to see the page or skyvern_extract to get data from it. 
""" - if wait_until is not None and wait_until not in ("load", "domcontentloaded", "networkidle", "commit"): + try: + validate_wait_until(wait_until) + except GuardError as e: return make_result( "skyvern_navigate", ok=False, - error=make_error( - ErrorCode.INVALID_INPUT, - f"Invalid wait_until: {wait_until}", - "Use load, domcontentloaded, networkidle, or commit", - ), + error=make_error(ErrorCode.INVALID_INPUT, str(e), e.hint), ) try: @@ -98,10 +74,16 @@ async def skyvern_navigate( with Timer() as timer: try: - await page.goto(url, timeout=timeout, wait_until=wait_until) + result = await do_navigate(page, url, timeout=timeout, wait_until=wait_until) timer.mark("sdk") - final_url = page.url - title = await page.title() + except GuardError as e: + return make_result( + "skyvern_navigate", + ok=False, + browser_context=ctx, + timing_ms=timer.timing_ms, + error=make_error(ErrorCode.INVALID_INPUT, str(e), e.hint), + ) except Exception as e: return make_result( "skyvern_navigate", @@ -114,7 +96,7 @@ async def skyvern_navigate( return make_result( "skyvern_navigate", browser_context=ctx, - data={"url": final_url, "title": title, "sdk_equivalent": f'await page.goto("{url}")'}, + data={"url": result.url, "title": result.title, "sdk_equivalent": f'await page.goto("{url}")'}, timing_ms=timer.timing_ms, ) @@ -355,14 +337,14 @@ async def skyvern_type( """ # Block password entry — redirect to skyvern_login target_text = f"{intent or ''} {selector or ''}" - if _PASSWORD_PATTERN.search(target_text): + if PASSWORD_PATTERN.search(target_text): return make_result( "skyvern_type", ok=False, error=make_error( ErrorCode.INVALID_INPUT, "Cannot type into password fields — credentials must not be passed through tool calls", - _CREDENTIAL_ERROR_HINT, + CREDENTIAL_HINT, ), ) @@ -402,7 +384,7 @@ async def skyvern_type( error=make_error( ErrorCode.INVALID_INPUT, "Cannot type into password fields — credentials must not be passed through tool calls", - _CREDENTIAL_ERROR_HINT, + 
CREDENTIAL_HINT, ), ) @@ -491,11 +473,7 @@ async def skyvern_screenshot( with Timer() as timer: try: - if selector: - element = page.locator(selector) - screenshot_bytes = await element.screenshot() - else: - screenshot_bytes = await page.screenshot(full_page=full_page) + result = await do_screenshot(page, full_page=full_page, selector=selector) timer.mark("sdk") except Exception as e: return make_result( @@ -507,7 +485,7 @@ async def skyvern_screenshot( ) if inline: - data_b64 = base64.b64encode(screenshot_bytes).decode("utf-8") + data_b64 = base64.b64encode(result.data).decode("utf-8") return make_result( "skyvern_screenshot", browser_context=ctx, @@ -515,7 +493,7 @@ async def skyvern_screenshot( "inline": True, "data": data_b64, "mime": "image/png", - "bytes": len(screenshot_bytes), + "bytes": len(result.data), "sdk_equivalent": "await page.screenshot()", }, timing_ms=timer.timing_ms, @@ -525,7 +503,7 @@ async def skyvern_screenshot( ts = datetime.now(timezone.utc).strftime("%H%M%S_%f") filename = f"screenshot_{ts}.png" artifact = save_artifact( - screenshot_bytes, + result.data, kind="screenshot", filename=filename, mime="image/png", @@ -896,14 +874,14 @@ async def skyvern_evaluate( Security: This executes arbitrary JS in the page context. Only use with trusted expressions. """ # Block JS that sets password field values - if _JS_PASSWORD_PATTERN.search(expression): + if JS_PASSWORD_PATTERN.search(expression): return make_result( "skyvern_evaluate", ok=False, error=make_error( ErrorCode.INVALID_INPUT, "Cannot set password field values via JavaScript — credentials must not be passed through tool calls", - _CREDENTIAL_ERROR_HINT, + CREDENTIAL_HINT, ), ) @@ -947,20 +925,17 @@ async def skyvern_extract( For visual inspection instead of structured data, use skyvern_screenshot. Optionally provide a JSON `schema` to enforce the output structure (pass as a JSON string). 
""" - parsed_schema: dict[str, Any] | None = None if schema is not None: try: - parsed_schema = json.loads(schema) - except (json.JSONDecodeError, TypeError) as e: + parsed_schema = parse_extract_schema(schema) + except GuardError as e: return make_result( "skyvern_extract", ok=False, - error=make_error( - ErrorCode.INVALID_INPUT, - f"Invalid JSON schema: {e}", - "Provide schema as a valid JSON string", - ), + error=make_error(ErrorCode.INVALID_INPUT, str(e), e.hint), ) + else: + parsed_schema = None try: page, ctx = await get_page(session_id=session_id, cdp_url=cdp_url) @@ -969,8 +944,16 @@ async def skyvern_extract( with Timer() as timer: try: - extracted = await page.extract(prompt=prompt, schema=parsed_schema) + result = await do_extract(page, prompt, schema=parsed_schema) timer.mark("sdk") + except GuardError as e: + return make_result( + "skyvern_extract", + ok=False, + browser_context=ctx, + timing_ms=timer.timing_ms, + error=make_error(ErrorCode.INVALID_INPUT, str(e), e.hint), + ) except Exception as e: return make_result( "skyvern_extract", @@ -983,7 +966,10 @@ async def skyvern_extract( return make_result( "skyvern_extract", browser_context=ctx, - data={"extracted": extracted, "sdk_equivalent": f'await page.extract(prompt="{prompt}")'}, + data={ + "extracted": result.extracted, + "sdk_equivalent": f'await page.extract(prompt="{prompt}")', + }, timing_ms=timer.timing_ms, ) @@ -1037,16 +1023,13 @@ async def skyvern_act( For multi-step automations (4+ pages), use skyvern_workflow_create with one block per step. For quick one-off multi-page tasks, use skyvern_run_task. 
""" - # Block login/password actions — redirect to skyvern_login - if _PASSWORD_PATTERN.search(prompt): + try: + check_password_prompt(prompt) + except GuardError as e: return make_result( "skyvern_act", ok=False, - error=make_error( - ErrorCode.INVALID_INPUT, - "Cannot perform password/credential actions — credentials must not be passed through tool calls", - _CREDENTIAL_ERROR_HINT, - ), + error=make_error(ErrorCode.INVALID_INPUT, str(e), e.hint), ) try: @@ -1056,8 +1039,16 @@ async def skyvern_act( with Timer() as timer: try: - await page.act(prompt) + result = await do_act(page, prompt) timer.mark("sdk") + except GuardError as e: + return make_result( + "skyvern_act", + ok=False, + browser_context=ctx, + timing_ms=timer.timing_ms, + error=make_error(ErrorCode.INVALID_INPUT, str(e), e.hint), + ) except Exception as e: return make_result( "skyvern_act", @@ -1070,7 +1061,11 @@ async def skyvern_act( return make_result( "skyvern_act", browser_context=ctx, - data={"prompt": prompt, "completed": True, "sdk_equivalent": f'await page.act("{prompt}")'}, + data={ + "prompt": result.prompt, + "completed": result.completed, + "sdk_equivalent": f'await page.act("{prompt}")', + }, timing_ms=timer.timing_ms, ) @@ -1099,14 +1094,14 @@ async def skyvern_run_task( For simple single-step actions on the current page, use skyvern_act instead. 
""" # Block password/credential actions — redirect to skyvern_login - if _PASSWORD_PATTERN.search(prompt): + if PASSWORD_PATTERN.search(prompt): return make_result( "skyvern_run_task", ok=False, error=make_error( ErrorCode.INVALID_INPUT, "Cannot perform password/credential actions — credentials must not be passed through tool calls", - _CREDENTIAL_ERROR_HINT, + CREDENTIAL_HINT, ), ) diff --git a/skyvern/cli/mcp_tools/session.py b/skyvern/cli/mcp_tools/session.py index 67e5b8e8..3366f460 100644 --- a/skyvern/cli/mcp_tools/session.py +++ b/skyvern/cli/mcp_tools/session.py @@ -4,7 +4,7 @@ from typing import Annotated, Any from pydantic import Field -from skyvern.schemas.runs import ProxyLocation +from skyvern.cli.core.session_ops import do_session_close, do_session_create, do_session_list from ._common import BrowserContext, ErrorCode, Timer, make_error, make_result from ._session import ( @@ -30,25 +30,21 @@ async def skyvern_session_create( with Timer() as timer: try: skyvern = get_skyvern() - - if local: - browser = await skyvern.launch_local_browser(headless=headless) - ctx = BrowserContext(mode="local") - set_current_session(SessionState(browser=browser, context=ctx)) - timer.mark("sdk") - return make_result( - "skyvern_session_create", - browser_context=ctx, - data={"local": True, "headless": headless}, - timing_ms=timer.timing_ms, - ) - - proxy = ProxyLocation(proxy_location) if proxy_location else None - browser = await skyvern.launch_cloud_browser(timeout=timeout, proxy_location=proxy) - ctx = BrowserContext(mode="cloud_session", session_id=browser.browser_session_id) - set_current_session(SessionState(browser=browser, context=ctx)) + browser, result = await do_session_create( + skyvern, + timeout=timeout or 60, + proxy_location=proxy_location, + local=local, + headless=headless, + ) timer.mark("sdk") + if result.local: + ctx = BrowserContext(mode="local") + else: + ctx = BrowserContext(mode="cloud_session", session_id=result.session_id) + 
set_current_session(SessionState(browser=browser, context=ctx)) + except ValueError as e: return make_result( "skyvern_session_create", @@ -68,12 +64,20 @@ async def skyvern_session_create( error=make_error(ErrorCode.SDK_ERROR, str(e), "Failed to create browser session"), ) + if result.local: + return make_result( + "skyvern_session_create", + browser_context=ctx, + data={"local": True, "headless": result.headless}, + timing_ms=timer.timing_ms, + ) + return make_result( "skyvern_session_create", browser_context=ctx, data={ - "session_id": browser.browser_session_id, - "timeout_minutes": timeout, + "session_id": result.session_id, + "timeout_minutes": result.timeout_minutes, }, timing_ms=timer.timing_ms, ) @@ -92,13 +96,13 @@ async def skyvern_session_close( try: if session_id: skyvern = get_skyvern() - await skyvern.close_browser_session(session_id) + result = await do_session_close(skyvern, session_id) if current.context and current.context.session_id == session_id: set_current_session(SessionState()) timer.mark("sdk") return make_result( "skyvern_session_close", - data={"session_id": session_id, "closed": True}, + data={"session_id": result.session_id, "closed": result.closed}, timing_ms=timer.timing_ms, ) @@ -138,17 +142,17 @@ async def skyvern_session_list() -> dict[str, Any]: with Timer() as timer: try: skyvern = get_skyvern() - sessions = await skyvern.get_browser_sessions() + sessions = await do_session_list(skyvern) timer.mark("sdk") session_data = [ { - "session_id": s.browser_session_id, + "session_id": s.session_id, "status": s.status, - "started_at": s.started_at.isoformat() if s.started_at else None, + "started_at": s.started_at, "timeout": s.timeout, "runnable_id": s.runnable_id, - "available": s.runnable_id is None and s.browser_address is not None, + "available": s.available, } for s in sessions ] diff --git a/tests/unit/test_cli_commands.py b/tests/unit/test_cli_commands.py new file mode 100644 index 00000000..3f142b4a --- /dev/null +++ 
b/tests/unit/test_cli_commands.py @@ -0,0 +1,149 @@ +"""Tests for CLI commands infrastructure: _state.py and _output.py.""" + +from __future__ import annotations + +import json +from pathlib import Path + +import pytest +import typer + +from skyvern.cli.commands._state import CLIState, clear_state, load_state, save_state + +# --------------------------------------------------------------------------- +# _state.py +# --------------------------------------------------------------------------- + + +def _patch_state_dir(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: + monkeypatch.setattr("skyvern.cli.commands._state.STATE_DIR", tmp_path) + monkeypatch.setattr("skyvern.cli.commands._state.STATE_FILE", tmp_path / "state.json") + + +class TestCLIState: + def test_save_load_roundtrip(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + _patch_state_dir(monkeypatch, tmp_path) + save_state(CLIState(session_id="pbs_123", cdp_url=None, mode="cloud")) + loaded = load_state() + assert loaded is not None + assert loaded.session_id == "pbs_123" + assert loaded.cdp_url is None + assert loaded.mode == "cloud" + + def test_save_load_roundtrip_cdp(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + _patch_state_dir(monkeypatch, tmp_path) + save_state(CLIState(session_id=None, cdp_url="ws://localhost:9222/devtools/browser/abc", mode="cdp")) + loaded = load_state() + assert loaded is not None + assert loaded.session_id is None + assert loaded.cdp_url == "ws://localhost:9222/devtools/browser/abc" + assert loaded.mode == "cdp" + + def test_load_returns_none_when_missing(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("skyvern.cli.commands._state.STATE_FILE", tmp_path / "nonexistent.json") + assert load_state() is None + + def test_24h_ttl_expires(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + _patch_state_dir(monkeypatch, tmp_path) + state_file = tmp_path / "state.json" + state_file.write_text( 
+ json.dumps( + { + "session_id": "pbs_old", + "mode": "cloud", + "created_at": "2020-01-01T00:00:00+00:00", + } + ) + ) + assert load_state() is None + + def test_clear_state(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + _patch_state_dir(monkeypatch, tmp_path) + save_state(CLIState(session_id="pbs_123")) + clear_state() + assert not (tmp_path / "state.json").exists() + + def test_load_ignores_corrupt_file(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + state_file = tmp_path / "state.json" + monkeypatch.setattr("skyvern.cli.commands._state.STATE_FILE", state_file) + state_file.write_text("not-json") + assert load_state() is None + + +# --------------------------------------------------------------------------- +# _output.py +# --------------------------------------------------------------------------- + + +class TestOutput: + def test_json_envelope(self, capsys: pytest.CaptureFixture) -> None: + from skyvern.cli.commands._output import output + + output({"key": "value"}, action="test", json_mode=True) + parsed = json.loads(capsys.readouterr().out) + assert parsed["ok"] is True + assert parsed["action"] == "test" + assert parsed["data"]["key"] == "value" + + def test_json_error(self, capsys: pytest.CaptureFixture) -> None: + from skyvern.cli.commands._output import output_error + + with pytest.raises(SystemExit, match="1"): + output_error("bad thing", hint="fix it", json_mode=True) + parsed = json.loads(capsys.readouterr().out) + assert parsed["ok"] is False + assert parsed["error"]["message"] == "bad thing" + + +# --------------------------------------------------------------------------- +# Connection resolution +# --------------------------------------------------------------------------- + + +class TestResolveConnection: + def test_explicit_session_wins(self) -> None: + from skyvern.cli.commands.browser import _resolve_connection + + result = _resolve_connection("pbs_explicit", None) + assert result.mode == "cloud" + assert 
result.session_id == "pbs_explicit" + assert result.cdp_url is None + + def test_explicit_cdp_wins(self) -> None: + from skyvern.cli.commands.browser import _resolve_connection + + result = _resolve_connection(None, "ws://localhost:9222/devtools/browser/abc") + assert result.mode == "cdp" + assert result.session_id is None + assert result.cdp_url == "ws://localhost:9222/devtools/browser/abc" + + def test_rejects_both_connection_flags(self) -> None: + from skyvern.cli.commands.browser import _resolve_connection + + with pytest.raises(typer.BadParameter, match="Pass only one of --session or --cdp"): + _resolve_connection("pbs_explicit", "ws://localhost:9222/devtools/browser/abc") + + def test_state_fallback(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + from skyvern.cli.commands.browser import _resolve_connection + + _patch_state_dir(monkeypatch, tmp_path) + save_state(CLIState(session_id="pbs_from_state", mode="cloud")) + result = _resolve_connection(None, None) + assert result.mode == "cloud" + assert result.session_id == "pbs_from_state" + + def test_state_fallback_cdp(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + from skyvern.cli.commands.browser import _resolve_connection + + _patch_state_dir(monkeypatch, tmp_path) + save_state(CLIState(session_id=None, cdp_url="ws://localhost:9222/devtools/browser/abc", mode="cdp")) + result = _resolve_connection(None, None) + assert result.mode == "cdp" + assert result.cdp_url == "ws://localhost:9222/devtools/browser/abc" + + def test_no_session_raises(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + from skyvern.cli.commands.browser import _resolve_connection + + monkeypatch.setattr("skyvern.cli.commands._state.STATE_FILE", tmp_path / "nonexistent.json") + with pytest.raises(typer.BadParameter, match="No active browser connection"): + _resolve_connection(None, None) diff --git a/tests/unit/test_cli_commands_logging.py b/tests/unit/test_cli_commands_logging.py new file 
mode 100644 index 00000000..b714838d --- /dev/null +++ b/tests/unit/test_cli_commands_logging.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +import logging + +import skyvern.cli.commands as cli_commands + + +def test_configure_cli_logging_is_idempotent(monkeypatch) -> None: + setup_calls: list[str] = [] + monkeypatch.setattr(cli_commands, "_setup_logger", lambda: setup_calls.append("called")) + monkeypatch.setattr(cli_commands, "_cli_logging_configured", False) + + logger_names = ("skyvern", "httpx", "litellm", "playwright", "httpcore") + previous_levels = {name: logging.getLogger(name).level for name in logger_names} + try: + cli_commands.configure_cli_logging() + assert setup_calls == ["called"] + for name in logger_names: + assert logging.getLogger(name).level == logging.WARNING + + cli_commands.configure_cli_logging() + assert setup_calls == ["called"] + finally: + for name, level in previous_levels.items(): + logging.getLogger(name).setLevel(level) + + +def test_cli_callback_configures_logging(monkeypatch) -> None: + called = False + + def _fake_configure() -> None: + nonlocal called + called = True + + monkeypatch.setattr(cli_commands, "configure_cli_logging", _fake_configure) + cli_commands.cli_callback() + assert called diff --git a/tests/unit/test_cli_shared_core.py b/tests/unit/test_cli_shared_core.py new file mode 100644 index 00000000..5c265c55 --- /dev/null +++ b/tests/unit/test_cli_shared_core.py @@ -0,0 +1,234 @@ +"""Tests for skyvern.cli.core shared modules (guards, browser_ops, session_ops).""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from skyvern.cli.core.browser_ops import do_act, do_extract, do_navigate, do_screenshot, parse_extract_schema +from skyvern.cli.core.guards import ( + GuardError, + check_js_password, + check_password_prompt, + resolve_ai_mode, + validate_button, + validate_wait_until, +) +from skyvern.cli.core.session_ops import do_session_close, 
do_session_create, do_session_list + +# --------------------------------------------------------------------------- +# guards.py +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "text", + [ + "enter your password", + "use credential to login", + "type the secret", + "enter passphrase", + "enter passcode", + "enter your pin code", + "type pwd here", + "enter passwd", + ], +) +def test_password_guard_blocks_sensitive_text(text: str) -> None: + with pytest.raises(GuardError) as exc_info: + check_password_prompt(text) + assert exc_info.value.hint # hint should always be populated + + +@pytest.mark.parametrize("text", ["click the submit button", "fill in the email field", ""]) +def test_password_guard_allows_normal_text(text: str) -> None: + check_password_prompt(text) # should not raise + + +def test_js_password_guard() -> None: + with pytest.raises(GuardError): + check_js_password('input[type=password].value = "secret"') + with pytest.raises(GuardError): + check_js_password('.type === "password"; el.value = "x"') + check_js_password("document.title") # allowed + + +@pytest.mark.parametrize("value", ["load", "domcontentloaded", "networkidle", "commit", None]) +def test_wait_until_accepts_valid(value: str | None) -> None: + validate_wait_until(value) + + +def test_wait_until_rejects_invalid() -> None: + with pytest.raises(GuardError, match="Invalid wait_until"): + validate_wait_until("badvalue") + + +@pytest.mark.parametrize("value", ["left", "right", "middle", None]) +def test_button_accepts_valid(value: str | None) -> None: + validate_button(value) + + +def test_button_rejects_invalid() -> None: + with pytest.raises(GuardError, match="Invalid button"): + validate_button("double") + + +@pytest.mark.parametrize( + "selector,intent,expected", + [ + (None, "click it", ("proactive", None)), + ("#btn", "click it", ("fallback", None)), + ("#btn", None, (None, None)), + (None, None, (None, "INVALID_INPUT")), + ], 
+) +def test_resolve_ai_mode(selector: str | None, intent: str | None, expected: tuple) -> None: + assert resolve_ai_mode(selector, intent) == expected + + +# --------------------------------------------------------------------------- +# browser_ops.py +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_do_navigate_success() -> None: + page = MagicMock() + page.goto = AsyncMock() + page.url = "https://example.com/final" + page.title = AsyncMock(return_value="Example") + + result = await do_navigate(page, "https://example.com") + assert result.url == "https://example.com/final" + assert result.title == "Example" + + +@pytest.mark.asyncio +async def test_do_navigate_passes_wait_until_through() -> None: + page = MagicMock() + page.goto = AsyncMock() + page.url = "https://example.com/final" + page.title = AsyncMock(return_value="Example") + + result = await do_navigate(page, "https://example.com", wait_until="badvalue") + assert result.url == "https://example.com/final" + page.goto.assert_awaited_once_with("https://example.com", timeout=30000, wait_until="badvalue") + + +@pytest.mark.asyncio +async def test_do_screenshot_full_page() -> None: + page = MagicMock() + page.screenshot = AsyncMock(return_value=b"png-data") + + result = await do_screenshot(page, full_page=True) + assert result.data == b"png-data" + assert result.full_page is True + + +@pytest.mark.asyncio +async def test_do_screenshot_with_selector() -> None: + page = MagicMock() + element = MagicMock() + element.screenshot = AsyncMock(return_value=b"element-data") + page.locator.return_value = element + + result = await do_screenshot(page, selector="#header") + assert result.data == b"element-data" + + +@pytest.mark.asyncio +async def test_do_act_success() -> None: + page = MagicMock() + page.act = AsyncMock() + result = await do_act(page, "enter the password") + assert result.prompt == "enter the password" + assert result.completed is True 
+ + +@pytest.mark.asyncio +async def test_do_extract_rejects_bad_schema() -> None: + with pytest.raises(GuardError, match="Invalid JSON schema"): + await do_extract(MagicMock(), "get data", schema="not-json") + + +@pytest.mark.asyncio +async def test_do_extract_success() -> None: + page = MagicMock() + page.extract = AsyncMock(return_value={"title": "Example"}) + + result = await do_extract(page, "get the title") + assert result.extracted == {"title": "Example"} + + +def test_parse_extract_schema_accepts_preparsed_dict() -> None: + schema = {"type": "object", "properties": {"title": {"type": "string"}}} + parsed = parse_extract_schema(schema) + assert parsed is schema + + +@pytest.mark.asyncio +async def test_do_extract_accepts_preparsed_dict() -> None: + page = MagicMock() + page.extract = AsyncMock(return_value={"title": "Example"}) + schema = {"type": "object", "properties": {"title": {"type": "string"}}} + + result = await do_extract(page, "get the title", schema=schema) + assert result.extracted == {"title": "Example"} + page.extract.assert_awaited_once_with(prompt="get the title", schema=schema) + + +# --------------------------------------------------------------------------- +# session_ops.py +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_do_session_create_local() -> None: + skyvern = MagicMock() + skyvern.launch_local_browser = AsyncMock(return_value=MagicMock()) + + browser, result = await do_session_create(skyvern, local=True, headless=True) + assert result.local is True + assert result.session_id is None + + +@pytest.mark.asyncio +async def test_do_session_create_cloud() -> None: + skyvern = MagicMock() + browser_mock = MagicMock() + browser_mock.browser_session_id = "pbs_123" + skyvern.launch_cloud_browser = AsyncMock(return_value=browser_mock) + + browser, result = await do_session_create(skyvern, timeout=30) + assert result.session_id == "pbs_123" + assert 
result.timeout_minutes == 30 + + +@pytest.mark.asyncio +async def test_do_session_close() -> None: + skyvern = MagicMock() + skyvern.close_browser_session = AsyncMock() + + result = await do_session_close(skyvern, "pbs_123") + assert result.session_id == "pbs_123" + assert result.closed is True + + +@pytest.mark.asyncio +async def test_do_session_list() -> None: + session = MagicMock() + session.browser_session_id = "pbs_1" + session.status = "active" + session.started_at = None + session.timeout = 60 + session.runnable_id = None + session.browser_address = "ws://localhost:1234" + + skyvern = MagicMock() + skyvern.get_browser_sessions = AsyncMock(return_value=[session]) + + result = await do_session_list(skyvern) + assert len(result) == 1 + assert result[0].session_id == "pbs_1" + assert result[0].available is True diff --git a/tests/unit/test_mcp_browser_tools.py b/tests/unit/test_mcp_browser_tools.py new file mode 100644 index 00000000..7485db86 --- /dev/null +++ b/tests/unit/test_mcp_browser_tools.py @@ -0,0 +1,65 @@ +"""Tests for MCP browser tool preflight validation behavior.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import AsyncMock + +import pytest + +from skyvern.cli.core.result import BrowserContext +from skyvern.cli.mcp_tools import browser as mcp_browser + + +@pytest.mark.asyncio +async def test_skyvern_extract_invalid_schema_preflight_before_session(monkeypatch: pytest.MonkeyPatch) -> None: + get_page = AsyncMock(side_effect=AssertionError("get_page should not be called for invalid schema")) + monkeypatch.setattr(mcp_browser, "get_page", get_page) + + result = await mcp_browser.skyvern_extract(prompt="extract data", schema="{invalid") + + assert result["ok"] is False + assert result["error"]["code"] == mcp_browser.ErrorCode.INVALID_INPUT + assert "Invalid JSON schema" in result["error"]["message"] + get_page.assert_not_awaited() + + +@pytest.mark.asyncio +async def 
test_skyvern_extract_preparsed_schema_passed_to_core(monkeypatch: pytest.MonkeyPatch) -> None: + page = object() + context = BrowserContext(mode="cloud_session", session_id="pbs_test") + monkeypatch.setattr(mcp_browser, "get_page", AsyncMock(return_value=(page, context))) + + do_extract = AsyncMock(return_value=SimpleNamespace(extracted={"ok": True})) + monkeypatch.setattr(mcp_browser, "do_extract", do_extract) + + result = await mcp_browser.skyvern_extract(prompt="extract data", schema='{"type":"object"}') + + assert result["ok"] is True + await_args = do_extract.await_args + assert await_args is not None + assert isinstance(await_args.kwargs["schema"], dict) + + +@pytest.mark.asyncio +async def test_skyvern_navigate_invalid_wait_until_preflight_before_session(monkeypatch: pytest.MonkeyPatch) -> None: + get_page = AsyncMock(side_effect=AssertionError("get_page should not be called for invalid wait_until")) + monkeypatch.setattr(mcp_browser, "get_page", get_page) + + result = await mcp_browser.skyvern_navigate(url="https://example.com", wait_until="not-a-real-wait-until") + + assert result["ok"] is False + assert result["error"]["code"] == mcp_browser.ErrorCode.INVALID_INPUT + get_page.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_skyvern_act_password_prompt_preflight_before_session(monkeypatch: pytest.MonkeyPatch) -> None: + get_page = AsyncMock(side_effect=AssertionError("get_page should not be called for password prompt")) + monkeypatch.setattr(mcp_browser, "get_page", get_page) + + result = await mcp_browser.skyvern_act(prompt="enter the password and submit") + + assert result["ok"] is False + assert result["error"]["code"] == mcp_browser.ErrorCode.INVALID_INPUT + get_page.assert_not_awaited()