Browser streaming refactor (#4064)

This commit is contained in:
Jonathan Dobson
2025-11-21 15:12:26 -05:00
committed by GitHub
parent 91b8a9e0bb
commit d96de3b7a2
16 changed files with 1948 additions and 1295 deletions

View File

@@ -1,191 +0,0 @@
"""
A lightweight "agent" for interacting with the streaming browser over CDP.
"""
import typing
from contextlib import asynccontextmanager
import structlog
from playwright.async_api import Browser, BrowserContext, Page, Playwright, async_playwright
import skyvern.forge.sdk.routes.streaming.clients as sc
from skyvern.config import settings
LOG = structlog.get_logger()
class StreamingAgent:
    """
    A minimal agent that can connect to a browser via CDP and execute JavaScript.

    Specifically for operations during streaming sessions (like copy/pasting selected text, etc.).
    """

    def __init__(self, streaming: sc.Streaming) -> None:
        self.streaming = streaming
        self.browser: Browser | None = None
        self.browser_context: BrowserContext | None = None
        self.page: Page | None = None
        self.pw: Playwright | None = None

    async def connect(self, cdp_url: str | None = None) -> None:
        """
        Connect to the browser over CDP and pick the first context/page, if any.

        Falls back to `settings.BROWSER_REMOTE_DEBUGGING_URL` when `cdp_url` is not given.
        """
        url = cdp_url or settings.BROWSER_REMOTE_DEBUGGING_URL

        LOG.info("StreamingAgent connecting to CDP", cdp_url=url)

        pw = self.pw or await async_playwright().start()

        # Remember the Playwright instance right away so close() can stop it even
        # if the connection attempt below fails.
        self.pw = pw

        headers = {
            "x-api-key": self.streaming.x_api_key,
        }

        self.browser = await pw.chromium.connect_over_cdp(url, headers=headers)

        org_id = self.streaming.organization_id
        browser_session_id = (
            self.streaming.browser_session.persistent_browser_session_id if self.streaming.browser_session else None
        )

        # The download directory is only configured when a browser session id is known.
        if browser_session_id:
            cdp_session = await self.browser.new_browser_cdp_session()
            await cdp_session.send(
                "Browser.setDownloadBehavior",
                {
                    "behavior": "allow",
                    "downloadPath": f"/app/downloads/{org_id}/{browser_session_id}",
                    "eventsEnabled": True,
                },
            )

        contexts = self.browser.contexts
        if contexts:
            LOG.info("StreamingAgent using existing browser context")
            self.browser_context = contexts[0]
        else:
            LOG.warning("No existing browser context found, creating new one")
            self.browser_context = await self.browser.new_context()

        pages = self.browser_context.pages
        if pages:
            self.page = pages[0]
            LOG.info("StreamingAgent connected to page", url=self.page.url)
        else:
            # self.page stays None; evaluate_js() will refuse to run.
            LOG.warning("No pages found in browser context")

        LOG.info("StreamingAgent connected successfully")

    async def evaluate_js(
        self, expression: str, arg: str | int | float | bool | list | dict | None = None
    ) -> str | int | float | bool | list | dict | None:
        """
        Evaluate `expression` on the connected page, passing `arg` to it.

        Raises:
            RuntimeError: if no page is connected.
            Exception: whatever `page.evaluate` raises is logged and re-raised.
        """
        if not self.page:
            raise RuntimeError("StreamingAgent is not connected to a page. Call connect() first.")

        # Only a prefix of the expression is logged on the happy path.
        LOG.info("StreamingAgent evaluating JS", expression=expression[:100])

        try:
            result = await self.page.evaluate(expression, arg)
            LOG.info("StreamingAgent JS evaluation successful")
            return result
        except Exception as ex:
            LOG.exception("StreamingAgent JS evaluation failed", expression=expression, ex=str(ex))
            raise

    async def get_selected_text(self) -> str:
        """
        Return the page's current text selection, or "" when there is none.

        Raises:
            RuntimeError: if the evaluation returns something other than a string or None.
        """
        LOG.info("StreamingAgent getting selected text")

        js_expression = """
        () => {
            const selection = window.getSelection();
            return selection ? selection.toString() : '';
        }
        """

        selected_text = await self.evaluate_js(js_expression)

        if isinstance(selected_text, str) or selected_text is None:
            LOG.info("StreamingAgent got selected text", length=len(selected_text) if selected_text else 0)
            return selected_text or ""

        raise RuntimeError(f"StreamingAgent selected text is not a string, but a(n) '{type(selected_text)}'")

    async def paste_text(self, text: str) -> None:
        """
        Insert `text` into the page's active element.

        INPUT/TEXTAREA elements have the current selection replaced within their
        value. contentEditable elements are handled separately via the Selection/Range
        API, because they expose neither `value` nor `setSelectionRange()` (calling
        the latter on them throws a TypeError). Any other active element is ignored.
        """
        LOG.info("StreamingAgent pasting text")

        js_expression = """
        (text) => {
            const activeElement = document.activeElement;
            if (!activeElement) {
                return;
            }
            if (activeElement.tagName === 'INPUT' || activeElement.tagName === 'TEXTAREA') {
                const start = activeElement.selectionStart || 0;
                const end = activeElement.selectionEnd || 0;
                const value = activeElement.value || '';
                activeElement.value = value.slice(0, start) + text + value.slice(end);
                const newCursorPos = start + text.length;
                activeElement.setSelectionRange(newCursorPos, newCursorPos);
                return;
            }
            if (activeElement.isContentEditable) {
                const selection = window.getSelection();
                if (!selection || selection.rangeCount === 0) {
                    return;
                }
                const range = selection.getRangeAt(0);
                range.deleteContents();
                const node = document.createTextNode(text);
                range.insertNode(node);
                range.setStartAfter(node);
                range.collapse(true);
                selection.removeAllRanges();
                selection.addRange(range);
            }
        }
        """

        await self.evaluate_js(js_expression, text)

        LOG.info("StreamingAgent pasted text successfully")

    async def close(self) -> None:
        """
        Close the browser connection and stop Playwright.

        State is reset and Playwright is stopped even when closing the browser
        raises; the original exception still propagates to the caller.
        """
        LOG.info("StreamingAgent closing connection")

        try:
            if self.browser:
                await self.browser.close()
        finally:
            self.browser = None
            self.browser_context = None
            self.page = None

            if self.pw:
                try:
                    await self.pw.stop()
                finally:
                    self.pw = None

        LOG.info("StreamingAgent closed")
@asynccontextmanager
async def connected_agent(streaming: sc.Streaming | None) -> typing.AsyncIterator[StreamingAgent]:
    """
    Yield a StreamingAgent connected to the streaming client's browser, closing it on exit.

    Every use creates a fresh agent, connects it, lets the caller do its work, and
    then closes the agent again. That may add latency, but it keeps things
    stateless. If it proves too slow, this can be refactored to keep a persistent
    agent per streaming client.

    Raises:
        Exception: if `streaming` is missing, or it has no browser session / browser address.
    """
    if not streaming:
        error_text = "connected_agent: no streaming client provided."
        LOG.error(error_text)
        raise Exception(error_text)

    session = streaming.browser_session

    if not (session and session.browser_address):
        error_text = "connected_agent: no browser session or browser address found for streaming client."
        LOG.error(
            error_text,
            client_id=streaming.client_id,
            organization_id=streaming.organization_id,
        )
        raise Exception(error_text)

    streaming_agent = StreamingAgent(streaming=streaming)

    try:
        # NOTE(jdo:streaming-local-dev): call connect() without an address to use
        # BROWSER_REMOTE_DEBUGGING_URL from settings instead.
        await streaming_agent.connect(session.browser_address)
        yield streaming_agent
    finally:
        # Runs even when connect() itself failed, so nothing is left running.
        await streaming_agent.close()

View File

@@ -6,11 +6,35 @@ import structlog
from fastapi import WebSocket
from websockets.exceptions import ConnectionClosedOK
from skyvern.forge import app
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
from skyvern.forge.sdk.services.org_auth_service import get_current_org
LOG = structlog.get_logger()
class Constants:
    """Shared constant values for this module."""

    # Placeholder returned in place of a real API key when none could be found.
    MISSING_API_KEY = "<missing-x-api-key>"
async def get_x_api_key(organization_id: str) -> str:
    """
    Look up the organization's valid API auth token and return its token string.

    When no valid token exists, a warning is logged and the
    `Constants.MISSING_API_KEY` placeholder is returned instead of raising.
    """
    auth_token = await app.DATABASE.get_valid_org_auth_token(
        organization_id,
        OrganizationAuthTokenType.api.value,
    )

    if auth_token:
        return auth_token.token

    LOG.warning(
        "No valid API key found for organization when streaming.",
        organization_id=organization_id,
    )

    return Constants.MISSING_API_KEY
async def auth(apikey: str | None, token: str | None, websocket: WebSocket) -> str | None:
"""
Accepts the websocket connection.

View File

@@ -0,0 +1,336 @@
# Channels
A "channel", as used within the streaming mechanism of our remote browsers,
is a WebSocket dedicated to one particular purpose.
The channels are:
- a "VNC" channel that transmits NoVNC's RFB protocol data
- a "Message" channel that transmits JSON between the frontend app and
the api server
- "CDP" channels that send messages to a remote browser using CDP protocol
data
- an "Execution" channel (one-off executions)
- a soon-to-be "Exfiltration" channel (user event streaming)
In all cases, these are just WebSockets. They have been bucketed into "named channels"
to aid understanding.
These channels are described at the top of their respective files.
## Architecture
WARNING: the architecture document below is AI-generated. It covers all of the
code beneath the `skyvern/forge/sdk/routes/streaming` directory. It appears to
be correct, but verify details against the code before relying on them.
### High-Level Component Diagram
```
┌─────────────────┐
│ Frontend App │
│ (Skyvern) │
└────────┬────────┘
│ Two WebSocket Connections (paired via client_id)
┌────┴────┬──────────────────────────────────────────┐
│ │ │
│ VNC Channel Message Channel
http (RFB Protocol) (JSON Messages)
│ │ │
│ │ │
┌───▼─────────▼──────────────────────────────────────────▼────┐
│ API Server │
│ ┌────────────────────────────────────────────────────────┐ │
│ │ Registries (In-Memory State) │ │
│ │ - vnc_channels: dict[client_id -> VncChannel] │ │
│ │ - message_channels: dict[client_id -> MessageChannel] │ │
│ └────────────────────────────────────────────────────────┘ │
│ │
│ VNC Channel Logic: Message Channel Logic: │
│ - RFB pass-through - Copy/paste coordination │
│ - Keyboard/mouse filtering - Control handoff (agent/user)│
│ - Interactor control - Clipboard management │
│ - Copy/paste detection - Channel coordination │
│ │
│ CDP Channels (created on-demand): │
│ - ExecutionChannel: JS evaluation (paste, get selected) │
│ - ExfiltrationChannel: (future) user event streaming │
│ │
└────┬─────────────────────────────────────────────────────┬──┘
│ │
│ WebSocket (RFB) Playwright (CDP)│
│ │
┌────▼─────────────────────────────────────────────────────▼──┐
│ Persistent Browser Session │
│ ┌──────────────────────────────────────────────────────┐ │
│ │ noVNC Server Chrome/Chromium │ │
│ │ (websockify) │ │
│ └──────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────┘
```
### Channel Pairing & Sticky Sessions
**Critical Design Constraint**: The VNC and Message channels for a given frontend
instance MUST connect to the same API server instance because they coordinate
via in-memory registries keyed by `client_id`.
```
Frontend Instance (client_id="abc123")
├─→ VNC Channel ─────→ API Server Instance #2
│ ↓
│ vnc_channels["abc123"] = VncChannel
│ ↕ (coordinate via client_id)
└─→ Message Channel ──→ API Server Instance #2
message_channels["abc123"] = MessageChannel
```
**Deployment Requirement**: Load balancer must use sticky sessions (e.g., cookie-based
or IP-based affinity) to ensure both WebSocket connections from the same client_id
reach the same backend instance.
### Channel Lifecycle & Verification
```
┌──────────────────────────────────────────────────────────────┐
│ Channel Creation (per browser_session/task/workflow_run) │
└────────────────────┬─────────────────────────────────────────┘
┌────────────────────────────────────────┐
│ Initial Verification │
│ - verify_browser_session() │
│ - verify_task() │
│ - verify_workflow_run() │
│ Returns: entity + browser session │
└────────┬───────────────────────────────┘
┌────────────────────────────────────────┐
│ Channel + Loops Created │
│ VncChannel/MessageChannel initialized │
│ + Added to registry │
└────────┬───────────────────────────────┘
┌────────────────────────────────────────┐
│ Concurrent Loop Execution │
│ (via collect() - fail-fast) │
│ │
│ Loop 1: Verification Loop │
│ - Polls every 5s │
│ - Updates channel state │
│ - Exits if entity invalid │
│ │
│ Loop 2: Data Streaming Loop │
│ - VNC: bidirectional RFB │
│ - Message: JSON messages │
│ - Exits on disconnect │
└────────┬───────────────────────────────┘
┌────────────────────────────────────────┐
│ Channel Cleanup │
│ - Close WebSocket │
│ - Remove from registry │
│ - Clear channel state │
└────────────────────────────────────────┘
```
### VNC Channel Data Flow
```
User Keyboard/Mouse Input
┌───────────────────────────────────────────────────────────┐
│ Frontend (noVNC client) │
│ Encodes input as RFB protocol bytes │
└────────┬──────────────────────────────────────────────────┘
│ WebSocket (bytes)
┌───────────────────────────────────────────────────────────┐
│ API Server: VncChannel.loop_stream_vnc() │
│ frontend_to_browser() coroutine │
│ │
│ 1. Receive RFB bytes from frontend │
│ 2. Detect message type (keyboard=4, mouse=5) │
│ 3. Update key_state tracking │
│ 4. Check for special key combinations: │
│ - Ctrl+C / Cmd+C → copy_text() via CDP │
│ - Ctrl+V / Cmd+V → ask_for_clipboard() via Message │
│ - Ctrl+O → BLOCK (forbidden) │
│ 5. Check interactor mode: │
│ - If interactor=="agent" → BLOCK user input │
│ - If interactor=="user" → PASS THROUGH │
│ 6. Block right-mouse-button (security) │
│ 7. Forward to noVNC server │
└────────┬──────────────────────────────────────────────────┘
│ WebSocket (bytes)
┌───────────────────────────────────────────────────────────┐
│ Persistent Browser: noVNC Server (websockify) │
│ Translates RFB → VNC protocol │
└────────┬──────────────────────────────────────────────────┘
┌───────────────────────────────────────────────────────────┐
│ Browser Display Update │
└───────────────────────────────────────────────────────────┘
Screen Updates (reverse direction):
Browser → noVNC → VncChannel.browser_to_frontend() → Frontend
```
### Message Channel + CDP Execution Flow
```
User Pastes (Ctrl+V detected in VNC channel)
┌───────────────────────────────────────────────────────────┐
│ VncChannel: ask_for_clipboard() │
│ Finds MessageChannel via registry[client_id] │
└────────┬──────────────────────────────────────────────────┘
┌───────────────────────────────────────────────────────────┐
│ MessageChannel.ask_for_clipboard() │
│ Sends: {"kind": "ask-for-clipboard"} │
└────────┬──────────────────────────────────────────────────┘
│ WebSocket (JSON)
┌───────────────────────────────────────────────────────────┐
│ Frontend: User's clipboard content │
│ Responds: {"kind": "ask-for-clipboard-response", │
│ "text": "clipboard content"} │
└────────┬──────────────────────────────────────────────────┘
│ WebSocket (JSON)
┌───────────────────────────────────────────────────────────┐
│ MessageChannel: handle_data() │
│ Receives clipboard text │
│ Finds VncChannel via registry[client_id] │
└────────┬──────────────────────────────────────────────────┘
┌───────────────────────────────────────────────────────────┐
│ ExecutionChannel (CDP) │
│ 1. Connect to browser via Playwright │
│ 2. Get browser context + page │
│ 3. evaluate_js(paste_text_script, clipboard_text) │
│ 4. Close CDP connection │
└────────┬──────────────────────────────────────────────────┘
│ CDP over WebSocket
┌───────────────────────────────────────────────────────────┐
│ Browser: Text pasted into active element │
└───────────────────────────────────────────────────────────┘
Similar flow for Copy (Ctrl+C):
VNC detects → ExecutionChannel.get_selected_text() →
MessageChannel sends {"kind": "copied-text", "text": "..."}
→ Frontend updates clipboard
```
### Control Flow: Agent ↔ User Interaction
NOTE: there is no real "agent" at this time. Any control of the browser that
is not user-originated is treated as agent-like, by a loose definition of
"agent"; no "AI agent" is involved here. Future work may change this, so that
some "agent" operates the browser automatically.
```
┌───────────────────────────────────────────────────────────┐
│ Initial State: interactor = "agent" │
│ - User keyboard/mouse input is BLOCKED │
│ - Agent can control browser via CDP │
└────────┬──────────────────────────────────────────────────┘
│ User clicks "Take Control" in frontend
┌───────────────────────────────────────────────────────────┐
│ Frontend → MessageChannel │
│ {"kind": "take-control"} │
└────────┬──────────────────────────────────────────────────┘
┌───────────────────────────────────────────────────────────┐
│ MessageChannel.handle_data() │
│ vnc_channel.interactor = "user" │
└────────┬──────────────────────────────────────────────────┘
┌───────────────────────────────────────────────────────────┐
│ New State: interactor = "user" │
│ - User keyboard/mouse input is PASSED THROUGH │
│ - Agent should pause automation │
└────────┬──────────────────────────────────────────────────┘
│ User clicks "Cede Control" in frontend
┌───────────────────────────────────────────────────────────┐
│ Frontend → MessageChannel │
│ {"kind": "cede-control"} │
└────────┬──────────────────────────────────────────────────┘
┌───────────────────────────────────────────────────────────┐
│ MessageChannel.handle_data() │
│ vnc_channel.interactor = "agent" │
│ → Back to initial state │
└───────────────────────────────────────────────────────────┘
```
### Error Propagation & Cleanup
The system uses `collect()` (fail-fast gather) for loop management:
```
Channel has 2 concurrent loops:
- Verification loop (polls DB every 5s)
- Streaming loop (handles WebSocket I/O)
collect() behavior:
1. Waits for ANY loop to fail or complete
2. Cancels all other loops
3. Propagates the first exception
Cleanup (always executed via finally):
- channel.close()
- Sets browser_session/task/workflow_run = None
- Closes WebSocket
- Removes from registry
```
### Database Entity Relationships
```
Organization
├─→ BrowserSession ────────┐
│ │
├─→ Task ──────────────────┤
│ (has optional │
│ browser_session) │
│ │
└─→ WorkflowRun ───────────┤
(has optional │
browser_session) │
VncChannel + MessageChannel
(in-memory, paired by client_id)
```
### Key Design Patterns
1. **Channel Pairing**: Two WebSocket connections coordinated via in-memory registry
2. **Fail-Fast Loops**: `collect()` ensures any loop failure closes the entire channel
3. **Interactor Mode**: Binary state controlling whether user input is allowed
4. **On-Demand CDP**: ExecutionChannel creates temporary connections for each operation
5. **Polling Verification**: Every 5s, channels verify their backing entity still exists
6. **Pass-Through Proxy**: API server intercepts but doesn't transform RFB data

View File

@@ -0,0 +1,185 @@
"""
A channel for connecting to a persistent browser instance.
What this channel looks like:
[API Server] <--> [Browser (CDP)]
Channel data:
CDP protocol data, with Playwright thrown in.
"""
from __future__ import annotations
import asyncio
import typing as t
import structlog
from playwright.async_api import Browser, BrowserContext, Page, Playwright, async_playwright
from skyvern.config import settings
if t.TYPE_CHECKING:
from skyvern.forge.sdk.routes.streaming.channels.vnc import VncChannel
LOG = structlog.get_logger()
class CdpChannel:
    """
    Base class for channels that talk to a persistent browser over CDP (via Playwright).

    Relies on a VncChannel - without one, a CdpChannel has no raison d'être.
    Cannot be instantiated directly; use a subclass.
    """

    def __new__(cls, *_: t.Iterable[t.Any], **__: t.Mapping[str, t.Any]) -> t.Self:  # noqa: N805
        if cls is CdpChannel:
            raise TypeError("CdpChannel class cannot be instantiated directly.")

        return super().__new__(cls)

    def __init__(self, *, vnc_channel: VncChannel) -> None:
        self.vnc_channel = vnc_channel
        # --
        self.browser: Browser | None = None
        self.browser_context: BrowserContext | None = None
        self.page: Page | None = None
        self.pw: Playwright | None = None
        self.url: str | None = None

    @property
    def class_name(self) -> str:
        """Name of the concrete class; used as a prefix in log messages."""
        return self.__class__.__name__

    @property
    def identity(self) -> t.Dict[str, t.Any]:
        """The vnc channel's identity, extended with the CDP url (for logging)."""
        base = self.vnc_channel.identity

        return base | {"url": self.url}

    async def connect(self, cdp_url: str | None = None) -> t.Self:
        """
        Connect to the browser over CDP and pick the first context/page, if any.

        Idempotent: returns immediately when a connected browser is already held.

        The url is chosen in this order: the `cdp_url` argument, the vnc channel's
        browser session address, then `settings.BROWSER_REMOTE_DEBUGGING_URL`.
        """
        if self.browser and self.browser.is_connected():
            return self

        # Drop whatever stale connection state is left before reconnecting.
        await self.close()

        if cdp_url:
            url = cdp_url
        elif self.vnc_channel.browser_session and self.vnc_channel.browser_session.browser_address:
            url = self.vnc_channel.browser_session.browser_address
        else:
            url = settings.BROWSER_REMOTE_DEBUGGING_URL

        self.url = url

        LOG.info(f"{self.class_name} connecting to CDP", **self.identity)

        pw = self.pw or await async_playwright().start()

        # Remember the Playwright instance right away so close() can stop it even
        # if the connection attempt below fails.
        self.pw = pw

        # Only send the header when there is a key to send.
        headers = (
            {
                "x-api-key": self.vnc_channel.x_api_key,
            }
            if self.vnc_channel.x_api_key
            else None
        )

        browser = await pw.chromium.connect_over_cdp(url, headers=headers)
        self.browser = browser

        def on_close() -> None:
            # close() clears `self.browser` *before* it closes the connection. If
            # this channel no longer holds the browser that disconnected, the
            # disconnect was requested by us (or belongs to a superseded
            # connection) and must not trigger a reconnect.
            if self.browser is not browser:
                return

            LOG.warning(
                f"{self.class_name} closing because the persistent browser disconnected itself.", **self.identity
            )
            close_task = asyncio.create_task(self.close())
            close_task.add_done_callback(lambda _: asyncio.create_task(self.connect()))  # TODO: avoid blind reconnect

        browser.on("disconnected", on_close)

        await self.apply_download_behavior(browser)

        contexts = browser.contexts
        if contexts:
            LOG.info(f"{self.class_name} using existing browser context", **self.identity)
            self.browser_context = contexts[0]
        else:
            LOG.warning(f"{self.class_name} No existing browser context found, creating new one", **self.identity)
            self.browser_context = await browser.new_context()

        pages = self.browser_context.pages
        if pages:
            self.page = pages[0]
            LOG.info(f"{self.class_name} connected to page", **self.identity)
        else:
            # self.page stays None; evaluate_js() will refuse to run.
            LOG.warning(f"{self.class_name} No pages found in browser context", **self.identity)

        LOG.info(f"{self.class_name} connected successfully", **self.identity)

        return self

    async def apply_download_behavior(self, browser: Browser) -> t.Self:
        """
        Allow downloads on `browser` and point them at the download directory.

        The directory is per organization and browser session when a session id is
        known, and the shared "/app/downloads/" directory otherwise.
        """
        org_id = self.vnc_channel.organization_id

        browser_session_id = (
            self.vnc_channel.browser_session.persistent_browser_session_id if self.vnc_channel.browser_session else None
        )

        download_path = f"/app/downloads/{org_id}/{browser_session_id}" if browser_session_id else "/app/downloads/"

        cdp_session = await browser.new_browser_cdp_session()

        await cdp_session.send(
            "Browser.setDownloadBehavior",
            {
                "behavior": "allow",
                "downloadPath": download_path,
                "eventsEnabled": True,
            },
        )

        # The session is only needed for this one command.
        await cdp_session.detach()

        return self

    async def close(self) -> None:
        """
        Close the browser connection and stop Playwright. Safe to call repeatedly.

        Failures while closing the browser are logged and otherwise ignored
        (best-effort); Playwright is stopped regardless.
        """
        LOG.info(f"{self.class_name} closing connection", **self.identity)

        # Detach all state up front: the "disconnected" handler installed by
        # connect() compares against `self.browser` to tell an intentional close
        # from the remote browser going away.
        browser, self.browser = self.browser, None
        pw, self.pw = self.pw, None
        self.browser_context = None
        self.page = None

        try:
            if browser:
                try:
                    await browser.close()
                except Exception:
                    # The browser may already be gone; there is nothing left to do about it.
                    LOG.debug(f"{self.class_name} ignoring error while closing browser", exc_info=True, **self.identity)
        finally:
            if pw:
                await pw.stop()

        LOG.info(f"{self.class_name} closed", **self.identity)

    async def evaluate_js(
        self,
        expression: str,
        arg: str | int | float | bool | list | dict | None = None,
    ) -> str | int | float | bool | list | dict | None:
        """
        Evaluate `expression` on the connected page, passing `arg` to it.

        Connects first if necessary.

        Raises:
            RuntimeError: if there is no page to evaluate against after connecting.
            Exception: whatever `page.evaluate` raises is logged and re-raised.
        """
        await self.connect()

        if not self.page:
            raise RuntimeError(f"{self.class_name} evaluate_js: not connected to a page. Call connect() first.")

        # Only a prefix of the expression is logged on the happy path.
        LOG.info(f"{self.class_name} evaluating js", expression=expression[:100], **self.identity)

        try:
            result = await self.page.evaluate(expression, arg)
            LOG.info(f"{self.class_name} evaluated js successfully", **self.identity)
            return result
        except Exception:
            LOG.exception(f"{self.class_name} failed to evaluate js", expression=expression, **self.identity)
            raise

View File

@@ -0,0 +1,121 @@
"""
A channel for executing JavaScript against a persistent browser instance.
What this channel looks like:
[API Server] <--> [Browser (CDP)]
Channel data:
Chrome DevTools Protocol (CDP) over WebSockets. We cheat and use Playwright.
"""
from __future__ import annotations
import typing as t
from contextlib import asynccontextmanager
import structlog
from skyvern.forge.sdk.routes.streaming.channels.cdp import CdpChannel
if t.TYPE_CHECKING:
from skyvern.forge.sdk.routes.streaming.channels.vnc import VncChannel
LOG = structlog.get_logger()
class ExecutionChannel(CdpChannel):
    """
    A CdpChannel for one-off JavaScript executions against the connected page
    (reading the current selection, pasting text).

    Connection handling (`connect`, `close`, `class_name`, `identity`,
    `evaluate_js`) is inherited from CdpChannel rather than duplicated here.
    """

    async def get_selected_text(self) -> str:
        """
        Return the page's current text selection, or "" when there is none.

        Raises:
            RuntimeError: if the evaluation returns something other than a string or None.
        """
        LOG.info(f"{self.class_name} getting selected text", **self.identity)

        js_expression = """
        () => {
            const selection = window.getSelection();
            return selection ? selection.toString() : '';
        }
        """

        # The expression takes no argument, so none is passed.
        selected_text = await self.evaluate_js(js_expression)

        if isinstance(selected_text, str) or selected_text is None:
            LOG.info(
                f"{self.class_name} got selected text",
                length=len(selected_text) if selected_text else 0,
                **self.identity,
            )
            return selected_text or ""

        raise RuntimeError(f"{self.class_name} selected text is not a string, but a(n) '{type(selected_text)}'")

    async def paste_text(self, text: str) -> None:
        """
        Insert `text` into the page's active element.

        INPUT/TEXTAREA elements have the current selection replaced within their
        value. contentEditable elements are handled separately via the Selection/Range
        API, because they expose neither `value` nor `setSelectionRange()` (calling
        the latter on them throws a TypeError). Any other active element is ignored.
        """
        LOG.info(f"{self.class_name} pasting text", **self.identity)

        js_expression = """
        (text) => {
            const activeElement = document.activeElement;
            if (!activeElement) {
                return;
            }
            if (activeElement.tagName === 'INPUT' || activeElement.tagName === 'TEXTAREA') {
                const start = activeElement.selectionStart || 0;
                const end = activeElement.selectionEnd || 0;
                const value = activeElement.value || '';
                activeElement.value = value.slice(0, start) + text + value.slice(end);
                const newCursorPos = start + text.length;
                activeElement.setSelectionRange(newCursorPos, newCursorPos);
                return;
            }
            if (activeElement.isContentEditable) {
                const selection = window.getSelection();
                if (!selection || selection.rangeCount === 0) {
                    return;
                }
                const range = selection.getRangeAt(0);
                range.deleteContents();
                const node = document.createTextNode(text);
                range.insertNode(node);
                range.setStartAfter(node);
                range.collapse(true);
                selection.removeAllRanges();
                selection.addRange(range);
            }
        }
        """

        await self.evaluate_js(js_expression, text)

        LOG.info(f"{self.class_name} pasted text successfully", **self.identity)
@asynccontextmanager
async def execution_channel(vnc_channel: VncChannel) -> t.AsyncIterator[ExecutionChannel]:
    """
    Yield an ExecutionChannel connected on behalf of `vnc_channel`, closing it on exit.

    Every use creates a fresh channel, connects it, lets the caller do its work,
    and then closes the channel again. That may add latency, but it keeps things
    stateless. If it proves too slow, this can be refactored to keep a persistent
    channel per vnc client.
    """
    executor = ExecutionChannel(vnc_channel=vnc_channel)

    try:
        # NOTE(jdo:streaming-local-dev): for local development, import `settings`
        # from skyvern.config and connect to settings.BROWSER_REMOTE_DEBUGGING_URL.
        await executor.connect()
        yield executor
    finally:
        # Runs even when connect() itself failed, so nothing is left running.
        await executor.close()

View File

@@ -0,0 +1 @@
# stub

View File

@@ -0,0 +1,456 @@
"""
A channel for streaming whole messages between our frontend and our API server.
This channel can access a persistent browser instance through the execution channel.
What this channel looks like:
[Skyvern App] <--> [API Server]
Channel data:
JSON over WebSockets. Semantics are fire and forget. Req-resp is built on
top of that using message types.
"""
import asyncio
import dataclasses
import typing as t
import structlog
from fastapi import WebSocket, WebSocketDisconnect
from starlette.websockets import WebSocketState
from websockets.exceptions import ConnectionClosedError
from skyvern.forge.sdk.routes.streaming.channels.execution import execution_channel
from skyvern.forge.sdk.routes.streaming.registries import (
add_message_channel,
del_message_channel,
get_vnc_channel,
)
from skyvern.forge.sdk.routes.streaming.verify import (
loop_verify_browser_session,
loop_verify_workflow_run,
verify_browser_session,
verify_workflow_run,
)
from skyvern.forge.sdk.schemas.persistent_browser_sessions import AddressablePersistentBrowserSession
from skyvern.forge.sdk.utils.aio import collect
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRun
LOG = structlog.get_logger()
# Tasks that run concurrently for a channel - aka "queue-less actors"; or "programs".
Loops = list[asyncio.Task]

# Every `kind` the frontend may send over the message channel.
MessageKinds = t.Literal[
    "ask-for-clipboard-response",
    "cede-control",
    "take-control",
]


@dataclasses.dataclass
class Message:
    """Base class of all incoming messages; `kind` discriminates the concrete type."""

    kind: MessageKinds


@dataclasses.dataclass
class MessageInTakeControl(Message):
    """Incoming "take-control" message."""

    kind: t.Literal["take-control"] = "take-control"


@dataclasses.dataclass
class MessageInCedeControl(Message):
    """Incoming "cede-control" message."""

    kind: t.Literal["cede-control"] = "cede-control"


@dataclasses.dataclass
class MessageInAskForClipboardResponse(Message):
    """Incoming "ask-for-clipboard-response" message, carrying the clipboard text."""

    kind: t.Literal["ask-for-clipboard-response"] = "ask-for-clipboard-response"
    text: str = ""


# Union of every concrete incoming message type.
ChannelMessage = t.Union[
    MessageInAskForClipboardResponse,
    MessageInCedeControl,
    MessageInTakeControl,
]
def reify_channel_message(data: dict) -> ChannelMessage:
kind = data.get("kind", None)
match kind:
case "ask-for-clipboard-response":
text = data.get("text") or ""
return MessageInAskForClipboardResponse(text=text)
case "cede-control":
return MessageInCedeControl()
case "take-control":
return MessageInTakeControl()
case _:
raise ValueError(f"Unknown message kind: '{kind}'")
@dataclasses.dataclass
class MessageChannel:
    """
    A message channel for streaming JSON messages between our frontend and our API server.

    Creating an instance registers it via `add_message_channel`; `close()` removes
    it again via `del_message_channel`.
    """

    client_id: str
    organization_id: str
    websocket: WebSocket

    # --

    out_queue: asyncio.Queue = dataclasses.field(default_factory=asyncio.Queue)  # warn: unbounded
    browser_session: AddressablePersistentBrowserSession | None = None
    workflow_run: WorkflowRun | None = None

    def __post_init__(self) -> None:
        add_message_channel(self)

    @property
    def class_name(self) -> str:
        """Name of the concrete class; used as a prefix in log messages."""
        return self.__class__.__name__

    @property
    def identity(self) -> dict[str, str]:
        """Identifying fields for log records; the browser session wins over the workflow run."""
        base = {"organization_id": self.organization_id}

        if self.browser_session:
            return base | {"browser_session_id": self.browser_session.persistent_browser_session_id}

        if self.workflow_run:
            return base | {"workflow_run_id": self.workflow_run.id}

        return base

    async def close(self, code: int = 1000, reason: str | None = None) -> "MessageChannel":
        """Clear the channel's entities, close the websocket (best-effort) and deregister the channel."""
        LOG.info(f"{self.class_name} closing message stream.", reason=reason, code=code, **self.identity)

        self.browser_session = None
        self.workflow_run = None

        try:
            await self.websocket.close(code=code, reason=reason)
        except Exception:
            # Deliberately best-effort: a failure to close the socket must not
            # prevent the channel from being deregistered below.
            pass

        del_message_channel(self.client_id)

        return self

    @property
    def is_open(self) -> bool:
        """Whether the websocket's client side is currently connected."""
        return self.websocket.client_state == WebSocketState.CONNECTED

    async def drain(self) -> list[dict]:
        """Collect whatever is immediately available from the out queue and from the websocket."""
        datums: list[dict] = []

        tasks = [
            asyncio.create_task(self.receive_from_out_queue()),
            asyncio.create_task(self.receive_from_user()),
        ]

        results = await asyncio.gather(*tasks)

        for result in results:
            datums.extend(result)

        if datums:
            LOG.info(f"{self.class_name} Drained {len(datums)} messages from message channel.", **self.identity)

        return datums

    async def receive_from_user(self) -> list[dict]:
        """
        Read JSON messages from the websocket until a read times out or fails.

        A RuntimeError mentioning "not connected" ends the read quietly. Every
        other failure is logged and also ends the read; nothing is raised.
        """
        datums: list[dict] = []

        while True:
            try:
                data = await asyncio.wait_for(self.websocket.receive_json(), timeout=0.001)
                datums.append(data)
            except asyncio.TimeoutError:
                break
            except RuntimeError as ex:
                # Inspect the caught exception instance (not the RuntimeError
                # class). Always stop afterwards: retrying the same failing read
                # would spin this loop forever.
                if "not connected" not in str(ex).lower():
                    LOG.exception(f"{self.class_name} Failed to receive message from message channel", **self.identity)
                break
            except Exception:
                LOG.exception(f"{self.class_name} Failed to receive message from message channel", **self.identity)
                break

        return datums

    async def receive_from_out_queue(self) -> list[dict]:
        """Take items from the out queue until a `get()` times out."""
        datums: list[dict] = []

        while True:
            try:
                # `Queue.get()` waits rather than raising QueueEmpty, so the
                # timeout is the only way out of this loop.
                data = await asyncio.wait_for(self.out_queue.get(), timeout=0.001)
                datums.append(data)
            except asyncio.TimeoutError:
                break

        return datums

    def receive_from_out_queue_nowait(self) -> list[dict]:
        """Take every item currently in the out queue, without waiting."""
        datums: list[dict] = []

        while True:
            try:
                data = self.out_queue.get_nowait()
                datums.append(data)
            except asyncio.QueueEmpty:
                break

        return datums

    async def send(self, *, messages: list[dict]) -> t.Self:
        """Put `messages` on the out queue, in order."""
        for message in messages:
            await self.out_queue.put(message)

        return self

    def send_nowait(self, *, messages: list[dict]) -> t.Self:
        """Put `messages` on the out queue, in order, without waiting."""
        for message in messages:
            self.out_queue.put_nowait(message)

        return self

    async def ask_for_clipboard(self) -> None:
        """Send an "ask-for-clipboard" message over the websocket; failures are logged, not raised."""
        LOG.info(f"{self.class_name} Sending ask-for-clipboard to message channel", **self.identity)

        try:
            await self.websocket.send_json(
                {
                    "kind": "ask-for-clipboard",
                }
            )
        except Exception:
            LOG.exception(f"{self.class_name} Failed to send ask-for-clipboard to message channel", **self.identity)

    async def send_copied_text(self, copied_text: str) -> None:
        """Send a "copied-text" message carrying `copied_text`; failures are logged, not raised."""
        LOG.info(f"{self.class_name} Sending copied text to message channel", **self.identity)

        try:
            await self.websocket.send_json(
                {
                    "kind": "copied-text",
                    "text": copied_text,
                }
            )
        except Exception:
            LOG.exception(f"{self.class_name} Failed to send copied text to message channel", **self.identity)
async def loop_stream_messages(message_channel: MessageChannel) -> None:
    """
    Receive messages from the frontend and act on them for as long as the channel is open.

    The frontend-to-backend loop runs under `collect()`. An exception escaping it
    is logged here rather than re-raised, and the message channel is always closed
    on the way out.
    """

    class_name = message_channel.class_name

    def find_vnc_channel(message: t.Any, not_found: str) -> t.Any:
        """Return the vnc channel registered for this client, or log `not_found` and return None."""
        found = get_vnc_channel(message_channel.client_id)

        if not found:
            LOG.error(not_found, message=message, **message_channel.identity)
            return None

        return found

    async def handle_data(data: dict) -> None:
        """Turn one raw payload into a typed message and carry out what it asks for."""
        try:
            message = reify_channel_message(data)
        except ValueError:
            LOG.error(f"MessageChannel: cannot reify channel message from data: {data}", **message_channel.identity)
            return

        kind = message.kind

        if kind == "ask-for-clipboard-response":
            if not isinstance(message, MessageInAskForClipboardResponse):
                LOG.error(
                    f"{class_name} invalid message type for ask-for-clipboard-response.",
                    message=message,
                    **message_channel.identity,
                )
                return

            vnc = find_vnc_channel(message, f"{class_name} no vnc channel found for message channel.")
            if vnc is None:
                return

            clipboard_text = message.text

            # Paste the frontend's clipboard content into the remote browser.
            async with execution_channel(vnc) as execute:
                await execute.paste_text(clipboard_text)

            return

        if kind == "cede-control":
            vnc = find_vnc_channel(message, f"{class_name} no vnc channel client found for message channel.")
            if vnc is None:
                return

            vnc.interactor = "agent"
            return

        if kind == "take-control":
            LOG.info(f"{class_name} processing take-control message.", **message_channel.identity)

            vnc = find_vnc_channel(message, f"{class_name} no vnc channel client found for message channel.")
            if vnc is None:
                return

            vnc.interactor = "user"
            return

        LOG.error(f"{class_name} unknown message kind: '{kind}'", **message_channel.identity)

    async def frontend_to_backend() -> None:
        """Drain and handle incoming data until the channel closes; errors are logged and re-raised."""
        LOG.info(f"{class_name} starting frontend-to-backend loop.", **message_channel.identity)

        while message_channel.is_open:
            try:
                for item in await message_channel.drain():
                    if isinstance(item, dict):
                        await handle_data(item)
                        continue

                    LOG.error(
                        f"{class_name} cannot create message: expected dict, got {type(item)}",
                        **message_channel.identity,
                    )
            except WebSocketDisconnect:
                LOG.info(f"{class_name} frontend disconnected.", **message_channel.identity)
                raise
            except ConnectionClosedError:
                LOG.info(f"{class_name} frontend closed channel.", **message_channel.identity)
                raise
            except Exception:
                LOG.exception(f"{class_name} An unexpected exception occurred.", **message_channel.identity)
                raise

    loops = [asyncio.create_task(frontend_to_backend())]

    try:
        await collect(loops)
    except Exception:
        LOG.exception(f"{class_name} An exception occurred in loop message channel stream.", **message_channel.identity)
    finally:
        LOG.info(f"{class_name} Closing the message channel stream.", **message_channel.identity)
        await message_channel.close(reason="loop-channel-closed")
async def get_message_channel_for_browser_session(
    client_id: str,
    browser_session_id: str,
    organization_id: str,
    websocket: WebSocket,
) -> tuple[MessageChannel, Loops] | None:
    """
    Build a message channel for a browser session, plus the loops that drive it.

    :return: ``(message_channel, loops)`` with the loops already scheduled as
        tasks, or ``None`` when the browser session cannot be verified.
    """

    LOG.info("Getting message channel for browser session.", browser_session_id=browser_session_id)

    session = await verify_browser_session(
        browser_session_id=browser_session_id,
        organization_id=organization_id,
    )

    if not session:
        LOG.info(
            "Message channel: no initial browser session found.",
            browser_session_id=browser_session_id,
            organization_id=organization_id,
        )
        return None

    channel = MessageChannel(
        client_id=client_id,
        organization_id=organization_id,
        browser_session=session,
        websocket=websocket,
    )

    LOG.info("Got message channel for browser session.", message_channel=channel)

    # One loop keeps re-verifying the session, the other handles the messages.
    coroutines = (
        loop_verify_browser_session(channel),
        loop_stream_messages(channel),
    )

    return channel, [asyncio.create_task(coroutine) for coroutine in coroutines]
async def get_message_channel_for_workflow_run(
    client_id: str,
    workflow_run_id: str,
    organization_id: str,
    websocket: WebSocket,
) -> tuple[MessageChannel, Loops] | None:
    """
    Return a message channel for a workflow run, with a list of loops to run concurrently.

    :return: ``(message_channel, loops)`` with the loops already scheduled as
        tasks, or ``None`` when either the workflow run or its browser session
        cannot be verified.
    """

    LOG.info("Getting message channel for workflow run.", workflow_run_id=workflow_run_id)

    workflow_run, browser_session = await verify_workflow_run(
        workflow_run_id=workflow_run_id,
        organization_id=organization_id,
    )

    if not workflow_run:
        LOG.info(
            "Message channel: no initial workflow run found.",
            workflow_run_id=workflow_run_id,
            organization_id=organization_id,
        )
        return None

    if not browser_session:
        LOG.info(
            "Message channel: no initial browser session found for workflow run.",
            workflow_run_id=workflow_run_id,
            organization_id=organization_id,
        )
        return None

    # Keyword arguments throughout, matching get_message_channel_for_browser_session,
    # so the call does not depend on the dataclass field order.
    message_channel = MessageChannel(
        client_id=client_id,
        organization_id=organization_id,
        browser_session=browser_session,
        websocket=websocket,
        workflow_run=workflow_run,
    )

    LOG.info("Got message channel for workflow run.", message_channel=message_channel)

    loops = [
        asyncio.create_task(loop_verify_workflow_run(message_channel)),
        asyncio.create_task(loop_stream_messages(message_channel)),
    ]

    return message_channel, loops

View File

@@ -0,0 +1,592 @@
"""
A channel for streaming the VNC protocol data between our frontend and a
persistent browser instance.
This is a pass-thru channel, through our API server. As such, we can monitor and/or
intercept RFB protocol messages as needed.
What this channel looks like:
[Skyvern App] <--> [API Server] <--> [websockified noVNC] <--> [Browser]
Channel data:
One could call this RFB over WebSockets (rockets?), as the protocol data streaming
over the WebSocket is raw RFB protocol data.
"""
import asyncio
import dataclasses
import typing as t
from enum import IntEnum
from urllib.parse import urlparse
import structlog
import websockets
from fastapi import WebSocket, WebSocketDisconnect
from starlette.websockets import WebSocketState
from websockets import ConnectionClosedError, Data
from skyvern.config import settings
from skyvern.forge.sdk.routes.streaming.auth import get_x_api_key
from skyvern.forge.sdk.routes.streaming.channels.execution import execution_channel
from skyvern.forge.sdk.routes.streaming.registries import (
add_vnc_channel,
del_vnc_channel,
get_message_channel,
get_vnc_channel,
)
from skyvern.forge.sdk.routes.streaming.verify import (
loop_verify_browser_session,
loop_verify_task,
loop_verify_workflow_run,
verify_browser_session,
verify_task,
verify_workflow_run,
)
from skyvern.forge.sdk.schemas.persistent_browser_sessions import AddressablePersistentBrowserSession
from skyvern.forge.sdk.schemas.tasks import Task
from skyvern.forge.sdk.utils.aio import collect
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRun
# Module-level structured logger.
LOG = structlog.get_logger()
# The two roles that can be in control of the streamed browser.
Interactor = t.Literal["agent", "user"]
"""
NOTE: we don't really have an "agent" at this time. But any control of the
browser that is not user-originated is kinda' agent-like, by some
definition of "agent". Here, we do not have an "AI agent". Future work may
alter this state of affairs - and some "agent" could operate the browser
automatically. In any case, if the interactor is not a "user", we assume
it is an "agent".
"""
# The asyncio tasks handed back alongside a channel, meant to be run concurrently.
Loops = list[asyncio.Task] # aka "queue-less actors"; or "programs"
class MessageType(IntEnum):
    """
    Message types this channel inspects in frontend frames.

    The value is compared against the first byte of each frame received from
    the frontend.
    """

    # Key press / release frames.
    Keyboard = 4
    # Pointer (mouse) frames.
    Mouse = 5
class Keys:
    """
    Complete 8-byte VNC RFB key frames for the keys this channel cares about.

    Each constant is a whole frame: it starts with the keyboard message type
    (0x04), the second byte is 0x01 in `Down` and 0x00 in `Up`, and the last
    bytes identify the key. There's likely a pithier repr (indexes 6-7). This
    is ok for now.

    ref: https://www.notion.so/References-21c426c42cd480fb9258ecc9eb8f09b4
    ref: https://github.com/novnc/noVNC/blob/master/docs/rfbproto-3.8.pdf
    """

    class Down:
        """Key-press frames."""

        # modifiers
        Ctrl = b"\x04\x01\x00\x00\x00\x00\xff\xe3"
        Cmd = b"\x04\x01\x00\x00\x00\x00\xff\xe9"
        Alt = b"\x04\x01\x00\x00\x00\x00\xff~"  # option

        # letters
        CKey = b"\x04\x01\x00\x00\x00\x00\x00c"
        OKey = b"\x04\x01\x00\x00\x00\x00\x00o"
        VKey = b"\x04\x01\x00\x00\x00\x00\x00v"

    class Up:
        """Key-release frames (modifiers only)."""

        Ctrl = b"\x04\x00\x00\x00\x00\x00\xff\xe3"
        Cmd = b"\x04\x00\x00\x00\x00\x00\xff\xe9"
        Alt = b"\x04\x00\x00\x00\x00\x00\xff\x7e"  # option
def is_rmb(data: bytes) -> bool:
    """
    Return True when `data` is an RFB pointer frame with the right button set.

    An RFB PointerEvent starts with the message type (5) followed by a button
    mask in which bit 2 (0x04) is the right mouse button. The mask bit is
    tested rather than the whole byte, so a right button combined with any
    other button is detected too. Frames shorter than two bytes never match.
    """
    return len(data) >= 2 and data[0] == 0x05 and bool(data[1] & 0x04)
class Mouse:
    """Namespace of predicates over pointer frames."""

    class Up:
        # Call as `Mouse.Up.Right(data)`; this is `is_rmb` under another name.
        # NOTE(review): `is_rmb` inspects the button byte of the frame - confirm
        # that "Up" is the right name for what it detects.
        Right = is_rmb
@dataclasses.dataclass
class KeyState:
    """Which modifier keys are currently held, plus shortcut checks built on that."""

    ctrl_is_down: bool = False
    alt_is_down: bool = False
    cmd_is_down: bool = False

    def is_forbidden(self, data: bytes) -> bool:
        """
        :return: True if the key is forbidden, else False
        """

        # Ctrl+O is the only forbidden combination at the moment.
        return self.is_ctrl_o(data)

    def is_ctrl_o(self, data: bytes) -> bool:
        """
        Detect Ctrl+O: do not allow the opening of files.
        """

        if self.ctrl_is_down:
            return data == Keys.Down.OKey

        return self.ctrl_is_down

    def is_copy(self, data: bytes) -> bool:
        """
        Detect Ctrl+C or Cmd+C for copy.
        """

        modifier_is_down = self.ctrl_is_down or self.cmd_is_down

        return modifier_is_down and data == Keys.Down.CKey

    def is_paste(self, data: bytes) -> bool:
        """
        Detect Ctrl+V or Cmd+V for paste.
        """

        modifier_is_down = self.ctrl_is_down or self.cmd_is_down

        return modifier_is_down and data == Keys.Down.VKey
@dataclasses.dataclass
class VncChannel:
    """
    A VNC channel for streaming RFB protocol data between our frontend app, and
    a remote browser.

    Creating an instance registers it in the vnc channel registry under its
    `client_id`; `close()` removes it again.
    """

    # Unique to a frontend app instance.
    client_id: str
    organization_id: str
    vnc_port: int
    x_api_key: str
    websocket: WebSocket
    # Init-only: the role of the entity interacting with the channel at creation
    # time, either "agent" or "user". Stored via the `interactor` property.
    initial_interactor: dataclasses.InitVar[Interactor]
    _interactor: Interactor = dataclasses.field(init=False, repr=False)
    # --
    browser_session: AddressablePersistentBrowserSession | None = None
    key_state: KeyState = dataclasses.field(default_factory=KeyState)
    task: Task | None = None
    workflow_run: WorkflowRun | None = None

    def __post_init__(self, initial_interactor: Interactor) -> None:
        # Goes through the property setter so the initial role is logged too.
        self.interactor = initial_interactor
        add_vnc_channel(self)

    @property
    def class_name(self) -> str:
        """The name of this object's class, used as a log prefix."""
        return type(self).__name__

    @property
    def identity(self) -> dict:
        """
        Logging context: the organization id plus one id for whatever the channel
        is attached to (task, else workflow run, else browser session, else the
        client id).
        """
        ids = {"organization_id": self.organization_id}

        if self.task:
            ids["task_id"] = self.task.task_id
        elif self.workflow_run:
            ids["workflow_run_id"] = self.workflow_run.workflow_run_id
        elif self.browser_session:
            ids["browser_session_id"] = self.browser_session.persistent_browser_session_id
        else:
            ids["client_id"] = self.client_id

        return ids

    @property
    def interactor(self) -> Interactor:
        """Who is currently in control: "agent" or "user"."""
        return self._interactor

    @interactor.setter
    def interactor(self, value: Interactor) -> None:
        self._interactor = value
        LOG.info(f"{self.class_name} Setting interactor to {value}", **self.identity)

    @property
    def is_open(self) -> bool:
        """
        True while the websocket is connected, the channel is still attached to a
        task, workflow run or browser session, and it is still registered.
        """
        if self.websocket.client_state != WebSocketState.CONNECTED:
            return False

        if not (self.task or self.workflow_run or self.browser_session):
            return False

        return bool(get_vnc_channel(self.client_id))

    async def close(self, code: int = 1000, reason: str | None = None) -> t.Self:
        """
        Detach the channel from its task / workflow run / browser session, close
        the websocket (errors while closing are ignored) and unregister the channel.

        :return: this channel
        """
        LOG.info(f"{self.class_name} closing.", reason=reason, code=code, **self.identity)

        self.browser_session = None
        self.task = None
        self.workflow_run = None

        try:
            await self.websocket.close(code=code, reason=reason)
        except Exception:
            # best-effort close
            pass

        del_vnc_channel(self.client_id)

        return self

    def update_key_state(self, data: bytes) -> None:
        """Track modifier keys: update `key_state` when `data` is a Ctrl/Alt/Cmd press or release frame."""
        transitions = (
            (Keys.Down.Ctrl, "ctrl_is_down", True),
            (Keys.Up.Ctrl, "ctrl_is_down", False),
            (Keys.Down.Alt, "alt_is_down", True),
            (Keys.Up.Alt, "alt_is_down", False),
            (Keys.Down.Cmd, "cmd_is_down", True),
            (Keys.Up.Cmd, "cmd_is_down", False),
        )

        for frame, flag_name, is_down in transitions:
            if data == frame:
                setattr(self.key_state, flag_name, is_down)
                return
async def copy_text(vnc_channel: VncChannel) -> None:
    """
    Fetch the text currently selected in the browser and send it to the
    client's message channel.

    Any failure is logged and swallowed; nothing is raised to the caller.
    """
    class_name = vnc_channel.class_name

    LOG.info(f"{class_name} Retrieving selected text via CDP", **vnc_channel.identity)

    try:
        async with execution_channel(vnc_channel) as execute:
            copied_text = await execute.get_selected_text()

            message_channel = get_message_channel(vnc_channel.client_id)

            if not message_channel:
                LOG.warning(
                    f"{class_name} No message channel found for client, or it is not open",
                    message_channel=message_channel,
                    **vnc_channel.identity,
                )
            else:
                await message_channel.send_copied_text(copied_text)
    except Exception:
        LOG.exception(f"{class_name} Failed to retrieve selected text via CDP", **vnc_channel.identity)
async def ask_for_clipboard(vnc_channel: VncChannel) -> None:
    """
    Ask the client's message channel for its clipboard contents.

    Any failure is logged and swallowed; nothing is raised to the caller.
    """
    class_name = vnc_channel.class_name

    # NOTE(review): the log text says "via CDP", but this function only talks to
    # the message channel - confirm the wording is intended.
    LOG.info(f"{class_name} Asking for clipboard data via CDP", **vnc_channel.identity)

    try:
        message_channel = get_message_channel(vnc_channel.client_id)

        if not message_channel:
            LOG.warning(
                f"{class_name} No message channel found for client, or it is not open",
                message_channel=message_channel,
                **vnc_channel.identity,
            )
        else:
            await message_channel.ask_for_clipboard()
    except Exception:
        LOG.exception(f"{class_name} Failed to ask for clipboard via CDP", **vnc_channel.identity)
async def loop_stream_vnc(vnc_channel: VncChannel) -> None:
    """
    Actually stream the VNC data between a frontend and a browser.

    Connects to the browser's VNC websocket and then runs two loops
    concurrently - frontend-to-browser and browser-to-frontend - for as long as
    `vnc_channel.is_open`. Once connected, the vnc channel is always closed
    when the loops finish.

    Frontend frames are filtered before being forwarded:
      - keyboard frames update the channel's modifier key state; copy and paste
        shortcuts additionally trigger `copy_text` / `ask_for_clipboard`;
      - forbidden key combinations are dropped;
      - right-mouse-button frames are dropped;
      - all keyboard and mouse frames are dropped while the interactor is not
        the "user".

    :raises Exception: if the channel has no browser session (raised before any
        connection is attempted, so the channel is not closed on this path).
    """

    browser_session = vnc_channel.browser_session
    class_name = vnc_channel.class_name

    if not browser_session:
        raise Exception(f"{class_name} No browser session associated with vnc channel.")

    if browser_session.ip_address:
        # The address may carry a ":port" suffix; only the host part is wanted.
        # Splitting once from the right keeps "host" and "host:port" working as
        # before and no longer raises ValueError when there are more colons.
        host = browser_session.ip_address.rsplit(":", 1)[0]
    else:
        host = urlparse(browser_session.browser_address).hostname

    vnc_url = f"ws://{host}:{vnc_channel.vnc_port}"

    # NOTE(jdo:streaming-local-dev)
    # vnc_url = "ws://localhost:6080"

    LOG.info(
        f"{class_name} Connecting to vnc url.",
        vnc_url=vnc_url,
        **vnc_channel.identity,
    )

    async with websockets.connect(vnc_url) as novnc_ws:

        async def frontend_to_browser() -> None:
            LOG.info(f"{class_name} Starting frontend-to-browser data transfer.", **vnc_channel.identity)

            data: Data | None = None

            while vnc_channel.is_open:
                # Start every iteration empty. The CancelledError handler below
                # falls through to the send step, and a frame left over from the
                # previous iteration must not be forwarded a second time.
                data = None

                try:
                    data = await vnc_channel.websocket.receive_bytes()

                    if data:
                        message_type = data[0]

                        if message_type == MessageType.Keyboard.value:
                            vnc_channel.update_key_state(data)

                            if vnc_channel.key_state.is_copy(data):
                                await copy_text(vnc_channel)

                            if vnc_channel.key_state.is_paste(data):
                                await ask_for_clipboard(vnc_channel)

                            if vnc_channel.key_state.is_forbidden(data):
                                continue

                        # prevent right-mouse-button clicks for "security" reasons
                        if message_type == MessageType.Mouse.value:
                            if Mouse.Up.Right(data):
                                continue

                        # Only the user may drive the browser with keyboard / mouse.
                        if vnc_channel.interactor != "user" and message_type in (
                            MessageType.Keyboard.value,
                            MessageType.Mouse.value,
                        ):
                            LOG.debug(f"{class_name} Blocking user message.", **vnc_channel.identity)
                            continue

                except WebSocketDisconnect:
                    LOG.info(f"{class_name} Frontend disconnected.", **vnc_channel.identity)
                    raise
                except ConnectionClosedError:
                    LOG.info(f"{class_name} Frontend closed the vnc channel.", **vnc_channel.identity)
                    raise
                except asyncio.CancelledError:
                    # NOTE(review): cancellation is deliberately not propagated
                    # here; the loop ends through `is_open` instead - confirm.
                    pass
                except Exception:
                    LOG.exception(f"{class_name} An unexpected exception occurred.", **vnc_channel.identity)
                    raise

                if not data:
                    continue

                try:
                    await novnc_ws.send(data)
                except WebSocketDisconnect:
                    LOG.info(f"{class_name} Browser disconnected from vnc.", **vnc_channel.identity)
                    raise
                except ConnectionClosedError:
                    LOG.info(f"{class_name} Browser closed vnc.", **vnc_channel.identity)
                    raise
                except asyncio.CancelledError:
                    pass
                except Exception:
                    LOG.exception(
                        f"{class_name} An unexpected exception occurred in frontend-to-browser loop.",
                        **vnc_channel.identity,
                    )
                    raise

        async def browser_to_frontend() -> None:
            LOG.info(f"{class_name} Starting browser-to-frontend data transfer.", **vnc_channel.identity)

            data: Data | None = None

            while vnc_channel.is_open:
                # Same reason as in frontend_to_browser: several handlers below
                # fall through, and stale data must not be re-sent.
                data = None

                try:
                    data = await novnc_ws.recv()
                except WebSocketDisconnect:
                    LOG.info(f"{class_name} Browser disconnected from the vnc channel session.", **vnc_channel.identity)
                    await vnc_channel.close(reason="browser-disconnected")
                except ConnectionClosedError:
                    LOG.info(f"{class_name} Browser closed the vnc channel session.", **vnc_channel.identity)
                    await vnc_channel.close(reason="browser-closed")
                except asyncio.CancelledError:
                    pass
                except Exception:
                    LOG.exception(
                        f"{class_name} An unexpected exception occurred in browser-to-frontend loop.",
                        **vnc_channel.identity,
                    )
                    raise

                if not data:
                    continue

                try:
                    # Nothing to deliver to once the frontend socket is gone.
                    if vnc_channel.websocket.client_state != WebSocketState.CONNECTED:
                        continue

                    await vnc_channel.websocket.send_bytes(data)
                except WebSocketDisconnect:
                    LOG.info(
                        f"{class_name} Frontend disconnected from the vnc channel session.", **vnc_channel.identity
                    )
                    await vnc_channel.close(reason="frontend-disconnected")
                except ConnectionClosedError:
                    LOG.info(f"{class_name} Frontend closed the vnc channel session.", **vnc_channel.identity)
                    await vnc_channel.close(reason="frontend-closed")
                except asyncio.CancelledError:
                    pass
                except Exception:
                    LOG.exception(f"{class_name} An unexpected exception occurred.", **vnc_channel.identity)
                    raise

        loops = [
            asyncio.create_task(frontend_to_browser()),
            asyncio.create_task(browser_to_frontend()),
        ]

        try:
            await collect(loops)
        except WebSocketDisconnect:
            pass
        except Exception:
            LOG.exception(f"{class_name} An exception occurred in loop stream.", **vnc_channel.identity)
        finally:
            LOG.info(f"{class_name} Closing the loop stream.", **vnc_channel.identity)
            await vnc_channel.close(reason="loop-stream-vnc-closed")
async def get_vnc_channel_for_browser_session(
    client_id: str,
    browser_session_id: str,
    organization_id: str,
    websocket: WebSocket,
) -> tuple[VncChannel, Loops] | None:
    """
    Build a vnc channel for a browser session, plus the loops that drive it.

    :return: ``(vnc_channel, loops)`` with the loops already scheduled as tasks,
        or ``None`` when the browser session cannot be verified or the channel
        cannot be constructed.
    """

    LOG.info("Getting vnc context for browser session.", browser_session_id=browser_session_id)

    session = await verify_browser_session(
        browser_session_id=browser_session_id,
        organization_id=organization_id,
    )

    if not session:
        LOG.info(
            "No initial browser session found.", browser_session_id=browser_session_id, organization_id=organization_id
        )
        return None

    api_key = await get_x_api_key(organization_id)

    try:
        # The channel starts out under "agent" control.
        channel = VncChannel(
            client_id=client_id,
            organization_id=organization_id,
            vnc_port=settings.SKYVERN_BROWSER_VNC_PORT,
            x_api_key=api_key,
            websocket=websocket,
            initial_interactor="agent",
            browser_session=session,
        )
    except Exception as e:
        LOG.exception("Failed to create VncChannel.", error=str(e))
        return None

    LOG.info("Got vnc context for browser session.", vnc_channel=channel)

    # One loop keeps re-verifying the session, the other relays the VNC data.
    coroutines = (
        loop_verify_browser_session(channel),
        loop_stream_vnc(channel),
    )

    return channel, [asyncio.create_task(coroutine) for coroutine in coroutines]
async def get_vnc_channel_for_task(
    client_id: str,
    task_id: str,
    organization_id: str,
    websocket: WebSocket,
) -> tuple[VncChannel, Loops] | None:
    """
    Build a vnc channel for a task, plus the loops that drive it.

    :return: ``(vnc_channel, loops)`` with the loops already scheduled as tasks,
        or ``None`` when either the task or its browser session cannot be verified.
    """

    verified_task, session = await verify_task(task_id=task_id, organization_id=organization_id)

    if not verified_task:
        LOG.info("No initial task found.", task_id=task_id, organization_id=organization_id)
        return None

    if not session:
        LOG.info("No initial browser session found for task.", task_id=task_id, organization_id=organization_id)
        return None

    api_key = await get_x_api_key(organization_id)

    # The channel starts out under "agent" control.
    channel = VncChannel(
        client_id=client_id,
        organization_id=organization_id,
        vnc_port=settings.SKYVERN_BROWSER_VNC_PORT,
        x_api_key=api_key,
        websocket=websocket,
        initial_interactor="agent",
        browser_session=session,
        task=verified_task,
    )

    # One loop keeps re-verifying the task, the other relays the VNC data.
    coroutines = (
        loop_verify_task(channel),
        loop_stream_vnc(channel),
    )

    return channel, [asyncio.create_task(coroutine) for coroutine in coroutines]
async def get_vnc_channel_for_workflow_run(
    client_id: str,
    workflow_run_id: str,
    organization_id: str,
    websocket: WebSocket,
) -> tuple[VncChannel, Loops] | None:
    """
    Build a vnc channel for a workflow run, plus the loops that drive it.

    :return: ``(vnc_channel, loops)`` with the loops already scheduled as tasks,
        or ``None`` when either the workflow run or its browser session cannot
        be verified.
    """

    LOG.info("Getting vnc channel for workflow run.", workflow_run_id=workflow_run_id)

    run, session = await verify_workflow_run(
        workflow_run_id=workflow_run_id,
        organization_id=organization_id,
    )

    if not run:
        LOG.info("No initial workflow run found.", workflow_run_id=workflow_run_id, organization_id=organization_id)
        return None

    if not session:
        LOG.info(
            "No initial browser session found for workflow run.",
            workflow_run_id=workflow_run_id,
            organization_id=organization_id,
        )
        return None

    api_key = await get_x_api_key(organization_id)

    # The channel starts out under "agent" control.
    channel = VncChannel(
        client_id=client_id,
        organization_id=organization_id,
        vnc_port=settings.SKYVERN_BROWSER_VNC_PORT,
        x_api_key=api_key,
        websocket=websocket,
        initial_interactor="agent",
        browser_session=session,
        workflow_run=run,
    )

    LOG.info("Got vnc channel context for workflow run.", vnc_channel=channel)

    # One loop keeps re-verifying the workflow run, the other relays the VNC data.
    coroutines = (
        loop_verify_workflow_run(channel),
        loop_stream_vnc(channel),
    )

    return channel, [asyncio.create_task(coroutine) for coroutine in coroutines]

View File

@@ -1,328 +0,0 @@
"""
Streaming types.
"""
import asyncio
import dataclasses
import typing as t
from enum import IntEnum
import structlog
from fastapi import WebSocket
from starlette.websockets import WebSocketState
from skyvern.forge.sdk.schemas.persistent_browser_sessions import AddressablePersistentBrowserSession
from skyvern.forge.sdk.schemas.tasks import Task
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRun
# Module-level structured logger.
LOG = structlog.get_logger()
# The two roles that can be in control of the streamed browser.
Interactor = t.Literal["agent", "user"]
# The asyncio tasks handed back alongside a channel, meant to be run concurrently.
Loops = list[asyncio.Task] # aka "queue-less actors"; or "programs"
# Messages
# a global registry for WS message clients
# Keyed by client_id; maintained by add_message_client / get_message_client /
# del_message_client.
message_channels: dict[str, "MessageChannel"] = {}
def add_message_client(message_channel: "MessageChannel") -> None:
    """Register `message_channel` under its client id, replacing any previous entry."""
    client_id = message_channel.client_id
    message_channels[client_id] = message_channel
def get_message_client(client_id: str) -> t.Union["MessageChannel", None]:
    """Return the message channel registered for `client_id`, or None if there is none."""
    return message_channels.get(client_id)
def del_message_client(client_id: str) -> None:
    """Remove the message channel registered for `client_id`; a missing entry is ignored."""
    message_channels.pop(client_id, None)
@dataclasses.dataclass
class MessageChannel:
    """
    A JSON message channel to one frontend client.

    Creating an instance registers it in the message client registry under its
    `client_id`; `close()` removes it again.
    """

    client_id: str
    organization_id: str
    websocket: WebSocket
    # --
    browser_session: AddressablePersistentBrowserSession | None = None
    workflow_run: WorkflowRun | None = None

    def __post_init__(self) -> None:
        add_message_client(self)

    async def close(self, code: int = 1000, reason: str | None = None) -> "MessageChannel":
        """
        Detach from the browser session / workflow run, close the websocket
        (errors while closing are ignored) and unregister the channel.

        :return: this channel
        """
        LOG.info("Closing message stream.", reason=reason, code=code)

        self.browser_session = None
        self.workflow_run = None

        try:
            await self.websocket.close(code=code, reason=reason)
        except Exception:
            # best-effort close
            pass

        del_message_client(self.client_id)

        return self

    @property
    def is_open(self) -> bool:
        """
        True while the websocket is connected or connecting, the channel is
        still attached to a workflow run or browser session, and it is still
        registered.
        """
        live_states = (WebSocketState.CONNECTED, WebSocketState.CONNECTING)

        if self.websocket.client_state not in live_states:
            return False

        if not (self.workflow_run or self.browser_session):
            return False

        return bool(get_message_client(self.client_id))

    async def ask_for_clipboard(self, streaming: "Streaming") -> None:
        """
        Send an "ask-for-clipboard" message to the frontend.

        Failures are logged, not raised. `streaming` only supplies the
        organization id for the log lines.
        """
        payload = {
            "kind": "ask-for-clipboard",
        }

        try:
            await self.websocket.send_json(payload)

            LOG.info(
                "Sent ask-for-clipboard to message channel",
                organization_id=streaming.organization_id,
            )
        except Exception:
            LOG.exception(
                "Failed to send ask-for-clipboard to message channel",
                organization_id=streaming.organization_id,
            )

    async def send_copied_text(self, copied_text: str, streaming: "Streaming") -> None:
        """
        Send `copied_text` to the frontend in a "copied-text" message.

        Failures are logged, not raised. `streaming` only supplies the
        organization id for the log lines.
        """
        payload = {
            "kind": "copied-text",
            "text": copied_text,
        }

        try:
            await self.websocket.send_json(payload)

            LOG.info(
                "Sent copied text to message channel",
                organization_id=streaming.organization_id,
            )
        except Exception:
            LOG.exception(
                "Failed to send copied text to message channel",
                organization_id=streaming.organization_id,
            )
# The "kind" values a channel message can have.
MessageKinds = t.Literal["take-control", "cede-control", "ask-for-clipboard-response"]
@dataclasses.dataclass
class Message:
    """Base class for channel messages; `kind` tells the concrete messages apart."""

    kind: MessageKinds
@dataclasses.dataclass
class MessageTakeControl(Message):
    """The "take-control" message."""

    kind: t.Literal["take-control"] = "take-control"
@dataclasses.dataclass
class MessageCedeControl(Message):
    """The "cede-control" message."""

    kind: t.Literal["cede-control"] = "cede-control"
@dataclasses.dataclass
class MessageInAskForClipboardResponse(Message):
    """The "ask-for-clipboard-response" message; `text` carries the clipboard text."""

    kind: t.Literal["ask-for-clipboard-response"] = "ask-for-clipboard-response"
    # Empty when the incoming payload has no (or an empty) "text" value.
    text: str = ""
# Every concrete message type that reify_channel_message can return.
ChannelMessage = t.Union[MessageTakeControl, MessageCedeControl, MessageInAskForClipboardResponse]
def reify_channel_message(data: dict) -> ChannelMessage:
    """
    Build the typed channel message described by a raw payload.

    The payload's "kind" selects the message class. For
    "ask-for-clipboard-response" the payload's "text" is carried over, with a
    missing or falsy value becoming "".

    :raises ValueError: if "kind" is missing or not a known message kind
    """
    kind = data.get("kind", None)

    if kind == "take-control":
        return MessageTakeControl()

    if kind == "cede-control":
        return MessageCedeControl()

    if kind == "ask-for-clipboard-response":
        clipboard_text = data.get("text") or ""
        return MessageInAskForClipboardResponse(text=clipboard_text)

    raise ValueError(f"Unknown message kind: '{kind}'")
# Streaming
# a global registry for WS streaming VNC clients
# Keyed by client_id; maintained by add_streaming_client / get_streaming_client /
# del_streaming_client.
streaming_clients: dict[str, "Streaming"] = {}
def add_streaming_client(streaming: "Streaming") -> None:
    """Register `streaming` under its client id, replacing any previous entry."""
    client_id = streaming.client_id
    streaming_clients[client_id] = streaming
def get_streaming_client(client_id: str) -> t.Union["Streaming", None]:
    """Return the streaming state registered for `client_id`, or None if there is none."""
    return streaming_clients.get(client_id)
def del_streaming_client(client_id: str) -> None:
    """Remove the streaming state registered for `client_id`; a missing entry is ignored."""
    streaming_clients.pop(client_id, None)
class MessageType(IntEnum):
    """Message type values for the two kinds of input frame named here."""

    # Key press / release frames.
    Keyboard = 4
    # Pointer (mouse) frames.
    Mouse = 5
class Keys:
    """
    Complete 8-byte VNC RFB key frames for a handful of keys.

    Each constant is a whole frame: it starts with the keyboard message type
    (0x04), the second byte is 0x01 in `Down` and 0x00 in `Up`, and the last
    bytes identify the key. There's likely a pithier repr (indexes 6-7). This
    is ok for now.

    ref: https://www.notion.so/References-21c426c42cd480fb9258ecc9eb8f09b4
    ref: https://github.com/novnc/noVNC/blob/master/docs/rfbproto-3.8.pdf
    """

    class Down:
        """Key-press frames."""

        # modifiers
        Ctrl = b"\x04\x01\x00\x00\x00\x00\xff\xe3"
        Cmd = b"\x04\x01\x00\x00\x00\x00\xff\xe9"
        Alt = b"\x04\x01\x00\x00\x00\x00\xff~"  # option

        # letters
        CKey = b"\x04\x01\x00\x00\x00\x00\x00c"
        OKey = b"\x04\x01\x00\x00\x00\x00\x00o"
        VKey = b"\x04\x01\x00\x00\x00\x00\x00v"

    class Up:
        """Key-release frames (modifiers only)."""

        Ctrl = b"\x04\x00\x00\x00\x00\x00\xff\xe3"
        Cmd = b"\x04\x00\x00\x00\x00\x00\xff\xe9"
        Alt = b"\x04\x00\x00\x00\x00\x00\xff\x7e"  # option
def is_rmb(data: bytes) -> bool:
    """
    Return True when `data` is an RFB pointer frame with the right button set.

    An RFB PointerEvent starts with the message type (5) followed by a button
    mask in which bit 2 (0x04) is the right mouse button. The mask bit is
    tested rather than the whole byte, so a right button combined with any
    other button is detected too. Frames shorter than two bytes never match.
    """
    return len(data) >= 2 and data[0] == 0x05 and bool(data[1] & 0x04)
class Mouse:
    """Namespace of predicates over pointer frames."""

    class Up:
        # Call as `Mouse.Up.Right(data)`; this is `is_rmb` under another name.
        # NOTE(review): `is_rmb` inspects the button byte of the frame - confirm
        # that "Up" is the right name for what it detects.
        Right = is_rmb
@dataclasses.dataclass
class KeyState:
    """Which modifier keys are currently held, plus shortcut checks built on that."""

    ctrl_is_down: bool = False
    alt_is_down: bool = False
    cmd_is_down: bool = False

    def is_forbidden(self, data: bytes) -> bool:
        """
        :return: True if the key is forbidden, else False
        """

        # Ctrl+O is the only forbidden combination at the moment.
        return self.is_ctrl_o(data)

    def is_ctrl_o(self, data: bytes) -> bool:
        """
        Detect Ctrl+O: do not allow the opening of files.
        """

        if self.ctrl_is_down:
            return data == Keys.Down.OKey

        return self.ctrl_is_down

    def is_copy(self, data: bytes) -> bool:
        """
        Detect Ctrl+C or Cmd+C for copy.
        """

        modifier_is_down = self.ctrl_is_down or self.cmd_is_down

        return modifier_is_down and data == Keys.Down.CKey

    def is_paste(self, data: bytes) -> bool:
        """
        Detect Ctrl+V or Cmd+V for paste.
        """

        modifier_is_down = self.ctrl_is_down or self.cmd_is_down

        return modifier_is_down and data == Keys.Down.VKey
@dataclasses.dataclass
class Streaming:
    """
    Streaming state.

    Creating an instance registers it in the streaming client registry under
    its `client_id`; `close()` removes it again.
    """

    # Unique to frontend app instance.
    client_id: str
    # Whether the user or the agent are the interactor.
    interactor: Interactor
    organization_id: str
    vnc_port: int
    x_api_key: str
    websocket: WebSocket
    # --
    browser_session: AddressablePersistentBrowserSession | None = None
    key_state: KeyState = dataclasses.field(default_factory=KeyState)
    task: Task | None = None
    workflow_run: WorkflowRun | None = None

    def __post_init__(self) -> None:
        add_streaming_client(self)

    @property
    def is_open(self) -> bool:
        """
        True while the websocket is connected or connecting, the state is still
        attached to a task, workflow run or browser session, and it is still
        registered.
        """
        live_states = (WebSocketState.CONNECTED, WebSocketState.CONNECTING)

        if self.websocket.client_state not in live_states:
            return False

        if not (self.task or self.workflow_run or self.browser_session):
            return False

        return bool(get_streaming_client(self.client_id))

    async def close(self, code: int = 1000, reason: str | None = None) -> "Streaming":
        """
        Detach from the task / workflow run / browser session, close the
        websocket (errors while closing are ignored) and unregister the state.

        :return: this object
        """
        LOG.info("Closing Streaming.", reason=reason, code=code)

        self.browser_session = None
        self.task = None
        self.workflow_run = None

        try:
            await self.websocket.close(code=code, reason=reason)
        except Exception:
            # best-effort close
            pass

        del_streaming_client(self.client_id)

        return self

    def update_key_state(self, data: bytes) -> None:
        """Track modifier keys: update `key_state` when `data` is a Ctrl/Alt/Cmd press or release frame."""
        transitions = (
            (Keys.Down.Ctrl, "ctrl_is_down", True),
            (Keys.Up.Ctrl, "ctrl_is_down", False),
            (Keys.Down.Alt, "alt_is_down", True),
            (Keys.Up.Alt, "alt_is_down", False),
            (Keys.Down.Cmd, "cmd_is_down", True),
            (Keys.Up.Cmd, "cmd_is_down", False),
        )

        for frame, flag_name, is_down in transitions:
            if data == frame:
                setattr(self.key_state, flag_name, is_down)
                return

View File

@@ -1,240 +1,24 @@
"""
Streaming messages for WebSocket connections.
Provides WS endpoints for streaming messages to/from our frontend application.
"""
import asyncio
import structlog
from fastapi import WebSocket, WebSocketDisconnect
from websockets.exceptions import ConnectionClosedError
from fastapi import WebSocket
import skyvern.forge.sdk.routes.streaming.clients as sc
from skyvern.forge.sdk.routes.routers import base_router, legacy_base_router
from skyvern.forge.sdk.routes.streaming.agent import connected_agent
from skyvern.forge.sdk.routes.streaming.auth import auth
from skyvern.forge.sdk.routes.streaming.verify import (
loop_verify_browser_session,
loop_verify_workflow_run,
verify_browser_session,
verify_workflow_run,
from skyvern.forge.sdk.routes.streaming.channels.message import (
Loops,
MessageChannel,
get_message_channel_for_browser_session,
get_message_channel_for_workflow_run,
)
from skyvern.forge.sdk.utils.aio import collect
LOG = structlog.get_logger()
async def get_messages_for_browser_session(
    client_id: str,
    browser_session_id: str,
    organization_id: str,
    websocket: WebSocket,
) -> tuple[sc.MessageChannel, sc.Loops] | None:
    """
    Build a message channel for a browser session, plus the loops that drive it.

    :return: ``(message_channel, loops)`` with the loops already scheduled as
        tasks, or ``None`` when the browser session cannot be verified.
    """

    LOG.info("Getting message channel for browser session.", browser_session_id=browser_session_id)

    session = await verify_browser_session(
        browser_session_id=browser_session_id,
        organization_id=organization_id,
    )

    if not session:
        LOG.info(
            "Message channel: no initial browser session found.",
            browser_session_id=browser_session_id,
            organization_id=organization_id,
        )
        return None

    channel = sc.MessageChannel(
        client_id=client_id,
        organization_id=organization_id,
        websocket=websocket,
        browser_session=session,
    )

    LOG.info("Got message channel for browser session.", message_channel=channel)

    # One loop keeps re-verifying the session, the other handles the messages.
    coroutines = (
        loop_verify_browser_session(channel),
        loop_channel(channel),
    )

    return channel, [asyncio.create_task(coroutine) for coroutine in coroutines]
async def get_messages_for_workflow_run(
    client_id: str,
    workflow_run_id: str,
    organization_id: str,
    websocket: WebSocket,
) -> tuple[sc.MessageChannel, sc.Loops] | None:
    """
    Build a message channel for a workflow run, plus the loops that drive it.

    :return: ``(message_channel, loops)`` with the loops already scheduled as
        tasks, or ``None`` when either the workflow run or its browser session
        cannot be verified.
    """

    LOG.info("Getting message channel for workflow run.", workflow_run_id=workflow_run_id)

    run, session = await verify_workflow_run(
        workflow_run_id=workflow_run_id,
        organization_id=organization_id,
    )

    if not run:
        LOG.info(
            "Message channel: no initial workflow run found.",
            workflow_run_id=workflow_run_id,
            organization_id=organization_id,
        )
        return None

    if not session:
        LOG.info(
            "Message channel: no initial browser session found for workflow run.",
            workflow_run_id=workflow_run_id,
            organization_id=organization_id,
        )
        return None

    channel = sc.MessageChannel(
        client_id=client_id,
        organization_id=organization_id,
        websocket=websocket,
        browser_session=session,
        workflow_run=run,
    )

    LOG.info("Got message channel for workflow run.", message_channel=channel)

    # One loop keeps re-verifying the workflow run, the other handles the messages.
    coroutines = (
        loop_verify_workflow_run(channel),
        loop_channel(channel),
    )

    return channel, [asyncio.create_task(coroutine) for coroutine in coroutines]
async def loop_channel(message_channel: sc.MessageChannel) -> None:
"""
Stream messages and their results back and forth.
Loops until the workflow run is cleared or the websocket is closed.
"""
if not message_channel.browser_session:
LOG.info(
"No browser session found for workflow run.",
workflow_run=message_channel.workflow_run,
organization_id=message_channel.organization_id,
)
return
async def frontend_to_backend() -> None:
LOG.info("Starting frontend-to-backend channel loop.", message_channel=message_channel)
while message_channel.is_open:
try:
data = await message_channel.websocket.receive_json()
if not isinstance(data, dict):
LOG.error(f"Cannot create channel message: expected dict, got {type(data)}")
continue
try:
message = sc.reify_channel_message(data)
except ValueError:
continue
message_kind = message.kind
match message_kind:
case "take-control":
streaming = sc.get_streaming_client(message_channel.client_id)
if not streaming:
LOG.error(
"No streaming client found for message.",
message_channel=message_channel,
message=message,
)
continue
streaming.interactor = "user"
case "cede-control":
streaming = sc.get_streaming_client(message_channel.client_id)
if not streaming:
LOG.error(
"No streaming client found for message.",
message_channel=message_channel,
message=message,
)
continue
streaming.interactor = "agent"
case "ask-for-clipboard-response":
if not isinstance(message, sc.MessageInAskForClipboardResponse):
LOG.error(
"Invalid message type for ask-for-clipboard-response.",
message_channel=message_channel,
message=message,
)
continue
streaming = sc.get_streaming_client(message_channel.client_id)
text = message.text
async with connected_agent(streaming) as agent:
await agent.paste_text(text)
case _:
LOG.error(f"Unknown message kind: '{message_kind}'")
continue
except WebSocketDisconnect:
LOG.info(
"Frontend disconnected.",
workflow_run=message_channel.workflow_run,
organization_id=message_channel.organization_id,
)
raise
except ConnectionClosedError:
LOG.info(
"Frontend closed the streaming session.",
workflow_run=message_channel.workflow_run,
organization_id=message_channel.organization_id,
)
raise
except asyncio.CancelledError:
pass
except Exception:
LOG.exception(
"An unexpected exception occurred.",
workflow_run=message_channel.workflow_run,
organization_id=message_channel.organization_id,
)
raise
loops = [
asyncio.create_task(frontend_to_backend()),
]
try:
await collect(loops)
except Exception:
LOG.exception(
"An exception occurred in loop channel stream.",
workflow_run=message_channel.workflow_run,
organization_id=message_channel.organization_id,
)
finally:
LOG.info(
"Closing the loop channel stream.",
workflow_run=message_channel.workflow_run,
organization_id=message_channel.organization_id,
)
await message_channel.close(reason="loop-channel-closed")
@base_router.websocket("/stream/messages/browser_session/{browser_session_id}")
@base_router.websocket("/stream/commands/browser_session/{browser_session_id}")
async def browser_session_messages(
websocket: WebSocket,
browser_session_id: str,
@@ -242,64 +26,16 @@ async def browser_session_messages(
client_id: str | None = None,
token: str | None = None,
) -> None:
LOG.info("Starting message stream for browser session.", browser_session_id=browser_session_id)
organization_id = await auth(apikey=apikey, token=token, websocket=websocket)
if not organization_id:
LOG.error("Authentication failed.", browser_session_id=browser_session_id)
return
if not client_id:
LOG.error("No client ID provided.", browser_session_id=browser_session_id)
return
message_channel: sc.MessageChannel
loops: list[asyncio.Task] = []
result = await get_messages_for_browser_session(
client_id=client_id,
browser_session_id=browser_session_id,
organization_id=organization_id,
return await messages(
websocket=websocket,
browser_session_id=browser_session_id,
apikey=apikey,
client_id=client_id,
token=token,
)
if not result:
LOG.error(
"No streaming context found for the browser session.",
browser_session_id=browser_session_id,
organization_id=organization_id,
)
await websocket.close(code=1013)
return
message_channel, loops = result
try:
LOG.info(
"Starting message stream loops for browser session.",
browser_session_id=browser_session_id,
organization_id=organization_id,
)
await collect(loops)
except Exception:
LOG.exception(
"An exception occurred in the message stream function for browser session.",
browser_session_id=browser_session_id,
organization_id=organization_id,
)
finally:
LOG.info(
"Closing the message stream session for browser session.",
browser_session_id=browser_session_id,
organization_id=organization_id,
)
await message_channel.close(reason="stream-closed")
@legacy_base_router.websocket("/stream/messages/workflow_run/{workflow_run_id}")
@legacy_base_router.websocket("/stream/commands/workflow_run/{workflow_run_id}")
async def workflow_run_messages(
websocket: WebSocket,
workflow_run_id: str,
@@ -307,31 +43,77 @@ async def workflow_run_messages(
client_id: str | None = None,
token: str | None = None,
) -> None:
LOG.info("Starting message stream.", workflow_run_id=workflow_run_id)
return await messages(
websocket=websocket,
workflow_run_id=workflow_run_id,
apikey=apikey,
client_id=client_id,
token=token,
)
async def messages(
websocket: WebSocket,
browser_session_id: str | None = None,
workflow_run_id: str | None = None,
apikey: str | None = None,
client_id: str | None = None,
token: str | None = None,
) -> None:
LOG.info(
"Starting message stream.",
browser_session_id=browser_session_id,
workflow_run_id=workflow_run_id,
)
organization_id = await auth(apikey=apikey, token=token, websocket=websocket)
if not organization_id:
LOG.error("Authentication failed.", workflow_run_id=workflow_run_id)
LOG.error(
"Authentication failed.",
browser_session_id=browser_session_id,
workflow_run_id=workflow_run_id,
)
return
if not client_id:
LOG.error("No client ID provided.", workflow_run_id=workflow_run_id)
LOG.error(
"No client ID provided.",
browser_session_id=browser_session_id,
workflow_run_id=workflow_run_id,
)
return
message_channel: sc.MessageChannel
loops: list[asyncio.Task] = []
message_channel: MessageChannel
loops: Loops = []
result = await get_messages_for_workflow_run(
client_id=client_id,
workflow_run_id=workflow_run_id,
organization_id=organization_id,
websocket=websocket,
)
if browser_session_id:
result = await get_message_channel_for_browser_session(
client_id=client_id,
browser_session_id=browser_session_id,
organization_id=organization_id,
websocket=websocket,
)
elif workflow_run_id:
result = await get_message_channel_for_workflow_run(
client_id=client_id,
workflow_run_id=workflow_run_id,
organization_id=organization_id,
websocket=websocket,
)
else:
LOG.error(
"Message channel: no browser_session_id or workflow_run_id provided.",
client_id=client_id,
organization_id=organization_id,
)
await websocket.close(code=1002)
return
if not result:
LOG.error(
"No streaming context found for the workflow run.",
"No message channel found.",
browser_session_id=browser_session_id,
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
@@ -342,22 +124,25 @@ async def workflow_run_messages(
try:
LOG.info(
"Starting message stream loops.",
"Starting message channel loops.",
browser_session_id=browser_session_id,
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
await collect(loops)
except Exception:
LOG.exception(
"An exception occurred in the message stream function.",
"An exception occurred in the message loop function(s).",
browser_session_id=browser_session_id,
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
finally:
LOG.info(
"Closing the message stream session.",
"Closing the message channel.",
browser_session_id=browser_session_id,
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
await message_channel.close(reason="stream-closed")
await message_channel.close(reason="message-stream-closed")

View File

@@ -0,0 +1,78 @@
"""
Contains registries for coordinating active WS connections (aka "channels", see
`./channels/README.md`).
NOTE: in AWS we had to turn on what amounts to sticky sessions for frontend apps,
so that an individual frontend app instance is guaranteed to always connect to
the same backend api instance. This is because the two registries here are
tied together via a `client_id` string.
The tale-of-the-tape is this:
- frontend app requires two different channels (WS connections) to the backend api
- one dedicated to streaming VNC's RFB protocol
- the other dedicated to messaging (JSON)
- both of these channels are stateful and need to coordinate with one another
"""
from __future__ import annotations
import typing as t
import structlog
if t.TYPE_CHECKING:
from skyvern.forge.sdk.routes.streaming.channels.message import MessageChannel
from skyvern.forge.sdk.routes.streaming.channels.vnc import VncChannel
LOG = structlog.get_logger()
# a registry for VNC channels, keyed by `client_id`
vnc_channels: dict[str, VncChannel] = {}
def add_vnc_channel(vnc_channel: VncChannel) -> None:
    """
    Register `vnc_channel` in the VNC registry under its `client_id`.

    An existing entry for the same `client_id` is overwritten.
    """

    key = vnc_channel.client_id
    vnc_channels[key] = vnc_channel
def get_vnc_channel(client_id: str) -> t.Union[VncChannel, None]:
    """
    Look up the VNC channel registered for `client_id`.

    Returns `None` when no channel is registered under that id.
    """

    channel = vnc_channels.get(client_id)
    return channel
def del_vnc_channel(client_id: str) -> None:
    """
    Remove the VNC channel registered for `client_id`, if any.

    Unknown ids are ignored; this never raises for a missing entry.
    """

    vnc_channels.pop(client_id, None)
# a registry for message channels, keyed by `client_id`
message_channels: dict[str, MessageChannel] = {}
def add_message_channel(message_channel: MessageChannel) -> None:
    """
    Register `message_channel` in the message registry under its `client_id`.

    An existing entry for the same `client_id` is overwritten.
    """

    key = message_channel.client_id
    message_channels[key] = message_channel
def get_message_channel(client_id: str) -> t.Union[MessageChannel, None]:
    """
    Return the open message channel registered for `client_id`.

    A registered channel that is no longer open is evicted from the registry
    as a side effect, and `None` is returned. `None` is also returned when
    nothing is registered under `client_id`.
    """

    channel = message_channels.get(client_id)

    if not channel:
        return None

    if channel.is_open:
        return channel

    # Stale entry: the channel exists but has been closed, so drop it.
    LOG.info(
        "MessageChannel: message channel is not open; deleting it",
        client_id=channel.client_id,
    )
    del_message_channel(channel.client_id)

    return None
def del_message_channel(client_id: str) -> None:
    """
    Remove the message channel registered for `client_id`, if any.

    Unknown ids are ignored; this never raises for a missing entry.
    """

    message_channels.pop(client_id, None)

View File

@@ -1,3 +1,14 @@
"""
Provides WS endpoints for streaming screenshots.
Screenshot streaming is created on the basis of one of these database entities:
- task (run)
- workflow run
Screenshot streaming is used for a run that is invoked without a browser session.
Otherwise, VNC streaming is used.
"""
import asyncio
import base64
from datetime import datetime

View File

@@ -1,18 +1,43 @@
"""
Channels (see ./channels/README.md) variously rely on the state of one of
these database entities:
- browser session
- task
- workflow run
That is, channels are created on the basis of one of those entities, and that
entity must be in a valid state for the channel to continue.
So, in order to continue operating a channel, we need to periodically verify
that the underlying entity is still valid. This module provides logic to
perform those verifications.
"""
from __future__ import annotations
import asyncio
import typing as t
from datetime import datetime
import structlog
import skyvern.forge.sdk.routes.streaming.clients as sc
from skyvern.config import settings
from skyvern.forge import app
from skyvern.forge.sdk.schemas.persistent_browser_sessions import AddressablePersistentBrowserSession
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRun, WorkflowRunStatus
if t.TYPE_CHECKING:
from skyvern.forge.sdk.routes.streaming.channels.message import MessageChannel
from skyvern.forge.sdk.routes.streaming.channels.vnc import VncChannel
LOG = structlog.get_logger()
class Constants:
    """Tunable constants for the verification loops in this module."""

    # Seconds to sleep between successive verification passes in the
    # `loop_verify_*` functions.
    POLL_INTERVAL_FOR_VERIFICATION_SECONDS = 5
async def verify_browser_session(
browser_session_id: str,
organization_id: str,
@@ -122,9 +147,7 @@ async def verify_task(
**browser_session.model_dump() | {"browser_address": browser_session.browser_address},
)
except Exception as e:
LOG.error(
"streaming-vnc.browser-session-reify-error", task_id=task_id, organization_id=organization_id, error=e
)
LOG.error("vnc.browser-session-reify-error", task_id=task_id, organization_id=organization_id, error=e)
return task, None
return task, addressable_browser_session
@@ -238,7 +261,7 @@ async def verify_workflow_run(
return workflow_run, addressable_browser_session
async def loop_verify_browser_session(verifiable: sc.MessageChannel | sc.Streaming) -> None:
async def loop_verify_browser_session(verifiable: MessageChannel | VncChannel) -> None:
"""
Loop until the browser session is cleared or the websocket is closed.
"""
@@ -251,27 +274,27 @@ async def loop_verify_browser_session(verifiable: sc.MessageChannel | sc.Streami
verifiable.browser_session = browser_session
await asyncio.sleep(2)
await asyncio.sleep(Constants.POLL_INTERVAL_FOR_VERIFICATION_SECONDS)
async def loop_verify_task(streaming: sc.Streaming) -> None:
async def loop_verify_task(vnc_channel: VncChannel) -> None:
"""
Loop until the task is cleared or the websocket is closed.
"""
while streaming.task and streaming.is_open:
while vnc_channel.task and vnc_channel.is_open:
task, browser_session = await verify_task(
task_id=streaming.task.task_id,
organization_id=streaming.organization_id,
task_id=vnc_channel.task.task_id,
organization_id=vnc_channel.organization_id,
)
streaming.task = task
streaming.browser_session = browser_session
vnc_channel.task = task
vnc_channel.browser_session = browser_session
await asyncio.sleep(2)
await asyncio.sleep(Constants.POLL_INTERVAL_FOR_VERIFICATION_SECONDS)
async def loop_verify_workflow_run(verifiable: sc.MessageChannel | sc.Streaming) -> None:
async def loop_verify_workflow_run(verifiable: MessageChannel | VncChannel) -> None:
"""
Loop until the workflow run is cleared or the websocket is closed.
"""
@@ -285,4 +308,4 @@ async def loop_verify_workflow_run(verifiable: sc.MessageChannel | sc.Streaming)
verifiable.workflow_run = workflow_run
verifiable.browser_session = browser_session
await asyncio.sleep(2)
await asyncio.sleep(Constants.POLL_INTERVAL_FOR_VERIFICATION_SECONDS)

View File

@@ -1,5 +1,5 @@
"""
Streaming VNC WebSocket connections.
Provides WS endpoints for streaming a remote browser via VNC.
NOTE(jdo:streaming-local-dev)
-----------------------------
@@ -9,459 +9,23 @@ NOTE(jdo:streaming-local-dev)
"""
import asyncio
import typing as t
from urllib.parse import urlparse
import structlog
import websockets
from fastapi import WebSocket, WebSocketDisconnect
from websockets import Data
from websockets.exceptions import ConnectionClosedError
from fastapi import WebSocket
import skyvern.forge.sdk.routes.streaming.clients as sc
from skyvern.config import settings
from skyvern.forge import app
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
from skyvern.forge.sdk.routes.routers import base_router, legacy_base_router
from skyvern.forge.sdk.routes.streaming.agent import connected_agent
from skyvern.forge.sdk.routes.streaming.auth import auth
from skyvern.forge.sdk.routes.streaming.verify import (
loop_verify_browser_session,
loop_verify_task,
loop_verify_workflow_run,
verify_browser_session,
verify_task,
verify_workflow_run,
from skyvern.forge.sdk.routes.streaming.channels.vnc import (
Loops,
VncChannel,
get_vnc_channel_for_browser_session,
get_vnc_channel_for_task,
get_vnc_channel_for_workflow_run,
)
from skyvern.forge.sdk.utils.aio import collect
LOG = structlog.get_logger()
class Constants:
    """Sentinel values used by the VNC streaming helpers."""

    # Placeholder returned by `get_x_api_key` when the organization has no
    # valid API token.
    MissingXApiKey = "<missing-x-api-key>"
async def get_x_api_key(organization_id: str) -> str:
    """
    Resolve the API key to use when streaming on behalf of an organization.

    Returns the organization's valid API token, or
    `Constants.MissingXApiKey` (after logging a warning) when none exists.
    """

    auth_token = await app.DATABASE.get_valid_org_auth_token(
        organization_id,
        OrganizationAuthTokenType.api.value,
    )

    if auth_token:
        return auth_token.token

    LOG.warning(
        "No valid API key found for organization when streaming.",
        organization_id=organization_id,
    )

    return Constants.MissingXApiKey
async def get_streaming_for_browser_session(
    client_id: str,
    browser_session_id: str,
    organization_id: str,
    websocket: WebSocket,
) -> tuple[sc.Streaming, sc.Loops] | None:
    """
    Build the streaming context for a browser session.

    Returns the `sc.Streaming` instance together with the already-scheduled
    tasks (session verification and VNC relay) that the caller must run
    concurrently, or `None` when the browser session cannot be verified.
    """

    LOG.info("Getting streaming context for browser session.", browser_session_id=browser_session_id)

    verified_session = await verify_browser_session(
        browser_session_id=browser_session_id,
        organization_id=organization_id,
    )

    if not verified_session:
        LOG.info(
            "No initial browser session found.", browser_session_id=browser_session_id, organization_id=organization_id
        )
        return None

    api_key = await get_x_api_key(organization_id)

    streaming = sc.Streaming(
        client_id=client_id,
        interactor="agent",
        organization_id=organization_id,
        vnc_port=settings.SKYVERN_BROWSER_VNC_PORT,
        browser_session=verified_session,
        x_api_key=api_key,
        websocket=websocket,
    )

    LOG.info("Got streaming context for browser session.", streaming=streaming)

    # Task creation order is preserved: verification first, then the relay.
    verifier = asyncio.create_task(loop_verify_browser_session(streaming))
    relay = asyncio.create_task(loop_stream_vnc(streaming))

    return streaming, [verifier, relay]
async def get_streaming_for_task(
    client_id: str,
    task_id: str,
    organization_id: str,
    websocket: WebSocket,
) -> tuple[sc.Streaming, sc.Loops] | None:
    """
    Build the streaming context for a task.

    Returns the `sc.Streaming` instance together with the already-scheduled
    tasks (task verification and VNC relay) that the caller must run
    concurrently, or `None` when either the task or its browser session
    cannot be verified.
    """

    verified_task, verified_session = await verify_task(task_id=task_id, organization_id=organization_id)

    if not verified_task:
        LOG.info("No initial task found.", task_id=task_id, organization_id=organization_id)
        return None

    if not verified_session:
        LOG.info("No initial browser session found for task.", task_id=task_id, organization_id=organization_id)
        return None

    api_key = await get_x_api_key(organization_id)

    streaming = sc.Streaming(
        client_id=client_id,
        interactor="agent",
        organization_id=organization_id,
        vnc_port=settings.SKYVERN_BROWSER_VNC_PORT,
        x_api_key=api_key,
        websocket=websocket,
        browser_session=verified_session,
        task=verified_task,
    )

    # Task creation order is preserved: verification first, then the relay.
    verifier = asyncio.create_task(loop_verify_task(streaming))
    relay = asyncio.create_task(loop_stream_vnc(streaming))

    return streaming, [verifier, relay]
async def get_streaming_for_workflow_run(
    client_id: str,
    workflow_run_id: str,
    organization_id: str,
    websocket: WebSocket,
) -> tuple[sc.Streaming, sc.Loops] | None:
    """
    Build the streaming context for a workflow run.

    Returns the `sc.Streaming` instance together with the already-scheduled
    tasks (workflow-run verification and VNC relay) that the caller must run
    concurrently, or `None` when either the workflow run or its browser
    session cannot be verified.
    """

    LOG.info("Getting streaming context for workflow run.", workflow_run_id=workflow_run_id)

    verified_run, verified_session = await verify_workflow_run(
        workflow_run_id=workflow_run_id,
        organization_id=organization_id,
    )

    if not verified_run:
        LOG.info("No initial workflow run found.", workflow_run_id=workflow_run_id, organization_id=organization_id)
        return None

    if not verified_session:
        LOG.info(
            "No initial browser session found for workflow run.",
            workflow_run_id=workflow_run_id,
            organization_id=organization_id,
        )
        return None

    api_key = await get_x_api_key(organization_id)

    streaming = sc.Streaming(
        client_id=client_id,
        interactor="agent",
        organization_id=organization_id,
        vnc_port=settings.SKYVERN_BROWSER_VNC_PORT,
        browser_session=verified_session,
        workflow_run=verified_run,
        x_api_key=api_key,
        websocket=websocket,
    )

    LOG.info("Got streaming context for workflow run.", streaming=streaming)

    # Task creation order is preserved: verification first, then the relay.
    verifier = asyncio.create_task(loop_verify_workflow_run(streaming))
    relay = asyncio.create_task(loop_stream_vnc(streaming))

    return streaming, [verifier, relay]
def verify_message_channel(
    message_channel: sc.MessageChannel | None, streaming: sc.Streaming
) -> sc.MessageChannel | t.Literal[False]:
    """
    Return `message_channel` if it exists and is open, otherwise `False`.

    When the channel is missing or closed, a warning is logged with the
    client and organization ids taken from `streaming`.
    """

    if not message_channel or not message_channel.is_open:
        LOG.warning(
            "No message channel found for client, or it is not open",
            message_channel=message_channel,
            client_id=streaming.client_id,
            organization_id=streaming.organization_id,
        )
        return False

    return message_channel
async def copy_text(streaming: sc.Streaming) -> None:
    """
    Read the browser's current text selection over CDP and forward it to the
    client's message channel.

    This is best-effort: any failure is logged and swallowed so that a copy
    problem never interrupts the VNC stream that triggered it.
    """

    try:
        # The CDP agent is only needed to read the selection; release it
        # before talking to the message channel.
        async with connected_agent(streaming) as agent:
            copied_text = await agent.get_selected_text()

        LOG.info(
            "Retrieved selected text via CDP",
            organization_id=streaming.organization_id,
        )

        message_channel = sc.get_message_client(streaming.client_id)

        # `verify_message_channel` logs the warning itself when the channel is
        # missing or closed, so no second warning is emitted here.
        if cc := verify_message_channel(message_channel, streaming):
            await cc.send_copied_text(copied_text, streaming)
    except Exception:
        LOG.exception(
            "Failed to retrieve selected text via CDP",
            organization_id=streaming.organization_id,
        )
async def ask_for_clipboard(streaming: sc.Streaming) -> None:
    """
    Ask the client, via its message channel, to send its clipboard contents.

    Does nothing beyond logging when the client has no open message channel.
    Any exception is logged and swallowed.
    """

    try:
        LOG.info(
            "Asking for clipboard data via CDP",
            organization_id=streaming.organization_id,
        )

        candidate = sc.get_message_client(streaming.client_id)
        channel = verify_message_channel(candidate, streaming)

        if channel:
            await channel.ask_for_clipboard(streaming)
    except Exception:
        LOG.exception(
            "Failed to ask for clipboard via CDP",
            organization_id=streaming.organization_id,
        )
async def loop_stream_vnc(streaming: sc.Streaming) -> None:
    """
    Actually stream the VNC session data between a frontend and a browser
    session.

    Opens a websocket to the browser's VNC endpoint and runs two relay loops
    concurrently: frontend -> browser (with keyboard/mouse filtering) and
    browser -> frontend. Returns immediately if `streaming` has no browser
    session. The streaming context is always closed on exit.

    Loops until the task is cleared or the websocket is closed.
    """

    if not streaming.browser_session:
        LOG.info("No browser session found for task.", task=streaming.task, organization_id=streaming.organization_id)
        return

    # Derive the VNC websocket URL: prefer the session's ip address (dropping
    # any ":port" suffix), otherwise fall back to the browser address host.
    vnc_url: str = ""
    if streaming.browser_session.ip_address:
        if ":" in streaming.browser_session.ip_address:
            # NOTE(review): raises ValueError if the address contains more
            # than one ":" (e.g. an IPv6 literal) - confirm that is acceptable.
            ip, _ = streaming.browser_session.ip_address.split(":")
            vnc_url = f"ws://{ip}:{streaming.vnc_port}"
        else:
            vnc_url = f"ws://{streaming.browser_session.ip_address}:{streaming.vnc_port}"
    else:
        browser_address = streaming.browser_session.browser_address
        parsed_browser_address = urlparse(browser_address)
        host = parsed_browser_address.hostname
        vnc_url = f"ws://{host}:{streaming.vnc_port}"

    # NOTE(jdo:streaming-local-dev)
    # vnc_url = "ws://localhost:9001/ws/novnc"

    LOG.info(
        "Connecting to VNC URL.",
        vnc_url=vnc_url,
        task=streaming.task,
        workflow_run=streaming.workflow_run,
        organization_id=streaming.organization_id,
    )

    async with websockets.connect(vnc_url) as novnc_ws:

        async def frontend_to_browser() -> None:
            """Relay frames from the frontend websocket to the browser, filtering input events."""

            LOG.info("Starting frontend-to-browser data transfer.", streaming=streaming)

            while streaming.is_open:
                # Reset every iteration: if the receive below is interrupted
                # without re-raising, the previous frame must not be re-sent.
                data: Data | None = None

                try:
                    data = await streaming.websocket.receive_bytes()

                    if data:
                        message_type = data[0]

                        if message_type == sc.MessageType.Keyboard.value:
                            streaming.update_key_state(data)

                            # Copy/paste are handled out-of-band; the key
                            # frame itself still falls through below.
                            if streaming.key_state.is_copy(data):
                                await copy_text(streaming)

                            if streaming.key_state.is_paste(data):
                                await ask_for_clipboard(streaming)

                            if streaming.key_state.is_forbidden(data):
                                continue

                        if message_type == sc.MessageType.Mouse.value:
                            # presumably drops right-button release events - verify against sc.Mouse
                            if sc.Mouse.Up.Right(data):
                                continue

                        # Keyboard/mouse input is only forwarded while the
                        # user (not the agent) is the interactor.
                        if not streaming.interactor == "user" and message_type in (
                            sc.MessageType.Keyboard.value,
                            sc.MessageType.Mouse.value,
                        ):
                            LOG.info(
                                "Blocking user message.", task=streaming.task, organization_id=streaming.organization_id
                            )
                            continue

                except WebSocketDisconnect:
                    LOG.info("Frontend disconnected.", task=streaming.task, organization_id=streaming.organization_id)
                    raise
                except ConnectionClosedError:
                    LOG.info(
                        "Frontend closed the streaming session.",
                        task=streaming.task,
                        organization_id=streaming.organization_id,
                    )
                    raise
                except asyncio.CancelledError:
                    # NOTE(review): cancellation is swallowed here, as in the
                    # original; the loop only ends once `is_open` is false.
                    pass
                except Exception:
                    LOG.exception(
                        "An unexpected exception occurred.",
                        task=streaming.task,
                        organization_id=streaming.organization_id,
                    )
                    raise

                if not data:
                    continue

                try:
                    await novnc_ws.send(data)
                except WebSocketDisconnect:
                    LOG.info(
                        "Browser disconnected from the streaming session.",
                        task=streaming.task,
                        organization_id=streaming.organization_id,
                    )
                    raise
                except ConnectionClosedError:
                    LOG.info(
                        "Browser closed the streaming session.",
                        task=streaming.task,
                        organization_id=streaming.organization_id,
                    )
                    raise
                except asyncio.CancelledError:
                    pass
                except Exception:
                    LOG.exception(
                        "An unexpected exception occurred in frontend-to-browser loop.",
                        task=streaming.task,
                        organization_id=streaming.organization_id,
                    )
                    raise

        async def browser_to_frontend() -> None:
            """Relay frames from the browser's VNC websocket to the frontend websocket."""

            LOG.info("Starting browser-to-frontend data transfer.", streaming=streaming)

            while streaming.is_open:
                # Reset every iteration: after a disconnect is handled below,
                # a stale frame must not be written to a closed websocket.
                data: Data | None = None

                try:
                    data = await novnc_ws.recv()
                except WebSocketDisconnect:
                    LOG.info(
                        "Browser disconnected from the streaming session.",
                        task=streaming.task,
                        organization_id=streaming.organization_id,
                    )
                    await streaming.close(reason="browser-disconnected")
                except ConnectionClosedError:
                    LOG.info(
                        "Browser closed the streaming session.",
                        task=streaming.task,
                        organization_id=streaming.organization_id,
                    )
                    await streaming.close(reason="browser-closed")
                except asyncio.CancelledError:
                    pass
                except Exception:
                    LOG.exception(
                        "An unexpected exception occurred in browser-to-frontend loop.",
                        task=streaming.task,
                        organization_id=streaming.organization_id,
                    )
                    raise

                if not data:
                    continue

                try:
                    await streaming.websocket.send_bytes(data)
                except WebSocketDisconnect:
                    LOG.info(
                        "Frontend disconnected from the streaming session.",
                        task=streaming.task,
                        organization_id=streaming.organization_id,
                    )
                    await streaming.close(reason="frontend-disconnected")
                except ConnectionClosedError:
                    LOG.info(
                        "Frontend closed the streaming session.",
                        task=streaming.task,
                        organization_id=streaming.organization_id,
                    )
                    await streaming.close(reason="frontend-closed")
                except asyncio.CancelledError:
                    pass
                except Exception:
                    LOG.exception(
                        "An unexpected exception occurred.",
                        task=streaming.task,
                        organization_id=streaming.organization_id,
                    )
                    raise

        loops = [
            asyncio.create_task(frontend_to_browser()),
            asyncio.create_task(browser_to_frontend()),
        ]

        try:
            await collect(loops)
        except Exception:
            LOG.exception(
                "An exception occurred in loop stream.", task=streaming.task, organization_id=streaming.organization_id
            )
        finally:
            LOG.info("Closing the loop stream.", task=streaming.task, organization_id=streaming.organization_id)
            await streaming.close(reason="loop-stream-vnc-closed")
@base_router.websocket("/stream/vnc/browser_session/{browser_session_id}")
async def browser_session_stream(
websocket: WebSocket,
@@ -507,7 +71,7 @@ async def stream(
) -> None:
if not client_id:
LOG.error(
"Client ID not provided for VNC stream.",
"Client ID not provided for vnc stream.",
browser_session_id=browser_session_id,
task_id=task_id,
workflow_run_id=workflow_run_id,
@@ -515,7 +79,7 @@ async def stream(
return
LOG.info(
"Starting VNC stream.",
"Starting vnc stream.",
browser_session_id=browser_session_id,
client_id=client_id,
task_id=task_id,
@@ -528,11 +92,11 @@ async def stream(
LOG.error("Authentication failed.", task_id=task_id, workflow_run_id=workflow_run_id)
return
streaming: sc.Streaming
loops: list[asyncio.Task] = []
vnc_channel: VncChannel
loops: Loops
if browser_session_id:
result = await get_streaming_for_browser_session(
result = await get_vnc_channel_for_browser_session(
client_id=client_id,
browser_session_id=browser_session_id,
organization_id=organization_id,
@@ -541,22 +105,22 @@ async def stream(
if not result:
LOG.error(
"No streaming context found for the browser session.",
"No vnc context found for the browser session.",
browser_session_id=browser_session_id,
organization_id=organization_id,
)
await websocket.close(code=1013)
return
streaming, loops = result
vnc_channel, loops = result
LOG.info(
"Starting streaming for browser session.",
"Starting vnc for browser session.",
browser_session_id=browser_session_id,
organization_id=organization_id,
)
elif task_id:
result = await get_streaming_for_task(
result = await get_vnc_channel_for_task(
client_id=client_id,
task_id=task_id,
organization_id=organization_id,
@@ -564,19 +128,17 @@ async def stream(
)
if not result:
LOG.error("No streaming context found for the task.", task_id=task_id, organization_id=organization_id)
LOG.error("No vnc context found for the task.", task_id=task_id, organization_id=organization_id)
await websocket.close(code=1013)
return
streaming, loops = result
vnc_channel, loops = result
LOG.info("Starting streaming for task.", task_id=task_id, organization_id=organization_id)
LOG.info("Starting vnc for task.", task_id=task_id, organization_id=organization_id)
elif workflow_run_id:
LOG.info(
"Starting streaming for workflow run.", workflow_run_id=workflow_run_id, organization_id=organization_id
)
result = await get_streaming_for_workflow_run(
LOG.info("Starting vnc for workflow run.", workflow_run_id=workflow_run_id, organization_id=organization_id)
result = await get_vnc_channel_for_workflow_run(
client_id=client_id,
workflow_run_id=workflow_run_id,
organization_id=organization_id,
@@ -585,25 +147,23 @@ async def stream(
if not result:
LOG.error(
"No streaming context found for the workflow run.",
"No vnc context found for the workflow run.",
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
await websocket.close(code=1013)
return
streaming, loops = result
vnc_channel, loops = result
LOG.info(
"Starting streaming for workflow run.", workflow_run_id=workflow_run_id, organization_id=organization_id
)
LOG.info("Starting vnc for workflow run.", workflow_run_id=workflow_run_id, organization_id=organization_id)
else:
LOG.error("Neither task ID nor workflow run ID was provided.")
return
try:
LOG.info(
"Starting streaming loops.",
"Starting vnc loops.",
task_id=task_id,
workflow_run_id=workflow_run_id,
organization_id=organization_id,
@@ -611,16 +171,16 @@ async def stream(
await collect(loops)
except Exception:
LOG.exception(
"An exception occurred in the stream function.",
"An exception occurred in the vnc loop.",
task_id=task_id,
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
finally:
LOG.info(
"Closing the streaming session.",
"Closing the vnc session.",
task_id=task_id,
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
await streaming.close(reason="stream-closed")
await vnc_channel.close(reason="vnc-closed")

0
skyvern/py.typed Normal file
View File