Browser Exfiltration (#4093)
This commit is contained in:
@@ -13,6 +13,8 @@ Channel data:
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import pathlib
|
||||
import typing as t
|
||||
|
||||
import structlog
|
||||
@@ -183,3 +185,17 @@ class CdpChannel:
|
||||
except Exception:
|
||||
LOG.exception(f"{self.class_name} failed to evaluate js", expression=expression, **self.identity)
|
||||
raise
|
||||
|
||||
@staticmethod
@functools.lru_cache(maxsize=None)
def _read_js(file_name: str) -> str:
    """
    Read a bundled JavaScript file from the sibling ``js`` directory.

    The cache is keyed on the (already normalised) file name only. Caching
    on the instance method instead would key on ``self`` and keep every
    channel instance alive for as long as the cache exists.
    """
    full_path = pathlib.Path(__file__).parent / "js" / pathlib.Path(file_name)

    with open(full_path, encoding="utf-8") as f:
        return f.read()

def js(self, file_name: str) -> str:
    """
    Return the source of a bundled JavaScript snippet.

    `file_name` is resolved relative to the ``js`` directory next to this
    module. Leading slashes are stripped and a ``.js`` suffix is appended
    when missing, so ``"exfiltrate"``, ``"exfiltrate.js"`` and
    ``"/exfiltrate.js"`` all load the same file (and share one cache entry).

    Raises whatever `open` raises when the file does not exist or cannot be
    read.
    """
    file_name = file_name.lstrip("/")

    if not file_name.endswith(".js"):
        file_name += ".js"

    return self._read_js(file_name)
|
||||
|
||||
@@ -1 +1,234 @@
|
||||
# stub
|
||||
"""
|
||||
This channel exfiltrates all user activity in a browser.
|
||||
|
||||
What this channel looks like:
|
||||
|
||||
[Skyvern App] <-- [API Server] <--> [Browser (CDP)]
|
||||
|
||||
Channel data:
|
||||
|
||||
Raw JavaScript events (as JSON) over WebSockets.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import enum
|
||||
import json
|
||||
import typing as t
|
||||
|
||||
import structlog
|
||||
from playwright.async_api import CDPSession, ConsoleMessage, Page
|
||||
|
||||
from skyvern.forge.sdk.routes.streaming.channels.cdp import CdpChannel
|
||||
from skyvern.forge.sdk.routes.streaming.channels.vnc import VncChannel
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
@enum.unique
class ExfiltratedEventSource(enum.Enum):
    """Where an exfiltrated event was captured from."""

    # Parsed out of a page console message.
    CONSOLE = "console"
    # Delivered by a CDP session event.
    CDP = "cdp"
    # Default used when the producer did not state a source.
    NOT_SPECIFIED = "[not-specified]"
|
||||
|
||||
|
||||
@dataclasses.dataclass
class ExfiltratedEvent:
    """A single event captured from the browser, as handed to the `on_event` callback."""

    # Fixed discriminator for this record type.
    kind: t.Literal["exfiltrated-event"] = "exfiltrated-event"

    # Name of the captured event (e.g. "user_interaction", "frame_navigated").
    event_name: str = "[not-specified]"

    # TODO(jdo): improve typing for params
    # Raw event payload; its shape depends on the producer.
    params: dict = dataclasses.field(default_factory=dict)

    # Capture path the event arrived through.
    source: ExfiltratedEventSource = ExfiltratedEventSource.NOT_SPECIFIED
|
||||
|
||||
|
||||
OnExfiltrationEvent = t.Callable[[list[ExfiltratedEvent]], None]
|
||||
|
||||
|
||||
class ExfiltrationChannel(CdpChannel):
    """
    CDP channel that captures browser activity and forwards it to a callback.

    Two capture paths feed `on_event`:

    - console: `exfiltrate.js` is evaluated in each page and logs
      ``[EXFIL]``-prefixed JSON, which `_handle_console_event` parses;
    - cdp: target/page events from a dedicated CDP session, forwarded by
      `_handle_cdp_event`.

    The channel also injects (and on `stop` removes) a mouse-follower
    decoration into each page.
    """

    def __init__(self, *, on_event: OnExfiltrationEvent, vnc_channel: VncChannel) -> None:
        self.cdp_session: CDPSession | None = None
        self.on_event = on_event

        # asyncio only keeps weak references to tasks; hold strong ones so
        # fire-and-forget page setup cannot be garbage collected mid-flight.
        self._background_tasks: set[asyncio.Task] = set()

        super().__init__(vnc_channel=vnc_channel)

    # -- background task bookkeeping -------------------------------------

    def _spawn(self, coro: t.Coroutine[t.Any, t.Any, t.Any]) -> asyncio.Task:
        """Schedule `coro` as a task, keep it referenced, and log its failure if any."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._on_background_task_done)

        return task

    def _on_background_task_done(self, task: asyncio.Task) -> None:
        """Forget a finished background task; surface its exception in the log."""
        self._background_tasks.discard(task)

        if task.cancelled():
            return

        error = task.exception()

        if error is not None:
            LOG.error(f"{self.class_name} background task failed.", exc_info=error)

    def _on_new_page_exfiltrate(self, page: Page) -> None:
        """Context "page" listener: set up exfiltration on a newly opened page."""
        self._spawn(self.exfiltrate(page))

    def _on_new_page_decorate(self, page: Page) -> None:
        """Context "page" listener: decorate a newly opened page."""
        self._spawn(self.decorate(page))

    def _remove_listener_quietly(self, emitter: t.Any, event: str, handler: t.Callable[..., t.Any]) -> None:
        """Remove `handler` from `emitter`; a handler that was never registered is not an error."""
        try:
            emitter.remove_listener(event, handler)
        except Exception:
            LOG.debug(f"{self.class_name} listener was not registered; nothing to remove.", event_name=event)

    # -- event handlers --------------------------------------------------

    def _handle_console_event(self, msg: ConsoleMessage) -> None:
        """Parse console messages for exfiltrated event data."""
        text = msg.text

        if not text.startswith("[EXFIL]"):
            return

        try:
            # json.loads tolerates the whitespace left after the prefix.
            event_data = json.loads(text.removeprefix("[EXFIL]"))
        except Exception:
            LOG.exception(f"{self.class_name} Failed to parse exfiltrated event", text=text)
            return

        messages = [
            ExfiltratedEvent(
                kind="exfiltrated-event",
                event_name="user_interaction",
                params=event_data,
                source=ExfiltratedEventSource.CONSOLE,
            ),
        ]

        try:
            self.on_event(messages)
        except Exception:
            # Never let a callback failure propagate into the page's event emitter.
            LOG.exception(f"{self.class_name} on_event callback failed for exfiltrated event", text=text)

    def _handle_cdp_event(self, event_name: str, params: dict) -> None:
        """Wrap a CDP event in an `ExfiltratedEvent` and hand it to `on_event`."""
        LOG.debug(f"{self.class_name} cdp event captured: {event_name}", params=params)

        messages = [
            ExfiltratedEvent(
                kind="exfiltrated-event",
                event_name=event_name,
                params=params,
                source=ExfiltratedEventSource.CDP,
            ),
        ]

        self.on_event(messages)

    # -- connection ------------------------------------------------------

    async def connect(self, cdp_url: str | None = None) -> t.Self:
        """
        Connect to the browser and open a CDP session on the current page.

        No-op when a connected browser and a CDP session already exist.

        Raises:
            RuntimeError: no page is available after connecting.
        """
        if self.browser and self.browser.is_connected() and self.cdp_session:
            return self

        await super().connect(cdp_url)

        # NOTE(jdo:streaming-local-dev)
        # from skyvern.config import settings
        # await super().connect(cdp_url or settings.BROWSER_REMOTE_DEBUGGING_URL)

        page = self.page

        if not page:
            raise RuntimeError(f"{self.class_name} No page available after connecting to browser.")

        self.cdp_session = await page.context.new_cdp_session(page)

        return self

    # -- per-page setup --------------------------------------------------

    async def exfiltrate(self, page: Page) -> t.Self:
        """
        Track user interactions and send to console for CDP to capture.
        """
        LOG.info(f"{self.class_name} setting up exfiltration on new page.", url=page.url)

        page.on("console", self._handle_console_event)

        # NOTE(review): the script is evaluated in the current document only;
        # presumably it must be re-injected after a full navigation - confirm.
        await page.evaluate(self.js("exfiltrate"))

        LOG.info(f"{self.class_name} setup complete on page.", url=page.url)

        return self

    async def decorate(self, page: Page) -> t.Self:
        """Add a mouse-following follower to the page."""
        LOG.info(f"{self.class_name} adding decoration to page.", url=page.url)

        await page.evaluate(self.js("decorate"))

        LOG.info(f"{self.class_name} decoration setup complete on page.", url=page.url)

        return self

    async def undecorate(self, page: Page) -> t.Self:
        """Remove the mouse-following follower from the page."""
        LOG.info(f"{self.class_name} removing decoration from page.", url=page.url)

        await page.evaluate(self.js("undecorate"))

        LOG.info(f"{self.class_name} decoration removed from page.", url=page.url)

        return self

    # -- feature switches ------------------------------------------------

    async def enable_cdp_events(self) -> t.Self:
        """
        Enable the CDP domains in use and forward tab/navigation events.

        Raises:
            RuntimeError: no CDP session is available after connecting.
        """
        await self.connect()

        cdp_session = self.cdp_session

        if not cdp_session:
            raise RuntimeError(f"{self.class_name} No CDP session available to enable events.")

        enables = [
            cdp_session.send("Runtime.enable"),
            cdp_session.send("DOM.enable"),
            cdp_session.send("Page.enable"),
            cdp_session.send("Target.setDiscoverTargets", {"discover": True}),
        ]

        await asyncio.gather(*enables)

        # listen to CDP events for tab management and navigation
        cdp_session.on("Target.targetCreated", lambda params: self._handle_cdp_event("target_created", params))
        cdp_session.on("Target.targetDestroyed", lambda params: self._handle_cdp_event("target_destroyed", params))
        cdp_session.on("Target.targetInfoChanged", lambda params: self._handle_cdp_event("target_info_changed", params))
        cdp_session.on("Page.frameNavigated", lambda params: self._handle_cdp_event("frame_navigated", params))
        cdp_session.on(
            "Page.navigatedWithinDocument", lambda params: self._handle_cdp_event("navigated_within_document", params)
        )

        return self

    def enable_console_events(self) -> t.Self:
        """Set up exfiltration on every open page and on pages opened later."""
        browser_context = self.browser_context

        if not browser_context:
            LOG.error(f"{self.class_name} no browser context to enable console events.")
            return self

        for page in browser_context.pages:
            self._spawn(self.exfiltrate(page))

        # A bound method (not a lambda) so stop() can remove it again.
        browser_context.on("page", self._on_new_page_exfiltrate)

        return self

    def enable_decoration(self) -> t.Self:
        """Decorate every open page and pages opened later."""
        browser_context = self.browser_context

        if not browser_context:
            LOG.error(f"{self.class_name} no browser context to enable decoration.")
            return self

        for page in browser_context.pages:
            self._spawn(self.decorate(page))

        # A bound method (not a lambda) so stop() can remove it again.
        browser_context.on("page", self._on_new_page_decorate)

        return self

    # -- lifecycle -------------------------------------------------------

    async def start(self) -> t.Self:
        """Enable CDP events, console capture, and page decoration."""
        LOG.info(f"{self.class_name} starting.")

        await self.enable_cdp_events()

        self.enable_console_events()

        self.enable_decoration()

        return self

    async def stop(self) -> t.Self:
        """
        Tear down everything `start` set up.

        Cleanup is best-effort: a failure on one step or one page is logged
        and does not prevent the remaining steps from running.
        """
        LOG.info(f"{self.class_name} stopping.")

        # Abort in-flight page setup so it cannot re-attach listeners mid-teardown.
        for task in list(self._background_tasks):
            task.cancel()

        cdp_session, self.cdp_session = self.cdp_session, None

        if cdp_session:
            try:
                await cdp_session.detach()
            except Exception:
                LOG.warning(f"{self.class_name} failed to detach CDP session.", exc_info=True)

        browser_context = self.browser_context

        if browser_context:
            # Stop reacting to pages opened after this point.
            self._remove_listener_quietly(browser_context, "page", self._on_new_page_exfiltrate)
            self._remove_listener_quietly(browser_context, "page", self._on_new_page_decorate)

            for page in browser_context.pages:
                self._remove_listener_quietly(page, "console", self._handle_console_event)

                try:
                    await self.undecorate(page)
                except Exception:
                    LOG.warning(
                        f"{self.class_name} failed to remove decoration from page.",
                        url=page.url,
                        exc_info=True,
                    )

        LOG.info(f"{self.class_name} stopped.")

        return self
|
||||
|
||||
95
skyvern/forge/sdk/routes/streaming/channels/js/decorate.js
Normal file
95
skyvern/forge/sdk/routes/streaming/channels/js/decorate.js
Normal file
@@ -0,0 +1,95 @@
|
||||
(function () {
  // Decorates the page with a translucent circle that follows the mouse and
  // grows while a mouse button is held down.
  //
  // The script may be evaluated more than once per document: listeners are
  // registered only on the first run; later runs merely (re)create the
  // follower element if it is missing.

  const FOLLOWER_ID = "__skyvern_mouse_follower";

  // Create the follower element. Does nothing when a follower is already
  // attached, so repeated runs never stack several circles with the same id.
  window.__skyvern_create_mouse_follower = function () {
    const existing = window.__skyvern_decoration_mouse_follower;

    if (existing && existing.isConnected) {
      return;
    }

    // drop any orphaned followers that are no longer tracked
    document
      .querySelectorAll("#" + FOLLOWER_ID)
      .forEach((node) => node.remove());

    // create the circle element
    const circle = document.createElement("div");
    window.__skyvern_decoration_mouse_follower = circle;
    circle.id = FOLLOWER_ID;
    circle.style.position = "fixed";
    circle.style.left = "0";
    circle.style.top = "0";
    circle.style.width = "30px";
    circle.style.height = "30px";
    circle.style.borderRadius = "50%";
    circle.style.backgroundColor = "rgba(255, 0, 0, 0.2)";
    circle.style.pointerEvents = "none";
    circle.style.zIndex = "999999";
    circle.style.willChange = "transform";

    // body can be missing very early in a document's life
    (document.body || document.documentElement).appendChild(circle);
  };

  if (window.__skyvern_decoration_initialized) {
    // listeners already exist from an earlier run; only restore the element
    window.__skyvern_create_mouse_follower();
    return;
  }

  window.__skyvern_decoration_initialized = true;
  window.__skyvern_create_mouse_follower();

  let scale = 1;
  let targetScale = 1;
  let mouseX = 0;
  let mouseY = 0;

  // the circle is 30px wide, so offset by 15px to centre it on the pointer
  const place = (follower) => {
    follower.style.transform = `translate(${mouseX - 15}px, ${mouseY - 15}px) scale(${scale})`;
  };

  // smooth scale animation: ease towards targetScale one frame at a time
  function animate() {
    const follower = window.__skyvern_decoration_mouse_follower;

    if (!follower) {
      return;
    }

    scale += (targetScale - scale) * 0.2;

    if (Math.abs(targetScale - scale) > 0.001) {
      requestAnimationFrame(animate);
    } else {
      scale = targetScale;
    }

    place(follower);
  }

  // update follower position on mouse move
  document.addEventListener(
    "mousemove",
    (e) => {
      const follower = window.__skyvern_decoration_mouse_follower;

      if (!follower) {
        return;
      }

      mouseX = e.clientX;
      mouseY = e.clientY;
      place(follower);
    },
    true,
  );

  // expand follower on mouse down
  document.addEventListener(
    "mousedown",
    () => {
      if (!window.__skyvern_decoration_mouse_follower) {
        return;
      }

      targetScale = 50 / 30;
      requestAnimationFrame(animate);
    },
    true,
  );

  // return follower to original size on mouse up
  document.addEventListener(
    "mouseup",
    () => {
      if (!window.__skyvern_decoration_mouse_follower) {
        return;
      }

      targetScale = 1;
      requestAnimationFrame(animate);
    },
    true,
  );
})();
|
||||
165
skyvern/forge/sdk/routes/streaming/channels/js/exfiltrate.js
Normal file
165
skyvern/forge/sdk/routes/streaming/channels/js/exfiltrate.js
Normal file
@@ -0,0 +1,165 @@
|
||||
(function () {
  // Reports user interaction events by logging "[EXFIL]" followed by a JSON
  // payload to the console. Listeners are registered once per document, in
  // the capture phase, on `document`.
  if (window.__skyvern_exfiltration_initialized) {
    return;
  }

  window.__skyvern_exfiltration_initialized = true;

  const EVENT_TYPES = [
    "click",
    "mousedown",
    "mouseup",
    "mouseenter",
    "mouseleave",
    "keydown",
    "keyup",
    "keypress",
    "focus",
    "blur",
    "input",
    "change",
    "scroll",
    "contextmenu",
    "dblclick",
  ];

  // event types for which the target's value is also reported as `inputValue`
  const VALUE_EVENT_TYPES = ["input", "focus", "blur"];

  // Targets are not always elements: a document-level scroll event has
  // `document` as its target, which has no getAttribute().
  const isElement = (node) => Boolean(node) && node.nodeType === 1;

  // Quote an id for use inside label[for="..."]; ids may contain characters
  // that would otherwise make the selector invalid.
  const escapeForSelector = (value) =>
    typeof CSS !== "undefined" && typeof CSS.escape === "function"
      ? CSS.escape(value)
      : String(value).replace(/["\\]/g, "\\$&");

  // find associated labels
  const getAssociatedLabels = (element) => {
    if (!element) {
      return null;
    }

    const labels = [];

    // label with 'for' attribute matching element's id
    if (element.id) {
      document
        .querySelectorAll(`label[for="${escapeForSelector(element.id)}"]`)
        .forEach((label) => {
          if (label.textContent) labels.push(label.textContent.trim());
        });
    }

    // label wrapping the element (nearest LABEL ancestor only)
    for (
      let parent = element.parentElement;
      parent;
      parent = parent.parentElement
    ) {
      if (parent.tagName === "LABEL") {
        if (parent.textContent) labels.push(parent.textContent.trim());
        break;
      }
    }

    return labels.length > 0 ? labels : null;
  };

  // get any kind of text content
  const getElementText = (element) => {
    if (!isElement(element)) {
      return [];
    }

    const textSources = [];

    const ariaLabel = element.getAttribute("aria-label");
    if (ariaLabel) {
      textSources.push(ariaLabel);
    }

    const labelledBy = element.getAttribute("aria-labelledby");
    if (labelledBy) {
      labelledBy.split(" ").forEach((id) => {
        const labelElement = document.getElementById(id);

        if (labelElement?.textContent) {
          textSources.push(labelElement.textContent.trim());
        }
      });
    }

    ["placeholder", "title", "alt"].forEach((attribute) => {
      const value = element.getAttribute(attribute);

      if (value) {
        textSources.push(value);
      }
    });

    // visible text first; fall back to raw text content; both capped at 100 chars
    if (element.innerText) {
      textSources.push(element.innerText.substring(0, 100));
    } else if (element.textContent) {
      textSources.push(element.textContent.trim().substring(0, 100));
    }

    return textSources;
  };

  // snapshot of an element's bounding rect (one layout read), or null
  const describeRect = (element) => {
    if (!element?.getBoundingClientRect) {
      return null;
    }

    const rect = element.getBoundingClientRect();

    return {
      x: rect.x,
      y: rect.y,
      width: rect.width,
      height: rect.height,
      top: rect.top,
      right: rect.right,
      bottom: rect.bottom,
      left: rect.left,
    };
  };

  // snapshot of an element's scroll metrics, or null
  const describeScroll = (element) =>
    element
      ? {
          scrollTop: element.scrollTop,
          scrollLeft: element.scrollLeft,
          scrollHeight: element.scrollHeight,
          scrollWidth: element.scrollWidth,
          clientHeight: element.clientHeight,
          clientWidth: element.clientWidth,
        }
      : null;

  EVENT_TYPES.forEach((eventType) => {
    document.addEventListener(
      eventType,
      (e) => {
        const target = e.target;
        const active = document.activeElement;

        const eventData = {
          url: window.location.href,
          type: eventType,
          timestamp: Date.now(),
          target: {
            tagName: target?.tagName,
            id: target?.id,
            className: target?.className,
            // NOTE(review): values are forwarded verbatim, including those of
            // password inputs - confirm this is intended.
            value: target?.value,
            text: getElementText(target),
            labels: getAssociatedLabels(target),
          },
          inputValue: VALUE_EVENT_TYPES.includes(eventType)
            ? target?.value
            : undefined,
          // xa/ya: absolute client coordinates; xp/yp: fraction of the viewport
          mousePosition: {
            xa: e.clientX,
            ya: e.clientY,
            xp: e.clientX / window.innerWidth,
            yp: e.clientY / window.innerHeight,
          },
          key: e.key,
          code: e.code,
          activeElement: {
            tagName: active?.tagName,
            id: active?.id,
            className: active?.className,
            boundingRect: describeRect(active),
            scroll: describeScroll(active),
          },
          window: {
            width: window.innerWidth,
            height: window.innerHeight,
            scrollX: window.scrollX,
            scrollY: window.scrollY,
          },
        };

        console.log("[EXFIL]", JSON.stringify(eventData));
      },
      true,
    );
  });
})();
|
||||
@@ -0,0 +1,9 @@
|
||||
(function () {
  // Remove every mouse-follower element from the document and clear the
  // global handle that points at it.
  document
    .querySelectorAll("#__skyvern_mouse_follower")
    .forEach((node) => node.remove());

  window.__skyvern_decoration_mouse_follower = null;
})();
|
||||
@@ -14,6 +14,7 @@ Channel data:
|
||||
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import enum
|
||||
import typing as t
|
||||
|
||||
import structlog
|
||||
@@ -22,6 +23,7 @@ from starlette.websockets import WebSocketState
|
||||
from websockets.exceptions import ConnectionClosedError
|
||||
|
||||
from skyvern.forge.sdk.routes.streaming.channels.execution import execution_channel
|
||||
from skyvern.forge.sdk.routes.streaming.channels.exfiltration import ExfiltratedEvent, ExfiltrationChannel
|
||||
from skyvern.forge.sdk.routes.streaming.registries import (
|
||||
add_message_channel,
|
||||
del_message_channel,
|
||||
@@ -42,10 +44,42 @@ LOG = structlog.get_logger()
|
||||
Loops = list[asyncio.Task] # aka "queue-less actors"; or "programs"
|
||||
|
||||
|
||||
class MessageKind(enum.StrEnum):
|
||||
ASK_FOR_CLIPBOARD_RESPONSE = "ask-for-clipboard-response"
|
||||
BEGIN_EXFILTRATION = "begin-exfiltration"
|
||||
BROWSER_TABS = "browser-tabs"
|
||||
CEDE_CONTROL = "cede-control"
|
||||
END_EXFILTRATION = "end-exfiltration"
|
||||
EXFILTRATED_EVENT = "exfiltrated-event"
|
||||
TAKE_CONTROL = "take-control"
|
||||
|
||||
|
||||
class ExfiltratedEventSource(enum.StrEnum):
|
||||
CONSOLE = "console"
|
||||
CDP = "cdp"
|
||||
NOT_SPECIFIED = "[not-specified]"
|
||||
|
||||
|
||||
@dataclasses.dataclass
class TabInfo:
    """Description of one browser tab, as carried in a browser-tabs message."""

    # required
    id: str
    title: str
    url: str
    # -- optional, with defaults
    active: bool = False
    favicon: str | None = None
    isReady: bool = True
    pageNumber: int | None = None
|
||||
|
||||
|
||||
MessageKinds = t.Literal[
|
||||
"ask-for-clipboard-response",
|
||||
"cede-control",
|
||||
"take-control",
|
||||
MessageKind.ASK_FOR_CLIPBOARD_RESPONSE,
|
||||
MessageKind.BEGIN_EXFILTRATION,
|
||||
MessageKind.BROWSER_TABS,
|
||||
MessageKind.CEDE_CONTROL,
|
||||
MessageKind.END_EXFILTRATION,
|
||||
MessageKind.EXFILTRATED_EVENT,
|
||||
MessageKind.TAKE_CONTROL,
|
||||
]
|
||||
|
||||
|
||||
@@ -54,44 +88,95 @@ class Message:
|
||||
kind: MessageKinds
|
||||
|
||||
|
||||
@dataclasses.dataclass
class MessageInBeginExfiltration(Message):
    """Incoming message asking the server to begin exfiltration."""

    # Discriminator, fixed to the begin-exfiltration kind.
    kind: t.Literal[MessageKind.BEGIN_EXFILTRATION] = MessageKind.BEGIN_EXFILTRATION
|
||||
|
||||
|
||||
@dataclasses.dataclass
class MessageInEndExfiltration(Message):
    """Incoming message asking the server to end exfiltration."""

    # Discriminator, fixed to the end-exfiltration kind.
    kind: t.Literal[MessageKind.END_EXFILTRATION] = MessageKind.END_EXFILTRATION
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MessageInTakeControl(Message):
|
||||
kind: t.Literal["take-control"] = "take-control"
|
||||
kind: t.Literal[MessageKind.TAKE_CONTROL] = MessageKind.TAKE_CONTROL
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MessageInCedeControl(Message):
|
||||
kind: t.Literal["cede-control"] = "cede-control"
|
||||
kind: t.Literal[MessageKind.CEDE_CONTROL] = MessageKind.CEDE_CONTROL
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MessageInAskForClipboardResponse(Message):
|
||||
kind: t.Literal["ask-for-clipboard-response"] = "ask-for-clipboard-response"
|
||||
kind: t.Literal[MessageKind.ASK_FOR_CLIPBOARD_RESPONSE] = MessageKind.ASK_FOR_CLIPBOARD_RESPONSE
|
||||
text: str = ""
|
||||
|
||||
|
||||
ChannelMessage = t.Union[
|
||||
MessageInAskForClipboardResponse,
|
||||
MessageInCedeControl,
|
||||
MessageInTakeControl,
|
||||
]
|
||||
@dataclasses.dataclass
class MessageOutExfiltratedEvent(Message):
    """Outgoing message carrying one exfiltrated browser event."""

    # Discriminator, fixed to the exfiltrated-event kind.
    kind: t.Literal[MessageKind.EXFILTRATED_EVENT] = MessageKind.EXFILTRATED_EVENT

    # Name of the captured event.
    event_name: str = "[not-specified]"

    # TODO(jdo): improve typing for params
    # Raw event payload; its shape depends on the producer.
    params: dict = dataclasses.field(default_factory=dict)

    # Capture path the event arrived through.
    source: ExfiltratedEventSource = ExfiltratedEventSource.NOT_SPECIFIED
|
||||
|
||||
|
||||
@dataclasses.dataclass
class MessageOutTabInfo(Message):
    """Outgoing message carrying a list of browser tabs."""

    # Discriminator, fixed to the browser-tabs kind.
    kind: t.Literal[MessageKind.BROWSER_TABS] = MessageKind.BROWSER_TABS

    # Tabs being reported; empty by default.
    tabs: list[TabInfo] = dataclasses.field(default_factory=list)
|
||||
|
||||
|
||||
MessageIn = (
|
||||
MessageInAskForClipboardResponse
|
||||
| MessageInBeginExfiltration
|
||||
| MessageInCedeControl
|
||||
| MessageInEndExfiltration
|
||||
| MessageInTakeControl
|
||||
)
|
||||
|
||||
|
||||
MessageOut = MessageOutExfiltratedEvent | MessageOutTabInfo
|
||||
|
||||
|
||||
ChannelMessage = MessageIn | MessageOut
|
||||
|
||||
|
||||
def reify_channel_message(data: dict) -> ChannelMessage:
|
||||
kind = data.get("kind", None)
|
||||
|
||||
match kind:
|
||||
case "ask-for-clipboard-response":
|
||||
case MessageKind.ASK_FOR_CLIPBOARD_RESPONSE:
|
||||
text = data.get("text") or ""
|
||||
return MessageInAskForClipboardResponse(text=text)
|
||||
case "cede-control":
|
||||
case MessageKind.BEGIN_EXFILTRATION:
|
||||
return MessageInBeginExfiltration()
|
||||
case MessageKind.CEDE_CONTROL:
|
||||
return MessageInCedeControl()
|
||||
case "take-control":
|
||||
case MessageKind.END_EXFILTRATION:
|
||||
return MessageInEndExfiltration()
|
||||
case MessageKind.TAKE_CONTROL:
|
||||
return MessageInTakeControl()
|
||||
case _:
|
||||
raise ValueError(f"Unknown message kind: '{kind}'")
|
||||
|
||||
|
||||
def message_to_dict(message: MessageOut) -> dict:
    """
    Convert message to dict with enums as their values.

    The dataclass is converted recursively with `dataclasses.asdict`. Enum
    members are replaced by their `.value` wherever they appear as a field
    value, including inside list and dict field values (e.g. within
    `params`). Tuples are left untouched.
    """

    def convert_value(obj: t.Any) -> t.Any:
        if isinstance(obj, enum.Enum):
            return obj.value
        if isinstance(obj, dict):
            return {key: convert_value(value) for key, value in obj.items()}
        if isinstance(obj, list):
            return [convert_value(item) for item in obj]
        return obj

    # asdict calls dict_factory for the message and every nested dataclass;
    # plain containers are walked by convert_value itself.
    return dataclasses.asdict(message, dict_factory=lambda pairs: {k: convert_value(v) for k, v in pairs})
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MessageChannel:
|
||||
"""
|
||||
@@ -102,7 +187,7 @@ class MessageChannel:
|
||||
organization_id: str
|
||||
websocket: WebSocket
|
||||
# --
|
||||
out_queue: asyncio.Queue = dataclasses.field(default_factory=asyncio.Queue) # warn: unbounded
|
||||
out_queue: asyncio.Queue[MessageOut] = dataclasses.field(default_factory=asyncio.Queue) # warn: unbounded
|
||||
browser_session: AddressablePersistentBrowserSession | None = None
|
||||
workflow_run: WorkflowRun | None = None
|
||||
|
||||
@@ -147,18 +232,31 @@ class MessageChannel:
|
||||
|
||||
return True
|
||||
|
||||
async def drain(self) -> list[dict]:
|
||||
datums: list[dict] = []
|
||||
async def drain(self) -> list[dict | MessageOut]:
|
||||
datums: list[dict | MessageOut] = []
|
||||
|
||||
tasks = [
|
||||
asyncio.create_task(self.receive_from_out_queue()),
|
||||
asyncio.create_task(self.receive_from_user()),
|
||||
]
|
||||
result = await asyncio.gather(
|
||||
self.receive_from_out_queue(),
|
||||
self.receive_from_user(),
|
||||
)
|
||||
|
||||
results = await asyncio.gather(*tasks)
|
||||
# NOTE(jdo): mypy seems to be unable to infer this, whereas pylance has
|
||||
# no issue; added explicit type hints here to help mypy out.
|
||||
out_queue: list[MessageOut] = result[0]
|
||||
in_queue: list[dict] = result[1]
|
||||
|
||||
for result in results:
|
||||
datums.extend(result)
|
||||
for out_message in out_queue:
|
||||
datums.append(out_message)
|
||||
|
||||
for in_message in in_queue:
|
||||
if isinstance(in_message, dict):
|
||||
datums.append(in_message)
|
||||
else:
|
||||
LOG.error(
|
||||
f"{self.class_name} drain dropping user message: unexpected result type: {type(in_message)}",
|
||||
message=in_message,
|
||||
**self.identity,
|
||||
)
|
||||
|
||||
if datums:
|
||||
LOG.info(f"{self.class_name} Drained {len(datums)} messages from message channel.", **self.identity)
|
||||
@@ -183,8 +281,8 @@ class MessageChannel:
|
||||
|
||||
return datums
|
||||
|
||||
async def receive_from_out_queue(self) -> list[dict]:
|
||||
datums: list[dict] = []
|
||||
async def receive_from_out_queue(self) -> list[MessageOut]:
|
||||
datums: list[MessageOut] = []
|
||||
|
||||
while True:
|
||||
try:
|
||||
@@ -197,8 +295,8 @@ class MessageChannel:
|
||||
|
||||
return datums
|
||||
|
||||
def receive_from_out_queue_nowait(self) -> list[dict]:
|
||||
datums: list[dict] = []
|
||||
def receive_from_out_queue_nowait(self) -> list[MessageOut]:
|
||||
datums: list[MessageOut] = []
|
||||
|
||||
while True:
|
||||
try:
|
||||
@@ -209,13 +307,14 @@ class MessageChannel:
|
||||
|
||||
return datums
|
||||
|
||||
async def send(self, *, messages: list[dict]) -> t.Self:
|
||||
# async def send(self, *, messages: list[dict]) -> t.Self:
|
||||
async def send(self, *, messages: list[MessageOut]) -> t.Self:
|
||||
for message in messages:
|
||||
await self.out_queue.put(message)
|
||||
|
||||
return self
|
||||
|
||||
def send_nowait(self, *, messages: list[dict]) -> t.Self:
|
||||
def send_nowait(self, *, messages: list[MessageOut]) -> t.Self:
|
||||
for message in messages:
|
||||
self.out_queue.put_nowait(message)
|
||||
|
||||
@@ -255,28 +354,43 @@ async def loop_stream_messages(message_channel: MessageChannel) -> None:
|
||||
"""
|
||||
|
||||
class_name = message_channel.class_name
|
||||
exfiltration_channel: ExfiltrationChannel | None = None
|
||||
|
||||
async def handle_data(data: dict) -> None:
|
||||
nonlocal class_name
|
||||
|
||||
try:
|
||||
message = reify_channel_message(data)
|
||||
except ValueError:
|
||||
LOG.error(f"MessageChannel: cannot reify channel message from data: {data}", **message_channel.identity)
|
||||
async def send(message: MessageOut) -> None:
|
||||
if message_channel.websocket.client_state != WebSocketState.CONNECTED:
|
||||
return
|
||||
|
||||
message_kind = message.kind
|
||||
data = message_to_dict(message)
|
||||
|
||||
match message_kind:
|
||||
case "ask-for-clipboard-response":
|
||||
if not isinstance(message, MessageInAskForClipboardResponse):
|
||||
LOG.error(
|
||||
f"{class_name} invalid message type for ask-for-clipboard-response.",
|
||||
message=message,
|
||||
**message_channel.identity,
|
||||
)
|
||||
return
|
||||
try:
|
||||
await message_channel.websocket.send_json(data)
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
except Exception:
|
||||
LOG.exception("MessageChannel: failed to send data.")
|
||||
|
||||
async def handle_data(data: dict | MessageOut) -> None:
|
||||
nonlocal class_name
|
||||
nonlocal exfiltration_channel
|
||||
message: ChannelMessage
|
||||
|
||||
if isinstance(data, MessageOut):
|
||||
message = data
|
||||
elif isinstance(data, dict):
|
||||
try:
|
||||
message = reify_channel_message(data)
|
||||
except ValueError:
|
||||
LOG.error(f"MessageChannel: cannot reify channel message from data: {data}", **message_channel.identity)
|
||||
return
|
||||
else:
|
||||
LOG.error(
|
||||
f"{class_name} cannot handle data: expected dict or MessageOut, got {type(data)}",
|
||||
**message_channel.identity,
|
||||
)
|
||||
return
|
||||
|
||||
match message.kind:
|
||||
case MessageKind.ASK_FOR_CLIPBOARD_RESPONSE:
|
||||
vnc_channel = get_vnc_channel(message_channel.client_id)
|
||||
|
||||
if not vnc_channel:
|
||||
@@ -292,7 +406,43 @@ async def loop_stream_messages(message_channel: MessageChannel) -> None:
|
||||
async with execution_channel(vnc_channel) as execute:
|
||||
await execute.paste_text(text)
|
||||
|
||||
case "cede-control":
|
||||
case MessageKind.BEGIN_EXFILTRATION:
|
||||
if exfiltration_channel is not None:
|
||||
LOG.error(
|
||||
"MessageChannel: cannot begin exfiltration: already active.", message_channel=message_channel
|
||||
)
|
||||
return
|
||||
|
||||
vnc_channel = get_vnc_channel(message_channel.client_id)
|
||||
|
||||
if not vnc_channel:
|
||||
LOG.error(
|
||||
f"{class_name} no vnc channel client found for message channel - cannot exfiltrate.",
|
||||
message=message,
|
||||
**message_channel.identity,
|
||||
)
|
||||
return
|
||||
|
||||
def on_event(events: list[ExfiltratedEvent]) -> None:
|
||||
for event in events:
|
||||
message_out_exfiltrated_event = MessageOutExfiltratedEvent(
|
||||
kind=t.cast(t.Literal[MessageKind.EXFILTRATED_EVENT], event.kind),
|
||||
event_name=event.event_name,
|
||||
params=event.params,
|
||||
source=t.cast(ExfiltratedEventSource, event.source or ExfiltratedEventSource.NOT_SPECIFIED),
|
||||
)
|
||||
|
||||
message_channel.send_nowait(messages=[message_out_exfiltrated_event])
|
||||
|
||||
exfiltration_channel = await ExfiltrationChannel(
|
||||
on_event=on_event,
|
||||
vnc_channel=vnc_channel,
|
||||
).start()
|
||||
|
||||
case MessageKind.BROWSER_TABS:
|
||||
await send(message)
|
||||
|
||||
case MessageKind.CEDE_CONTROL:
|
||||
vnc_channel = get_vnc_channel(message_channel.client_id)
|
||||
|
||||
if not vnc_channel:
|
||||
@@ -304,7 +454,24 @@ async def loop_stream_messages(message_channel: MessageChannel) -> None:
|
||||
return
|
||||
vnc_channel.interactor = "agent"
|
||||
|
||||
case "take-control":
|
||||
case MessageKind.END_EXFILTRATION:
|
||||
if exfiltration_channel is None:
|
||||
return
|
||||
|
||||
await exfiltration_channel.stop()
|
||||
|
||||
exfiltration_channel = None
|
||||
|
||||
case MessageKind.EXFILTRATED_EVENT:
|
||||
await send(message)
|
||||
|
||||
# case MessageKind.GET_TAB_INFO:
|
||||
# """
|
||||
# TODO(jdo): implement - this is an on-demand request for tab info, which is
|
||||
# required when connecting to an existing browser session.
|
||||
# """
|
||||
|
||||
case MessageKind.TAKE_CONTROL:
|
||||
LOG.info(f"{class_name} processing take-control message.", **message_channel.identity)
|
||||
vnc_channel = get_vnc_channel(message_channel.client_id)
|
||||
|
||||
@@ -318,8 +485,7 @@ async def loop_stream_messages(message_channel: MessageChannel) -> None:
|
||||
vnc_channel.interactor = "user"
|
||||
|
||||
case _:
|
||||
LOG.error(f"{class_name} unknown message kind: '{message_kind}'", **message_channel.identity)
|
||||
return
|
||||
t.assert_never(message.kind)
|
||||
|
||||
async def frontend_to_backend() -> None:
|
||||
nonlocal class_name
|
||||
@@ -331,9 +497,9 @@ async def loop_stream_messages(message_channel: MessageChannel) -> None:
|
||||
datums = await message_channel.drain()
|
||||
|
||||
for data in datums:
|
||||
if not isinstance(data, dict):
|
||||
if not isinstance(data, (dict, MessageOut)):
|
||||
LOG.error(
|
||||
f"{class_name} cannot create message: expected dict, got {type(data)}",
|
||||
f"{class_name} cannot handle message: expected dict or MessageOut, got {type(data)}",
|
||||
**message_channel.identity,
|
||||
)
|
||||
continue
|
||||
|
||||
Reference in New Issue
Block a user