From f00e82c1bb97a97dfa0076580512410d08eb3142 Mon Sep 17 00:00:00 2001 From: Jonathan Dobson Date: Tue, 25 Nov 2025 13:54:29 -0500 Subject: [PATCH] Browser Exfiltration (#4093) --- .../sdk/routes/streaming/channels/cdp.py | 16 ++ .../routes/streaming/channels/exfiltration.py | 235 ++++++++++++++- .../routes/streaming/channels/js/decorate.js | 95 ++++++ .../streaming/channels/js/exfiltrate.js | 165 +++++++++++ .../streaming/channels/js/undecorate.js | 9 + .../sdk/routes/streaming/channels/message.py | 272 ++++++++++++++---- 6 files changed, 738 insertions(+), 54 deletions(-) create mode 100644 skyvern/forge/sdk/routes/streaming/channels/js/decorate.js create mode 100644 skyvern/forge/sdk/routes/streaming/channels/js/exfiltrate.js create mode 100644 skyvern/forge/sdk/routes/streaming/channels/js/undecorate.js diff --git a/skyvern/forge/sdk/routes/streaming/channels/cdp.py b/skyvern/forge/sdk/routes/streaming/channels/cdp.py index e65fe3df..76e9bae1 100644 --- a/skyvern/forge/sdk/routes/streaming/channels/cdp.py +++ b/skyvern/forge/sdk/routes/streaming/channels/cdp.py @@ -13,6 +13,8 @@ Channel data: from __future__ import annotations import asyncio +import functools +import pathlib import typing as t import structlog @@ -183,3 +185,17 @@ class CdpChannel: except Exception: LOG.exception(f"{self.class_name} failed to evaluate js", expression=expression, **self.identity) raise + + @functools.lru_cache(maxsize=None) + def js(self, file_name: str) -> str: + base_path = pathlib.Path(__file__).parent / "js" + file_name = file_name.lstrip("/") + + if not file_name.endswith(".js"): + file_name += ".js" + + relative_path = pathlib.Path(file_name) + full_path = base_path / relative_path + + with open(full_path, encoding="utf-8") as f: + return f.read() diff --git a/skyvern/forge/sdk/routes/streaming/channels/exfiltration.py b/skyvern/forge/sdk/routes/streaming/channels/exfiltration.py index d352c7e8..b76cefc5 100644 --- a/skyvern/forge/sdk/routes/streaming/channels/exfiltration.py +++ b/skyvern/forge/sdk/routes/streaming/channels/exfiltration.py @@ -1 +1,234 @@ -# stub +""" +This channel exfiltrates all user activity in a browser. + +What this channel looks like: + + [Skyvern App] <-- [API Server] <--> [Browser (CDP)] + +Channel data: + + Raw JavaScript events (as JSON) over WebSockets. +""" + +import asyncio +import dataclasses +import enum +import json +import typing as t + +import structlog +from playwright.async_api import CDPSession, ConsoleMessage, Page + +from skyvern.forge.sdk.routes.streaming.channels.cdp import CdpChannel +from skyvern.forge.sdk.routes.streaming.channels.vnc import VncChannel + +LOG = structlog.get_logger() + + +class ExfiltratedEventSource(enum.Enum): + CONSOLE = "console" + CDP = "cdp" + NOT_SPECIFIED = "[not-specified]" + + +@dataclasses.dataclass +class ExfiltratedEvent: + kind: t.Literal["exfiltrated-event"] = "exfiltrated-event" + event_name: str = "[not-specified]" + + # TODO(jdo): improve typing for params + params: dict = dataclasses.field(default_factory=dict) + source: ExfiltratedEventSource = ExfiltratedEventSource.NOT_SPECIFIED + + +OnExfiltrationEvent = t.Callable[[list[ExfiltratedEvent]], None] + + +class ExfiltrationChannel(CdpChannel): + """ + ExfiltrationChannel. + """ + + def __init__(self, *, on_event: OnExfiltrationEvent, vnc_channel: VncChannel) -> None: + self.cdp_session: CDPSession | None = None + self.on_event = on_event + + super().__init__(vnc_channel=vnc_channel) + + def _handle_console_event(self, msg: ConsoleMessage) -> None: + """Parse console messages for exfiltrated event data.""" + text = msg.text + if text.startswith("[EXFIL]"): + try: + event_data = json.loads(text[7:]) # Strip '[EXFIL]' prefix + + messages = [ + ExfiltratedEvent( + kind="exfiltrated-event", + event_name="user_interaction", + params=event_data, + source=ExfiltratedEventSource.CONSOLE, + ), + ] + + self.on_event(messages) + except Exception: + LOG.exception(f"{self.class_name} Failed to parse exfiltrated event", text=text) + + def _handle_cdp_event(self, event_name: str, params: dict) -> None: + LOG.debug(f"{self.class_name} cdp event captured: {event_name}", params=params) + + messages = [ + ExfiltratedEvent( + kind="exfiltrated-event", + event_name=event_name, + params=params, + source=ExfiltratedEventSource.CDP, + ), + ] + + self.on_event(messages) + + async def connect(self, cdp_url: str | None = None) -> t.Self: + if self.browser and self.browser.is_connected() and self.cdp_session: + return self + + await super().connect(cdp_url) + + # NOTE(jdo:streaming-local-dev) + # from skyvern.config import settings + # await super().connect(cdp_url or settings.BROWSER_REMOTE_DEBUGGING_URL) + + page = self.page + + if not page: + raise RuntimeError(f"{self.class_name} No page available after connecting to browser.") + + self.cdp_session = await page.context.new_cdp_session(page) + + return self + + async def exfiltrate(self, page: Page) -> t.Self: + """ + Track user interactions and send to console for CDP to capture. + """ + + LOG.info(f"{self.class_name} setting up exfiltration on new page.", url=page.url) + + page.on("console", self._handle_console_event) + + await page.evaluate(self.js("exfiltrate")) + + LOG.info(f"{self.class_name} setup complete on page.", url=page.url) + + return self + + async def decorate(self, page: Page) -> t.Self: + """Add a mouse-following follower to the page.""" + LOG.info(f"{self.class_name} adding decoration to page.", url=page.url) + + await page.evaluate(self.js("decorate")) + + LOG.info(f"{self.class_name} decoration setup complete on page.", url=page.url) + + return self + + async def undecorate(self, page: Page) -> t.Self: + """Remove the mouse-following follower from the page.""" + LOG.info(f"{self.class_name} removing decoration from page.", url=page.url) + + await page.evaluate(self.js("undecorate")) + + LOG.info(f"{self.class_name} decoration removed from page.", url=page.url) + + return self + + async def enable_cdp_events(self) -> t.Self: + await self.connect() + + cdp_session = self.cdp_session + + if not cdp_session: + raise RuntimeError(f"{self.class_name} No CDP session available to enable events.") + + enables = [ + cdp_session.send("Runtime.enable"), + cdp_session.send("DOM.enable"), + cdp_session.send("Page.enable"), + cdp_session.send("Target.setDiscoverTargets", {"discover": True}), + ] + + await asyncio.gather(*enables) + + # listen to CDP events for tab management and navigation + cdp_session.on("Target.targetCreated", lambda params: self._handle_cdp_event("target_created", params)) + cdp_session.on("Target.targetDestroyed", lambda params: self._handle_cdp_event("target_destroyed", params)) + cdp_session.on("Target.targetInfoChanged", lambda params: self._handle_cdp_event("target_info_changed", params)) + cdp_session.on("Page.frameNavigated", lambda params: self._handle_cdp_event("frame_navigated", params)) + cdp_session.on( + "Page.navigatedWithinDocument", lambda params: self._handle_cdp_event("navigated_within_document", params) + ) + + return self + + def enable_console_events(self) -> t.Self: + browser_context = self.browser_context + + if not browser_context: + LOG.error(f"{self.class_name} no browser context to enable console events.") + return self + + for page in browser_context.pages: + asyncio.create_task(self.exfiltrate(page)) + + browser_context.on("page", lambda page: asyncio.create_task(self.exfiltrate(page))) + + return self + + def enable_decoration(self) -> t.Self: + browser_context = self.browser_context + + if not browser_context: + LOG.error(f"{self.class_name} no browser context to enable decoration.") + return self + + for page in browser_context.pages: + asyncio.create_task(self.decorate(page)) + + browser_context.on("page", lambda page: asyncio.create_task(self.decorate(page))) + + return self + + async def start(self) -> t.Self: + LOG.info(f"{self.class_name} starting.") + + await self.enable_cdp_events() + + self.enable_console_events() + + self.enable_decoration() + + return self + + async def stop(self) -> t.Self: + LOG.info(f"{self.class_name} stopping.") + + if not self.cdp_session: + return self + + try: + await self.cdp_session.detach() + except Exception: + pass + + self.cdp_session = None + + pages = self.browser_context.pages if self.browser_context else [] + + for page in pages: + page.remove_listener("console", self._handle_console_event) + await self.undecorate(page) + + LOG.info(f"{self.class_name} stopped.") + + return self diff --git a/skyvern/forge/sdk/routes/streaming/channels/js/decorate.js b/skyvern/forge/sdk/routes/streaming/channels/js/decorate.js new file mode 100644 index 00000000..1228d77c --- /dev/null +++ b/skyvern/forge/sdk/routes/streaming/channels/js/decorate.js @@ -0,0 +1,95 @@ +(function () { + if (!window.__skyvern_decoration_initialized) { + window.__skyvern_decoration_initialized = true; + + window.__skyvern_create_mouse_follower = function () { + // create the circle element + const circle = document.createElement("div"); + window.__skyvern_decoration_mouse_follower = circle; + circle.id = "__skyvern_mouse_follower"; + circle.style.position = "fixed"; + circle.style.left = "0"; + circle.style.top = "0"; + circle.style.width = "30px"; + circle.style.height = "30px"; + circle.style.borderRadius = "50%"; + circle.style.backgroundColor = "rgba(255, 0, 0, 0.2)"; + circle.style.pointerEvents = "none"; + circle.style.zIndex = "999999"; + circle.style.willChange = "transform"; + document.body.appendChild(circle); + }; + + window.__skyvern_create_mouse_follower(); + + let scale = 1; + let targetScale = 1; + let mouseX = 0; + let mouseY = 0; + + // smooth scale animation + function animate() { + if (!window.__skyvern_decoration_mouse_follower) { + return; + } + + const follower = window.__skyvern_decoration_mouse_follower; + + scale += (targetScale - scale) * 0.2; + + if (Math.abs(targetScale - scale) > 0.001) { + requestAnimationFrame(animate); + } else { + scale = targetScale; + } + + follower.style.transform = `translate(${mouseX - 15}px, ${mouseY - 15}px) scale(${scale})`; + } + + // update follower position on mouse move + document.addEventListener( + "mousemove", + (e) => { + if (!window.__skyvern_decoration_mouse_follower) { + return; + } + + const follower = window.__skyvern_decoration_mouse_follower; + mouseX = e.clientX; + mouseY = e.clientY; + follower.style.transform = `translate(${mouseX - 15}px, ${mouseY - 15}px) scale(${scale})`; + }, + true, + ); + + // expand follower on mouse down + document.addEventListener( + "mousedown", + () => { + if (!window.__skyvern_decoration_mouse_follower) { + return; + } + + targetScale = 50 / 30; + requestAnimationFrame(animate); + }, + true, + ); + + // return follower to original size on mouse up + document.addEventListener( + "mouseup", + () => { + if (!window.__skyvern_decoration_mouse_follower) { + return; + } + + targetScale = 1; + requestAnimationFrame(animate); + }, + true, + ); + } else { + window.__skyvern_create_mouse_follower(); + } +})(); diff --git a/skyvern/forge/sdk/routes/streaming/channels/js/exfiltrate.js b/skyvern/forge/sdk/routes/streaming/channels/js/exfiltrate.js new file mode 100644 index 00000000..280ac82b --- /dev/null +++ b/skyvern/forge/sdk/routes/streaming/channels/js/exfiltrate.js @@ -0,0 +1,165 @@ +(function () { + if (!window.__skyvern_exfiltration_initialized) { + window.__skyvern_exfiltration_initialized = true; + + [ + "click", + "mousedown", + "mouseup", + "mouseenter", + "mouseleave", + "keydown", + "keyup", + "keypress", + "focus", + "blur", + "input", + "change", + "scroll", + "contextmenu", + "dblclick", + ].forEach((eventType) => { + document.addEventListener( + eventType, + (e) => { + // find associated labels + const getAssociatedLabels = (element) => { + const labels = []; + + // label with 'for' attribute matching element's id + if (element.id) { + const labelsByFor = document.querySelectorAll( + `label[for="${element.id}"]`, + ); + + labelsByFor.forEach((label) => { + if (label.textContent) labels.push(label.textContent.trim()); + }); + } + + // label wrapping the element + let parent = element.parentElement; + + while (parent) { + if (parent.tagName === "LABEL") { + if (parent.textContent) labels.push(parent.textContent.trim()); + break; + } + parent = parent.parentElement; + } + + return labels.length > 0 ? labels : null; + }; + + // get any kind of text content + const getElementText = (element) => { + const textSources = []; + + if (element.getAttribute("aria-label")) { + textSources.push(element.getAttribute("aria-label")); + } + + if (element.getAttribute("aria-labelledby")) { + const labelIds = element + .getAttribute("aria-labelledby") + .split(" "); + + labelIds.forEach((id) => { + const labelElement = document.getElementById(id); + + if (labelElement?.textContent) { + textSources.push(labelElement.textContent.trim()); + } + }); + } + + if (element.getAttribute("placeholder")) { + textSources.push(element.getAttribute("placeholder")); + } + + if (element.getAttribute("title")) { + textSources.push(element.getAttribute("title")); + } + + if (element.getAttribute("alt")) { + textSources.push(element.getAttribute("alt")); + } + + if (element.innerText) { + textSources.push(element.innerText.substring(0, 100)); + } + + if (!element.innerText && element.textContent) { + textSources.push(element.textContent.trim().substring(0, 100)); + } + + return textSources.length > 0 ? textSources : []; + }; + + const eventData = { + url: window.location.href, + type: eventType, + timestamp: Date.now(), + target: { + tagName: e.target?.tagName, + id: e.target?.id, + className: e.target?.className, + value: e.target?.value, + text: getElementText(e.target), + labels: getAssociatedLabels(e.target), + }, + inputValue: ["input", "focus", "blur"].includes(eventType) + ? e.target?.value + : undefined, + mousePosition: { + xa: e.clientX, + ya: e.clientY, + xp: e.clientX / window.innerWidth, + yp: e.clientY / window.innerHeight, + }, + key: e.key, + code: e.code, + activeElement: { + tagName: document.activeElement?.tagName, + id: document.activeElement?.id, + className: document.activeElement?.className, + boundingRect: document.activeElement?.getBoundingClientRect + ? { + x: document.activeElement.getBoundingClientRect().x, + y: document.activeElement.getBoundingClientRect().y, + width: document.activeElement.getBoundingClientRect().width, + height: + document.activeElement.getBoundingClientRect().height, + top: document.activeElement.getBoundingClientRect().top, + right: document.activeElement.getBoundingClientRect().right, + bottom: + document.activeElement.getBoundingClientRect().bottom, + left: document.activeElement.getBoundingClientRect().left, + } + : null, + scroll: document.activeElement + ? { + scrollTop: document.activeElement.scrollTop, + scrollLeft: document.activeElement.scrollLeft, + scrollHeight: document.activeElement.scrollHeight, + scrollWidth: document.activeElement.scrollWidth, + clientHeight: document.activeElement.clientHeight, + clientWidth: document.activeElement.clientWidth, + } + : null, + }, + window: { + width: window.innerWidth, + height: window.innerHeight, + scrollX: window.scrollX, + scrollY: window.scrollY, + }, + }; + + console.log("[EXFIL]", JSON.stringify(eventData)); + }, + true, + ); + }); + } +})(); diff --git a/skyvern/forge/sdk/routes/streaming/channels/js/undecorate.js b/skyvern/forge/sdk/routes/streaming/channels/js/undecorate.js new file mode 100644 index 00000000..9f8cd13f --- /dev/null +++ b/skyvern/forge/sdk/routes/streaming/channels/js/undecorate.js @@ -0,0 +1,9 @@ +(function () { + const followers = document.querySelectorAll("#__skyvern_mouse_follower"); + + for (const follower of followers) { + follower.remove(); + } + + window.__skyvern_decoration_mouse_follower = null; +})(); diff --git a/skyvern/forge/sdk/routes/streaming/channels/message.py b/skyvern/forge/sdk/routes/streaming/channels/message.py index 5e12dfaf..c95cc9f1 100644 --- a/skyvern/forge/sdk/routes/streaming/channels/message.py +++ b/skyvern/forge/sdk/routes/streaming/channels/message.py @@ -14,6 +14,7 @@ Channel data: import asyncio import dataclasses +import enum import typing as t import structlog @@ -22,6 +23,7 @@ from starlette.websockets import WebSocketState from websockets.exceptions import ConnectionClosedError from skyvern.forge.sdk.routes.streaming.channels.execution import execution_channel +from skyvern.forge.sdk.routes.streaming.channels.exfiltration import ExfiltratedEvent, ExfiltrationChannel from skyvern.forge.sdk.routes.streaming.registries import ( add_message_channel, del_message_channel, @@ -42,10 +44,42 @@ LOG = structlog.get_logger() Loops = list[asyncio.Task] # aka "queue-less actors"; or "programs" +class MessageKind(enum.StrEnum): + ASK_FOR_CLIPBOARD_RESPONSE = "ask-for-clipboard-response" + BEGIN_EXFILTRATION = "begin-exfiltration" + BROWSER_TABS = "browser-tabs" + CEDE_CONTROL = "cede-control" + END_EXFILTRATION = "end-exfiltration" + EXFILTRATED_EVENT = "exfiltrated-event" + TAKE_CONTROL = "take-control" + + +class ExfiltratedEventSource(enum.StrEnum): + CONSOLE = "console" + CDP = "cdp" + NOT_SPECIFIED = "[not-specified]" + + +@dataclasses.dataclass +class TabInfo: + id: str + title: str + url: str + # -- + active: bool = False + favicon: str | None = None + isReady: bool = True + pageNumber: int | None = None + + MessageKinds = t.Literal[ - "ask-for-clipboard-response", - "cede-control", - "take-control", + MessageKind.ASK_FOR_CLIPBOARD_RESPONSE, + MessageKind.BEGIN_EXFILTRATION, + MessageKind.BROWSER_TABS, + MessageKind.CEDE_CONTROL, + MessageKind.END_EXFILTRATION, + MessageKind.EXFILTRATED_EVENT, + MessageKind.TAKE_CONTROL, ] @@ -54,44 +88,95 @@ class Message: kind: MessageKinds +@dataclasses.dataclass +class MessageInBeginExfiltration(Message): + kind: t.Literal[MessageKind.BEGIN_EXFILTRATION] = MessageKind.BEGIN_EXFILTRATION + + +@dataclasses.dataclass +class MessageInEndExfiltration(Message): + kind: t.Literal[MessageKind.END_EXFILTRATION] = MessageKind.END_EXFILTRATION + + @dataclasses.dataclass class MessageInTakeControl(Message): - kind: t.Literal["take-control"] = "take-control" + kind: t.Literal[MessageKind.TAKE_CONTROL] = MessageKind.TAKE_CONTROL @dataclasses.dataclass class MessageInCedeControl(Message): - kind: t.Literal["cede-control"] = "cede-control" + kind: t.Literal[MessageKind.CEDE_CONTROL] = MessageKind.CEDE_CONTROL @dataclasses.dataclass class MessageInAskForClipboardResponse(Message): - kind: t.Literal["ask-for-clipboard-response"] = "ask-for-clipboard-response" + kind: t.Literal[MessageKind.ASK_FOR_CLIPBOARD_RESPONSE] = MessageKind.ASK_FOR_CLIPBOARD_RESPONSE text: str = "" -ChannelMessage = t.Union[ - MessageInAskForClipboardResponse, - MessageInCedeControl, - MessageInTakeControl, -] +@dataclasses.dataclass +class MessageOutExfiltratedEvent(Message): + kind: t.Literal[MessageKind.EXFILTRATED_EVENT] = MessageKind.EXFILTRATED_EVENT + event_name: str = "[not-specified]" + + # TODO(jdo): improve typing for params + params: dict = dataclasses.field(default_factory=dict) + source: ExfiltratedEventSource = ExfiltratedEventSource.NOT_SPECIFIED + + +@dataclasses.dataclass +class MessageOutTabInfo(Message): + kind: t.Literal[MessageKind.BROWSER_TABS] = MessageKind.BROWSER_TABS + tabs: list[TabInfo] = dataclasses.field(default_factory=list) + + +MessageIn = ( + MessageInAskForClipboardResponse + | MessageInBeginExfiltration + | MessageInCedeControl + | MessageInEndExfiltration + | MessageInTakeControl +) + + +MessageOut = MessageOutExfiltratedEvent | MessageOutTabInfo + + +ChannelMessage = MessageIn | MessageOut def reify_channel_message(data: dict) -> ChannelMessage: kind = data.get("kind", None) match kind: - case "ask-for-clipboard-response": + case MessageKind.ASK_FOR_CLIPBOARD_RESPONSE: text = data.get("text") or "" return MessageInAskForClipboardResponse(text=text) - case "cede-control": + case MessageKind.BEGIN_EXFILTRATION: + return MessageInBeginExfiltration() + case MessageKind.CEDE_CONTROL: return MessageInCedeControl() - case "take-control": + case MessageKind.END_EXFILTRATION: + return MessageInEndExfiltration() + case MessageKind.TAKE_CONTROL: return MessageInTakeControl() case _: raise ValueError(f"Unknown message kind: '{kind}'") +def message_to_dict(message: MessageOut) -> dict: + """ + Convert message to dict with enums as their values. + """ + + def convert_value(obj: t.Any) -> t.Any: + if isinstance(obj, enum.Enum): + return obj.value + return obj + + return dataclasses.asdict(message, dict_factory=lambda x: {k: convert_value(v) for k, v in x}) + + @dataclasses.dataclass class MessageChannel: """ @@ -102,7 +187,7 @@ class MessageChannel: organization_id: str websocket: WebSocket # -- - out_queue: asyncio.Queue = dataclasses.field(default_factory=asyncio.Queue) # warn: unbounded + out_queue: asyncio.Queue[MessageOut] = dataclasses.field(default_factory=asyncio.Queue) # warn: unbounded browser_session: AddressablePersistentBrowserSession | None = None workflow_run: WorkflowRun | None = None @@ -147,18 +232,31 @@ class MessageChannel: return True - async def drain(self) -> list[dict]: - datums: list[dict] = [] + async def drain(self) -> list[dict | MessageOut]: + datums: list[dict | MessageOut] = [] - tasks = [ - asyncio.create_task(self.receive_from_out_queue()), - asyncio.create_task(self.receive_from_user()), - ] + result = await asyncio.gather( + self.receive_from_out_queue(), + self.receive_from_user(), + ) - results = await asyncio.gather(*tasks) + # NOTE(jdo): mypy seems to be unable to infer this, whereas pylance has + # no issue; added explicit type hints here to help mypy out. + out_queue: list[MessageOut] = result[0] + in_queue: list[dict] = result[1] - for result in results: - datums.extend(result) + for out_message in out_queue: + datums.append(out_message) + + for in_message in in_queue: + if isinstance(in_message, dict): + datums.append(in_message) + else: + LOG.error( + f"{self.class_name} drain dropping user message: unexpected result type: {type(in_message)}", + message=in_message, + **self.identity, + ) if datums: LOG.info(f"{self.class_name} Drained {len(datums)} messages from message channel.", **self.identity) @@ -183,8 +281,8 @@ class MessageChannel: return datums - async def receive_from_out_queue(self) -> list[dict]: - datums: list[dict] = [] + async def receive_from_out_queue(self) -> list[MessageOut]: + datums: list[MessageOut] = [] while True: try: @@ -197,8 +295,8 @@ class MessageChannel: return datums - def receive_from_out_queue_nowait(self) -> list[dict]: - datums: list[dict] = [] + def receive_from_out_queue_nowait(self) -> list[MessageOut]: + datums: list[MessageOut] = [] while True: try: @@ -209,13 +307,14 @@ class MessageChannel: return datums - async def send(self, *, messages: list[dict]) -> t.Self: + # async def send(self, *, messages: list[dict]) -> t.Self: + async def send(self, *, messages: list[MessageOut]) -> t.Self: for message in messages: await self.out_queue.put(message) return self - def send_nowait(self, *, messages: list[dict]) -> t.Self: + def send_nowait(self, *, messages: list[MessageOut]) -> t.Self: for message in messages: self.out_queue.put_nowait(message) @@ -255,28 +354,43 @@ async def loop_stream_messages(message_channel: MessageChannel) -> None: """ class_name = message_channel.class_name + exfiltration_channel: ExfiltrationChannel | None = None - async def handle_data(data: dict) -> None: - nonlocal class_name - - try: - message = reify_channel_message(data) - except ValueError: - LOG.error(f"MessageChannel: cannot reify channel message from data: {data}", **message_channel.identity) + async def send(message: MessageOut) -> None: + if message_channel.websocket.client_state != WebSocketState.CONNECTED: return - message_kind = message.kind + data = message_to_dict(message) - match message_kind: - case "ask-for-clipboard-response": - if not isinstance(message, MessageInAskForClipboardResponse): - LOG.error( - f"{class_name} invalid message type for ask-for-clipboard-response.", - message=message, - **message_channel.identity, - ) - return + try: + await message_channel.websocket.send_json(data) + except WebSocketDisconnect: + pass + except Exception: + LOG.exception("MessageChannel: failed to send data.") + async def handle_data(data: dict | MessageOut) -> None: + nonlocal class_name + nonlocal exfiltration_channel + message: ChannelMessage + + if isinstance(data, MessageOut): + message = data + elif isinstance(data, dict): + try: + message = reify_channel_message(data) + except ValueError: + LOG.error(f"MessageChannel: cannot reify channel message from data: {data}", **message_channel.identity) + return + else: + LOG.error( + f"{class_name} cannot handle data: expected dict or MessageOut, got {type(data)}", + **message_channel.identity, + ) + return + + match message.kind: + case MessageKind.ASK_FOR_CLIPBOARD_RESPONSE: vnc_channel = get_vnc_channel(message_channel.client_id) if not vnc_channel: @@ -292,7 +406,43 @@ async def loop_stream_messages(message_channel: MessageChannel) -> None: async with execution_channel(vnc_channel) as execute: await execute.paste_text(text) - case "cede-control": + case MessageKind.BEGIN_EXFILTRATION: + if exfiltration_channel is not None: + LOG.error( + "MessageChannel: cannot begin exfiltration: already active.", message_channel=message_channel + ) + return + + vnc_channel = get_vnc_channel(message_channel.client_id) + + if not vnc_channel: + LOG.error( + f"{class_name} no vnc channel client found for message channel - cannot exfiltrate.", + message=message, + **message_channel.identity, + ) + return + + def on_event(events: list[ExfiltratedEvent]) -> None: + for event in events: + message_out_exfiltrated_event = MessageOutExfiltratedEvent( + kind=t.cast(t.Literal[MessageKind.EXFILTRATED_EVENT], event.kind), + event_name=event.event_name, + params=event.params, + source=t.cast(ExfiltratedEventSource, event.source or ExfiltratedEventSource.NOT_SPECIFIED), + ) + + message_channel.send_nowait(messages=[message_out_exfiltrated_event]) + + exfiltration_channel = await ExfiltrationChannel( + on_event=on_event, + vnc_channel=vnc_channel, + ).start() + + case MessageKind.BROWSER_TABS: + await send(message) + + case MessageKind.CEDE_CONTROL: vnc_channel = get_vnc_channel(message_channel.client_id) if not vnc_channel: @@ -304,7 +454,24 @@ async def loop_stream_messages(message_channel: MessageChannel) -> None: return vnc_channel.interactor = "agent" - case "take-control": + case MessageKind.END_EXFILTRATION: + if exfiltration_channel is None: + return + + await exfiltration_channel.stop() + + exfiltration_channel = None + + case MessageKind.EXFILTRATED_EVENT: + await send(message) + + # case MessageKind.GET_TAB_INFO: + # """ + # TODO(jdo): implement - this is an on-demand request for tab info, which is + # required when connecting to an existing browser session. + # """ + + case MessageKind.TAKE_CONTROL: LOG.info(f"{class_name} processing take-control message.", **message_channel.identity) vnc_channel = get_vnc_channel(message_channel.client_id) @@ -318,8 +485,7 @@ async def loop_stream_messages(message_channel: MessageChannel) -> None: vnc_channel.interactor = "user" case _: - LOG.error(f"{class_name} unknown message kind: '{message_kind}'", **message_channel.identity) - return + t.assert_never(message.kind) async def frontend_to_backend() -> None: nonlocal class_name @@ -331,9 +497,9 @@ async def loop_stream_messages(message_channel: MessageChannel) -> None: datums = await message_channel.drain() for data in datums: - if not isinstance(data, dict): + if not isinstance(data, (dict, MessageOut)): LOG.error( - f"{class_name} cannot create message: expected dict, got {type(data)}", + f"{class_name} cannot handle message: expected dict or MessageOut, got {type(data)}", **message_channel.identity, ) continue