From d96de3b7a2040b54378f4efc76103131d0e794c5 Mon Sep 17 00:00:00 2001 From: Jonathan Dobson Date: Fri, 21 Nov 2025 15:12:26 -0500 Subject: [PATCH] Browser streaming refactor (#4064) --- skyvern/forge/sdk/routes/streaming/agent.py | 191 ------ skyvern/forge/sdk/routes/streaming/auth.py | 24 + .../sdk/routes/streaming/channels/README.md | 336 ++++++++++ .../sdk/routes/streaming/channels/__init__.py | 0 .../sdk/routes/streaming/channels/cdp.py | 185 ++++++ .../routes/streaming/channels/execution.py | 121 ++++ .../routes/streaming/channels/exfiltration.py | 1 + .../sdk/routes/streaming/channels/message.py | 456 ++++++++++++++ .../sdk/routes/streaming/channels/vnc.py | 592 ++++++++++++++++++ skyvern/forge/sdk/routes/streaming/clients.py | 328 ---------- .../forge/sdk/routes/streaming/messages.py | 369 +++-------- .../forge/sdk/routes/streaming/registries.py | 78 +++ .../forge/sdk/routes/streaming/screenshot.py | 11 + skyvern/forge/sdk/routes/streaming/verify.py | 53 +- skyvern/forge/sdk/routes/streaming/vnc.py | 498 +-------------- skyvern/py.typed | 0 16 files changed, 1948 insertions(+), 1295 deletions(-) delete mode 100644 skyvern/forge/sdk/routes/streaming/agent.py create mode 100644 skyvern/forge/sdk/routes/streaming/channels/README.md create mode 100644 skyvern/forge/sdk/routes/streaming/channels/__init__.py create mode 100644 skyvern/forge/sdk/routes/streaming/channels/cdp.py create mode 100644 skyvern/forge/sdk/routes/streaming/channels/execution.py create mode 100644 skyvern/forge/sdk/routes/streaming/channels/exfiltration.py create mode 100644 skyvern/forge/sdk/routes/streaming/channels/message.py create mode 100644 skyvern/forge/sdk/routes/streaming/channels/vnc.py delete mode 100644 skyvern/forge/sdk/routes/streaming/clients.py create mode 100644 skyvern/forge/sdk/routes/streaming/registries.py create mode 100644 skyvern/py.typed diff --git a/skyvern/forge/sdk/routes/streaming/agent.py b/skyvern/forge/sdk/routes/streaming/agent.py deleted file mode 100644 
index 7966de7a..00000000 --- a/skyvern/forge/sdk/routes/streaming/agent.py +++ /dev/null @@ -1,191 +0,0 @@ -""" -A lightweight "agent" for interacting with the streaming browser over CDP. -""" - -import typing -from contextlib import asynccontextmanager - -import structlog -from playwright.async_api import Browser, BrowserContext, Page, Playwright, async_playwright - -import skyvern.forge.sdk.routes.streaming.clients as sc -from skyvern.config import settings - -LOG = structlog.get_logger() - - -class StreamingAgent: - """ - A minimal agent that can connect to a browser via CDP and execute JavaScript. - - Specifically for operations during streaming sessions (like copy/pasting selected text, etc.). - """ - - def __init__(self, streaming: sc.Streaming) -> None: - self.streaming = streaming - self.browser: Browser | None = None - self.browser_context: BrowserContext | None = None - self.page: Page | None = None - self.pw: Playwright | None = None - - async def connect(self, cdp_url: str | None = None) -> None: - url = cdp_url or settings.BROWSER_REMOTE_DEBUGGING_URL - - LOG.info("StreamingAgent connecting to CDP", cdp_url=url) - - pw = self.pw or await async_playwright().start() - - self.pw = pw - - headers = { - "x-api-key": self.streaming.x_api_key, - } - - self.browser = await pw.chromium.connect_over_cdp(url, headers=headers) - - org_id = self.streaming.organization_id - browser_session_id = ( - self.streaming.browser_session.persistent_browser_session_id if self.streaming.browser_session else None - ) - - if browser_session_id: - cdp_session = await self.browser.new_browser_cdp_session() - await cdp_session.send( - "Browser.setDownloadBehavior", - { - "behavior": "allow", - "downloadPath": f"/app/downloads/{org_id}/{browser_session_id}", - "eventsEnabled": True, - }, - ) - - contexts = self.browser.contexts - if contexts: - LOG.info("StreamingAgent using existing browser context") - self.browser_context = contexts[0] - else: - LOG.warning("No existing browser 
context found, creating new one") - self.browser_context = await self.browser.new_context() - - pages = self.browser_context.pages - if pages: - self.page = pages[0] - LOG.info("StreamingAgent connected to page", url=self.page.url) - else: - LOG.warning("No pages found in browser context") - - LOG.info("StreamingAgent connected successfully") - - async def evaluate_js( - self, expression: str, arg: str | int | float | bool | list | dict | None = None - ) -> str | int | float | bool | list | dict | None: - if not self.page: - raise RuntimeError("StreamingAgent is not connected to a page. Call connect() first.") - - LOG.info("StreamingAgent evaluating JS", expression=expression[:100]) - - try: - result = await self.page.evaluate(expression, arg) - LOG.info("StreamingAgent JS evaluation successful") - return result - except Exception as ex: - LOG.exception("StreamingAgent JS evaluation failed", expression=expression, ex=str(ex)) - raise - - async def get_selected_text(self) -> str: - LOG.info("StreamingAgent getting selected text") - - js_expression = """ - () => { - const selection = window.getSelection(); - return selection ? 
selection.toString() : ''; - } - """ - - selected_text = await self.evaluate_js(js_expression) - - if isinstance(selected_text, str) or selected_text is None: - LOG.info("StreamingAgent got selected text", length=len(selected_text) if selected_text else 0) - return selected_text or "" - - raise RuntimeError(f"StreamingAgent selected text is not a string, but a(n) '{type(selected_text)}'") - - async def paste_text(self, text: str) -> None: - LOG.info("StreamingAgent pasting text") - - js_expression = """ - (text) => { - const activeElement = document.activeElement; - if (activeElement && (activeElement.tagName === 'INPUT' || activeElement.tagName === 'TEXTAREA' || activeElement.isContentEditable)) { - const start = activeElement.selectionStart || 0; - const end = activeElement.selectionEnd || 0; - const value = activeElement.value || ''; - activeElement.value = value.slice(0, start) + text + value.slice(end); - const newCursorPos = start + text.length; - activeElement.setSelectionRange(newCursorPos, newCursorPos); - } - } - """ - - await self.evaluate_js(js_expression, text) - - LOG.info("StreamingAgent pasted text successfully") - - async def close(self) -> None: - LOG.info("StreamingAgent closing connection") - - if self.browser: - await self.browser.close() - self.browser = None - self.browser_context = None - self.page = None - - if self.pw: - await self.pw.stop() - self.pw = None - - LOG.info("StreamingAgent closed") - - -@asynccontextmanager -async def connected_agent(streaming: sc.Streaming | None) -> typing.AsyncIterator[StreamingAgent]: - """ - The first pass at this has us doing the following for every operation: - - creating a new agent - - connecting - - [doing smth] - - closing the agent - - This may add latency, but locally it is pretty fast. This keeps things stateless for now. - - If it turns out it's too slow, we can refactor to keep a persistent agent per streaming client. 
class Constants:
    """Sentinel values used by the streaming auth helpers."""

    # Returned in place of an API key when the organization has no valid one.
    MISSING_API_KEY = ""


async def get_x_api_key(organization_id: str) -> str:
    """
    Look up the organization's valid API auth token for streaming.

    Returns the token string when one exists. Otherwise logs a warning and
    returns `Constants.MISSING_API_KEY` (an empty string) rather than raising.
    """

    auth_token = await app.DATABASE.get_valid_org_auth_token(
        organization_id,
        OrganizationAuthTokenType.api.value,
    )

    if auth_token:
        return auth_token.token

    LOG.warning(
        "No valid API key found for organization when streaming.",
        organization_id=organization_id,
    )

    return Constants.MISSING_API_KEY
diff --git a/skyvern/forge/sdk/routes/streaming/channels/README.md b/skyvern/forge/sdk/routes/streaming/channels/README.md new file mode 100644 index 00000000..1fd14ae9 --- /dev/null +++ b/skyvern/forge/sdk/routes/streaming/channels/README.md @@ -0,0 +1,336 @@ +# Channels + +A "channel", as used within the streaming mechanism of our remote browsers, +is a WebSocket fit to some particular purpose. + +There is/are: + - a "VNC" channel that transmits NoVNC's RFB protocol data + - a "Message" channel that transmits JSON between the frontend app and + the api server + - "CDP" channels that send messages to a remote browser using CDP protocol + data + - an "Execution" channel (one-off executions) + - a soon-to-be "Exfiltration" channel (user event streaming) + +In all cases, these are just WebSockets. They have been bucketed into "named channels" +to aid understanding. + +These channels are described at the top of their respective files. + +## Architecture + +WARN: below is an AI-generated architecture document for all of the code beneath +the `skyvern/forge/sdk/routes/streaming` directory. It looks correct. 
+ +### High-Level Component Diagram + +``` +┌─────────────────┐ +│ Frontend App │ +│ (Skyvern) │ +└────────┬────────┘ + │ + │ Two WebSocket Connections (paired via client_id) + │ + ┌────┴────┬──────────────────────────────────────────┐ + │ │ │ + │ VNC Channel Message Channel + http (RFB Protocol) (JSON Messages) + │ │ │ + │ │ │ +┌───▼─────────▼──────────────────────────────────────────▼────┐ +│ API Server │ +│ ┌────────────────────────────────────────────────────────┐ │ +│ │ Registries (In-Memory State) │ │ +│ │ - vnc_channels: dict[client_id -> VncChannel] │ │ +│ │ - message_channels: dict[client_id -> MessageChannel] │ │ +│ └────────────────────────────────────────────────────────┘ │ +│ │ +│ VNC Channel Logic: Message Channel Logic: │ +│ - RFB pass-through - Copy/paste coordination │ +│ - Keyboard/mouse filtering - Control handoff (agent/user)│ +│ - Interactor control - Clipboard management │ +│ - Copy/paste detection - Channel coordination │ +│ │ +│ CDP Channels (created on-demand): │ +│ - ExecutionChannel: JS evaluation (paste, get selected) │ +│ - ExfiltrationChannel: (future) user event streaming │ +│ │ +└────┬─────────────────────────────────────────────────────┬──┘ + │ │ + │ WebSocket (RFB) Playwright (CDP)│ + │ │ +┌────▼─────────────────────────────────────────────────────▼──┐ +│ Persistent Browser Session │ +│ ┌──────────────────────────────────────────────────────┐ │ +│ │ noVNC Server Chrome/Chromium │ │ +│ │ (websockify) │ │ +│ └──────────────────────────────────────────────────────┘ │ +└─────────────────────────────────────────────────────────────┘ +``` + +### Channel Pairing & Sticky Sessions + +**Critical Design Constraint**: The VNC and Message channels for a given frontend +instance MUST connect to the same API server instance because they coordinate +via in-memory registries keyed by `client_id`. 
+ +``` +Frontend Instance (client_id="abc123") + │ + ├─→ VNC Channel ─────→ API Server Instance #2 + │ ↓ + │ vnc_channels["abc123"] = VncChannel + │ ↕ (coordinate via client_id) + └─→ Message Channel ──→ API Server Instance #2 + ↓ + message_channels["abc123"] = MessageChannel +``` + +**Deployment Requirement**: Load balancer must use sticky sessions (e.g., cookie-based +or IP-based affinity) to ensure both WebSocket connections from the same client_id +reach the same backend instance. + +### Channel Lifecycle & Verification + +``` +┌──────────────────────────────────────────────────────────────┐ +│ Channel Creation (per browser_session/task/workflow_run) │ +└────────────────────┬─────────────────────────────────────────┘ + │ + ▼ + ┌────────────────────────────────────────┐ + │ Initial Verification │ + │ - verify_browser_session() │ + │ - verify_task() │ + │ - verify_workflow_run() │ + │ Returns: entity + browser session │ + └────────┬───────────────────────────────┘ + │ + ▼ + ┌────────────────────────────────────────┐ + │ Channel + Loops Created │ + │ VncChannel/MessageChannel initialized │ + │ + Added to registry │ + └────────┬───────────────────────────────┘ + │ + ▼ + ┌────────────────────────────────────────┐ + │ Concurrent Loop Execution │ + │ (via collect() - fail-fast) │ + │ │ + │ Loop 1: Verification Loop │ + │ - Polls every 5s │ + │ - Updates channel state │ + │ - Exits if entity invalid │ + │ │ + │ Loop 2: Data Streaming Loop │ + │ - VNC: bidirectional RFB │ + │ - Message: JSON messages │ + │ - Exits on disconnect │ + └────────┬───────────────────────────────┘ + │ + ▼ + ┌────────────────────────────────────────┐ + │ Channel Cleanup │ + │ - Close WebSocket │ + │ - Remove from registry │ + │ - Clear channel state │ + └────────────────────────────────────────┘ +``` + +### VNC Channel Data Flow + +``` +User Keyboard/Mouse Input + │ + ▼ +┌───────────────────────────────────────────────────────────┐ +│ Frontend (noVNC client) │ +│ Encodes input as RFB protocol 
bytes │ +└────────┬──────────────────────────────────────────────────┘ + │ WebSocket (bytes) + ▼ +┌───────────────────────────────────────────────────────────┐ +│ API Server: VncChannel.loop_stream_vnc() │ +│ frontend_to_browser() coroutine │ +│ │ +│ 1. Receive RFB bytes from frontend │ +│ 2. Detect message type (keyboard=4, mouse=5) │ +│ 3. Update key_state tracking │ +│ 4. Check for special key combinations: │ +│ - Ctrl+C / Cmd+C → copy_text() via CDP │ +│ - Ctrl+V / Cmd+V → ask_for_clipboard() via Message │ +│ - Ctrl+O → BLOCK (forbidden) │ +│ 5. Check interactor mode: │ +│ - If interactor=="agent" → BLOCK user input │ +│ - If interactor=="user" → PASS THROUGH │ +│ 6. Block right-mouse-button (security) │ +│ 7. Forward to noVNC server │ +└────────┬──────────────────────────────────────────────────┘ + │ WebSocket (bytes) + ▼ +┌───────────────────────────────────────────────────────────┐ +│ Persistent Browser: noVNC Server (websockify) │ +│ Translates RFB → VNC protocol │ +└────────┬──────────────────────────────────────────────────┘ + │ + ▼ +┌───────────────────────────────────────────────────────────┐ +│ Browser Display Update │ +└───────────────────────────────────────────────────────────┘ + +Screen Updates (reverse direction): +Browser → noVNC → VncChannel.browser_to_frontend() → Frontend +``` + +### Message Channel + CDP Execution Flow + +``` +User Pastes (Ctrl+V detected in VNC channel) + │ + ▼ +┌───────────────────────────────────────────────────────────┐ +│ VncChannel: ask_for_clipboard() │ +│ Finds MessageChannel via registry[client_id] │ +└────────┬──────────────────────────────────────────────────┘ + │ + ▼ +┌───────────────────────────────────────────────────────────┐ +│ MessageChannel.ask_for_clipboard() │ +│ Sends: {"kind": "ask-for-clipboard"} │ +└────────┬──────────────────────────────────────────────────┘ + │ WebSocket (JSON) + ▼ +┌───────────────────────────────────────────────────────────┐ +│ Frontend: User's clipboard content │ +│ Responds: 
{"kind": "ask-for-clipboard-response", │ +│ "text": "clipboard content"} │ +└────────┬──────────────────────────────────────────────────┘ + │ WebSocket (JSON) + ▼ +┌───────────────────────────────────────────────────────────┐ +│ MessageChannel: handle_data() │ +│ Receives clipboard text │ +│ Finds VncChannel via registry[client_id] │ +└────────┬──────────────────────────────────────────────────┘ + │ + ▼ +┌───────────────────────────────────────────────────────────┐ +│ ExecutionChannel (CDP) │ +│ 1. Connect to browser via Playwright │ +│ 2. Get browser context + page │ +│ 3. evaluate_js(paste_text_script, clipboard_text) │ +│ 4. Close CDP connection │ +└────────┬──────────────────────────────────────────────────┘ + │ CDP over WebSocket + ▼ +┌───────────────────────────────────────────────────────────┐ +│ Browser: Text pasted into active element │ +└───────────────────────────────────────────────────────────┘ + +Similar flow for Copy (Ctrl+C): +VNC detects → ExecutionChannel.get_selected_text() → +MessageChannel sends {"kind": "copied-text", "text": "..."} +→ Frontend updates clipboard +``` + +### Control Flow: Agent ↔ User Interaction + +NOTE: we don't really have an "agent" at this time. But any control of the +browser that is not user-originated is kinda' agent-like, by some +definition of "agent". Here, we do not have an "AI agent". Future work may +alter this state of affairs - and some "agent" could operate the browser +automatically. 
+ +``` +┌───────────────────────────────────────────────────────────┐ +│ Initial State: interactor = "agent" │ +│ - User keyboard/mouse input is BLOCKED │ +│ - Agent can control browser via CDP │ +└────────┬──────────────────────────────────────────────────┘ + │ + │ User clicks "Take Control" in frontend + ▼ +┌───────────────────────────────────────────────────────────┐ +│ Frontend → MessageChannel │ +│ {"kind": "take-control"} │ +└────────┬──────────────────────────────────────────────────┘ + │ + ▼ +┌───────────────────────────────────────────────────────────┐ +│ MessageChannel.handle_data() │ +│ vnc_channel.interactor = "user" │ +└────────┬──────────────────────────────────────────────────┘ + │ + ▼ +┌───────────────────────────────────────────────────────────┐ +│ New State: interactor = "user" │ +│ - User keyboard/mouse input is PASSED THROUGH │ +│ - Agent should pause automation │ +└────────┬──────────────────────────────────────────────────┘ + │ + │ User clicks "Cede Control" in frontend + ▼ +┌───────────────────────────────────────────────────────────┐ +│ Frontend → MessageChannel │ +│ {"kind": "cede-control"} │ +└────────┬──────────────────────────────────────────────────┘ + │ + ▼ +┌───────────────────────────────────────────────────────────┐ +│ MessageChannel.handle_data() │ +│ vnc_channel.interactor = "agent" │ +│ → Back to initial state │ +└───────────────────────────────────────────────────────────┘ +``` + +### Error Propagation & Cleanup + +The system uses `collect()` (fail-fast gather) for loop management: + +``` +Channel has 2 concurrent loops: + - Verification loop (polls DB every 5s) + - Streaming loop (handles WebSocket I/O) + +collect() behavior: + 1. Waits for ANY loop to fail or complete + 2. Cancels all other loops + 3. 
Propagates the first exception + +Cleanup (always executed via finally): + - channel.close() + - Sets browser_session/task/workflow_run = None + - Closes WebSocket + - Removes from registry +``` + +### Database Entity Relationships + +``` +Organization + │ + ├─→ BrowserSession ────────┐ + │ │ + ├─→ Task ──────────────────┤ + │ (has optional │ + │ browser_session) │ + │ │ + └─→ WorkflowRun ───────────┤ + (has optional │ + browser_session) │ + │ + ▼ + VncChannel + MessageChannel + (in-memory, paired by client_id) +``` + +### Key Design Patterns + +1. **Channel Pairing**: Two WebSocket connections coordinated via in-memory registry +2. **Fail-Fast Loops**: `collect()` ensures any loop failure closes the entire channel +3. **Interactor Mode**: Binary state controlling whether user input is allowed +4. **On-Demand CDP**: ExecutionChannel creates temporary connections for each operation +5. **Polling Verification**: Every 5s, channels verify their backing entity still exists +6. **Pass-Through Proxy**: API server intercepts but doesn't transform RFB data diff --git a/skyvern/forge/sdk/routes/streaming/channels/__init__.py b/skyvern/forge/sdk/routes/streaming/channels/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/skyvern/forge/sdk/routes/streaming/channels/cdp.py b/skyvern/forge/sdk/routes/streaming/channels/cdp.py new file mode 100644 index 00000000..e65fe3df --- /dev/null +++ b/skyvern/forge/sdk/routes/streaming/channels/cdp.py @@ -0,0 +1,185 @@ +""" +A channel for connecting to a persistent browser instance. + +What this channel looks like: + + [API Server] <--> [Browser (CDP)] + +Channel data: + + CDP protocol data, with Playwright thrown in. 
class CdpChannel:
    """
    Base class for channels that talk to a persistent browser over CDP, using
    Playwright.

    Relies on a VncChannel - without one, a CdpChannel has no raison d'être.
    The class cannot be instantiated directly; instantiate a subclass instead.
    """

    def __new__(cls, *_: t.Iterable[t.Any], **__: t.Mapping[str, t.Any]) -> t.Self:  # noqa: N805
        if cls is CdpChannel:
            raise TypeError("CdpChannel class cannot be instantiated directly.")

        return super().__new__(cls)

    def __init__(self, *, vnc_channel: VncChannel) -> None:
        self.vnc_channel = vnc_channel
        # --
        self.browser: Browser | None = None
        self.browser_context: BrowserContext | None = None
        self.page: Page | None = None
        self.pw: Playwright | None = None
        self.url: str | None = None
        # Strong reference to the in-flight reconnect, so the task cannot be
        # garbage-collected before it finishes.
        self._reconnect_task: asyncio.Task[None] | None = None

    @property
    def class_name(self) -> str:
        """Name of the concrete class; used as a prefix in log messages."""
        return self.__class__.__name__

    @property
    def identity(self) -> t.Dict[str, t.Any]:
        """The VNC channel's identity fields plus this channel's CDP url, for logging."""
        base = self.vnc_channel.identity

        return base | {"url": self.url}

    async def connect(self, cdp_url: str | None = None) -> t.Self:
        """
        Connect to the browser over CDP and pick up its first context and page.

        Idempotent: returns immediately when a connected browser is already held.

        The url is chosen in this order: `cdp_url`, the VNC channel's browser
        session address, then `settings.BROWSER_REMOTE_DEBUGGING_URL`.
        """

        if self.browser and self.browser.is_connected():
            return self

        await self.close()

        if cdp_url:
            url = cdp_url
        elif self.vnc_channel.browser_session and self.vnc_channel.browser_session.browser_address:
            url = self.vnc_channel.browser_session.browser_address
        else:
            url = settings.BROWSER_REMOTE_DEBUGGING_URL

        self.url = url

        LOG.info(f"{self.class_name} connecting to CDP", **self.identity)

        pw = self.pw or await async_playwright().start()

        self.pw = pw

        headers = (
            {
                "x-api-key": self.vnc_channel.x_api_key,
            }
            if self.vnc_channel.x_api_key
            else None
        )

        browser = await pw.chromium.connect_over_cdp(url, headers=headers)
        self.browser = browser

        def on_close() -> None:
            # Playwright also emits "disconnected" after we call browser.close()
            # ourselves. close() drops self.browser before closing, so a
            # mismatch here means the disconnect was deliberate (or belongs to
            # a connection we have already replaced) and must not reconnect.
            if self.browser is not browser:
                return

            LOG.warning(
                f"{self.class_name} closing because the persistent browser disconnected itself.", **self.identity
            )
            self._reconnect_task = asyncio.create_task(self._reconnect())

        browser.on("disconnected", on_close)

        await self.apply_download_behavior(browser)

        contexts = browser.contexts
        if contexts:
            LOG.info(f"{self.class_name} using existing browser context", **self.identity)
            self.browser_context = contexts[0]
        else:
            LOG.warning(f"{self.class_name} No existing browser context found, creating new one", **self.identity)
            self.browser_context = await browser.new_context()

        pages = self.browser_context.pages
        if pages:
            self.page = pages[0]
            LOG.info(f"{self.class_name} connected to page", **self.identity)
        else:
            LOG.warning(f"{self.class_name} No pages found in browser context", **self.identity)

        LOG.info(f"{self.class_name} connected successfully", **self.identity)

        return self

    async def _reconnect(self) -> None:
        """
        Tear down a connection the browser dropped, then connect again.

        Runs as a background task, so failures are logged here instead of
        surfacing as a never-retrieved task exception.
        """

        try:
            await self.close()
            await self.connect()  # TODO: avoid blind reconnect
        except Exception:
            LOG.exception(f"{self.class_name} failed to reconnect to the persistent browser", **self.identity)

    async def apply_download_behavior(self, browser: Browser) -> t.Self:
        """
        Allow downloads in `browser` and point them at this channel's download directory.

        The directory is scoped to the organization and browser session when a
        browser session is attached; otherwise the shared `/app/downloads/` is used.
        """

        org_id = self.vnc_channel.organization_id

        browser_session_id = (
            self.vnc_channel.browser_session.persistent_browser_session_id if self.vnc_channel.browser_session else None
        )

        download_path = f"/app/downloads/{org_id}/{browser_session_id}" if browser_session_id else "/app/downloads/"

        cdp_session = await browser.new_browser_cdp_session()

        await cdp_session.send(
            "Browser.setDownloadBehavior",
            {
                "behavior": "allow",
                "downloadPath": download_path,
                "eventsEnabled": True,
            },
        )

        await cdp_session.detach()

        return self

    async def close(self) -> None:
        """
        Close the browser connection and stop the Playwright driver.

        Safe to call when nothing is connected. Errors from closing the browser
        are logged and ignored; errors from stopping the driver propagate.
        """

        LOG.info(f"{self.class_name} closing connection", **self.identity)

        # Drop our references *before* awaiting anything: this marks the close
        # as deliberate for the "disconnected" handler, and guarantees no stale
        # handles remain if one of the awaits below raises.
        browser, self.browser = self.browser, None
        pw, self.pw = self.pw, None
        self.browser_context = None
        self.page = None

        if browser is not None:
            try:
                await browser.close()
            except Exception:
                # Best effort: the browser may already be gone.
                LOG.warning(f"{self.class_name} failed to close browser cleanly", exc_info=True, **self.identity)

        if pw is not None:
            await pw.stop()

        LOG.info(f"{self.class_name} closed", **self.identity)

    async def evaluate_js(
        self,
        expression: str,
        arg: str | int | float | bool | list | dict | None = None,
    ) -> str | int | float | bool | list | dict | None:
        """
        Evaluate `expression` in the connected page, passing `arg` to it, and return the result.

        Connects first if needed. Raises RuntimeError when the browser context
        has no page; exceptions from the evaluation are logged and re-raised.
        """

        await self.connect()

        if not self.page:
            raise RuntimeError(f"{self.class_name} evaluate_js: not connected to a page. Call connect() first.")

        LOG.info(f"{self.class_name} evaluating js", expression=expression[:100], **self.identity)

        try:
            result = await self.page.evaluate(expression, arg)
            LOG.info(f"{self.class_name} evaluated js successfully", **self.identity)
            return result
        except Exception:
            LOG.exception(f"{self.class_name} failed to evaluate js", expression=expression, **self.identity)
            raise
class ExecutionChannel(CdpChannel):
    """
    A CDP channel for one-off JavaScript executions against the browser's page
    (reading the current selection, pasting text).

    Connection handling, `class_name` and `close()` are inherited from
    CdpChannel.
    """

    async def get_selected_text(self) -> str:
        """
        Return the text currently selected in the page, or "" when there is none.

        Raises RuntimeError if the page returns something other than a string.
        """

        LOG.info(f"{self.class_name} getting selected text", **self.identity)

        js_expression = """
        () => {
            const selection = window.getSelection();
            return selection ? selection.toString() : '';
        }
        """

        # The function takes no argument, so none is passed.
        selected_text = await self.evaluate_js(js_expression)

        if isinstance(selected_text, str) or selected_text is None:
            LOG.info(
                f"{self.class_name} got selected text",
                length=len(selected_text) if selected_text else 0,
                **self.identity,
            )
            return selected_text or ""

        raise RuntimeError(f"{self.class_name} selected text is not a string, but a(n) '{type(selected_text)}'")

    async def paste_text(self, text: str) -> None:
        """
        Insert `text` at the cursor of the page's active element, replacing any
        selected range.

        Does nothing when the active element is not an input, a textarea, or
        content-editable.
        """

        LOG.info(f"{self.class_name} pasting text", **self.identity)

        js_expression = """
        (text) => {
            const activeElement = document.activeElement;
            if (activeElement && (activeElement.tagName === 'INPUT' || activeElement.tagName === 'TEXTAREA' || activeElement.isContentEditable)) {
                const start = activeElement.selectionStart || 0;
                const end = activeElement.selectionEnd || 0;
                const value = activeElement.value || '';
                activeElement.value = value.slice(0, start) + text + value.slice(end);
                const newCursorPos = start + text.length;
                activeElement.setSelectionRange(newCursorPos, newCursorPos);
            }
        }
        """

        await self.evaluate_js(js_expression, text)

        LOG.info(f"{self.class_name} pasted text successfully", **self.identity)


@asynccontextmanager
async def execution_channel(vnc_channel: VncChannel) -> t.AsyncIterator[ExecutionChannel]:
    """
    Yield a connected ExecutionChannel for `vnc_channel`, closing it on exit.

    The first pass at this has us doing the following for every operation:
      - creating a new channel
      - connecting
      - [doing smth]
      - closing the channel

    This may add latency, but locally it is pretty fast. This keeps things stateless for now.

    If it turns out it's too slow, we can refactor to keep a persistent channel per vnc client.
    """

    channel = ExecutionChannel(vnc_channel=vnc_channel)

    try:
        await channel.connect()

        # NOTE(jdo:streaming-local-dev)
        # from skyvern.config import settings
        # await channel.connect(settings.BROWSER_REMOTE_DEBUGGING_URL)

        yield channel
    finally:
        # Runs even when connect() fails, so a started Playwright driver is
        # always stopped.
        await channel.close()
+""" + +import asyncio +import dataclasses +import typing as t + +import structlog +from fastapi import WebSocket, WebSocketDisconnect +from starlette.websockets import WebSocketState +from websockets.exceptions import ConnectionClosedError + +from skyvern.forge.sdk.routes.streaming.channels.execution import execution_channel +from skyvern.forge.sdk.routes.streaming.registries import ( + add_message_channel, + del_message_channel, + get_vnc_channel, +) +from skyvern.forge.sdk.routes.streaming.verify import ( + loop_verify_browser_session, + loop_verify_workflow_run, + verify_browser_session, + verify_workflow_run, +) +from skyvern.forge.sdk.schemas.persistent_browser_sessions import AddressablePersistentBrowserSession +from skyvern.forge.sdk.utils.aio import collect +from skyvern.forge.sdk.workflow.models.workflow import WorkflowRun + +LOG = structlog.get_logger() + +Loops = list[asyncio.Task] # aka "queue-less actors"; or "programs" + + +MessageKinds = t.Literal[ + "ask-for-clipboard-response", + "cede-control", + "take-control", +] + + +@dataclasses.dataclass +class Message: + kind: MessageKinds + + +@dataclasses.dataclass +class MessageInTakeControl(Message): + kind: t.Literal["take-control"] = "take-control" + + +@dataclasses.dataclass +class MessageInCedeControl(Message): + kind: t.Literal["cede-control"] = "cede-control" + + +@dataclasses.dataclass +class MessageInAskForClipboardResponse(Message): + kind: t.Literal["ask-for-clipboard-response"] = "ask-for-clipboard-response" + text: str = "" + + +ChannelMessage = t.Union[ + MessageInAskForClipboardResponse, + MessageInCedeControl, + MessageInTakeControl, +] + + +def reify_channel_message(data: dict) -> ChannelMessage: + kind = data.get("kind", None) + + match kind: + case "ask-for-clipboard-response": + text = data.get("text") or "" + return MessageInAskForClipboardResponse(text=text) + case "cede-control": + return MessageInCedeControl() + case "take-control": + return MessageInTakeControl() + case _: + 
raise ValueError(f"Unknown message kind: '{kind}'") + + +@dataclasses.dataclass +class MessageChannel: + """ + A message channel for streaming JSON messages between our frontend and our API server. + """ + + client_id: str + organization_id: str + websocket: WebSocket + # -- + out_queue: asyncio.Queue = dataclasses.field(default_factory=asyncio.Queue) # warn: unbounded + browser_session: AddressablePersistentBrowserSession | None = None + workflow_run: WorkflowRun | None = None + + def __post_init__(self) -> None: + add_message_channel(self) + + @property + def class_name(self) -> str: + return self.__class__.__name__ + + @property + def identity(self) -> dict[str, str]: + base = {"organization_id": self.organization_id} + + if self.browser_session: + return base | {"browser_session_id": self.browser_session.persistent_browser_session_id} + + if self.workflow_run: + return base | {"workflow_run_id": self.workflow_run.id} + + return base + + async def close(self, code: int = 1000, reason: str | None = None) -> "MessageChannel": + LOG.info(f"{self.class_name} closing message stream.", reason=reason, code=code, **self.identity) + + self.browser_session = None + self.workflow_run = None + + try: + await self.websocket.close(code=code, reason=reason) + except Exception: + pass + + del_message_channel(self.client_id) + + return self + + @property + def is_open(self) -> bool: + if self.websocket.client_state != WebSocketState.CONNECTED: + return False + + return True + + async def drain(self) -> list[dict]: + datums: list[dict] = [] + + tasks = [ + asyncio.create_task(self.receive_from_out_queue()), + asyncio.create_task(self.receive_from_user()), + ] + + results = await asyncio.gather(*tasks) + + for result in results: + datums.extend(result) + + if datums: + LOG.info(f"{self.class_name} Drained {len(datums)} messages from message channel.", **self.identity) + + return datums + + async def receive_from_user(self) -> list[dict]: + datums: list[dict] = [] + + while True: + 
try: + data = await asyncio.wait_for(self.websocket.receive_json(), timeout=0.001) + datums.append(data) + except asyncio.TimeoutError: + break + except RuntimeError: + if "not connected" in str(RuntimeError).lower(): + break + except Exception: + LOG.exception(f"{self.class_name} Failed to receive message from message channel", **self.identity) + break + + return datums + + async def receive_from_out_queue(self) -> list[dict]: + datums: list[dict] = [] + + while True: + try: + data = await asyncio.wait_for(self.out_queue.get(), timeout=0.001) + datums.append(data) + except asyncio.TimeoutError: + break + except asyncio.QueueEmpty: + break + + return datums + + def receive_from_out_queue_nowait(self) -> list[dict]: + datums: list[dict] = [] + + while True: + try: + data = self.out_queue.get_nowait() + datums.append(data) + except asyncio.QueueEmpty: + break + + return datums + + async def send(self, *, messages: list[dict]) -> t.Self: + for message in messages: + await self.out_queue.put(message) + + return self + + def send_nowait(self, *, messages: list[dict]) -> t.Self: + for message in messages: + self.out_queue.put_nowait(message) + + return self + + async def ask_for_clipboard(self) -> None: + LOG.info(f"{self.class_name} Sending ask-for-clipboard to message channel", **self.identity) + + try: + await self.websocket.send_json( + { + "kind": "ask-for-clipboard", + } + ) + except Exception: + LOG.exception(f"{self.class_name} Failed to send ask-for-clipboard to message channel", **self.identity) + + async def send_copied_text(self, copied_text: str) -> None: + LOG.info(f"{self.class_name} Sending copied text to message channel", **self.identity) + + try: + await self.websocket.send_json( + { + "kind": "copied-text", + "text": copied_text, + } + ) + except Exception: + LOG.exception(f"{self.class_name} Failed to send copied text to message channel", **self.identity) + + +async def loop_stream_messages(message_channel: MessageChannel) -> None: + """ + Stream 
messages and their results back and forth. + + Loops until the websocket is closed. + """ + + class_name = message_channel.class_name + + async def handle_data(data: dict) -> None: + nonlocal class_name + + try: + message = reify_channel_message(data) + except ValueError: + LOG.error(f"MessageChannel: cannot reify channel message from data: {data}", **message_channel.identity) + return + + message_kind = message.kind + + match message_kind: + case "ask-for-clipboard-response": + if not isinstance(message, MessageInAskForClipboardResponse): + LOG.error( + f"{class_name} invalid message type for ask-for-clipboard-response.", + message=message, + **message_channel.identity, + ) + return + + vnc_channel = get_vnc_channel(message_channel.client_id) + + if not vnc_channel: + LOG.error( + f"{class_name} no vnc channel found for message channel.", + message=message, + **message_channel.identity, + ) + return + + text = message.text + + async with execution_channel(vnc_channel) as execute: + await execute.paste_text(text) + + case "cede-control": + vnc_channel = get_vnc_channel(message_channel.client_id) + + if not vnc_channel: + LOG.error( + f"{class_name} no vnc channel client found for message channel.", + message=message, + **message_channel.identity, + ) + return + vnc_channel.interactor = "agent" + + case "take-control": + LOG.info(f"{class_name} processing take-control message.", **message_channel.identity) + vnc_channel = get_vnc_channel(message_channel.client_id) + + if not vnc_channel: + LOG.error( + f"{class_name} no vnc channel client found for message channel.", + message=message, + **message_channel.identity, + ) + return + vnc_channel.interactor = "user" + + case _: + LOG.error(f"{class_name} unknown message kind: '{message_kind}'", **message_channel.identity) + return + + async def frontend_to_backend() -> None: + nonlocal class_name + + LOG.info(f"{class_name} starting frontend-to-backend loop.", **message_channel.identity) + + while 
message_channel.is_open: + try: + datums = await message_channel.drain() + + for data in datums: + if not isinstance(data, dict): + LOG.error( + f"{class_name} cannot create message: expected dict, got {type(data)}", + **message_channel.identity, + ) + continue + + await handle_data(data) + + except WebSocketDisconnect: + LOG.info(f"{class_name} frontend disconnected.", **message_channel.identity) + raise + except ConnectionClosedError: + LOG.info(f"{class_name} frontend closed channel.", **message_channel.identity) + raise + except Exception: + LOG.exception(f"{class_name} An unexpected exception occurred.", **message_channel.identity) + raise + + loops = [ + asyncio.create_task(frontend_to_backend()), + ] + + try: + await collect(loops) + except Exception: + LOG.exception(f"{class_name} An exception occurred in loop message channel stream.", **message_channel.identity) + finally: + LOG.info(f"{class_name} Closing the message channel stream.", **message_channel.identity) + await message_channel.close(reason="loop-channel-closed") + + +async def get_message_channel_for_browser_session( + client_id: str, + browser_session_id: str, + organization_id: str, + websocket: WebSocket, +) -> tuple[MessageChannel, Loops] | None: + """ + Return a message channel for a browser session, with a list of loops to run concurrently. 
+ """ + + LOG.info("Getting message channel for browser session.", browser_session_id=browser_session_id) + + browser_session = await verify_browser_session( + browser_session_id=browser_session_id, + organization_id=organization_id, + ) + + if not browser_session: + LOG.info( + "Message channel: no initial browser session found.", + browser_session_id=browser_session_id, + organization_id=organization_id, + ) + return None + + message_channel = MessageChannel( + client_id=client_id, + organization_id=organization_id, + browser_session=browser_session, + websocket=websocket, + ) + + LOG.info("Got message channel for browser session.", message_channel=message_channel) + + loops = [ + asyncio.create_task(loop_verify_browser_session(message_channel)), + asyncio.create_task(loop_stream_messages(message_channel)), + ] + + return message_channel, loops + + +async def get_message_channel_for_workflow_run( + client_id: str, + workflow_run_id: str, + organization_id: str, + websocket: WebSocket, +) -> tuple[MessageChannel, Loops] | None: + """ + Return a message channel for a workflow run, with a list of loops to run concurrently. 
+ """ + + LOG.info("Getting message channel for workflow run.", workflow_run_id=workflow_run_id) + + workflow_run, browser_session = await verify_workflow_run( + workflow_run_id=workflow_run_id, + organization_id=organization_id, + ) + + if not workflow_run: + LOG.info( + "Message channel: no initial workflow run found.", + workflow_run_id=workflow_run_id, + organization_id=organization_id, + ) + return None + + if not browser_session: + LOG.info( + "Message channel: no initial browser session found for workflow run.", + workflow_run_id=workflow_run_id, + organization_id=organization_id, + ) + return None + + message_channel = MessageChannel( + client_id, + organization_id, + browser_session=browser_session, + websocket=websocket, + workflow_run=workflow_run, + ) + + LOG.info("Got message channel for workflow run.", message_channel=message_channel) + + loops = [ + asyncio.create_task(loop_verify_workflow_run(message_channel)), + asyncio.create_task(loop_stream_messages(message_channel)), + ] + + return message_channel, loops diff --git a/skyvern/forge/sdk/routes/streaming/channels/vnc.py b/skyvern/forge/sdk/routes/streaming/channels/vnc.py new file mode 100644 index 00000000..a58bc662 --- /dev/null +++ b/skyvern/forge/sdk/routes/streaming/channels/vnc.py @@ -0,0 +1,592 @@ +""" +A channel for streaming the VNC protocol data between our frontend and a +persistent browser instance. + +This is a pass-thru channel, through our API server. As such, we can monitor and/or +intercept RFB protocol messages as needed. + +What this channel looks like: + + [Skyvern App] <--> [API Server] <--> [websockified noVNC] <--> [Browser] + + +Channel data: + + One could call this RFB over WebSockets (rockets?), as the protocol data streaming + over the WebSocket is raw RFB protocol data. 
+""" + +import asyncio +import dataclasses +import typing as t +from enum import IntEnum +from urllib.parse import urlparse + +import structlog +import websockets +from fastapi import WebSocket, WebSocketDisconnect +from starlette.websockets import WebSocketState +from websockets import ConnectionClosedError, Data + +from skyvern.config import settings +from skyvern.forge.sdk.routes.streaming.auth import get_x_api_key +from skyvern.forge.sdk.routes.streaming.channels.execution import execution_channel +from skyvern.forge.sdk.routes.streaming.registries import ( + add_vnc_channel, + del_vnc_channel, + get_message_channel, + get_vnc_channel, +) +from skyvern.forge.sdk.routes.streaming.verify import ( + loop_verify_browser_session, + loop_verify_task, + loop_verify_workflow_run, + verify_browser_session, + verify_task, + verify_workflow_run, +) +from skyvern.forge.sdk.schemas.persistent_browser_sessions import AddressablePersistentBrowserSession +from skyvern.forge.sdk.schemas.tasks import Task +from skyvern.forge.sdk.utils.aio import collect +from skyvern.forge.sdk.workflow.models.workflow import WorkflowRun + +LOG = structlog.get_logger() + + +Interactor = t.Literal["agent", "user"] +""" +NOTE: we don't really have an "agent" at this time. But any control of the +browser that is not user-originated is kinda' agent-like, by some +definition of "agent". Here, we do not have an "AI agent". Future work may +alter this state of affairs - and some "agent" could operate the browser +automatically. In any case, if the interactor is not a "user", we assume +it is an "agent". +""" + + +Loops = list[asyncio.Task] # aka "queue-less actors"; or "programs" + + +class MessageType(IntEnum): + Keyboard = 4 + Mouse = 5 + + +class Keys: + """ + VNC RFB keycodes. There's likely a pithier repr (indexes 6-7). This is ok for now. 
+ + ref: https://www.notion.so/References-21c426c42cd480fb9258ecc9eb8f09b4 + ref: https://github.com/novnc/noVNC/blob/master/docs/rfbproto-3.8.pdf + """ + + class Down: + Ctrl = b"\x04\x01\x00\x00\x00\x00\xff\xe3" + Cmd = b"\x04\x01\x00\x00\x00\x00\xff\xe9" + Alt = b"\x04\x01\x00\x00\x00\x00\xff~" # option + CKey = b"\x04\x01\x00\x00\x00\x00\x00c" + OKey = b"\x04\x01\x00\x00\x00\x00\x00o" + VKey = b"\x04\x01\x00\x00\x00\x00\x00v" + + class Up: + Ctrl = b"\x04\x00\x00\x00\x00\x00\xff\xe3" + Cmd = b"\x04\x00\x00\x00\x00\x00\xff\xe9" + Alt = b"\x04\x00\x00\x00\x00\x00\xff\x7e" # option + + +def is_rmb(data: bytes) -> bool: + return data[0:2] == b"\x05\x04" + + +class Mouse: + class Up: + Right = is_rmb + + +@dataclasses.dataclass +class KeyState: + ctrl_is_down: bool = False + alt_is_down: bool = False + cmd_is_down: bool = False + + def is_forbidden(self, data: bytes) -> bool: + """ + :return: True if the key is forbidden, else False + """ + return self.is_ctrl_o(data) + + def is_ctrl_o(self, data: bytes) -> bool: + """ + Do not allow the opening of files. + """ + return self.ctrl_is_down and data == Keys.Down.OKey + + def is_copy(self, data: bytes) -> bool: + """ + Detect Ctrl+C or Cmd+C for copy. + """ + return (self.ctrl_is_down or self.cmd_is_down) and data == Keys.Down.CKey + + def is_paste(self, data: bytes) -> bool: + """ + Detect Ctrl+V or Cmd+V for paste. + """ + return (self.ctrl_is_down or self.cmd_is_down) and data == Keys.Down.VKey + + +@dataclasses.dataclass +class VncChannel: + """ + A VNC channel for streaming RFB protocol data between our frontend app, and + a remote browser. + """ + + client_id: str + """ + Unique to a frontend app instance. + """ + + organization_id: str + vnc_port: int + x_api_key: str + websocket: WebSocket + + initial_interactor: dataclasses.InitVar[Interactor] + """ + The role of the entity interacting with the channel, either "agent" or "user". 
+ """ + + _interactor: Interactor = dataclasses.field(init=False, repr=False) + + # -- + + browser_session: AddressablePersistentBrowserSession | None = None + key_state: KeyState = dataclasses.field(default_factory=KeyState) + task: Task | None = None + workflow_run: WorkflowRun | None = None + + def __post_init__(self, initial_interactor: Interactor) -> None: + self.interactor = initial_interactor + add_vnc_channel(self) + + @property + def class_name(self) -> str: + return self.__class__.__name__ + + @property + def identity(self) -> dict: + base = {"organization_id": self.organization_id} + + if self.task: + return base | {"task_id": self.task.task_id} + elif self.workflow_run: + return base | {"workflow_run_id": self.workflow_run.workflow_run_id} + elif self.browser_session: + return base | {"browser_session_id": self.browser_session.persistent_browser_session_id} + else: + return base | {"client_id": self.client_id} + + @property + def interactor(self) -> Interactor: + return self._interactor + + @interactor.setter + def interactor(self, value: Interactor) -> None: + self._interactor = value + + LOG.info(f"{self.class_name} Setting interactor to {value}", **self.identity) + + @property + def is_open(self) -> bool: + if self.websocket.client_state != WebSocketState.CONNECTED: + return False + + if not self.task and not self.workflow_run and not self.browser_session: + return False + + if not get_vnc_channel(self.client_id): + return False + + return True + + async def close(self, code: int = 1000, reason: str | None = None) -> t.Self: + LOG.info(f"{self.class_name} closing.", reason=reason, code=code, **self.identity) + + self.browser_session = None + self.task = None + self.workflow_run = None + + try: + await self.websocket.close(code=code, reason=reason) + except Exception: + pass + + del_vnc_channel(self.client_id) + + return self + + def update_key_state(self, data: bytes) -> None: + if data == Keys.Down.Ctrl: + self.key_state.ctrl_is_down = True + elif 
data == Keys.Up.Ctrl: + self.key_state.ctrl_is_down = False + elif data == Keys.Down.Alt: + self.key_state.alt_is_down = True + elif data == Keys.Up.Alt: + self.key_state.alt_is_down = False + elif data == Keys.Down.Cmd: + self.key_state.cmd_is_down = True + elif data == Keys.Up.Cmd: + self.key_state.cmd_is_down = False + + +async def copy_text(vnc_channel: VncChannel) -> None: + class_name = vnc_channel.class_name + LOG.info(f"{class_name} Retrieving selected text via CDP", **vnc_channel.identity) + + try: + async with execution_channel(vnc_channel) as execute: + copied_text = await execute.get_selected_text() + + message_channel = get_message_channel(vnc_channel.client_id) + + if message_channel: + await message_channel.send_copied_text(copied_text) + else: + LOG.warning( + f"{class_name} No message channel found for client, or it is not open", + message_channel=message_channel, + **vnc_channel.identity, + ) + except Exception: + LOG.exception(f"{class_name} Failed to retrieve selected text via CDP", **vnc_channel.identity) + + +async def ask_for_clipboard(vnc_channel: VncChannel) -> None: + class_name = vnc_channel.class_name + LOG.info(f"{class_name} Asking for clipboard data via CDP", **vnc_channel.identity) + + try: + message_channel = get_message_channel(vnc_channel.client_id) + + if message_channel: + await message_channel.ask_for_clipboard() + else: + LOG.warning( + f"{class_name} No message channel found for client, or it is not open", + message_channel=message_channel, + **vnc_channel.identity, + ) + except Exception: + LOG.exception(f"{class_name} Failed to ask for clipboard via CDP", **vnc_channel.identity) + + +async def loop_stream_vnc(vnc_channel: VncChannel) -> None: + """ + Actually stream the VNC data between a frontend and a browser. + + Loops until the task is cleared or the websocket is closed. 
+ """ + + vnc_url: str = "" + browser_session = vnc_channel.browser_session + class_name = vnc_channel.class_name + + if browser_session: + if browser_session.ip_address: + if ":" in browser_session.ip_address: + ip, _ = browser_session.ip_address.split(":") + vnc_url = f"ws://{ip}:{vnc_channel.vnc_port}" + else: + vnc_url = f"ws://{browser_session.ip_address}:{vnc_channel.vnc_port}" + else: + browser_address = browser_session.browser_address + + parsed_browser_address = urlparse(browser_address) + host = parsed_browser_address.hostname + vnc_url = f"ws://{host}:{vnc_channel.vnc_port}" + else: + raise Exception(f"{class_name} No browser session associated with vnc channel.") + + # NOTE(jdo:streaming-local-dev) + # vnc_url = "ws://localhost:6080" + + LOG.info( + f"{class_name} Connecting to vnc url.", + vnc_url=vnc_url, + **vnc_channel.identity, + ) + + async with websockets.connect(vnc_url) as novnc_ws: + + async def frontend_to_browser() -> None: + nonlocal class_name + + LOG.info(f"{class_name} Starting frontend-to-browser data transfer.", **vnc_channel.identity) + data: Data | None = None + + while vnc_channel.is_open: + try: + data = await vnc_channel.websocket.receive_bytes() + + if data: + message_type = data[0] + + if message_type == MessageType.Keyboard.value: + vnc_channel.update_key_state(data) + + if vnc_channel.key_state.is_copy(data): + await copy_text(vnc_channel) + + if vnc_channel.key_state.is_paste(data): + await ask_for_clipboard(vnc_channel) + + if vnc_channel.key_state.is_forbidden(data): + continue + + # prevent right-mouse-button clicks for "security" reasons + if message_type == MessageType.Mouse.value: + if Mouse.Up.Right(data): + continue + + if not vnc_channel.interactor == "user" and message_type in ( + MessageType.Keyboard.value, + MessageType.Mouse.value, + ): + LOG.debug(f"{class_name} Blocking user message.", **vnc_channel.identity) + continue + + except WebSocketDisconnect: + LOG.info(f"{class_name} Frontend disconnected.", 
**vnc_channel.identity) + raise + except ConnectionClosedError: + LOG.info(f"{class_name} Frontend closed the vnc channel.", **vnc_channel.identity) + raise + except asyncio.CancelledError: + pass + except Exception: + LOG.exception(f"{class_name} An unexpected exception occurred.", **vnc_channel.identity) + raise + + if not data: + continue + + try: + await novnc_ws.send(data) + except WebSocketDisconnect: + LOG.info(f"{class_name} Browser disconnected from vnc.", **vnc_channel.identity) + raise + except ConnectionClosedError: + LOG.info(f"{class_name} Browser closed vnc.", **vnc_channel.identity) + raise + except asyncio.CancelledError: + pass + except Exception: + LOG.exception( + f"{class_name} An unexpected exception occurred in frontend-to-browser loop.", + **vnc_channel.identity, + ) + raise + + async def browser_to_frontend() -> None: + nonlocal class_name + + LOG.info(f"{class_name} Starting browser-to-frontend data transfer.", **vnc_channel.identity) + data: Data | None = None + + while vnc_channel.is_open: + try: + data = await novnc_ws.recv() + + except WebSocketDisconnect: + LOG.info(f"{class_name} Browser disconnected from the vnc channel session.", **vnc_channel.identity) + await vnc_channel.close(reason="browser-disconnected") + except ConnectionClosedError: + LOG.info(f"{class_name} Browser closed the vnc channel session.", **vnc_channel.identity) + await vnc_channel.close(reason="browser-closed") + except asyncio.CancelledError: + pass + except Exception: + LOG.exception( + f"{class_name} An unexpected exception occurred in browser-to-frontend loop.", + **vnc_channel.identity, + ) + raise + + if not data: + continue + + try: + if vnc_channel.websocket.client_state != WebSocketState.CONNECTED: + continue + await vnc_channel.websocket.send_bytes(data) + except WebSocketDisconnect: + LOG.info( + f"{class_name} Frontend disconnected from the vnc channel session.", **vnc_channel.identity + ) + await vnc_channel.close(reason="frontend-disconnected") + 
except ConnectionClosedError: + LOG.info(f"{class_name} Frontend closed the vnc channel session.", **vnc_channel.identity) + await vnc_channel.close(reason="frontend-closed") + except asyncio.CancelledError: + pass + except Exception: + LOG.exception(f"{class_name} An unexpected exception occurred.", **vnc_channel.identity) + raise + + loops = [ + asyncio.create_task(frontend_to_browser()), + asyncio.create_task(browser_to_frontend()), + ] + + try: + await collect(loops) + except WebSocketDisconnect: + pass + except Exception: + LOG.exception(f"{class_name} An exception occurred in loop stream.", **vnc_channel.identity) + finally: + LOG.info(f"{class_name} Closing the loop stream.", **vnc_channel.identity) + await vnc_channel.close(reason="loop-stream-vnc-closed") + + +async def get_vnc_channel_for_browser_session( + client_id: str, + browser_session_id: str, + organization_id: str, + websocket: WebSocket, +) -> tuple[VncChannel, Loops] | None: + """ + Return a vnc channel for a browser session, with a list of loops to run concurrently. 
+ """ + + LOG.info("Getting vnc context for browser session.", browser_session_id=browser_session_id) + + browser_session = await verify_browser_session( + browser_session_id=browser_session_id, + organization_id=organization_id, + ) + + if not browser_session: + LOG.info( + "No initial browser session found.", browser_session_id=browser_session_id, organization_id=organization_id + ) + return None + + x_api_key = await get_x_api_key(organization_id) + + try: + vnc_channel = VncChannel( + client_id=client_id, + initial_interactor="agent", + organization_id=organization_id, + vnc_port=settings.SKYVERN_BROWSER_VNC_PORT, + browser_session=browser_session, + x_api_key=x_api_key, + websocket=websocket, + ) + except Exception as e: + LOG.exception("Failed to create VncChannel.", error=str(e)) + return None + + LOG.info("Got vnc context for browser session.", vnc_channel=vnc_channel) + + loops = [ + asyncio.create_task(loop_verify_browser_session(vnc_channel)), + asyncio.create_task(loop_stream_vnc(vnc_channel)), + ] + + return vnc_channel, loops + + +async def get_vnc_channel_for_task( + client_id: str, + task_id: str, + organization_id: str, + websocket: WebSocket, +) -> tuple[VncChannel, Loops] | None: + """ + Return a vnc channel for a task, with a list of loops to run concurrently. 
+ """ + + task, browser_session = await verify_task(task_id=task_id, organization_id=organization_id) + + if not task: + LOG.info("No initial task found.", task_id=task_id, organization_id=organization_id) + return None + + if not browser_session: + LOG.info("No initial browser session found for task.", task_id=task_id, organization_id=organization_id) + return None + + x_api_key = await get_x_api_key(organization_id) + + vnc_channel = VncChannel( + client_id=client_id, + initial_interactor="agent", + organization_id=organization_id, + vnc_port=settings.SKYVERN_BROWSER_VNC_PORT, + x_api_key=x_api_key, + websocket=websocket, + browser_session=browser_session, + task=task, + ) + + loops = [ + asyncio.create_task(loop_verify_task(vnc_channel)), + asyncio.create_task(loop_stream_vnc(vnc_channel)), + ] + + return vnc_channel, loops + + +async def get_vnc_channel_for_workflow_run( + client_id: str, + workflow_run_id: str, + organization_id: str, + websocket: WebSocket, +) -> tuple[VncChannel, Loops] | None: + """ + Return a vnc channel for a workflow run, with a list of loops to run concurrently. 
+ """ + + LOG.info("Getting vnc channel for workflow run.", workflow_run_id=workflow_run_id) + + workflow_run, browser_session = await verify_workflow_run( + workflow_run_id=workflow_run_id, + organization_id=organization_id, + ) + + if not workflow_run: + LOG.info("No initial workflow run found.", workflow_run_id=workflow_run_id, organization_id=organization_id) + return None + + if not browser_session: + LOG.info( + "No initial browser session found for workflow run.", + workflow_run_id=workflow_run_id, + organization_id=organization_id, + ) + return None + + x_api_key = await get_x_api_key(organization_id) + + vnc_channel = VncChannel( + client_id=client_id, + initial_interactor="agent", + organization_id=organization_id, + vnc_port=settings.SKYVERN_BROWSER_VNC_PORT, + browser_session=browser_session, + workflow_run=workflow_run, + x_api_key=x_api_key, + websocket=websocket, + ) + + LOG.info("Got vnc channel context for workflow run.", vnc_channel=vnc_channel) + + loops = [ + asyncio.create_task(loop_verify_workflow_run(vnc_channel)), + asyncio.create_task(loop_stream_vnc(vnc_channel)), + ] + + return vnc_channel, loops diff --git a/skyvern/forge/sdk/routes/streaming/clients.py b/skyvern/forge/sdk/routes/streaming/clients.py deleted file mode 100644 index 3ec90cd4..00000000 --- a/skyvern/forge/sdk/routes/streaming/clients.py +++ /dev/null @@ -1,328 +0,0 @@ -""" -Streaming types. 
-""" - -import asyncio -import dataclasses -import typing as t -from enum import IntEnum - -import structlog -from fastapi import WebSocket -from starlette.websockets import WebSocketState - -from skyvern.forge.sdk.schemas.persistent_browser_sessions import AddressablePersistentBrowserSession -from skyvern.forge.sdk.schemas.tasks import Task -from skyvern.forge.sdk.workflow.models.workflow import WorkflowRun - -LOG = structlog.get_logger() - - -Interactor = t.Literal["agent", "user"] -Loops = list[asyncio.Task] # aka "queue-less actors"; or "programs" - - -# Messages - - -# a global registry for WS message clients -message_channels: dict[str, "MessageChannel"] = {} - - -def add_message_client(message_channel: "MessageChannel") -> None: - message_channels[message_channel.client_id] = message_channel - - -def get_message_client(client_id: str) -> t.Union["MessageChannel", None]: - return message_channels.get(client_id, None) - - -def del_message_client(client_id: str) -> None: - try: - del message_channels[client_id] - except KeyError: - pass - - -@dataclasses.dataclass -class MessageChannel: - client_id: str - organization_id: str - websocket: WebSocket - - # -- - - browser_session: AddressablePersistentBrowserSession | None = None - workflow_run: WorkflowRun | None = None - - def __post_init__(self) -> None: - add_message_client(self) - - async def close(self, code: int = 1000, reason: str | None = None) -> "MessageChannel": - LOG.info("Closing message stream.", reason=reason, code=code) - - self.browser_session = None - self.workflow_run = None - - try: - await self.websocket.close(code=code, reason=reason) - except Exception: - pass - - del_message_client(self.client_id) - - return self - - @property - def is_open(self) -> bool: - if self.websocket.client_state not in (WebSocketState.CONNECTED, WebSocketState.CONNECTING): - return False - - if not self.workflow_run and not self.browser_session: - return False - - if not get_message_client(self.client_id): - 
return False - - return True - - async def ask_for_clipboard(self, streaming: "Streaming") -> None: - try: - await self.websocket.send_json( - { - "kind": "ask-for-clipboard", - } - ) - LOG.info( - "Sent ask-for-clipboard to message channel", - organization_id=streaming.organization_id, - ) - except Exception: - LOG.exception( - "Failed to send ask-for-clipboard to message channel", - organization_id=streaming.organization_id, - ) - - async def send_copied_text(self, copied_text: str, streaming: "Streaming") -> None: - try: - await self.websocket.send_json( - { - "kind": "copied-text", - "text": copied_text, - } - ) - LOG.info( - "Sent copied text to message channel", - organization_id=streaming.organization_id, - ) - except Exception: - LOG.exception( - "Failed to send copied text to message channel", - organization_id=streaming.organization_id, - ) - - -MessageKinds = t.Literal["take-control", "cede-control", "ask-for-clipboard-response"] - - -@dataclasses.dataclass -class Message: - kind: MessageKinds - - -@dataclasses.dataclass -class MessageTakeControl(Message): - kind: t.Literal["take-control"] = "take-control" - - -@dataclasses.dataclass -class MessageCedeControl(Message): - kind: t.Literal["cede-control"] = "cede-control" - - -@dataclasses.dataclass -class MessageInAskForClipboardResponse(Message): - kind: t.Literal["ask-for-clipboard-response"] = "ask-for-clipboard-response" - text: str = "" - - -ChannelMessage = t.Union[MessageTakeControl, MessageCedeControl, MessageInAskForClipboardResponse] - - -def reify_channel_message(data: dict) -> ChannelMessage: - kind = data.get("kind", None) - - match kind: - case "take-control": - return MessageTakeControl() - case "cede-control": - return MessageCedeControl() - case "ask-for-clipboard-response": - text = data.get("text") or "" - return MessageInAskForClipboardResponse(text=text) - case _: - raise ValueError(f"Unknown message kind: '{kind}'") - - -# Streaming - - -# a global registry for WS streaming VNC 
clients -streaming_clients: dict[str, "Streaming"] = {} - - -def add_streaming_client(streaming: "Streaming") -> None: - streaming_clients[streaming.client_id] = streaming - - -def get_streaming_client(client_id: str) -> t.Union["Streaming", None]: - return streaming_clients.get(client_id, None) - - -def del_streaming_client(client_id: str) -> None: - try: - del streaming_clients[client_id] - except KeyError: - pass - - -class MessageType(IntEnum): - Keyboard = 4 - Mouse = 5 - - -class Keys: - """ - VNC RFB keycodes. There's likely a pithier repr (indexes 6-7). This is ok for now. - - ref: https://www.notion.so/References-21c426c42cd480fb9258ecc9eb8f09b4 - ref: https://github.com/novnc/noVNC/blob/master/docs/rfbproto-3.8.pdf - """ - - class Down: - Ctrl = b"\x04\x01\x00\x00\x00\x00\xff\xe3" - Cmd = b"\x04\x01\x00\x00\x00\x00\xff\xe9" - Alt = b"\x04\x01\x00\x00\x00\x00\xff~" # option - CKey = b"\x04\x01\x00\x00\x00\x00\x00c" - OKey = b"\x04\x01\x00\x00\x00\x00\x00o" - VKey = b"\x04\x01\x00\x00\x00\x00\x00v" - - class Up: - Ctrl = b"\x04\x00\x00\x00\x00\x00\xff\xe3" - Cmd = b"\x04\x00\x00\x00\x00\x00\xff\xe9" - Alt = b"\x04\x00\x00\x00\x00\x00\xff\x7e" # option - - -def is_rmb(data: bytes) -> bool: - return data[0:2] == b"\x05\x04" - - -class Mouse: - class Up: - Right = is_rmb - - -@dataclasses.dataclass -class KeyState: - ctrl_is_down: bool = False - alt_is_down: bool = False - cmd_is_down: bool = False - - def is_forbidden(self, data: bytes) -> bool: - """ - :return: True if the key is forbidden, else False - """ - return self.is_ctrl_o(data) - - def is_ctrl_o(self, data: bytes) -> bool: - """ - Do not allow the opening of files. - """ - return self.ctrl_is_down and data == Keys.Down.OKey - - def is_copy(self, data: bytes) -> bool: - """ - Detect Ctrl+C or Cmd+C for copy. - """ - return (self.ctrl_is_down or self.cmd_is_down) and data == Keys.Down.CKey - - def is_paste(self, data: bytes) -> bool: - """ - Detect Ctrl+V or Cmd+V for paste. 
- """ - return (self.ctrl_is_down or self.cmd_is_down) and data == Keys.Down.VKey - - -@dataclasses.dataclass -class Streaming: - """ - Streaming state. - """ - - client_id: str - """ - Unique to frontend app instance. - """ - - interactor: Interactor - """ - Whether the user or the agent are the interactor. - """ - - organization_id: str - vnc_port: int - x_api_key: str - websocket: WebSocket - - # -- - - browser_session: AddressablePersistentBrowserSession | None = None - key_state: KeyState = dataclasses.field(default_factory=KeyState) - task: Task | None = None - workflow_run: WorkflowRun | None = None - - def __post_init__(self) -> None: - add_streaming_client(self) - - @property - def is_open(self) -> bool: - if self.websocket.client_state not in (WebSocketState.CONNECTED, WebSocketState.CONNECTING): - return False - - if not self.task and not self.workflow_run and not self.browser_session: - return False - - if not get_streaming_client(self.client_id): - return False - - return True - - async def close(self, code: int = 1000, reason: str | None = None) -> "Streaming": - LOG.info("Closing Streaming.", reason=reason, code=code) - - self.browser_session = None - self.task = None - self.workflow_run = None - - try: - await self.websocket.close(code=code, reason=reason) - except Exception: - pass - - del_streaming_client(self.client_id) - - return self - - def update_key_state(self, data: bytes) -> None: - if data == Keys.Down.Ctrl: - self.key_state.ctrl_is_down = True - elif data == Keys.Up.Ctrl: - self.key_state.ctrl_is_down = False - elif data == Keys.Down.Alt: - self.key_state.alt_is_down = True - elif data == Keys.Up.Alt: - self.key_state.alt_is_down = False - elif data == Keys.Down.Cmd: - self.key_state.cmd_is_down = True - elif data == Keys.Up.Cmd: - self.key_state.cmd_is_down = False diff --git a/skyvern/forge/sdk/routes/streaming/messages.py b/skyvern/forge/sdk/routes/streaming/messages.py index 2bd6f0ef..775167ff 100644 --- 
a/skyvern/forge/sdk/routes/streaming/messages.py +++ b/skyvern/forge/sdk/routes/streaming/messages.py @@ -1,240 +1,24 @@ """ -Streaming messages for WebSocket connections. +Provides WS endpoints for streaming messages to/from our frontend application. """ -import asyncio - import structlog -from fastapi import WebSocket, WebSocketDisconnect -from websockets.exceptions import ConnectionClosedError +from fastapi import WebSocket -import skyvern.forge.sdk.routes.streaming.clients as sc from skyvern.forge.sdk.routes.routers import base_router, legacy_base_router -from skyvern.forge.sdk.routes.streaming.agent import connected_agent from skyvern.forge.sdk.routes.streaming.auth import auth -from skyvern.forge.sdk.routes.streaming.verify import ( - loop_verify_browser_session, - loop_verify_workflow_run, - verify_browser_session, - verify_workflow_run, +from skyvern.forge.sdk.routes.streaming.channels.message import ( + Loops, + MessageChannel, + get_message_channel_for_browser_session, + get_message_channel_for_workflow_run, ) from skyvern.forge.sdk.utils.aio import collect LOG = structlog.get_logger() -async def get_messages_for_browser_session( - client_id: str, - browser_session_id: str, - organization_id: str, - websocket: WebSocket, -) -> tuple[sc.MessageChannel, sc.Loops] | None: - """ - Return a message channel for a browser session, with a list of loops to run concurrently. 
- """ - - LOG.info("Getting message channel for browser session.", browser_session_id=browser_session_id) - - browser_session = await verify_browser_session( - browser_session_id=browser_session_id, - organization_id=organization_id, - ) - - if not browser_session: - LOG.info( - "Message channel: no initial browser session found.", - browser_session_id=browser_session_id, - organization_id=organization_id, - ) - return None - - message_channel = sc.MessageChannel( - client_id=client_id, - organization_id=organization_id, - browser_session=browser_session, - websocket=websocket, - ) - - LOG.info("Got message channel for browser session.", message_channel=message_channel) - - loops = [ - asyncio.create_task(loop_verify_browser_session(message_channel)), - asyncio.create_task(loop_channel(message_channel)), - ] - - return message_channel, loops - - -async def get_messages_for_workflow_run( - client_id: str, - workflow_run_id: str, - organization_id: str, - websocket: WebSocket, -) -> tuple[sc.MessageChannel, sc.Loops] | None: - """ - Return a message channel for a workflow run, with a list of loops to run concurrently. 
- """ - - LOG.info("Getting message channel for workflow run.", workflow_run_id=workflow_run_id) - - workflow_run, browser_session = await verify_workflow_run( - workflow_run_id=workflow_run_id, - organization_id=organization_id, - ) - - if not workflow_run: - LOG.info( - "Message channel: no initial workflow run found.", - workflow_run_id=workflow_run_id, - organization_id=organization_id, - ) - return None - - if not browser_session: - LOG.info( - "Message channel: no initial browser session found for workflow run.", - workflow_run_id=workflow_run_id, - organization_id=organization_id, - ) - return None - - message_channel = sc.MessageChannel( - client_id=client_id, - organization_id=organization_id, - browser_session=browser_session, - workflow_run=workflow_run, - websocket=websocket, - ) - - LOG.info("Got message channel for workflow run.", message_channel=message_channel) - - loops = [ - asyncio.create_task(loop_verify_workflow_run(message_channel)), - asyncio.create_task(loop_channel(message_channel)), - ] - - return message_channel, loops - - -async def loop_channel(message_channel: sc.MessageChannel) -> None: - """ - Stream messages and their results back and forth. - - Loops until the workflow run is cleared or the websocket is closed. 
- """ - - if not message_channel.browser_session: - LOG.info( - "No browser session found for workflow run.", - workflow_run=message_channel.workflow_run, - organization_id=message_channel.organization_id, - ) - return - - async def frontend_to_backend() -> None: - LOG.info("Starting frontend-to-backend channel loop.", message_channel=message_channel) - - while message_channel.is_open: - try: - data = await message_channel.websocket.receive_json() - - if not isinstance(data, dict): - LOG.error(f"Cannot create channel message: expected dict, got {type(data)}") - continue - - try: - message = sc.reify_channel_message(data) - except ValueError: - continue - - message_kind = message.kind - - match message_kind: - case "take-control": - streaming = sc.get_streaming_client(message_channel.client_id) - if not streaming: - LOG.error( - "No streaming client found for message.", - message_channel=message_channel, - message=message, - ) - continue - streaming.interactor = "user" - case "cede-control": - streaming = sc.get_streaming_client(message_channel.client_id) - if not streaming: - LOG.error( - "No streaming client found for message.", - message_channel=message_channel, - message=message, - ) - continue - streaming.interactor = "agent" - case "ask-for-clipboard-response": - if not isinstance(message, sc.MessageInAskForClipboardResponse): - LOG.error( - "Invalid message type for ask-for-clipboard-response.", - message_channel=message_channel, - message=message, - ) - continue - - streaming = sc.get_streaming_client(message_channel.client_id) - text = message.text - - async with connected_agent(streaming) as agent: - await agent.paste_text(text) - case _: - LOG.error(f"Unknown message kind: '{message_kind}'") - continue - - except WebSocketDisconnect: - LOG.info( - "Frontend disconnected.", - workflow_run=message_channel.workflow_run, - organization_id=message_channel.organization_id, - ) - raise - except ConnectionClosedError: - LOG.info( - "Frontend closed the streaming 
session.", - workflow_run=message_channel.workflow_run, - organization_id=message_channel.organization_id, - ) - raise - except asyncio.CancelledError: - pass - except Exception: - LOG.exception( - "An unexpected exception occurred.", - workflow_run=message_channel.workflow_run, - organization_id=message_channel.organization_id, - ) - raise - - loops = [ - asyncio.create_task(frontend_to_backend()), - ] - - try: - await collect(loops) - except Exception: - LOG.exception( - "An exception occurred in loop channel stream.", - workflow_run=message_channel.workflow_run, - organization_id=message_channel.organization_id, - ) - finally: - LOG.info( - "Closing the loop channel stream.", - workflow_run=message_channel.workflow_run, - organization_id=message_channel.organization_id, - ) - await message_channel.close(reason="loop-channel-closed") - - @base_router.websocket("/stream/messages/browser_session/{browser_session_id}") -@base_router.websocket("/stream/commands/browser_session/{browser_session_id}") async def browser_session_messages( websocket: WebSocket, browser_session_id: str, @@ -242,64 +26,16 @@ async def browser_session_messages( client_id: str | None = None, token: str | None = None, ) -> None: - LOG.info("Starting message stream for browser session.", browser_session_id=browser_session_id) - - organization_id = await auth(apikey=apikey, token=token, websocket=websocket) - - if not organization_id: - LOG.error("Authentication failed.", browser_session_id=browser_session_id) - return - - if not client_id: - LOG.error("No client ID provided.", browser_session_id=browser_session_id) - return - - message_channel: sc.MessageChannel - loops: list[asyncio.Task] = [] - - result = await get_messages_for_browser_session( - client_id=client_id, - browser_session_id=browser_session_id, - organization_id=organization_id, + return await messages( websocket=websocket, + browser_session_id=browser_session_id, + apikey=apikey, + client_id=client_id, + token=token, ) - if not 
result: - LOG.error( - "No streaming context found for the browser session.", - browser_session_id=browser_session_id, - organization_id=organization_id, - ) - await websocket.close(code=1013) - return - - message_channel, loops = result - - try: - LOG.info( - "Starting message stream loops for browser session.", - browser_session_id=browser_session_id, - organization_id=organization_id, - ) - await collect(loops) - except Exception: - LOG.exception( - "An exception occurred in the message stream function for browser session.", - browser_session_id=browser_session_id, - organization_id=organization_id, - ) - finally: - LOG.info( - "Closing the message stream session for browser session.", - browser_session_id=browser_session_id, - organization_id=organization_id, - ) - - await message_channel.close(reason="stream-closed") - @legacy_base_router.websocket("/stream/messages/workflow_run/{workflow_run_id}") -@legacy_base_router.websocket("/stream/commands/workflow_run/{workflow_run_id}") async def workflow_run_messages( websocket: WebSocket, workflow_run_id: str, @@ -307,31 +43,77 @@ async def workflow_run_messages( client_id: str | None = None, token: str | None = None, ) -> None: - LOG.info("Starting message stream.", workflow_run_id=workflow_run_id) + return await messages( + websocket=websocket, + workflow_run_id=workflow_run_id, + apikey=apikey, + client_id=client_id, + token=token, + ) + + +async def messages( + websocket: WebSocket, + browser_session_id: str | None = None, + workflow_run_id: str | None = None, + apikey: str | None = None, + client_id: str | None = None, + token: str | None = None, +) -> None: + LOG.info( + "Starting message stream.", + browser_session_id=browser_session_id, + workflow_run_id=workflow_run_id, + ) organization_id = await auth(apikey=apikey, token=token, websocket=websocket) if not organization_id: - LOG.error("Authentication failed.", workflow_run_id=workflow_run_id) + LOG.error( + "Authentication failed.", + 
browser_session_id=browser_session_id, + workflow_run_id=workflow_run_id, + ) return if not client_id: - LOG.error("No client ID provided.", workflow_run_id=workflow_run_id) + LOG.error( + "No client ID provided.", + browser_session_id=browser_session_id, + workflow_run_id=workflow_run_id, + ) return - message_channel: sc.MessageChannel - loops: list[asyncio.Task] = [] + message_channel: MessageChannel + loops: Loops = [] - result = await get_messages_for_workflow_run( - client_id=client_id, - workflow_run_id=workflow_run_id, - organization_id=organization_id, - websocket=websocket, - ) + if browser_session_id: + result = await get_message_channel_for_browser_session( + client_id=client_id, + browser_session_id=browser_session_id, + organization_id=organization_id, + websocket=websocket, + ) + elif workflow_run_id: + result = await get_message_channel_for_workflow_run( + client_id=client_id, + workflow_run_id=workflow_run_id, + organization_id=organization_id, + websocket=websocket, + ) + else: + LOG.error( + "Message channel: no browser_session_id or workflow_run_id provided.", + client_id=client_id, + organization_id=organization_id, + ) + await websocket.close(code=1002) + return if not result: LOG.error( - "No streaming context found for the workflow run.", + "No message channel found.", + browser_session_id=browser_session_id, workflow_run_id=workflow_run_id, organization_id=organization_id, ) @@ -342,22 +124,25 @@ async def workflow_run_messages( try: LOG.info( - "Starting message stream loops.", + "Starting message channel loops.", + browser_session_id=browser_session_id, workflow_run_id=workflow_run_id, organization_id=organization_id, ) await collect(loops) except Exception: LOG.exception( - "An exception occurred in the message stream function.", + "An exception occurred in the message loop function(s).", + browser_session_id=browser_session_id, workflow_run_id=workflow_run_id, organization_id=organization_id, ) finally: LOG.info( - "Closing the message 
stream session.", + "Closing the message channel.", + browser_session_id=browser_session_id, workflow_run_id=workflow_run_id, organization_id=organization_id, ) - await message_channel.close(reason="stream-closed") + await message_channel.close(reason="message-stream-closed") diff --git a/skyvern/forge/sdk/routes/streaming/registries.py b/skyvern/forge/sdk/routes/streaming/registries.py new file mode 100644 index 00000000..689f0c6f --- /dev/null +++ b/skyvern/forge/sdk/routes/streaming/registries.py @@ -0,0 +1,78 @@ +""" +Contains registries for coordinating active WS connections (aka "channels", see +`./channels/README.md`). + +NOTE: in AWS we had to turn on what amounts to sticky sessions for frontend apps, +so that an individual frontend app instance is guaranteed to always connect to +the same backend api instance. This is beccause the two registries here are +tied together via a `client_id` string. + +The tale-of-the-tape is this: + - frontend app requires two different channels (WS connections) to the backend api + - one dedicated to streaming VNC's RFB protocol + - the other dedicated to messaging (JSON) + - both of these channels are stateful and need to coordinate with one another +""" + +from __future__ import annotations + +import typing as t + +import structlog + +if t.TYPE_CHECKING: + from skyvern.forge.sdk.routes.streaming.channels.message import MessageChannel + from skyvern.forge.sdk.routes.streaming.channels.vnc import VncChannel + +LOG = structlog.get_logger() + + +# a registry for VNC channels, keyed by `client_id` +vnc_channels: dict[str, VncChannel] = {} + + +def add_vnc_channel(vnc_channel: VncChannel) -> None: + vnc_channels[vnc_channel.client_id] = vnc_channel + + +def get_vnc_channel(client_id: str) -> t.Union[VncChannel, None]: + return vnc_channels.get(client_id, None) + + +def del_vnc_channel(client_id: str) -> None: + try: + del vnc_channels[client_id] + except KeyError: + pass + + +# a registry for message channels, keyed by 
`client_id` +message_channels: dict[str, MessageChannel] = {} + + +def add_message_channel(message_channel: MessageChannel) -> None: + message_channels[message_channel.client_id] = message_channel + + +def get_message_channel(client_id: str) -> t.Union[MessageChannel, None]: + candidate = message_channels.get(client_id, None) + + if candidate and candidate.is_open: + return candidate + + if candidate: + LOG.info( + "MessageChannel: message channel is not open; deleting it", + client_id=candidate.client_id, + ) + + del_message_channel(candidate.client_id) + + return None + + +def del_message_channel(client_id: str) -> None: + try: + del message_channels[client_id] + except KeyError: + pass diff --git a/skyvern/forge/sdk/routes/streaming/screenshot.py b/skyvern/forge/sdk/routes/streaming/screenshot.py index bf9b8aee..9eb4ee80 100644 --- a/skyvern/forge/sdk/routes/streaming/screenshot.py +++ b/skyvern/forge/sdk/routes/streaming/screenshot.py @@ -1,3 +1,14 @@ +""" +Provides WS endpoints for streaming screenshots. + +Screenshot streaming is created on the basis of one of these database entities: + - task (run) + - workflow run + +Screenshot streaming is used for a run that is invoked without a browser session. +Otherwise, VNC streaming is used. +""" + import asyncio import base64 from datetime import datetime diff --git a/skyvern/forge/sdk/routes/streaming/verify.py b/skyvern/forge/sdk/routes/streaming/verify.py index 79893d4e..474155a3 100644 --- a/skyvern/forge/sdk/routes/streaming/verify.py +++ b/skyvern/forge/sdk/routes/streaming/verify.py @@ -1,18 +1,43 @@ +""" +Channels (see ./channels/README.md) variously rely on the state of one of +these database entities: + - browser session + - task + - workflow run + +That is, channels are created on the basis of one of those entities, and that +entity must be in a valid state for the channel to continue. 
+ +So, in order to continue operating a channel, we need to periodically verify +that the underlying entity is still valid. This module provides logic to +perform those verifications. +""" + +from __future__ import annotations + import asyncio +import typing as t from datetime import datetime import structlog -import skyvern.forge.sdk.routes.streaming.clients as sc from skyvern.config import settings from skyvern.forge import app from skyvern.forge.sdk.schemas.persistent_browser_sessions import AddressablePersistentBrowserSession from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus from skyvern.forge.sdk.workflow.models.workflow import WorkflowRun, WorkflowRunStatus +if t.TYPE_CHECKING: + from skyvern.forge.sdk.routes.streaming.channels.message import MessageChannel + from skyvern.forge.sdk.routes.streaming.channels.vnc import VncChannel + LOG = structlog.get_logger() +class Constants: + POLL_INTERVAL_FOR_VERIFICATION_SECONDS = 5 + + async def verify_browser_session( browser_session_id: str, organization_id: str, @@ -122,9 +147,7 @@ async def verify_task( **browser_session.model_dump() | {"browser_address": browser_session.browser_address}, ) except Exception as e: - LOG.error( - "streaming-vnc.browser-session-reify-error", task_id=task_id, organization_id=organization_id, error=e - ) + LOG.error("vnc.browser-session-reify-error", task_id=task_id, organization_id=organization_id, error=e) return task, None return task, addressable_browser_session @@ -238,7 +261,7 @@ async def verify_workflow_run( return workflow_run, addressable_browser_session -async def loop_verify_browser_session(verifiable: sc.MessageChannel | sc.Streaming) -> None: +async def loop_verify_browser_session(verifiable: MessageChannel | VncChannel) -> None: """ Loop until the browser session is cleared or the websocket is closed. 
""" @@ -251,27 +274,27 @@ async def loop_verify_browser_session(verifiable: sc.MessageChannel | sc.Streami verifiable.browser_session = browser_session - await asyncio.sleep(2) + await asyncio.sleep(Constants.POLL_INTERVAL_FOR_VERIFICATION_SECONDS) -async def loop_verify_task(streaming: sc.Streaming) -> None: +async def loop_verify_task(vnc_channel: VncChannel) -> None: """ Loop until the task is cleared or the websocket is closed. """ - while streaming.task and streaming.is_open: + while vnc_channel.task and vnc_channel.is_open: task, browser_session = await verify_task( - task_id=streaming.task.task_id, - organization_id=streaming.organization_id, + task_id=vnc_channel.task.task_id, + organization_id=vnc_channel.organization_id, ) - streaming.task = task - streaming.browser_session = browser_session + vnc_channel.task = task + vnc_channel.browser_session = browser_session - await asyncio.sleep(2) + await asyncio.sleep(Constants.POLL_INTERVAL_FOR_VERIFICATION_SECONDS) -async def loop_verify_workflow_run(verifiable: sc.MessageChannel | sc.Streaming) -> None: +async def loop_verify_workflow_run(verifiable: MessageChannel | VncChannel) -> None: """ Loop until the workflow run is cleared or the websocket is closed. """ @@ -285,4 +308,4 @@ async def loop_verify_workflow_run(verifiable: sc.MessageChannel | sc.Streaming) verifiable.workflow_run = workflow_run verifiable.browser_session = browser_session - await asyncio.sleep(2) + await asyncio.sleep(Constants.POLL_INTERVAL_FOR_VERIFICATION_SECONDS) diff --git a/skyvern/forge/sdk/routes/streaming/vnc.py b/skyvern/forge/sdk/routes/streaming/vnc.py index 9b1ad14f..de5e229f 100644 --- a/skyvern/forge/sdk/routes/streaming/vnc.py +++ b/skyvern/forge/sdk/routes/streaming/vnc.py @@ -1,5 +1,5 @@ """ -Streaming VNC WebSocket connections. +Provides WS endpoints for streaming a remote browser via VNC. 
NOTE(jdo:streaming-local-dev) ----------------------------- @@ -9,459 +9,23 @@ NOTE(jdo:streaming-local-dev) """ -import asyncio -import typing as t -from urllib.parse import urlparse - import structlog -import websockets -from fastapi import WebSocket, WebSocketDisconnect -from websockets import Data -from websockets.exceptions import ConnectionClosedError +from fastapi import WebSocket -import skyvern.forge.sdk.routes.streaming.clients as sc -from skyvern.config import settings -from skyvern.forge import app -from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType from skyvern.forge.sdk.routes.routers import base_router, legacy_base_router -from skyvern.forge.sdk.routes.streaming.agent import connected_agent from skyvern.forge.sdk.routes.streaming.auth import auth -from skyvern.forge.sdk.routes.streaming.verify import ( - loop_verify_browser_session, - loop_verify_task, - loop_verify_workflow_run, - verify_browser_session, - verify_task, - verify_workflow_run, +from skyvern.forge.sdk.routes.streaming.channels.vnc import ( + Loops, + VncChannel, + get_vnc_channel_for_browser_session, + get_vnc_channel_for_task, + get_vnc_channel_for_workflow_run, ) from skyvern.forge.sdk.utils.aio import collect LOG = structlog.get_logger() -class Constants: - MissingXApiKey = "" - - -async def get_x_api_key(organization_id: str) -> str: - token = await app.DATABASE.get_valid_org_auth_token( - organization_id, - OrganizationAuthTokenType.api.value, - ) - - if not token: - LOG.warning( - "No valid API key found for organization when streaming.", - organization_id=organization_id, - ) - x_api_key = Constants.MissingXApiKey - else: - x_api_key = token.token - - return x_api_key - - -async def get_streaming_for_browser_session( - client_id: str, - browser_session_id: str, - organization_id: str, - websocket: WebSocket, -) -> tuple[sc.Streaming, sc.Loops] | None: - """ - Return a streaming context for a browser session, with a list of loops to run concurrently. 
- """ - - LOG.info("Getting streaming context for browser session.", browser_session_id=browser_session_id) - - browser_session = await verify_browser_session( - browser_session_id=browser_session_id, - organization_id=organization_id, - ) - - if not browser_session: - LOG.info( - "No initial browser session found.", browser_session_id=browser_session_id, organization_id=organization_id - ) - return None - - x_api_key = await get_x_api_key(organization_id) - - streaming = sc.Streaming( - client_id=client_id, - interactor="agent", - organization_id=organization_id, - vnc_port=settings.SKYVERN_BROWSER_VNC_PORT, - browser_session=browser_session, - x_api_key=x_api_key, - websocket=websocket, - ) - - LOG.info("Got streaming context for browser session.", streaming=streaming) - - loops = [ - asyncio.create_task(loop_verify_browser_session(streaming)), - asyncio.create_task(loop_stream_vnc(streaming)), - ] - - return streaming, loops - - -async def get_streaming_for_task( - client_id: str, - task_id: str, - organization_id: str, - websocket: WebSocket, -) -> tuple[sc.Streaming, sc.Loops] | None: - """ - Return a streaming context for a task, with a list of loops to run concurrently. 
- """ - - task, browser_session = await verify_task(task_id=task_id, organization_id=organization_id) - - if not task: - LOG.info("No initial task found.", task_id=task_id, organization_id=organization_id) - return None - - if not browser_session: - LOG.info("No initial browser session found for task.", task_id=task_id, organization_id=organization_id) - return None - - x_api_key = await get_x_api_key(organization_id) - - streaming = sc.Streaming( - client_id=client_id, - interactor="agent", - organization_id=organization_id, - vnc_port=settings.SKYVERN_BROWSER_VNC_PORT, - x_api_key=x_api_key, - websocket=websocket, - browser_session=browser_session, - task=task, - ) - - loops = [ - asyncio.create_task(loop_verify_task(streaming)), - asyncio.create_task(loop_stream_vnc(streaming)), - ] - - return streaming, loops - - -async def get_streaming_for_workflow_run( - client_id: str, - workflow_run_id: str, - organization_id: str, - websocket: WebSocket, -) -> tuple[sc.Streaming, sc.Loops] | None: - """ - Return a streaming context for a workflow run, with a list of loops to run concurrently. 
- """ - - LOG.info("Getting streaming context for workflow run.", workflow_run_id=workflow_run_id) - - workflow_run, browser_session = await verify_workflow_run( - workflow_run_id=workflow_run_id, - organization_id=organization_id, - ) - - if not workflow_run: - LOG.info("No initial workflow run found.", workflow_run_id=workflow_run_id, organization_id=organization_id) - return None - - if not browser_session: - LOG.info( - "No initial browser session found for workflow run.", - workflow_run_id=workflow_run_id, - organization_id=organization_id, - ) - return None - - x_api_key = await get_x_api_key(organization_id) - - streaming = sc.Streaming( - client_id=client_id, - interactor="agent", - organization_id=organization_id, - vnc_port=settings.SKYVERN_BROWSER_VNC_PORT, - browser_session=browser_session, - workflow_run=workflow_run, - x_api_key=x_api_key, - websocket=websocket, - ) - - LOG.info("Got streaming context for workflow run.", streaming=streaming) - - loops = [ - asyncio.create_task(loop_verify_workflow_run(streaming)), - asyncio.create_task(loop_stream_vnc(streaming)), - ] - - return streaming, loops - - -def verify_message_channel( - message_channel: sc.MessageChannel | None, streaming: sc.Streaming -) -> sc.MessageChannel | t.Literal[False]: - if message_channel and message_channel.is_open: - return message_channel - - LOG.warning( - "No message channel found for client, or it is not open", - message_channel=message_channel, - client_id=streaming.client_id, - organization_id=streaming.organization_id, - ) - - return False - - -async def copy_text(streaming: sc.Streaming) -> None: - try: - async with connected_agent(streaming) as agent: - copied_text = await agent.get_selected_text() - - LOG.info( - "Retrieved selected text via CDP", - organization_id=streaming.organization_id, - ) - - message_channel = sc.get_message_client(streaming.client_id) - - if cc := verify_message_channel(message_channel, streaming): - await cc.send_copied_text(copied_text, 
streaming) - else: - LOG.warning( - "No message channel found for client, or it is not open", - message_channel=message_channel, - client_id=streaming.client_id, - organization_id=streaming.organization_id, - ) - except Exception: - LOG.exception( - "Failed to retrieve selected text via CDP", - organization_id=streaming.organization_id, - ) - - -async def ask_for_clipboard(streaming: sc.Streaming) -> None: - try: - LOG.info( - "Asking for clipboard data via CDP", - organization_id=streaming.organization_id, - ) - - message_channel = sc.get_message_client(streaming.client_id) - - if cc := verify_message_channel(message_channel, streaming): - await cc.ask_for_clipboard(streaming) - except Exception: - LOG.exception( - "Failed to ask for clipboard via CDP", - organization_id=streaming.organization_id, - ) - - -async def loop_stream_vnc(streaming: sc.Streaming) -> None: - """ - Actually stream the VNC session data between a frontend and a browser - session. - - Loops until the task is cleared or the websocket is closed. 
- """ - - if not streaming.browser_session: - LOG.info("No browser session found for task.", task=streaming.task, organization_id=streaming.organization_id) - return - - vnc_url: str = "" - if streaming.browser_session.ip_address: - if ":" in streaming.browser_session.ip_address: - ip, _ = streaming.browser_session.ip_address.split(":") - vnc_url = f"ws://{ip}:{streaming.vnc_port}" - else: - vnc_url = f"ws://{streaming.browser_session.ip_address}:{streaming.vnc_port}" - else: - browser_address = streaming.browser_session.browser_address - - parsed_browser_address = urlparse(browser_address) - host = parsed_browser_address.hostname - vnc_url = f"ws://{host}:{streaming.vnc_port}" - - # NOTE(jdo:streaming-local-dev) - # vnc_url = "ws://localhost:9001/ws/novnc" - - LOG.info( - "Connecting to VNC URL.", - vnc_url=vnc_url, - task=streaming.task, - workflow_run=streaming.workflow_run, - organization_id=streaming.organization_id, - ) - - async with websockets.connect(vnc_url) as novnc_ws: - - async def frontend_to_browser() -> None: - LOG.info("Starting frontend-to-browser data transfer.", streaming=streaming) - data: Data | None = None - - while streaming.is_open: - try: - data = await streaming.websocket.receive_bytes() - - if data: - message_type = data[0] - - if message_type == sc.MessageType.Keyboard.value: - streaming.update_key_state(data) - - if streaming.key_state.is_copy(data): - await copy_text(streaming) - - if streaming.key_state.is_paste(data): - await ask_for_clipboard(streaming) - - if streaming.key_state.is_forbidden(data): - continue - - if message_type == sc.MessageType.Mouse.value: - if sc.Mouse.Up.Right(data): - continue - - if not streaming.interactor == "user" and message_type in ( - sc.MessageType.Keyboard.value, - sc.MessageType.Mouse.value, - ): - LOG.info( - "Blocking user message.", task=streaming.task, organization_id=streaming.organization_id - ) - continue - - except WebSocketDisconnect: - LOG.info("Frontend disconnected.", 
task=streaming.task, organization_id=streaming.organization_id) - raise - except ConnectionClosedError: - LOG.info( - "Frontend closed the streaming session.", - task=streaming.task, - organization_id=streaming.organization_id, - ) - raise - except asyncio.CancelledError: - pass - except Exception: - LOG.exception( - "An unexpected exception occurred.", - task=streaming.task, - organization_id=streaming.organization_id, - ) - raise - - if not data: - continue - - try: - await novnc_ws.send(data) - except WebSocketDisconnect: - LOG.info( - "Browser disconnected from the streaming session.", - task=streaming.task, - organization_id=streaming.organization_id, - ) - raise - except ConnectionClosedError: - LOG.info( - "Browser closed the streaming session.", - task=streaming.task, - organization_id=streaming.organization_id, - ) - raise - except asyncio.CancelledError: - pass - except Exception: - LOG.exception( - "An unexpected exception occurred in frontend-to-browser loop.", - task=streaming.task, - organization_id=streaming.organization_id, - ) - raise - - async def browser_to_frontend() -> None: - LOG.info("Starting browser-to-frontend data transfer.", streaming=streaming) - data: Data | None = None - - while streaming.is_open: - try: - data = await novnc_ws.recv() - - except WebSocketDisconnect: - LOG.info( - "Browser disconnected from the streaming session.", - task=streaming.task, - organization_id=streaming.organization_id, - ) - await streaming.close(reason="browser-disconnected") - except ConnectionClosedError: - LOG.info( - "Browser closed the streaming session.", - task=streaming.task, - organization_id=streaming.organization_id, - ) - await streaming.close(reason="browser-closed") - except asyncio.CancelledError: - pass - except Exception: - LOG.exception( - "An unexpected exception occurred in browser-to-frontend loop.", - task=streaming.task, - organization_id=streaming.organization_id, - ) - raise - - if not data: - continue - - try: - await 
streaming.websocket.send_bytes(data) - except WebSocketDisconnect: - LOG.info( - "Frontend disconnected from the streaming session.", - task=streaming.task, - organization_id=streaming.organization_id, - ) - await streaming.close(reason="frontend-disconnected") - except ConnectionClosedError: - LOG.info( - "Frontend closed the streaming session.", - task=streaming.task, - organization_id=streaming.organization_id, - ) - await streaming.close(reason="frontend-closed") - except asyncio.CancelledError: - pass - except Exception: - LOG.exception( - "An unexpected exception occurred.", - task=streaming.task, - organization_id=streaming.organization_id, - ) - raise - - loops = [ - asyncio.create_task(frontend_to_browser()), - asyncio.create_task(browser_to_frontend()), - ] - - try: - await collect(loops) - except Exception: - LOG.exception( - "An exception occurred in loop stream.", task=streaming.task, organization_id=streaming.organization_id - ) - finally: - LOG.info("Closing the loop stream.", task=streaming.task, organization_id=streaming.organization_id) - await streaming.close(reason="loop-stream-vnc-closed") - - @base_router.websocket("/stream/vnc/browser_session/{browser_session_id}") async def browser_session_stream( websocket: WebSocket, @@ -507,7 +71,7 @@ async def stream( ) -> None: if not client_id: LOG.error( - "Client ID not provided for VNC stream.", + "Client ID not provided for vnc stream.", browser_session_id=browser_session_id, task_id=task_id, workflow_run_id=workflow_run_id, @@ -515,7 +79,7 @@ async def stream( return LOG.info( - "Starting VNC stream.", + "Starting vnc stream.", browser_session_id=browser_session_id, client_id=client_id, task_id=task_id, @@ -528,11 +92,11 @@ async def stream( LOG.error("Authentication failed.", task_id=task_id, workflow_run_id=workflow_run_id) return - streaming: sc.Streaming - loops: list[asyncio.Task] = [] + vnc_channel: VncChannel + loops: Loops if browser_session_id: - result = await 
get_streaming_for_browser_session( + result = await get_vnc_channel_for_browser_session( client_id=client_id, browser_session_id=browser_session_id, organization_id=organization_id, @@ -541,22 +105,22 @@ async def stream( if not result: LOG.error( - "No streaming context found for the browser session.", + "No vnc context found for the browser session.", browser_session_id=browser_session_id, organization_id=organization_id, ) await websocket.close(code=1013) return - streaming, loops = result + vnc_channel, loops = result LOG.info( - "Starting streaming for browser session.", + "Starting vnc for browser session.", browser_session_id=browser_session_id, organization_id=organization_id, ) elif task_id: - result = await get_streaming_for_task( + result = await get_vnc_channel_for_task( client_id=client_id, task_id=task_id, organization_id=organization_id, @@ -564,19 +128,17 @@ async def stream( ) if not result: - LOG.error("No streaming context found for the task.", task_id=task_id, organization_id=organization_id) + LOG.error("No vnc context found for the task.", task_id=task_id, organization_id=organization_id) await websocket.close(code=1013) return - streaming, loops = result + vnc_channel, loops = result - LOG.info("Starting streaming for task.", task_id=task_id, organization_id=organization_id) + LOG.info("Starting vnc for task.", task_id=task_id, organization_id=organization_id) elif workflow_run_id: - LOG.info( - "Starting streaming for workflow run.", workflow_run_id=workflow_run_id, organization_id=organization_id - ) - result = await get_streaming_for_workflow_run( + LOG.info("Starting vnc for workflow run.", workflow_run_id=workflow_run_id, organization_id=organization_id) + result = await get_vnc_channel_for_workflow_run( client_id=client_id, workflow_run_id=workflow_run_id, organization_id=organization_id, @@ -585,25 +147,23 @@ async def stream( if not result: LOG.error( - "No streaming context found for the workflow run.", + "No vnc context found for 
the workflow run.", workflow_run_id=workflow_run_id, organization_id=organization_id, ) await websocket.close(code=1013) return - streaming, loops = result + vnc_channel, loops = result - LOG.info( - "Starting streaming for workflow run.", workflow_run_id=workflow_run_id, organization_id=organization_id - ) + LOG.info("Starting vnc for workflow run.", workflow_run_id=workflow_run_id, organization_id=organization_id) else: LOG.error("Neither task ID nor workflow run ID was provided.") return try: LOG.info( - "Starting streaming loops.", + "Starting vnc loops.", task_id=task_id, workflow_run_id=workflow_run_id, organization_id=organization_id, @@ -611,16 +171,16 @@ async def stream( await collect(loops) except Exception: LOG.exception( - "An exception occurred in the stream function.", + "An exception occurred in the vnc loop.", task_id=task_id, workflow_run_id=workflow_run_id, organization_id=organization_id, ) finally: LOG.info( - "Closing the streaming session.", + "Closing the vnc session.", task_id=task_id, workflow_run_id=workflow_run_id, organization_id=organization_id, ) - await streaming.close(reason="stream-closed") + await vnc_channel.close(reason="vnc-closed") diff --git a/skyvern/py.typed b/skyvern/py.typed new file mode 100644 index 00000000..e69de29b