Browser streaming refactor (#4064)
This commit is contained in:
@@ -1,191 +0,0 @@
|
||||
"""
|
||||
A lightweight "agent" for interacting with the streaming browser over CDP.
|
||||
"""
|
||||
|
||||
import typing
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import structlog
|
||||
from playwright.async_api import Browser, BrowserContext, Page, Playwright, async_playwright
|
||||
|
||||
import skyvern.forge.sdk.routes.streaming.clients as sc
|
||||
from skyvern.config import settings
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class StreamingAgent:
    """
    A minimal agent that can connect to a browser via CDP and execute JavaScript.

    Specifically for operations during streaming sessions (like copy/pasting selected text, etc.).

    Lifecycle: `connect()` attaches to the browser and selects the first
    browser context and page it finds; `close()` releases the browser
    connection and stops the Playwright driver.
    """

    def __init__(self, streaming: sc.Streaming) -> None:
        self.streaming = streaming
        self.browser: Browser | None = None
        self.browser_context: BrowserContext | None = None
        self.page: Page | None = None
        self.pw: Playwright | None = None

    async def connect(self, cdp_url: str | None = None) -> None:
        """
        Connect to the browser over CDP and select a context and page.

        Args:
            cdp_url: CDP endpoint to connect to. Falls back to
                `settings.BROWSER_REMOTE_DEBUGGING_URL` when not given.
        """
        url = cdp_url or settings.BROWSER_REMOTE_DEBUGGING_URL

        LOG.info("StreamingAgent connecting to CDP", cdp_url=url)

        # Reuse an already-started Playwright driver if there is one.
        pw = self.pw or await async_playwright().start()

        self.pw = pw

        headers = {
            "x-api-key": self.streaming.x_api_key,
        }

        self.browser = await pw.chromium.connect_over_cdp(url, headers=headers)

        org_id = self.streaming.organization_id
        browser_session_id = (
            self.streaming.browser_session.persistent_browser_session_id if self.streaming.browser_session else None
        )

        # Downloads are only redirected when a persistent browser session is known.
        if browser_session_id:
            cdp_session = await self.browser.new_browser_cdp_session()
            await cdp_session.send(
                "Browser.setDownloadBehavior",
                {
                    "behavior": "allow",
                    "downloadPath": f"/app/downloads/{org_id}/{browser_session_id}",
                    "eventsEnabled": True,
                },
            )

        contexts = self.browser.contexts
        if contexts:
            LOG.info("StreamingAgent using existing browser context")
            self.browser_context = contexts[0]
        else:
            LOG.warning("No existing browser context found, creating new one")
            self.browser_context = await self.browser.new_context()

        pages = self.browser_context.pages
        if pages:
            self.page = pages[0]
            LOG.info("StreamingAgent connected to page", url=self.page.url)
        else:
            # No page is created here; evaluate_js() will refuse to run without one.
            LOG.warning("No pages found in browser context")

        LOG.info("StreamingAgent connected successfully")

    async def evaluate_js(
        self, expression: str, arg: str | int | float | bool | list | dict | None = None
    ) -> str | int | float | bool | list | dict | None:
        """
        Evaluate a JavaScript expression in the connected page.

        Args:
            expression: JavaScript source handed to `page.evaluate`.
            arg: optional single argument passed to the expression.

        Returns:
            Whatever `page.evaluate` returns for the expression.

        Raises:
            RuntimeError: if no page has been selected by `connect()`.
            Exception: any evaluation failure is logged and re-raised.
        """
        if not self.page:
            raise RuntimeError("StreamingAgent is not connected to a page. Call connect() first.")

        # Only a prefix of the expression is logged on the happy path.
        LOG.info("StreamingAgent evaluating JS", expression=expression[:100])

        try:
            result = await self.page.evaluate(expression, arg)
            LOG.info("StreamingAgent JS evaluation successful")
            return result
        except Exception as ex:
            LOG.exception("StreamingAgent JS evaluation failed", expression=expression, ex=str(ex))
            raise

    async def get_selected_text(self) -> str:
        """
        Return the text currently selected in the page ("" when nothing is selected).

        Raises:
            RuntimeError: if the page returns something other than a string or null.
        """
        LOG.info("StreamingAgent getting selected text")

        js_expression = """
        () => {
            const selection = window.getSelection();
            return selection ? selection.toString() : '';
        }
        """

        selected_text = await self.evaluate_js(js_expression)

        if isinstance(selected_text, str) or selected_text is None:
            LOG.info("StreamingAgent got selected text", length=len(selected_text) if selected_text else 0)
            return selected_text or ""

        raise RuntimeError(f"StreamingAgent selected text is not a string, but a(n) '{type(selected_text)}'")

    async def paste_text(self, text: str) -> None:
        """
        Insert `text` at the caret of the page's active element.

        INPUT/TEXTAREA elements are edited through `value` and
        `setSelectionRange`. Content-editable elements have neither, so they
        are edited through the Selection/Range API instead. Any other active
        element is left untouched.
        """
        LOG.info("StreamingAgent pasting text")

        js_expression = """
        (text) => {
            const activeElement = document.activeElement;
            if (!activeElement) {
                return;
            }

            if (activeElement.tagName === 'INPUT' || activeElement.tagName === 'TEXTAREA') {
                const start = activeElement.selectionStart || 0;
                const end = activeElement.selectionEnd || 0;
                const value = activeElement.value || '';
                activeElement.value = value.slice(0, start) + text + value.slice(end);
                const newCursorPos = start + text.length;
                activeElement.setSelectionRange(newCursorPos, newCursorPos);
                return;
            }

            if (activeElement.isContentEditable) {
                const selection = window.getSelection();
                if (!selection || selection.rangeCount === 0) {
                    return;
                }
                const range = selection.getRangeAt(0);
                range.deleteContents();
                const node = document.createTextNode(text);
                range.insertNode(node);
                range.setStartAfter(node);
                range.collapse(true);
                selection.removeAllRanges();
                selection.addRange(range);
            }
        }
        """

        await self.evaluate_js(js_expression, text)

        LOG.info("StreamingAgent pasted text successfully")

    async def close(self) -> None:
        """
        Close the browser connection and stop the Playwright driver.

        State is cleared up front and the driver is stopped in a `finally`,
        so a failing `browser.close()` can no longer leak the driver. The
        failure itself still propagates to the caller.
        """
        LOG.info("StreamingAgent closing connection")

        browser, self.browser = self.browser, None
        pw, self.pw = self.pw, None
        self.browser_context = None
        self.page = None

        try:
            if browser:
                await browser.close()
        finally:
            if pw:
                await pw.stop()

        LOG.info("StreamingAgent closed")
|
||||
|
||||
|
||||
@asynccontextmanager
async def connected_agent(streaming: sc.Streaming | None) -> typing.AsyncIterator[StreamingAgent]:
    """
    Yield a freshly connected `StreamingAgent` and close it on exit.

    Every operation currently pays for the full cycle:
      - create a new agent
      - connect it
      - [do something with it]
      - close it

    That may add latency, though locally it is quick, and it keeps things
    stateless. If it proves too slow, a persistent agent per streaming client
    is the obvious refactor.

    Raises:
        Exception: when no streaming client is given, or when the client has
            no browser session or no browser address.
    """
    if not streaming:
        error = "connected_agent: no streaming client provided."
        LOG.error(error)
        raise Exception(error)

    session = streaming.browser_session

    if not (session and session.browser_address):
        error = "connected_agent: no browser session or browser address found for streaming client."
        LOG.error(
            error,
            client_id=streaming.client_id,
            organization_id=streaming.organization_id,
        )
        raise Exception(error)

    agent = StreamingAgent(streaming=streaming)

    try:
        # NOTE(jdo:streaming-local-dev): for local development, call
        # agent.connect() with no argument so that it falls back to
        # BROWSER_REMOTE_DEBUGGING_URL from settings.
        await agent.connect(session.browser_address)
        yield agent
    finally:
        await agent.close()
|
||||
@@ -6,11 +6,35 @@ import structlog
|
||||
from fastapi import WebSocket
|
||||
from websockets.exceptions import ConnectionClosedOK
|
||||
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
|
||||
from skyvern.forge.sdk.services.org_auth_service import get_current_org
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class Constants:
    """Named constants for the streaming auth helpers in this module."""

    # Placeholder returned in place of an API key when an organization has no valid one.
    MISSING_API_KEY = "<missing-x-api-key>"
|
||||
|
||||
|
||||
async def get_x_api_key(organization_id: str) -> str:
    """
    Look up the API key to use when streaming on behalf of an organization.

    Args:
        organization_id: the organization whose valid API auth token is fetched.

    Returns:
        The token string, or `Constants.MISSING_API_KEY` (after logging a
        warning) when the organization has no valid API token.
    """
    auth_token = await app.DATABASE.get_valid_org_auth_token(
        organization_id,
        OrganizationAuthTokenType.api.value,
    )

    if auth_token:
        return auth_token.token

    LOG.warning(
        "No valid API key found for organization when streaming.",
        organization_id=organization_id,
    )

    return Constants.MISSING_API_KEY
|
||||
|
||||
|
||||
async def auth(apikey: str | None, token: str | None, websocket: WebSocket) -> str | None:
|
||||
"""
|
||||
Accepts the websocket connection.
|
||||
|
||||
336
skyvern/forge/sdk/routes/streaming/channels/README.md
Normal file
336
skyvern/forge/sdk/routes/streaming/channels/README.md
Normal file
@@ -0,0 +1,336 @@
|
||||
# Channels
|
||||
|
||||
A "channel", as used within the streaming mechanism of our remote browsers,
|
||||
is a WebSocket fit to some particular purpose.
|
||||
|
||||
There are:
|
||||
- a "VNC" channel that transmits NoVNC's RFB protocol data
|
||||
- a "Message" channel that transmits JSON between the frontend app and
|
||||
the api server
|
||||
- "CDP" channels that send messages to a remote browser using CDP protocol
|
||||
data
|
||||
- an "Execution" channel (one-off executions)
|
||||
- a soon-to-be "Exfiltration" channel (user event streaming)
|
||||
|
||||
In all cases, these are just WebSockets. They have been bucketed into "named channels"
|
||||
to aid understanding.
|
||||
|
||||
These channels are described at the top of their respective files.
|
||||
|
||||
## Architecture
|
||||
|
||||
WARN: below is an AI-generated architecture document for all of the code beneath
|
||||
the `skyvern/forge/sdk/routes/streaming` directory. It looks correct.
|
||||
|
||||
### High-Level Component Diagram
|
||||
|
||||
```
|
||||
┌─────────────────┐
|
||||
│ Frontend App │
|
||||
│ (Skyvern) │
|
||||
└────────┬────────┘
|
||||
│
|
||||
│ Two WebSocket Connections (paired via client_id)
|
||||
│
|
||||
┌────┴────┬──────────────────────────────────────────┐
|
||||
│ │ │
|
||||
│ VNC Channel Message Channel
|
||||
http (RFB Protocol) (JSON Messages)
|
||||
│ │ │
|
||||
│ │ │
|
||||
┌───▼─────────▼──────────────────────────────────────────▼────┐
|
||||
│ API Server │
|
||||
│ ┌────────────────────────────────────────────────────────┐ │
|
||||
│ │ Registries (In-Memory State) │ │
|
||||
│ │ - vnc_channels: dict[client_id -> VncChannel] │ │
|
||||
│ │ - message_channels: dict[client_id -> MessageChannel] │ │
|
||||
│ └────────────────────────────────────────────────────────┘ │
|
||||
│ │
|
||||
│ VNC Channel Logic: Message Channel Logic: │
|
||||
│ - RFB pass-through - Copy/paste coordination │
|
||||
│ - Keyboard/mouse filtering - Control handoff (agent/user)│
|
||||
│ - Interactor control - Clipboard management │
|
||||
│ - Copy/paste detection - Channel coordination │
|
||||
│ │
|
||||
│ CDP Channels (created on-demand): │
|
||||
│ - ExecutionChannel: JS evaluation (paste, get selected) │
|
||||
│ - ExfiltrationChannel: (future) user event streaming │
|
||||
│ │
|
||||
└────┬─────────────────────────────────────────────────────┬──┘
|
||||
│ │
|
||||
│ WebSocket (RFB) Playwright (CDP)│
|
||||
│ │
|
||||
┌────▼─────────────────────────────────────────────────────▼──┐
|
||||
│ Persistent Browser Session │
|
||||
│ ┌──────────────────────────────────────────────────────┐ │
|
||||
│ │ noVNC Server Chrome/Chromium │ │
|
||||
│ │ (websockify) │ │
|
||||
│ └──────────────────────────────────────────────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### Channel Pairing & Sticky Sessions
|
||||
|
||||
**Critical Design Constraint**: The VNC and Message channels for a given frontend
|
||||
instance MUST connect to the same API server instance because they coordinate
|
||||
via in-memory registries keyed by `client_id`.
|
||||
|
||||
```
|
||||
Frontend Instance (client_id="abc123")
|
||||
│
|
||||
├─→ VNC Channel ─────→ API Server Instance #2
|
||||
│ ↓
|
||||
│ vnc_channels["abc123"] = VncChannel
|
||||
│ ↕ (coordinate via client_id)
|
||||
└─→ Message Channel ──→ API Server Instance #2
|
||||
↓
|
||||
message_channels["abc123"] = MessageChannel
|
||||
```
|
||||
|
||||
**Deployment Requirement**: Load balancer must use sticky sessions (e.g., cookie-based
|
||||
or IP-based affinity) to ensure both WebSocket connections from the same client_id
|
||||
reach the same backend instance.
|
||||
|
||||
### Channel Lifecycle & Verification
|
||||
|
||||
```
|
||||
┌──────────────────────────────────────────────────────────────┐
|
||||
│ Channel Creation (per browser_session/task/workflow_run) │
|
||||
└────────────────────┬─────────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌────────────────────────────────────────┐
|
||||
│ Initial Verification │
|
||||
│ - verify_browser_session() │
|
||||
│ - verify_task() │
|
||||
│ - verify_workflow_run() │
|
||||
│ Returns: entity + browser session │
|
||||
└────────┬───────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌────────────────────────────────────────┐
|
||||
│ Channel + Loops Created │
|
||||
│ VncChannel/MessageChannel initialized │
|
||||
│ + Added to registry │
|
||||
└────────┬───────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌────────────────────────────────────────┐
|
||||
│ Concurrent Loop Execution │
|
||||
│ (via collect() - fail-fast) │
|
||||
│ │
|
||||
│ Loop 1: Verification Loop │
|
||||
│ - Polls every 5s │
|
||||
│ - Updates channel state │
|
||||
│ - Exits if entity invalid │
|
||||
│ │
|
||||
│ Loop 2: Data Streaming Loop │
|
||||
│ - VNC: bidirectional RFB │
|
||||
│ - Message: JSON messages │
|
||||
│ - Exits on disconnect │
|
||||
└────────┬───────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌────────────────────────────────────────┐
|
||||
│ Channel Cleanup │
|
||||
│ - Close WebSocket │
|
||||
│ - Remove from registry │
|
||||
│ - Clear channel state │
|
||||
└────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### VNC Channel Data Flow
|
||||
|
||||
```
|
||||
User Keyboard/Mouse Input
|
||||
│
|
||||
▼
|
||||
┌───────────────────────────────────────────────────────────┐
|
||||
│ Frontend (noVNC client) │
|
||||
│ Encodes input as RFB protocol bytes │
|
||||
└────────┬──────────────────────────────────────────────────┘
|
||||
│ WebSocket (bytes)
|
||||
▼
|
||||
┌───────────────────────────────────────────────────────────┐
|
||||
│ API Server: VncChannel.loop_stream_vnc() │
|
||||
│ frontend_to_browser() coroutine │
|
||||
│ │
|
||||
│ 1. Receive RFB bytes from frontend │
|
||||
│ 2. Detect message type (keyboard=4, mouse=5) │
|
||||
│ 3. Update key_state tracking │
|
||||
│ 4. Check for special key combinations: │
|
||||
│ - Ctrl+C / Cmd+C → copy_text() via CDP │
|
||||
│ - Ctrl+V / Cmd+V → ask_for_clipboard() via Message │
|
||||
│ - Ctrl+O → BLOCK (forbidden) │
|
||||
│ 5. Check interactor mode: │
|
||||
│ - If interactor=="agent" → BLOCK user input │
|
||||
│ - If interactor=="user" → PASS THROUGH │
|
||||
│ 6. Block right-mouse-button (security) │
|
||||
│ 7. Forward to noVNC server │
|
||||
└────────┬──────────────────────────────────────────────────┘
|
||||
│ WebSocket (bytes)
|
||||
▼
|
||||
┌───────────────────────────────────────────────────────────┐
|
||||
│ Persistent Browser: noVNC Server (websockify) │
|
||||
│ Translates RFB → VNC protocol │
|
||||
└────────┬──────────────────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌───────────────────────────────────────────────────────────┐
|
||||
│ Browser Display Update │
|
||||
└───────────────────────────────────────────────────────────┘
|
||||
|
||||
Screen Updates (reverse direction):
|
||||
Browser → noVNC → VncChannel.browser_to_frontend() → Frontend
|
||||
```
|
||||
|
||||
### Message Channel + CDP Execution Flow
|
||||
|
||||
```
|
||||
User Pastes (Ctrl+V detected in VNC channel)
|
||||
│
|
||||
▼
|
||||
┌───────────────────────────────────────────────────────────┐
|
||||
│ VncChannel: ask_for_clipboard() │
|
||||
│ Finds MessageChannel via registry[client_id] │
|
||||
└────────┬──────────────────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌───────────────────────────────────────────────────────────┐
|
||||
│ MessageChannel.ask_for_clipboard() │
|
||||
│ Sends: {"kind": "ask-for-clipboard"} │
|
||||
└────────┬──────────────────────────────────────────────────┘
|
||||
│ WebSocket (JSON)
|
||||
▼
|
||||
┌───────────────────────────────────────────────────────────┐
|
||||
│ Frontend: User's clipboard content │
|
||||
│ Responds: {"kind": "ask-for-clipboard-response", │
|
||||
│ "text": "clipboard content"} │
|
||||
└────────┬──────────────────────────────────────────────────┘
|
||||
│ WebSocket (JSON)
|
||||
▼
|
||||
┌───────────────────────────────────────────────────────────┐
|
||||
│ MessageChannel: handle_data() │
|
||||
│ Receives clipboard text │
|
||||
│ Finds VncChannel via registry[client_id] │
|
||||
└────────┬──────────────────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌───────────────────────────────────────────────────────────┐
|
||||
│ ExecutionChannel (CDP) │
|
||||
│ 1. Connect to browser via Playwright │
|
||||
│ 2. Get browser context + page │
|
||||
│ 3. evaluate_js(paste_text_script, clipboard_text) │
|
||||
│ 4. Close CDP connection │
|
||||
└────────┬──────────────────────────────────────────────────┘
|
||||
│ CDP over WebSocket
|
||||
▼
|
||||
┌───────────────────────────────────────────────────────────┐
|
||||
│ Browser: Text pasted into active element │
|
||||
└───────────────────────────────────────────────────────────┘
|
||||
|
||||
Similar flow for Copy (Ctrl+C):
|
||||
VNC detects → ExecutionChannel.get_selected_text() →
|
||||
MessageChannel sends {"kind": "copied-text", "text": "..."}
|
||||
→ Frontend updates clipboard
|
||||
```
|
||||
|
||||
### Control Flow: Agent ↔ User Interaction
|
||||
|
||||
NOTE: we don't really have an "agent" at this time. But any control of the
|
||||
browser that is not user-originated is agent-like, by some
|
||||
definition of "agent". Here, we do not have an "AI agent". Future work may
|
||||
alter this state of affairs - and some "agent" could operate the browser
|
||||
automatically.
|
||||
|
||||
```
|
||||
┌───────────────────────────────────────────────────────────┐
|
||||
│ Initial State: interactor = "agent" │
|
||||
│ - User keyboard/mouse input is BLOCKED │
|
||||
│ - Agent can control browser via CDP │
|
||||
└────────┬──────────────────────────────────────────────────┘
|
||||
│
|
||||
│ User clicks "Take Control" in frontend
|
||||
▼
|
||||
┌───────────────────────────────────────────────────────────┐
|
||||
│ Frontend → MessageChannel │
|
||||
│ {"kind": "take-control"} │
|
||||
└────────┬──────────────────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌───────────────────────────────────────────────────────────┐
|
||||
│ MessageChannel.handle_data() │
|
||||
│ vnc_channel.interactor = "user" │
|
||||
└────────┬──────────────────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌───────────────────────────────────────────────────────────┐
|
||||
│ New State: interactor = "user" │
|
||||
│ - User keyboard/mouse input is PASSED THROUGH │
|
||||
│ - Agent should pause automation │
|
||||
└────────┬──────────────────────────────────────────────────┘
|
||||
│
|
||||
│ User clicks "Cede Control" in frontend
|
||||
▼
|
||||
┌───────────────────────────────────────────────────────────┐
|
||||
│ Frontend → MessageChannel │
|
||||
│ {"kind": "cede-control"} │
|
||||
└────────┬──────────────────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌───────────────────────────────────────────────────────────┐
|
||||
│ MessageChannel.handle_data() │
|
||||
│ vnc_channel.interactor = "agent" │
|
||||
│ → Back to initial state │
|
||||
└───────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### Error Propagation & Cleanup
|
||||
|
||||
The system uses `collect()` (fail-fast gather) for loop management:
|
||||
|
||||
```
|
||||
Channel has 2 concurrent loops:
|
||||
- Verification loop (polls DB every 5s)
|
||||
- Streaming loop (handles WebSocket I/O)
|
||||
|
||||
collect() behavior:
|
||||
1. Waits for ANY loop to fail or complete
|
||||
2. Cancels all other loops
|
||||
3. Propagates the first exception
|
||||
|
||||
Cleanup (always executed via finally):
|
||||
- channel.close()
|
||||
- Sets browser_session/task/workflow_run = None
|
||||
- Closes WebSocket
|
||||
- Removes from registry
|
||||
```
|
||||
|
||||
### Database Entity Relationships
|
||||
|
||||
```
|
||||
Organization
|
||||
│
|
||||
├─→ BrowserSession ────────┐
|
||||
│ │
|
||||
├─→ Task ──────────────────┤
|
||||
│ (has optional │
|
||||
│ browser_session) │
|
||||
│ │
|
||||
└─→ WorkflowRun ───────────┤
|
||||
(has optional │
|
||||
browser_session) │
|
||||
│
|
||||
▼
|
||||
VncChannel + MessageChannel
|
||||
(in-memory, paired by client_id)
|
||||
```
|
||||
|
||||
### Key Design Patterns
|
||||
|
||||
1. **Channel Pairing**: Two WebSocket connections coordinated via in-memory registry
|
||||
2. **Fail-Fast Loops**: `collect()` ensures any loop failure closes the entire channel
|
||||
3. **Interactor Mode**: Binary state controlling whether user input is allowed
|
||||
4. **On-Demand CDP**: ExecutionChannel creates temporary connections for each operation
|
||||
5. **Polling Verification**: Every 5s, channels verify their backing entity still exists
|
||||
6. **Pass-Through Proxy**: API server inspects RFB data and may block individual messages, but never transforms the data it forwards
|
||||
185
skyvern/forge/sdk/routes/streaming/channels/cdp.py
Normal file
185
skyvern/forge/sdk/routes/streaming/channels/cdp.py
Normal file
@@ -0,0 +1,185 @@
|
||||
"""
|
||||
A channel for connecting to a persistent browser instance.
|
||||
|
||||
What this channel looks like:
|
||||
|
||||
[API Server] <--> [Browser (CDP)]
|
||||
|
||||
Channel data:
|
||||
|
||||
CDP protocol data, with Playwright thrown in.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import typing as t
|
||||
|
||||
import structlog
|
||||
from playwright.async_api import Browser, BrowserContext, Page, Playwright, async_playwright
|
||||
|
||||
from skyvern.config import settings
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from skyvern.forge.sdk.routes.streaming.channels.vnc import VncChannel
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class CdpChannel:
    """
    CdpChannel. Relies on a VncChannel - without one, a CdpChannel has no
    raison d'être.

    This base class refuses direct instantiation (see `__new__`); use a
    subclass. An instance owns a Playwright driver plus a CDP connection to
    the browser, and exposes the first browser context and page it finds.
    """

    def __new__(cls, *_: t.Any, **__: t.Any) -> t.Self:  # noqa: N805
        """Reject direct instantiation of the base class; subclasses pass through."""
        if cls is CdpChannel:
            raise TypeError("CdpChannel class cannot be instantiated directly.")

        return super().__new__(cls)

    def __init__(self, *, vnc_channel: VncChannel) -> None:
        self.vnc_channel = vnc_channel
        # --
        self.browser: Browser | None = None
        self.browser_context: BrowserContext | None = None
        self.page: Page | None = None
        self.pw: Playwright | None = None
        self.url: str | None = None

    @property
    def class_name(self) -> str:
        """Name of the concrete class; used as a prefix in log messages."""
        return self.__class__.__name__

    @property
    def identity(self) -> dict[str, t.Any]:
        """The VNC channel's identity fields plus this channel's CDP url, for structured logging."""
        base = self.vnc_channel.identity

        return base | {"url": self.url}

    async def connect(self, cdp_url: str | None = None) -> t.Self:
        """
        Idempotent.

        Returns immediately when a live browser connection already exists.
        Otherwise tears down any stale state and connects afresh.

        Args:
            cdp_url: CDP endpoint to use. When omitted, the VNC channel's
                browser session address is used, and failing that
                `settings.BROWSER_REMOTE_DEBUGGING_URL`.

        Returns:
            This channel.
        """
        if self.browser and self.browser.is_connected():
            return self

        await self.close()

        if cdp_url:
            url = cdp_url
        elif self.vnc_channel.browser_session and self.vnc_channel.browser_session.browser_address:
            url = self.vnc_channel.browser_session.browser_address
        else:
            url = settings.BROWSER_REMOTE_DEBUGGING_URL

        self.url = url

        LOG.info(f"{self.class_name} connecting to CDP", **self.identity)

        pw = self.pw or await async_playwright().start()

        self.pw = pw

        # Only send the auth header when there is a key to send.
        headers = (
            {
                "x-api-key": self.vnc_channel.x_api_key,
            }
            if self.vnc_channel.x_api_key
            else None
        )

        browser = await pw.chromium.connect_over_cdp(url, headers=headers)

        def on_close() -> None:
            # Playwright also emits "disconnected" when browser.close() is
            # called. close() detaches self.browser before closing, so an
            # event for a browser this channel no longer owns is a deliberate
            # (or superseded) close and must not trigger a reconnect.
            if self.browser is not browser:
                return

            LOG.warning(
                f"{self.class_name} closing because the persistent browser disconnected itself.", **self.identity
            )
            close_task = asyncio.create_task(self.close())
            close_task.add_done_callback(lambda _: asyncio.create_task(self.connect()))  # TODO: avoid blind reconnect

        self.browser = browser
        browser.on("disconnected", on_close)

        await self.apply_download_behavior(browser)

        contexts = browser.contexts
        if contexts:
            LOG.info(f"{self.class_name} using existing browser context", **self.identity)
            self.browser_context = contexts[0]
        else:
            LOG.warning(f"{self.class_name} No existing browser context found, creating new one", **self.identity)
            self.browser_context = await browser.new_context()

        pages = self.browser_context.pages
        if pages:
            self.page = pages[0]
            LOG.info(f"{self.class_name} connected to page", **self.identity)
        else:
            # No page is created here; evaluate_js() raises when there is none.
            LOG.warning(f"{self.class_name} No pages found in browser context", **self.identity)

        LOG.info(f"{self.class_name} connected successfully", **self.identity)

        return self

    async def apply_download_behavior(self, browser: Browser) -> t.Self:
        """
        Allow downloads in `browser` and point them at this session's directory.

        The path is `/app/downloads/{org_id}/{browser_session_id}` when the VNC
        channel has a browser session, else `/app/downloads/`.

        Returns:
            This channel.
        """
        org_id = self.vnc_channel.organization_id

        browser_session_id = (
            self.vnc_channel.browser_session.persistent_browser_session_id if self.vnc_channel.browser_session else None
        )

        download_path = f"/app/downloads/{org_id}/{browser_session_id}" if browser_session_id else "/app/downloads/"

        cdp_session = await browser.new_browser_cdp_session()

        try:
            await cdp_session.send(
                "Browser.setDownloadBehavior",
                {
                    "behavior": "allow",
                    "downloadPath": download_path,
                    "eventsEnabled": True,
                },
            )
        finally:
            # The session is only needed for this one command.
            await cdp_session.detach()

        return self

    async def close(self) -> None:
        """
        Close the browser connection and stop the Playwright driver.

        State is detached *before* the browser is closed so that the
        "disconnected" handler installed by `connect()` recognises the close
        as deliberate and does not reconnect. Failure to close the browser is
        best-effort: it is logged, not raised.
        """
        LOG.info(f"{self.class_name} closing connection", **self.identity)

        browser, self.browser = self.browser, None
        pw, self.pw = self.pw, None
        self.browser_context = None
        self.page = None

        if browser:
            try:
                await browser.close()
            except Exception:
                # The connection may already be gone; the driver is still stopped below.
                LOG.warning(f"{self.class_name} failed to close browser cleanly", exc_info=True, **self.identity)

        if pw:
            await pw.stop()

        LOG.info(f"{self.class_name} closed", **self.identity)

    async def evaluate_js(
        self,
        expression: str,
        arg: str | int | float | bool | list | dict | None = None,
    ) -> str | int | float | bool | list | dict | None:
        """
        Evaluate a JavaScript expression in the connected page, connecting first if needed.

        Args:
            expression: JavaScript source handed to `page.evaluate`.
            arg: optional single argument passed to the expression.

        Returns:
            Whatever `page.evaluate` returns for the expression.

        Raises:
            RuntimeError: if no page is available after connecting.
            Exception: any evaluation failure is logged and re-raised.
        """
        await self.connect()

        if not self.page:
            raise RuntimeError(f"{self.class_name} evaluate_js: not connected to a page. Call connect() first.")

        # Only a prefix of the expression is logged on the happy path.
        LOG.info(f"{self.class_name} evaluating js", expression=expression[:100], **self.identity)

        try:
            result = await self.page.evaluate(expression, arg)
            LOG.info(f"{self.class_name} evaluated js successfully", **self.identity)
            return result
        except Exception:
            LOG.exception(f"{self.class_name} failed to evaluate js", expression=expression, **self.identity)
            raise
|
||||
121
skyvern/forge/sdk/routes/streaming/channels/execution.py
Normal file
121
skyvern/forge/sdk/routes/streaming/channels/execution.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""
|
||||
A channel for executing JavaScript against a persistent browser instance.
|
||||
|
||||
What this channel looks like:
|
||||
|
||||
[API Server] <--> [Browser (CDP)]
|
||||
|
||||
Channel data:
|
||||
|
||||
Chrome DevTools Protocol (CDP) over WebSockets. We cheat and use Playwright.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import structlog
|
||||
|
||||
from skyvern.forge.sdk.routes.streaming.channels.cdp import CdpChannel
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from skyvern.forge.sdk.routes.streaming.channels.vnc import VncChannel
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class ExecutionChannel(CdpChannel):
    """
    ExecutionChannel.

    A CdpChannel specialised for one-off JavaScript executions against the
    page: reading the current selection and pasting text. Connection
    handling, `class_name` and `close()` are inherited from `CdpChannel`.
    """

    async def get_selected_text(self) -> str:
        """
        Return the text currently selected in the page ("" when nothing is selected).

        Raises:
            RuntimeError: if the page returns something other than a string or null.
        """
        LOG.info(f"{self.class_name} getting selected text", **self.identity)

        js_expression = """
        () => {
            const selection = window.getSelection();
            return selection ? selection.toString() : '';
        }
        """

        # The expression takes no argument, so none is passed.
        selected_text = await self.evaluate_js(js_expression)

        if isinstance(selected_text, str) or selected_text is None:
            LOG.info(
                f"{self.class_name} got selected text",
                length=len(selected_text) if selected_text else 0,
                **self.identity,
            )
            return selected_text or ""

        raise RuntimeError(f"{self.class_name} selected text is not a string, but a(n) '{type(selected_text)}'")

    async def paste_text(self, text: str) -> None:
        """
        Insert `text` at the caret of the page's active element.

        INPUT/TEXTAREA elements are edited through `value` and
        `setSelectionRange`. Content-editable elements have neither, so they
        are edited through the Selection/Range API instead. Any other active
        element is left untouched.
        """
        LOG.info(f"{self.class_name} pasting text", **self.identity)

        js_expression = """
        (text) => {
            const activeElement = document.activeElement;
            if (!activeElement) {
                return;
            }

            if (activeElement.tagName === 'INPUT' || activeElement.tagName === 'TEXTAREA') {
                const start = activeElement.selectionStart || 0;
                const end = activeElement.selectionEnd || 0;
                const value = activeElement.value || '';
                activeElement.value = value.slice(0, start) + text + value.slice(end);
                const newCursorPos = start + text.length;
                activeElement.setSelectionRange(newCursorPos, newCursorPos);
                return;
            }

            if (activeElement.isContentEditable) {
                const selection = window.getSelection();
                if (!selection || selection.rangeCount === 0) {
                    return;
                }
                const range = selection.getRangeAt(0);
                range.deleteContents();
                const node = document.createTextNode(text);
                range.insertNode(node);
                range.setStartAfter(node);
                range.collapse(true);
                selection.removeAllRanges();
                selection.addRange(range);
            }
        }
        """

        await self.evaluate_js(js_expression, text)

        LOG.info(f"{self.class_name} pasted text successfully", **self.identity)
|
||||
|
||||
|
||||
@asynccontextmanager
async def execution_channel(vnc_channel: VncChannel) -> t.AsyncIterator[ExecutionChannel]:
    """
    Yield a connected `ExecutionChannel` for `vnc_channel`, closing it on exit.

    Every operation currently pays for the full cycle:
      - create a new channel
      - connect it
      - [do something with it]
      - close it

    That may add latency, though locally it is quick, and it keeps things
    stateless. If it proves too slow, a persistent channel per vnc client is
    the obvious refactor.
    """
    one_shot = ExecutionChannel(vnc_channel=vnc_channel)

    try:
        # NOTE(jdo:streaming-local-dev): for local development, pass
        # settings.BROWSER_REMOTE_DEBUGGING_URL (from skyvern.config) to
        # connect() instead of relying on the session's browser address.
        await one_shot.connect()
        yield one_shot
    finally:
        # Runs even when connect() fails part-way.
        await one_shot.close()
|
||||
@@ -0,0 +1 @@
|
||||
# stub
|
||||
456
skyvern/forge/sdk/routes/streaming/channels/message.py
Normal file
456
skyvern/forge/sdk/routes/streaming/channels/message.py
Normal file
@@ -0,0 +1,456 @@
|
||||
"""
|
||||
A channel for streaming whole messages between our frontend and our API server.
|
||||
This channel can access a persistent browser instance through the execution channel.
|
||||
|
||||
What this channel looks like:
|
||||
|
||||
[Skyvern App] <--> [API Server]
|
||||
|
||||
Channel data:
|
||||
|
||||
JSON over WebSockets. Semantics are fire and forget. Req-resp is built on
|
||||
top of that using message types.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import typing as t
|
||||
|
||||
import structlog
|
||||
from fastapi import WebSocket, WebSocketDisconnect
|
||||
from starlette.websockets import WebSocketState
|
||||
from websockets.exceptions import ConnectionClosedError
|
||||
|
||||
from skyvern.forge.sdk.routes.streaming.channels.execution import execution_channel
|
||||
from skyvern.forge.sdk.routes.streaming.registries import (
|
||||
add_message_channel,
|
||||
del_message_channel,
|
||||
get_vnc_channel,
|
||||
)
|
||||
from skyvern.forge.sdk.routes.streaming.verify import (
|
||||
loop_verify_browser_session,
|
||||
loop_verify_workflow_run,
|
||||
verify_browser_session,
|
||||
verify_workflow_run,
|
||||
)
|
||||
from skyvern.forge.sdk.schemas.persistent_browser_sessions import AddressablePersistentBrowserSession
|
||||
from skyvern.forge.sdk.utils.aio import collect
|
||||
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRun
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
Loops = list[asyncio.Task] # aka "queue-less actors"; or "programs"
|
||||
|
||||
|
||||
# The `kind` discriminators of the messages this channel accepts from the frontend.
MessageKinds = t.Literal[
    "ask-for-clipboard-response",
    "cede-control",
    "take-control",
]


@dataclasses.dataclass
class Message:
    """Base for every inbound message; `kind` identifies the concrete message type."""

    kind: MessageKinds


@dataclasses.dataclass
class MessageInTakeControl(Message):
    """Inbound message with kind "take-control"; carries no payload."""

    kind: t.Literal["take-control"] = "take-control"


@dataclasses.dataclass
class MessageInCedeControl(Message):
    """Inbound message with kind "cede-control"; carries no payload."""

    kind: t.Literal["cede-control"] = "cede-control"


@dataclasses.dataclass
class MessageInAskForClipboardResponse(Message):
    """Inbound message with kind "ask-for-clipboard-response"; `text` is its payload."""

    kind: t.Literal["ask-for-clipboard-response"] = "ask-for-clipboard-response"
    text: str = ""


# Any concrete inbound message.
ChannelMessage = t.Union[
    MessageInAskForClipboardResponse,
    MessageInCedeControl,
    MessageInTakeControl,
]
|
||||
|
||||
|
||||
def reify_channel_message(data: dict) -> ChannelMessage:
    """
    Build the typed message that corresponds to `data["kind"]`.

    For "ask-for-clipboard-response", a missing or falsy "text" becomes "".

    :raises ValueError: if "kind" is missing or is not a known message kind
    """

    kind = data.get("kind")

    if kind == "ask-for-clipboard-response":
        return MessageInAskForClipboardResponse(text=data.get("text") or "")

    if kind == "cede-control":
        return MessageInCedeControl()

    if kind == "take-control":
        return MessageInTakeControl()

    raise ValueError(f"Unknown message kind: '{kind}'")
|
||||
|
||||
|
||||
@dataclasses.dataclass
class MessageChannel:
    """
    A message channel for streaming JSON messages between our frontend and our API server.

    Constructing an instance registers it with `add_message_channel`; `close()`
    unregisters it with `del_message_channel`.
    """

    client_id: str
    organization_id: str
    websocket: WebSocket
    # --
    out_queue: asyncio.Queue = dataclasses.field(default_factory=asyncio.Queue)  # warn: unbounded
    browser_session: AddressablePersistentBrowserSession | None = None
    workflow_run: WorkflowRun | None = None

    def __post_init__(self) -> None:
        add_message_channel(self)

    @property
    def class_name(self) -> str:
        """The concrete class name; used as a prefix in log messages."""
        return self.__class__.__name__

    @property
    def identity(self) -> dict[str, str]:
        """
        Key/value pairs identifying this channel in structured logs.

        Always contains the organization id, plus the browser session id or
        (failing that) the workflow run id when one is attached.
        """
        base = {"organization_id": self.organization_id}

        if self.browser_session:
            return base | {"browser_session_id": self.browser_session.persistent_browser_session_id}

        if self.workflow_run:
            # Same attribute the vnc channel reads from WorkflowRun for its identity.
            return base | {"workflow_run_id": self.workflow_run.workflow_run_id}

        return base

    async def close(self, code: int = 1000, reason: str | None = None) -> "MessageChannel":
        """
        Detach the session/run, close the websocket, and unregister the channel.

        Closing the websocket is best-effort: any error raised by it is ignored
        so that the channel is always unregistered.

        :return: this channel
        """
        LOG.info(f"{self.class_name} closing message stream.", reason=reason, code=code, **self.identity)

        self.browser_session = None
        self.workflow_run = None

        try:
            await self.websocket.close(code=code, reason=reason)
        except Exception:
            pass

        del_message_channel(self.client_id)

        return self

    @property
    def is_open(self) -> bool:
        """True while the websocket's client state is CONNECTED."""
        return self.websocket.client_state == WebSocketState.CONNECTED

    async def drain(self) -> list[dict]:
        """
        Collect whatever is currently available from the out-queue and from the
        websocket, in that order, and return it as one list.
        """
        datums: list[dict] = []

        tasks = [
            asyncio.create_task(self.receive_from_out_queue()),
            asyncio.create_task(self.receive_from_user()),
        ]

        results = await asyncio.gather(*tasks)

        for result in results:
            datums.extend(result)

        if datums:
            LOG.info(f"{self.class_name} Drained {len(datums)} messages from message channel.", **self.identity)

        return datums

    async def receive_from_user(self) -> list[dict]:
        """
        Read JSON messages from the websocket until a read takes longer than 1ms
        or fails, and return the messages read so far.
        """
        datums: list[dict] = []

        while True:
            try:
                data = await asyncio.wait_for(self.websocket.receive_json(), timeout=0.001)
                datums.append(data)
            except asyncio.TimeoutError:
                break
            except RuntimeError as exc:
                # Inspect the raised instance (not the RuntimeError class) for the
                # "not connected" condition; that case ends the read quietly.
                if "not connected" not in str(exc).lower():
                    LOG.exception(
                        f"{self.class_name} Failed to receive message from message channel", **self.identity
                    )
                # Always stop: retrying the same read immediately would only
                # raise again and spin this loop.
                break
            except Exception:
                LOG.exception(f"{self.class_name} Failed to receive message from message channel", **self.identity)
                break

        return datums

    async def receive_from_out_queue(self) -> list[dict]:
        """
        Pop items from the out-queue until a `get()` takes longer than 1ms, and
        return the items popped.
        """
        datums: list[dict] = []

        while True:
            try:
                data = await asyncio.wait_for(self.out_queue.get(), timeout=0.001)
                datums.append(data)
            except asyncio.TimeoutError:
                break
            except asyncio.QueueEmpty:
                break

        return datums

    def receive_from_out_queue_nowait(self) -> list[dict]:
        """Pop and return every item currently in the out-queue, without waiting."""
        datums: list[dict] = []

        while True:
            try:
                data = self.out_queue.get_nowait()
                datums.append(data)
            except asyncio.QueueEmpty:
                break

        return datums

    async def send(self, *, messages: list[dict]) -> t.Self:
        """Put each message on the out-queue, in order. Returns this channel."""
        for message in messages:
            await self.out_queue.put(message)

        return self

    def send_nowait(self, *, messages: list[dict]) -> t.Self:
        """Put each message on the out-queue without awaiting. Returns this channel."""
        for message in messages:
            self.out_queue.put_nowait(message)

        return self

    async def ask_for_clipboard(self) -> None:
        """Send an "ask-for-clipboard" message over the websocket; failures are logged, not raised."""
        LOG.info(f"{self.class_name} Sending ask-for-clipboard to message channel", **self.identity)

        try:
            await self.websocket.send_json(
                {
                    "kind": "ask-for-clipboard",
                }
            )
        except Exception:
            LOG.exception(f"{self.class_name} Failed to send ask-for-clipboard to message channel", **self.identity)

    async def send_copied_text(self, copied_text: str) -> None:
        """Send a "copied-text" message carrying `copied_text`; failures are logged, not raised."""
        LOG.info(f"{self.class_name} Sending copied text to message channel", **self.identity)

        try:
            await self.websocket.send_json(
                {
                    "kind": "copied-text",
                    "text": copied_text,
                }
            )
        except Exception:
            LOG.exception(f"{self.class_name} Failed to send copied text to message channel", **self.identity)
|
||||
|
||||
|
||||
async def loop_stream_messages(message_channel: MessageChannel) -> None:
    """
    Stream messages and their results back and forth.

    Loops until the websocket is closed. The message channel is always closed
    (reason "loop-channel-closed") when the loop ends, however it ends.
    """

    class_name = message_channel.class_name

    def hand_control_to(interactor: str, message: Message) -> None:
        # Shared by "cede-control" and "take-control": flip the interactor on the
        # vnc channel registered under this channel's client id.
        vnc_channel = get_vnc_channel(message_channel.client_id)

        if not vnc_channel:
            LOG.error(
                f"{class_name} no vnc channel client found for message channel.",
                message=message,
                **message_channel.identity,
            )
            return

        vnc_channel.interactor = interactor

    async def handle_data(data: dict) -> None:
        try:
            message = reify_channel_message(data)
        except ValueError:
            LOG.error(f"MessageChannel: cannot reify channel message from data: {data}", **message_channel.identity)
            return

        message_kind = message.kind

        if message_kind == "ask-for-clipboard-response":
            if not isinstance(message, MessageInAskForClipboardResponse):
                LOG.error(
                    f"{class_name} invalid message type for ask-for-clipboard-response.",
                    message=message,
                    **message_channel.identity,
                )
                return

            vnc_channel = get_vnc_channel(message_channel.client_id)

            if not vnc_channel:
                LOG.error(
                    f"{class_name} no vnc channel found for message channel.",
                    message=message,
                    **message_channel.identity,
                )
                return

            # The frontend answered with its clipboard text: paste it via the execution channel.
            async with execution_channel(vnc_channel) as execute:
                await execute.paste_text(message.text)

        elif message_kind == "cede-control":
            hand_control_to("agent", message)

        elif message_kind == "take-control":
            LOG.info(f"{class_name} processing take-control message.", **message_channel.identity)
            hand_control_to("user", message)

        else:
            LOG.error(f"{class_name} unknown message kind: '{message_kind}'", **message_channel.identity)

    async def frontend_to_backend() -> None:
        LOG.info(f"{class_name} starting frontend-to-backend loop.", **message_channel.identity)

        while message_channel.is_open:
            try:
                for data in await message_channel.drain():
                    if isinstance(data, dict):
                        await handle_data(data)
                        continue

                    LOG.error(
                        f"{class_name} cannot create message: expected dict, got {type(data)}",
                        **message_channel.identity,
                    )
            except WebSocketDisconnect:
                LOG.info(f"{class_name} frontend disconnected.", **message_channel.identity)
                raise
            except ConnectionClosedError:
                LOG.info(f"{class_name} frontend closed channel.", **message_channel.identity)
                raise
            except Exception:
                LOG.exception(f"{class_name} An unexpected exception occurred.", **message_channel.identity)
                raise

    loops = [
        asyncio.create_task(frontend_to_backend()),
    ]

    try:
        await collect(loops)
    except Exception:
        LOG.exception(f"{class_name} An exception occurred in loop message channel stream.", **message_channel.identity)
    finally:
        LOG.info(f"{class_name} Closing the message channel stream.", **message_channel.identity)
        await message_channel.close(reason="loop-channel-closed")
|
||||
|
||||
|
||||
async def get_message_channel_for_browser_session(
    client_id: str,
    browser_session_id: str,
    organization_id: str,
    websocket: WebSocket,
) -> tuple[MessageChannel, Loops] | None:
    """
    Return a message channel for a browser session, with a list of loops to run concurrently.

    Returns None when `verify_browser_session` yields no browser session.
    """

    LOG.info("Getting message channel for browser session.", browser_session_id=browser_session_id)

    verified_session = await verify_browser_session(
        browser_session_id=browser_session_id,
        organization_id=organization_id,
    )

    if not verified_session:
        LOG.info(
            "Message channel: no initial browser session found.",
            browser_session_id=browser_session_id,
            organization_id=organization_id,
        )
        return None

    channel = MessageChannel(
        client_id=client_id,
        organization_id=organization_id,
        websocket=websocket,
        browser_session=verified_session,
    )

    LOG.info("Got message channel for browser session.", message_channel=channel)

    # One loop keeps verifying the browser session, the other streams messages.
    verify_loop = asyncio.create_task(loop_verify_browser_session(channel))
    stream_loop = asyncio.create_task(loop_stream_messages(channel))

    return channel, [verify_loop, stream_loop]
|
||||
|
||||
|
||||
async def get_message_channel_for_workflow_run(
    client_id: str,
    workflow_run_id: str,
    organization_id: str,
    websocket: WebSocket,
) -> tuple[MessageChannel, Loops] | None:
    """
    Return a message channel for a workflow run, with a list of loops to run concurrently.

    Returns None when `verify_workflow_run` yields no workflow run, or no
    browser session for it.
    """

    LOG.info("Getting message channel for workflow run.", workflow_run_id=workflow_run_id)

    verified_run, verified_session = await verify_workflow_run(
        workflow_run_id=workflow_run_id,
        organization_id=organization_id,
    )

    if not verified_run:
        LOG.info(
            "Message channel: no initial workflow run found.",
            workflow_run_id=workflow_run_id,
            organization_id=organization_id,
        )
        return None

    if not verified_session:
        LOG.info(
            "Message channel: no initial browser session found for workflow run.",
            workflow_run_id=workflow_run_id,
            organization_id=organization_id,
        )
        return None

    channel = MessageChannel(
        client_id=client_id,
        organization_id=organization_id,
        websocket=websocket,
        browser_session=verified_session,
        workflow_run=verified_run,
    )

    LOG.info("Got message channel for workflow run.", message_channel=channel)

    # One loop keeps verifying the workflow run, the other streams messages.
    verify_loop = asyncio.create_task(loop_verify_workflow_run(channel))
    stream_loop = asyncio.create_task(loop_stream_messages(channel))

    return channel, [verify_loop, stream_loop]
|
||||
592
skyvern/forge/sdk/routes/streaming/channels/vnc.py
Normal file
592
skyvern/forge/sdk/routes/streaming/channels/vnc.py
Normal file
@@ -0,0 +1,592 @@
|
||||
"""
|
||||
A channel for streaming the VNC protocol data between our frontend and a
|
||||
persistent browser instance.
|
||||
|
||||
This is a pass-thru channel, through our API server. As such, we can monitor and/or
|
||||
intercept RFB protocol messages as needed.
|
||||
|
||||
What this channel looks like:
|
||||
|
||||
[Skyvern App] <--> [API Server] <--> [websockified noVNC] <--> [Browser]
|
||||
|
||||
|
||||
Channel data:
|
||||
|
||||
One could call this RFB over WebSockets (rockets?), as the protocol data streaming
|
||||
over the WebSocket is raw RFB protocol data.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import typing as t
|
||||
from enum import IntEnum
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import structlog
|
||||
import websockets
|
||||
from fastapi import WebSocket, WebSocketDisconnect
|
||||
from starlette.websockets import WebSocketState
|
||||
from websockets import ConnectionClosedError, Data
|
||||
|
||||
from skyvern.config import settings
|
||||
from skyvern.forge.sdk.routes.streaming.auth import get_x_api_key
|
||||
from skyvern.forge.sdk.routes.streaming.channels.execution import execution_channel
|
||||
from skyvern.forge.sdk.routes.streaming.registries import (
|
||||
add_vnc_channel,
|
||||
del_vnc_channel,
|
||||
get_message_channel,
|
||||
get_vnc_channel,
|
||||
)
|
||||
from skyvern.forge.sdk.routes.streaming.verify import (
|
||||
loop_verify_browser_session,
|
||||
loop_verify_task,
|
||||
loop_verify_workflow_run,
|
||||
verify_browser_session,
|
||||
verify_task,
|
||||
verify_workflow_run,
|
||||
)
|
||||
from skyvern.forge.sdk.schemas.persistent_browser_sessions import AddressablePersistentBrowserSession
|
||||
from skyvern.forge.sdk.schemas.tasks import Task
|
||||
from skyvern.forge.sdk.utils.aio import collect
|
||||
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRun
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
Interactor = t.Literal["agent", "user"]
|
||||
"""
|
||||
NOTE: we don't really have an "agent" at this time. But any control of the
|
||||
browser that is not user-originated is kinda' agent-like, by some
|
||||
definition of "agent". Here, we do not have an "AI agent". Future work may
|
||||
alter this state of affairs - and some "agent" could operate the browser
|
||||
automatically. In any case, if the interactor is not a "user", we assume
|
||||
it is an "agent".
|
||||
"""
|
||||
|
||||
|
||||
Loops = list[asyncio.Task] # aka "queue-less actors"; or "programs"
|
||||
|
||||
|
||||
class MessageType(IntEnum):
    """
    Values of the first byte of a frontend message that this module inspects
    to recognise keyboard and mouse input.
    """

    Keyboard = 4
    Mouse = 5
|
||||
|
||||
|
||||
class Keys:
    """
    VNC RFB keycodes, as complete 8-byte keyboard messages.

    Each constant starts with byte 4 (the keyboard message type); the second
    byte is 1 in `Down` and 0 in `Up`; the trailing bytes identify the key.
    There's likely a pithier repr (indexes 6-7). This is ok for now.

    ref: https://www.notion.so/References-21c426c42cd480fb9258ecc9eb8f09b4
    ref: https://github.com/novnc/noVNC/blob/master/docs/rfbproto-3.8.pdf
    """

    class Down:
        # modifier keys
        Ctrl = b"\x04\x01\x00\x00\x00\x00\xff\xe3"
        Cmd = b"\x04\x01\x00\x00\x00\x00\xff\xe9"
        Alt = b"\x04\x01\x00\x00\x00\x00\xff~"  # option

        # letter keys used by the shortcuts this module watches for
        CKey = b"\x04\x01\x00\x00\x00\x00\x00c"
        OKey = b"\x04\x01\x00\x00\x00\x00\x00o"
        VKey = b"\x04\x01\x00\x00\x00\x00\x00v"

    class Up:
        # modifier keys
        Ctrl = b"\x04\x00\x00\x00\x00\x00\xff\xe3"
        Cmd = b"\x04\x00\x00\x00\x00\x00\xff\xe9"
        Alt = b"\x04\x00\x00\x00\x00\x00\xff\x7e"  # option
|
||||
|
||||
|
||||
def is_rmb(data: bytes) -> bool:
    """Return True when `data` begins with the two bytes 0x05 0x04."""
    leading_pair = data[:2]
    return leading_pair == b"\x05\x04"
|
||||
|
||||
|
||||
class Mouse:
    """Namespace of predicates over raw mouse messages."""

    class Up:
        # Called through the class, i.e. `Mouse.Up.Right(data)`.
        Right = is_rmb
|
||||
|
||||
|
||||
@dataclasses.dataclass
class KeyState:
    """
    Which modifier keys are currently held, plus predicates that combine that
    state with an incoming key-down message to recognise shortcuts.
    """

    ctrl_is_down: bool = False
    alt_is_down: bool = False
    cmd_is_down: bool = False

    def is_forbidden(self, data: bytes) -> bool:
        """
        :return: True if the key is forbidden, else False
        """
        return self.is_ctrl_o(data)

    def is_ctrl_o(self, data: bytes) -> bool:
        """
        Detect Ctrl+O. Do not allow the opening of files.
        """
        if not self.ctrl_is_down:
            return False

        return data == Keys.Down.OKey

    def is_copy(self, data: bytes) -> bool:
        """
        Detect Ctrl+C or Cmd+C for copy.
        """
        shortcut_modifier_down = self.ctrl_is_down or self.cmd_is_down

        return shortcut_modifier_down and data == Keys.Down.CKey

    def is_paste(self, data: bytes) -> bool:
        """
        Detect Ctrl+V or Cmd+V for paste.
        """
        shortcut_modifier_down = self.ctrl_is_down or self.cmd_is_down

        return shortcut_modifier_down and data == Keys.Down.VKey
|
||||
|
||||
|
||||
@dataclasses.dataclass
class VncChannel:
    """
    A VNC channel for streaming RFB protocol data between our frontend app, and
    a remote browser.

    Constructing an instance registers it with `add_vnc_channel`; `close()`
    unregisters it with `del_vnc_channel`.
    """

    # Unique to a frontend app instance.
    client_id: str

    organization_id: str
    vnc_port: int
    x_api_key: str
    websocket: WebSocket

    # The role of the entity interacting with the channel, either "agent" or "user".
    # Init-only: the value is stored through the `interactor` property.
    initial_interactor: dataclasses.InitVar[Interactor]

    _interactor: Interactor = dataclasses.field(init=False, repr=False)

    # --

    browser_session: AddressablePersistentBrowserSession | None = None
    key_state: KeyState = dataclasses.field(default_factory=KeyState)
    task: Task | None = None
    workflow_run: WorkflowRun | None = None

    def __post_init__(self, initial_interactor: Interactor) -> None:
        self.interactor = initial_interactor
        add_vnc_channel(self)

    @property
    def class_name(self) -> str:
        """The concrete class name; used as a prefix in log messages."""
        return self.__class__.__name__

    @property
    def identity(self) -> dict:
        """
        Key/value pairs identifying this channel in structured logs: the
        organization id plus the first available of task id, workflow run id,
        browser session id, or client id.
        """
        base = {"organization_id": self.organization_id}

        if self.task:
            return base | {"task_id": self.task.task_id}

        if self.workflow_run:
            return base | {"workflow_run_id": self.workflow_run.workflow_run_id}

        if self.browser_session:
            return base | {"browser_session_id": self.browser_session.persistent_browser_session_id}

        return base | {"client_id": self.client_id}

    @property
    def interactor(self) -> Interactor:
        return self._interactor

    @interactor.setter
    def interactor(self, value: Interactor) -> None:
        self._interactor = value

        LOG.info(f"{self.class_name} Setting interactor to {value}", **self.identity)

    @property
    def is_open(self) -> bool:
        """
        True while the websocket is connected, a task, workflow run or browser
        session is still attached, and the channel is still registered.
        """
        return (
            self.websocket.client_state == WebSocketState.CONNECTED
            and bool(self.task or self.workflow_run or self.browser_session)
            and bool(get_vnc_channel(self.client_id))
        )

    async def close(self, code: int = 1000, reason: str | None = None) -> t.Self:
        """
        Detach session/task/run, close the websocket, and unregister the channel.

        Closing the websocket is best-effort: any error raised by it is ignored
        so that the channel is always unregistered.

        :return: this channel
        """
        LOG.info(f"{self.class_name} closing.", reason=reason, code=code, **self.identity)

        self.browser_session = None
        self.task = None
        self.workflow_run = None

        try:
            await self.websocket.close(code=code, reason=reason)
        except Exception:
            pass

        del_vnc_channel(self.client_id)

        return self

    def update_key_state(self, data: bytes) -> None:
        """Record a modifier press/release when `data` is one of the known Ctrl/Alt/Cmd messages."""
        modifier_events = (
            (Keys.Down.Ctrl, "ctrl_is_down", True),
            (Keys.Up.Ctrl, "ctrl_is_down", False),
            (Keys.Down.Alt, "alt_is_down", True),
            (Keys.Up.Alt, "alt_is_down", False),
            (Keys.Down.Cmd, "cmd_is_down", True),
            (Keys.Up.Cmd, "cmd_is_down", False),
        )

        for frame, flag_name, is_down in modifier_events:
            if data == frame:
                setattr(self.key_state, flag_name, is_down)
                return
|
||||
|
||||
|
||||
async def copy_text(vnc_channel: VncChannel) -> None:
    """
    Fetch the browser's selected text through an execution channel and send it
    to the message channel registered under the same client id.

    Never raises: any failure is logged.
    """
    class_name = vnc_channel.class_name
    LOG.info(f"{class_name} Retrieving selected text via CDP", **vnc_channel.identity)

    try:
        async with execution_channel(vnc_channel) as execute:
            selection = await execute.get_selected_text()

        message_channel = get_message_channel(vnc_channel.client_id)

        if not message_channel:
            LOG.warning(
                f"{class_name} No message channel found for client, or it is not open",
                message_channel=message_channel,
                **vnc_channel.identity,
            )
            return

        await message_channel.send_copied_text(selection)
    except Exception:
        LOG.exception(f"{class_name} Failed to retrieve selected text via CDP", **vnc_channel.identity)
|
||||
|
||||
|
||||
async def ask_for_clipboard(vnc_channel: VncChannel) -> None:
    """
    Ask the frontend for its clipboard contents, through the message channel
    registered under the same client id.

    Never raises: any failure is logged.
    """
    class_name = vnc_channel.class_name
    LOG.info(f"{class_name} Asking for clipboard data via CDP", **vnc_channel.identity)

    try:
        message_channel = get_message_channel(vnc_channel.client_id)

        if not message_channel:
            LOG.warning(
                f"{class_name} No message channel found for client, or it is not open",
                message_channel=message_channel,
                **vnc_channel.identity,
            )
            return

        await message_channel.ask_for_clipboard()
    except Exception:
        LOG.exception(f"{class_name} Failed to ask for clipboard via CDP", **vnc_channel.identity)
|
||||
|
||||
|
||||
async def loop_stream_vnc(vnc_channel: VncChannel) -> None:
    """
    Actually stream the VNC data between a frontend and a browser.

    Loops until the task is cleared or the websocket is closed. The vnc channel
    is always closed (reason "loop-stream-vnc-closed") when streaming ends.

    :raises Exception: if the vnc channel has no browser session
    """

    vnc_url: str = ""
    browser_session = vnc_channel.browser_session
    class_name = vnc_channel.class_name

    if not browser_session:
        raise Exception(f"{class_name} No browser session associated with vnc channel.")

    if browser_session.ip_address:
        if ":" in browser_session.ip_address:
            # "host:port" -> keep the host, swap in the vnc port
            ip, _ = browser_session.ip_address.split(":")
            vnc_url = f"ws://{ip}:{vnc_channel.vnc_port}"
        else:
            vnc_url = f"ws://{browser_session.ip_address}:{vnc_channel.vnc_port}"
    else:
        # No ip address: fall back to the host of the browser address url.
        host = urlparse(browser_session.browser_address).hostname
        vnc_url = f"ws://{host}:{vnc_channel.vnc_port}"

    # NOTE(jdo:streaming-local-dev)
    # vnc_url = "ws://localhost:6080"

    LOG.info(
        f"{class_name} Connecting to vnc url.",
        vnc_url=vnc_url,
        **vnc_channel.identity,
    )

    async with websockets.connect(vnc_url) as novnc_ws:

        async def frontend_to_browser() -> None:
            LOG.info(f"{class_name} Starting frontend-to-browser data transfer.", **vnc_channel.identity)
            data: Data | None = None

            while vnc_channel.is_open:
                # Start every iteration empty so that a receive which yields
                # nothing (see the swallowed CancelledError below) can never
                # forward the previous iteration's frame a second time.
                data = None

                try:
                    data = await vnc_channel.websocket.receive_bytes()

                    if data:
                        message_type = data[0]

                        if message_type == MessageType.Keyboard.value:
                            vnc_channel.update_key_state(data)

                            if vnc_channel.key_state.is_copy(data):
                                await copy_text(vnc_channel)

                            if vnc_channel.key_state.is_paste(data):
                                await ask_for_clipboard(vnc_channel)

                            if vnc_channel.key_state.is_forbidden(data):
                                continue

                        # prevent right-mouse-button clicks for "security" reasons
                        if message_type == MessageType.Mouse.value:
                            if Mouse.Up.Right(data):
                                continue

                        # Keyboard/mouse input is only forwarded while the user is the interactor.
                        if not vnc_channel.interactor == "user" and message_type in (
                            MessageType.Keyboard.value,
                            MessageType.Mouse.value,
                        ):
                            LOG.debug(f"{class_name} Blocking user message.", **vnc_channel.identity)
                            continue

                except WebSocketDisconnect:
                    LOG.info(f"{class_name} Frontend disconnected.", **vnc_channel.identity)
                    raise
                except ConnectionClosedError:
                    LOG.info(f"{class_name} Frontend closed the vnc channel.", **vnc_channel.identity)
                    raise
                except asyncio.CancelledError:
                    # NOTE(review): cancellation is swallowed, so this loop keeps
                    # running until the channel is no longer open - confirm intended.
                    pass
                except Exception:
                    LOG.exception(f"{class_name} An unexpected exception occurred.", **vnc_channel.identity)
                    raise

                if not data:
                    continue

                try:
                    await novnc_ws.send(data)
                except WebSocketDisconnect:
                    LOG.info(f"{class_name} Browser disconnected from vnc.", **vnc_channel.identity)
                    raise
                except ConnectionClosedError:
                    LOG.info(f"{class_name} Browser closed vnc.", **vnc_channel.identity)
                    raise
                except asyncio.CancelledError:
                    pass
                except Exception:
                    LOG.exception(
                        f"{class_name} An unexpected exception occurred in frontend-to-browser loop.",
                        **vnc_channel.identity,
                    )
                    raise

        async def browser_to_frontend() -> None:
            LOG.info(f"{class_name} Starting browser-to-frontend data transfer.", **vnc_channel.identity)
            data: Data | None = None

            while vnc_channel.is_open:
                # Same reasoning as in frontend_to_browser: never carry a frame
                # over from the previous iteration.
                data = None

                try:
                    data = await novnc_ws.recv()

                except WebSocketDisconnect:
                    LOG.info(f"{class_name} Browser disconnected from the vnc channel session.", **vnc_channel.identity)
                    await vnc_channel.close(reason="browser-disconnected")
                except ConnectionClosedError:
                    LOG.info(f"{class_name} Browser closed the vnc channel session.", **vnc_channel.identity)
                    await vnc_channel.close(reason="browser-closed")
                except asyncio.CancelledError:
                    pass
                except Exception:
                    LOG.exception(
                        f"{class_name} An unexpected exception occurred in browser-to-frontend loop.",
                        **vnc_channel.identity,
                    )
                    raise

                if not data:
                    continue

                try:
                    # Drop the frame if the frontend socket is no longer connected.
                    if vnc_channel.websocket.client_state != WebSocketState.CONNECTED:
                        continue
                    await vnc_channel.websocket.send_bytes(data)
                except WebSocketDisconnect:
                    LOG.info(
                        f"{class_name} Frontend disconnected from the vnc channel session.", **vnc_channel.identity
                    )
                    await vnc_channel.close(reason="frontend-disconnected")
                except ConnectionClosedError:
                    LOG.info(f"{class_name} Frontend closed the vnc channel session.", **vnc_channel.identity)
                    await vnc_channel.close(reason="frontend-closed")
                except asyncio.CancelledError:
                    pass
                except Exception:
                    LOG.exception(f"{class_name} An unexpected exception occurred.", **vnc_channel.identity)
                    raise

        loops = [
            asyncio.create_task(frontend_to_browser()),
            asyncio.create_task(browser_to_frontend()),
        ]

        try:
            await collect(loops)
        except WebSocketDisconnect:
            pass
        except Exception:
            LOG.exception(f"{class_name} An exception occurred in loop stream.", **vnc_channel.identity)
        finally:
            LOG.info(f"{class_name} Closing the loop stream.", **vnc_channel.identity)
            await vnc_channel.close(reason="loop-stream-vnc-closed")
|
||||
|
||||
|
||||
async def get_vnc_channel_for_browser_session(
    client_id: str,
    browser_session_id: str,
    organization_id: str,
    websocket: WebSocket,
) -> tuple[VncChannel, Loops] | None:
    """
    Return a vnc channel for a browser session, with a list of loops to run concurrently.

    Returns None when `verify_browser_session` yields no browser session, or
    when constructing the channel fails.
    """

    LOG.info("Getting vnc context for browser session.", browser_session_id=browser_session_id)

    verified_session = await verify_browser_session(
        browser_session_id=browser_session_id,
        organization_id=organization_id,
    )

    if not verified_session:
        LOG.info(
            "No initial browser session found.", browser_session_id=browser_session_id, organization_id=organization_id
        )
        return None

    api_key = await get_x_api_key(organization_id)

    try:
        channel = VncChannel(
            client_id=client_id,
            organization_id=organization_id,
            vnc_port=settings.SKYVERN_BROWSER_VNC_PORT,
            x_api_key=api_key,
            websocket=websocket,
            initial_interactor="agent",
            browser_session=verified_session,
        )
    except Exception as e:
        LOG.exception("Failed to create VncChannel.", error=str(e))
        return None

    LOG.info("Got vnc context for browser session.", vnc_channel=channel)

    # One loop keeps verifying the browser session, the other streams vnc data.
    verify_loop = asyncio.create_task(loop_verify_browser_session(channel))
    stream_loop = asyncio.create_task(loop_stream_vnc(channel))

    return channel, [verify_loop, stream_loop]
|
||||
|
||||
|
||||
async def get_vnc_channel_for_task(
    client_id: str,
    task_id: str,
    organization_id: str,
    websocket: WebSocket,
) -> tuple[VncChannel, Loops] | None:
    """
    Return a vnc channel for a task, with a list of loops to run concurrently.

    Returns None when `verify_task` yields no task, or no browser session for it.
    """

    verified_task, verified_session = await verify_task(task_id=task_id, organization_id=organization_id)

    if not verified_task:
        LOG.info("No initial task found.", task_id=task_id, organization_id=organization_id)
        return None

    if not verified_session:
        LOG.info("No initial browser session found for task.", task_id=task_id, organization_id=organization_id)
        return None

    api_key = await get_x_api_key(organization_id)

    channel = VncChannel(
        client_id=client_id,
        organization_id=organization_id,
        vnc_port=settings.SKYVERN_BROWSER_VNC_PORT,
        x_api_key=api_key,
        websocket=websocket,
        initial_interactor="agent",
        browser_session=verified_session,
        task=verified_task,
    )

    # One loop keeps verifying the task, the other streams vnc data.
    verify_loop = asyncio.create_task(loop_verify_task(channel))
    stream_loop = asyncio.create_task(loop_stream_vnc(channel))

    return channel, [verify_loop, stream_loop]
|
||||
|
||||
|
||||
async def get_vnc_channel_for_workflow_run(
    client_id: str,
    workflow_run_id: str,
    organization_id: str,
    websocket: WebSocket,
) -> tuple[VncChannel, Loops] | None:
    """
    Return a vnc channel for a workflow run, with a list of loops to run concurrently.

    :param client_id: identifier the new channel is created with
    :param workflow_run_id: workflow run to look up via `verify_workflow_run`
    :param organization_id: organization used for verification and the API key lookup
    :param websocket: websocket handed to the new `VncChannel`
    :return: `(vnc_channel, loops)`, or None when `verify_workflow_run` yields
        no workflow run or no browser session
    """

    LOG.info("Getting vnc channel for workflow run.", workflow_run_id=workflow_run_id)

    workflow_run, browser_session = await verify_workflow_run(
        workflow_run_id=workflow_run_id,
        organization_id=organization_id,
    )

    if not workflow_run:
        LOG.info("No initial workflow run found.", workflow_run_id=workflow_run_id, organization_id=organization_id)
        return None

    if not browser_session:
        LOG.info(
            "No initial browser session found for workflow run.",
            workflow_run_id=workflow_run_id,
            organization_id=organization_id,
        )
        return None

    x_api_key = await get_x_api_key(organization_id)

    vnc_channel = VncChannel(
        client_id=client_id,
        initial_interactor="agent",
        organization_id=organization_id,
        vnc_port=settings.SKYVERN_BROWSER_VNC_PORT,
        browser_session=browser_session,
        workflow_run=workflow_run,
        x_api_key=x_api_key,
        websocket=websocket,
    )

    LOG.info("Got vnc channel context for workflow run.", vnc_channel=vnc_channel)

    # Both loops are scheduled immediately; the caller is expected to await them.
    loops = [
        asyncio.create_task(loop_verify_workflow_run(vnc_channel)),
        asyncio.create_task(loop_stream_vnc(vnc_channel)),
    ]

    return vnc_channel, loops
|
||||
@@ -1,328 +0,0 @@
|
||||
"""
|
||||
Streaming types.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import typing as t
|
||||
from enum import IntEnum
|
||||
|
||||
import structlog
|
||||
from fastapi import WebSocket
|
||||
from starlette.websockets import WebSocketState
|
||||
|
||||
from skyvern.forge.sdk.schemas.persistent_browser_sessions import AddressablePersistentBrowserSession
|
||||
from skyvern.forge.sdk.schemas.tasks import Task
|
||||
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRun
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
Interactor = t.Literal["agent", "user"]
|
||||
Loops = list[asyncio.Task] # aka "queue-less actors"; or "programs"
|
||||
|
||||
|
||||
# Messages
|
||||
|
||||
|
||||
# a global registry for WS message clients
|
||||
message_channels: dict[str, "MessageChannel"] = {}


def add_message_client(message_channel: "MessageChannel") -> None:
    """
    Register `message_channel` under its `client_id`, replacing any existing entry.
    """
    message_channels[message_channel.client_id] = message_channel


def get_message_client(client_id: str) -> t.Union["MessageChannel", None]:
    """
    :return: the message channel registered for `client_id`, or None if there is none
    """
    return message_channels.get(client_id)


def del_message_client(client_id: str) -> None:
    """
    Remove the entry for `client_id`; a missing entry is not an error.
    """
    message_channels.pop(client_id, None)
|
||||
|
||||
|
||||
@dataclasses.dataclass
class MessageChannel:
    """
    State for one JSON messaging WebSocket connection to the frontend.

    An instance registers itself in the module-level message-client registry
    when constructed and removes itself again in `close`.
    """

    client_id: str
    organization_id: str
    websocket: WebSocket

    # --

    browser_session: AddressablePersistentBrowserSession | None = None
    workflow_run: WorkflowRun | None = None

    def __post_init__(self) -> None:
        add_message_client(self)

    async def close(self, code: int = 1000, reason: str | None = None) -> "MessageChannel":
        """
        Clear the tracked entities, close the websocket and unregister the channel.

        :return: self
        """
        LOG.info("Closing message stream.", reason=reason, code=code)

        self.browser_session = None
        self.workflow_run = None

        try:
            await self.websocket.close(code=code, reason=reason)
        except Exception:
            # Best effort: a failure to close must not prevent unregistering below.
            pass

        del_message_client(self.client_id)

        return self

    @property
    def is_open(self) -> bool:
        """
        True only while the websocket is connected/connecting, at least one of
        `workflow_run` / `browser_session` is set, and the channel is still registered.
        """
        if self.websocket.client_state not in (WebSocketState.CONNECTED, WebSocketState.CONNECTING):
            return False

        if not self.workflow_run and not self.browser_session:
            return False

        if not get_message_client(self.client_id):
            return False

        return True

    async def ask_for_clipboard(self, streaming: "Streaming") -> None:
        """
        Send an "ask-for-clipboard" message to the frontend.

        Failures are logged, not raised. `streaming` is only used for log context.
        """
        try:
            await self.websocket.send_json(
                {
                    "kind": "ask-for-clipboard",
                }
            )
            LOG.info(
                "Sent ask-for-clipboard to message channel",
                organization_id=streaming.organization_id,
            )
        except Exception:
            LOG.exception(
                "Failed to send ask-for-clipboard to message channel",
                organization_id=streaming.organization_id,
            )

    async def send_copied_text(self, copied_text: str, streaming: "Streaming") -> None:
        """
        Send `copied_text` to the frontend as a "copied-text" message.

        Failures are logged, not raised. `streaming` is only used for log context.
        """
        try:
            await self.websocket.send_json(
                {
                    "kind": "copied-text",
                    "text": copied_text,
                }
            )
            LOG.info(
                "Sent copied text to message channel",
                organization_id=streaming.organization_id,
            )
        except Exception:
            LOG.exception(
                "Failed to send copied text to message channel",
                organization_id=streaming.organization_id,
            )
|
||||
|
||||
|
||||
MessageKinds = t.Literal["take-control", "cede-control", "ask-for-clipboard-response"]
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Message:
|
||||
kind: MessageKinds
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MessageTakeControl(Message):
|
||||
kind: t.Literal["take-control"] = "take-control"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MessageCedeControl(Message):
|
||||
kind: t.Literal["cede-control"] = "cede-control"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MessageInAskForClipboardResponse(Message):
|
||||
kind: t.Literal["ask-for-clipboard-response"] = "ask-for-clipboard-response"
|
||||
text: str = ""
|
||||
|
||||
|
||||
ChannelMessage = t.Union[MessageTakeControl, MessageCedeControl, MessageInAskForClipboardResponse]
|
||||
|
||||
|
||||
def reify_channel_message(data: dict) -> ChannelMessage:
|
||||
kind = data.get("kind", None)
|
||||
|
||||
match kind:
|
||||
case "take-control":
|
||||
return MessageTakeControl()
|
||||
case "cede-control":
|
||||
return MessageCedeControl()
|
||||
case "ask-for-clipboard-response":
|
||||
text = data.get("text") or ""
|
||||
return MessageInAskForClipboardResponse(text=text)
|
||||
case _:
|
||||
raise ValueError(f"Unknown message kind: '{kind}'")
|
||||
|
||||
|
||||
# Streaming
|
||||
|
||||
|
||||
# a global registry for WS streaming VNC clients
|
||||
streaming_clients: dict[str, "Streaming"] = {}


def add_streaming_client(streaming: "Streaming") -> None:
    """
    Register `streaming` under its `client_id`, replacing any existing entry.
    """
    streaming_clients[streaming.client_id] = streaming


def get_streaming_client(client_id: str) -> t.Union["Streaming", None]:
    """
    :return: the streaming client registered for `client_id`, or None if there is none
    """
    return streaming_clients.get(client_id)


def del_streaming_client(client_id: str) -> None:
    """
    Remove the entry for `client_id`; a missing entry is not an error.
    """
    streaming_clients.pop(client_id, None)
|
||||
|
||||
|
||||
class MessageType(IntEnum):
    """
    Leading byte of the client messages this module inspects.

    The values line up with the first byte of the `Keys` constants (4) and of
    the prefix tested by `is_rmb` (5).
    """

    Keyboard = 4
    Mouse = 5
|
||||
|
||||
|
||||
class Keys:
    """
    VNC RFB keycodes. There's likely a pithier repr (indexes 6-7). This is ok for now.

    Each constant is a complete 8-byte key message: byte 0 is the message type
    (4), byte 1 is the down flag (1 = pressed, 0 = released) and the trailing
    bytes carry the key code.

    ref: https://www.notion.so/References-21c426c42cd480fb9258ecc9eb8f09b4
    ref: https://github.com/novnc/noVNC/blob/master/docs/rfbproto-3.8.pdf
    """

    class Down:
        Ctrl = b"\x04\x01\x00\x00\x00\x00\xff\xe3"
        Cmd = b"\x04\x01\x00\x00\x00\x00\xff\xe9"
        # Written as \x7e (the same byte as "~") so it reads like Up.Alt.
        Alt = b"\x04\x01\x00\x00\x00\x00\xff\x7e"  # option
        CKey = b"\x04\x01\x00\x00\x00\x00\x00c"
        OKey = b"\x04\x01\x00\x00\x00\x00\x00o"
        VKey = b"\x04\x01\x00\x00\x00\x00\x00v"

    class Up:
        Ctrl = b"\x04\x00\x00\x00\x00\x00\xff\xe3"
        Cmd = b"\x04\x00\x00\x00\x00\x00\xff\xe9"
        Alt = b"\x04\x00\x00\x00\x00\x00\xff\x7e"  # option
|
||||
|
||||
|
||||
def is_rmb(data: bytes) -> bool:
    """
    :return: True if `data` starts with the two bytes 0x05 0x04, else False

    Inputs shorter than two bytes simply compare unequal and yield False.
    NOTE(review): presumably 0x05 is the mouse message type and 0x04 the
    right-button mask — confirm against the RFB reference.
    """
    return data[0:2] == b"\x05\x04"
|
||||
|
||||
|
||||
class Mouse:
    """
    Namespace of predicates over raw mouse messages.
    """

    class Up:
        # Looked up on the class (Mouse.Up.Right(data)), so this is the plain
        # function and no instance is bound.
        Right = is_rmb
|
||||
|
||||
|
||||
@dataclasses.dataclass
class KeyState:
    """
    Which modifier keys are currently held, plus predicates that combine that
    state with an incoming key message.
    """

    ctrl_is_down: bool = False
    alt_is_down: bool = False
    cmd_is_down: bool = False

    def is_forbidden(self, data: bytes) -> bool:
        """
        :return: True if the key is forbidden, else False
        """
        return self.is_ctrl_o(data)

    def is_ctrl_o(self, data: bytes) -> bool:
        """
        Do not allow the opening of files.
        """
        return self.ctrl_is_down and data == Keys.Down.OKey

    def is_copy(self, data: bytes) -> bool:
        """
        Detect Ctrl+C or Cmd+C for copy.
        """
        return (self.ctrl_is_down or self.cmd_is_down) and data == Keys.Down.CKey

    def is_paste(self, data: bytes) -> bool:
        """
        Detect Ctrl+V or Cmd+V for paste.
        """
        return (self.ctrl_is_down or self.cmd_is_down) and data == Keys.Down.VKey
|
||||
|
||||
|
||||
@dataclasses.dataclass
class Streaming:
    """
    Streaming state.

    An instance registers itself in the module-level streaming-client registry
    when constructed and removes itself again in `close`.
    """

    client_id: str
    """
    Unique to frontend app instance.
    """

    interactor: Interactor
    """
    Whether the user or the agent are the interactor.
    """

    organization_id: str
    vnc_port: int
    x_api_key: str
    websocket: WebSocket

    # --

    browser_session: AddressablePersistentBrowserSession | None = None
    key_state: KeyState = dataclasses.field(default_factory=KeyState)
    task: Task | None = None
    workflow_run: WorkflowRun | None = None

    def __post_init__(self) -> None:
        add_streaming_client(self)

    @property
    def is_open(self) -> bool:
        """
        True only while the websocket is connected/connecting, at least one of
        `task` / `workflow_run` / `browser_session` is set, and this client is
        still registered.
        """
        if self.websocket.client_state not in (WebSocketState.CONNECTED, WebSocketState.CONNECTING):
            return False

        if not self.task and not self.workflow_run and not self.browser_session:
            return False

        if not get_streaming_client(self.client_id):
            return False

        return True

    async def close(self, code: int = 1000, reason: str | None = None) -> "Streaming":
        """
        Clear the tracked entities, close the websocket and unregister this client.

        :return: self
        """
        LOG.info("Closing Streaming.", reason=reason, code=code)

        self.browser_session = None
        self.task = None
        self.workflow_run = None

        try:
            await self.websocket.close(code=code, reason=reason)
        except Exception:
            # Best effort: a failure to close must not prevent unregistering below.
            pass

        del_streaming_client(self.client_id)

        return self

    def update_key_state(self, data: bytes) -> None:
        """
        Update the modifier flags in `key_state` when `data` is exactly one of
        the Ctrl / Alt / Cmd press or release messages; otherwise do nothing.
        """
        if data == Keys.Down.Ctrl:
            self.key_state.ctrl_is_down = True
        elif data == Keys.Up.Ctrl:
            self.key_state.ctrl_is_down = False
        elif data == Keys.Down.Alt:
            self.key_state.alt_is_down = True
        elif data == Keys.Up.Alt:
            self.key_state.alt_is_down = False
        elif data == Keys.Down.Cmd:
            self.key_state.cmd_is_down = True
        elif data == Keys.Up.Cmd:
            self.key_state.cmd_is_down = False
|
||||
@@ -1,240 +1,24 @@
|
||||
"""
|
||||
Streaming messages for WebSocket connections.
|
||||
Provides WS endpoints for streaming messages to/from our frontend application.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
import structlog
|
||||
from fastapi import WebSocket, WebSocketDisconnect
|
||||
from websockets.exceptions import ConnectionClosedError
|
||||
from fastapi import WebSocket
|
||||
|
||||
import skyvern.forge.sdk.routes.streaming.clients as sc
|
||||
from skyvern.forge.sdk.routes.routers import base_router, legacy_base_router
|
||||
from skyvern.forge.sdk.routes.streaming.agent import connected_agent
|
||||
from skyvern.forge.sdk.routes.streaming.auth import auth
|
||||
from skyvern.forge.sdk.routes.streaming.verify import (
|
||||
loop_verify_browser_session,
|
||||
loop_verify_workflow_run,
|
||||
verify_browser_session,
|
||||
verify_workflow_run,
|
||||
from skyvern.forge.sdk.routes.streaming.channels.message import (
|
||||
Loops,
|
||||
MessageChannel,
|
||||
get_message_channel_for_browser_session,
|
||||
get_message_channel_for_workflow_run,
|
||||
)
|
||||
from skyvern.forge.sdk.utils.aio import collect
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
async def get_messages_for_browser_session(
    client_id: str,
    browser_session_id: str,
    organization_id: str,
    websocket: WebSocket,
) -> tuple[sc.MessageChannel, sc.Loops] | None:
    """
    Return a message channel for a browser session, with a list of loops to run concurrently.

    :return: `(message_channel, loops)`, or None when `verify_browser_session`
        yields no browser session
    """

    LOG.info("Getting message channel for browser session.", browser_session_id=browser_session_id)

    browser_session = await verify_browser_session(
        browser_session_id=browser_session_id,
        organization_id=organization_id,
    )

    if not browser_session:
        LOG.info(
            "Message channel: no initial browser session found.",
            browser_session_id=browser_session_id,
            organization_id=organization_id,
        )
        return None

    message_channel = sc.MessageChannel(
        client_id=client_id,
        organization_id=organization_id,
        browser_session=browser_session,
        websocket=websocket,
    )

    LOG.info("Got message channel for browser session.", message_channel=message_channel)

    # Both loops are scheduled immediately; the caller is expected to await them.
    loops = [
        asyncio.create_task(loop_verify_browser_session(message_channel)),
        asyncio.create_task(loop_channel(message_channel)),
    ]

    return message_channel, loops
|
||||
|
||||
|
||||
async def get_messages_for_workflow_run(
    client_id: str,
    workflow_run_id: str,
    organization_id: str,
    websocket: WebSocket,
) -> tuple[sc.MessageChannel, sc.Loops] | None:
    """
    Return a message channel for a workflow run, with a list of loops to run concurrently.

    :return: `(message_channel, loops)`, or None when `verify_workflow_run`
        yields no workflow run or no browser session
    """

    LOG.info("Getting message channel for workflow run.", workflow_run_id=workflow_run_id)

    workflow_run, browser_session = await verify_workflow_run(
        workflow_run_id=workflow_run_id,
        organization_id=organization_id,
    )

    if not workflow_run:
        LOG.info(
            "Message channel: no initial workflow run found.",
            workflow_run_id=workflow_run_id,
            organization_id=organization_id,
        )
        return None

    if not browser_session:
        LOG.info(
            "Message channel: no initial browser session found for workflow run.",
            workflow_run_id=workflow_run_id,
            organization_id=organization_id,
        )
        return None

    message_channel = sc.MessageChannel(
        client_id=client_id,
        organization_id=organization_id,
        browser_session=browser_session,
        workflow_run=workflow_run,
        websocket=websocket,
    )

    LOG.info("Got message channel for workflow run.", message_channel=message_channel)

    # Both loops are scheduled immediately; the caller is expected to await them.
    loops = [
        asyncio.create_task(loop_verify_workflow_run(message_channel)),
        asyncio.create_task(loop_channel(message_channel)),
    ]

    return message_channel, loops
|
||||
|
||||
|
||||
async def loop_channel(message_channel: sc.MessageChannel) -> None:
|
||||
"""
|
||||
Stream messages and their results back and forth.
|
||||
|
||||
Loops until the workflow run is cleared or the websocket is closed.
|
||||
"""
|
||||
|
||||
if not message_channel.browser_session:
|
||||
LOG.info(
|
||||
"No browser session found for workflow run.",
|
||||
workflow_run=message_channel.workflow_run,
|
||||
organization_id=message_channel.organization_id,
|
||||
)
|
||||
return
|
||||
|
||||
async def frontend_to_backend() -> None:
|
||||
LOG.info("Starting frontend-to-backend channel loop.", message_channel=message_channel)
|
||||
|
||||
while message_channel.is_open:
|
||||
try:
|
||||
data = await message_channel.websocket.receive_json()
|
||||
|
||||
if not isinstance(data, dict):
|
||||
LOG.error(f"Cannot create channel message: expected dict, got {type(data)}")
|
||||
continue
|
||||
|
||||
try:
|
||||
message = sc.reify_channel_message(data)
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
message_kind = message.kind
|
||||
|
||||
match message_kind:
|
||||
case "take-control":
|
||||
streaming = sc.get_streaming_client(message_channel.client_id)
|
||||
if not streaming:
|
||||
LOG.error(
|
||||
"No streaming client found for message.",
|
||||
message_channel=message_channel,
|
||||
message=message,
|
||||
)
|
||||
continue
|
||||
streaming.interactor = "user"
|
||||
case "cede-control":
|
||||
streaming = sc.get_streaming_client(message_channel.client_id)
|
||||
if not streaming:
|
||||
LOG.error(
|
||||
"No streaming client found for message.",
|
||||
message_channel=message_channel,
|
||||
message=message,
|
||||
)
|
||||
continue
|
||||
streaming.interactor = "agent"
|
||||
case "ask-for-clipboard-response":
|
||||
if not isinstance(message, sc.MessageInAskForClipboardResponse):
|
||||
LOG.error(
|
||||
"Invalid message type for ask-for-clipboard-response.",
|
||||
message_channel=message_channel,
|
||||
message=message,
|
||||
)
|
||||
continue
|
||||
|
||||
streaming = sc.get_streaming_client(message_channel.client_id)
|
||||
text = message.text
|
||||
|
||||
async with connected_agent(streaming) as agent:
|
||||
await agent.paste_text(text)
|
||||
case _:
|
||||
LOG.error(f"Unknown message kind: '{message_kind}'")
|
||||
continue
|
||||
|
||||
except WebSocketDisconnect:
|
||||
LOG.info(
|
||||
"Frontend disconnected.",
|
||||
workflow_run=message_channel.workflow_run,
|
||||
organization_id=message_channel.organization_id,
|
||||
)
|
||||
raise
|
||||
except ConnectionClosedError:
|
||||
LOG.info(
|
||||
"Frontend closed the streaming session.",
|
||||
workflow_run=message_channel.workflow_run,
|
||||
organization_id=message_channel.organization_id,
|
||||
)
|
||||
raise
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception:
|
||||
LOG.exception(
|
||||
"An unexpected exception occurred.",
|
||||
workflow_run=message_channel.workflow_run,
|
||||
organization_id=message_channel.organization_id,
|
||||
)
|
||||
raise
|
||||
|
||||
loops = [
|
||||
asyncio.create_task(frontend_to_backend()),
|
||||
]
|
||||
|
||||
try:
|
||||
await collect(loops)
|
||||
except Exception:
|
||||
LOG.exception(
|
||||
"An exception occurred in loop channel stream.",
|
||||
workflow_run=message_channel.workflow_run,
|
||||
organization_id=message_channel.organization_id,
|
||||
)
|
||||
finally:
|
||||
LOG.info(
|
||||
"Closing the loop channel stream.",
|
||||
workflow_run=message_channel.workflow_run,
|
||||
organization_id=message_channel.organization_id,
|
||||
)
|
||||
await message_channel.close(reason="loop-channel-closed")
|
||||
|
||||
|
||||
@base_router.websocket("/stream/messages/browser_session/{browser_session_id}")
|
||||
@base_router.websocket("/stream/commands/browser_session/{browser_session_id}")
|
||||
async def browser_session_messages(
|
||||
websocket: WebSocket,
|
||||
browser_session_id: str,
|
||||
@@ -242,64 +26,16 @@ async def browser_session_messages(
|
||||
client_id: str | None = None,
|
||||
token: str | None = None,
|
||||
) -> None:
|
||||
LOG.info("Starting message stream for browser session.", browser_session_id=browser_session_id)
|
||||
|
||||
organization_id = await auth(apikey=apikey, token=token, websocket=websocket)
|
||||
|
||||
if not organization_id:
|
||||
LOG.error("Authentication failed.", browser_session_id=browser_session_id)
|
||||
return
|
||||
|
||||
if not client_id:
|
||||
LOG.error("No client ID provided.", browser_session_id=browser_session_id)
|
||||
return
|
||||
|
||||
message_channel: sc.MessageChannel
|
||||
loops: list[asyncio.Task] = []
|
||||
|
||||
result = await get_messages_for_browser_session(
|
||||
client_id=client_id,
|
||||
browser_session_id=browser_session_id,
|
||||
organization_id=organization_id,
|
||||
return await messages(
|
||||
websocket=websocket,
|
||||
browser_session_id=browser_session_id,
|
||||
apikey=apikey,
|
||||
client_id=client_id,
|
||||
token=token,
|
||||
)
|
||||
|
||||
if not result:
|
||||
LOG.error(
|
||||
"No streaming context found for the browser session.",
|
||||
browser_session_id=browser_session_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
await websocket.close(code=1013)
|
||||
return
|
||||
|
||||
message_channel, loops = result
|
||||
|
||||
try:
|
||||
LOG.info(
|
||||
"Starting message stream loops for browser session.",
|
||||
browser_session_id=browser_session_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
await collect(loops)
|
||||
except Exception:
|
||||
LOG.exception(
|
||||
"An exception occurred in the message stream function for browser session.",
|
||||
browser_session_id=browser_session_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
finally:
|
||||
LOG.info(
|
||||
"Closing the message stream session for browser session.",
|
||||
browser_session_id=browser_session_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
|
||||
await message_channel.close(reason="stream-closed")
|
||||
|
||||
|
||||
@legacy_base_router.websocket("/stream/messages/workflow_run/{workflow_run_id}")
|
||||
@legacy_base_router.websocket("/stream/commands/workflow_run/{workflow_run_id}")
|
||||
async def workflow_run_messages(
|
||||
websocket: WebSocket,
|
||||
workflow_run_id: str,
|
||||
@@ -307,31 +43,77 @@ async def workflow_run_messages(
|
||||
client_id: str | None = None,
|
||||
token: str | None = None,
|
||||
) -> None:
|
||||
LOG.info("Starting message stream.", workflow_run_id=workflow_run_id)
|
||||
return await messages(
|
||||
websocket=websocket,
|
||||
workflow_run_id=workflow_run_id,
|
||||
apikey=apikey,
|
||||
client_id=client_id,
|
||||
token=token,
|
||||
)
|
||||
|
||||
|
||||
async def messages(
|
||||
websocket: WebSocket,
|
||||
browser_session_id: str | None = None,
|
||||
workflow_run_id: str | None = None,
|
||||
apikey: str | None = None,
|
||||
client_id: str | None = None,
|
||||
token: str | None = None,
|
||||
) -> None:
|
||||
LOG.info(
|
||||
"Starting message stream.",
|
||||
browser_session_id=browser_session_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
)
|
||||
|
||||
organization_id = await auth(apikey=apikey, token=token, websocket=websocket)
|
||||
|
||||
if not organization_id:
|
||||
LOG.error("Authentication failed.", workflow_run_id=workflow_run_id)
|
||||
LOG.error(
|
||||
"Authentication failed.",
|
||||
browser_session_id=browser_session_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
)
|
||||
return
|
||||
|
||||
if not client_id:
|
||||
LOG.error("No client ID provided.", workflow_run_id=workflow_run_id)
|
||||
LOG.error(
|
||||
"No client ID provided.",
|
||||
browser_session_id=browser_session_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
)
|
||||
return
|
||||
|
||||
message_channel: sc.MessageChannel
|
||||
loops: list[asyncio.Task] = []
|
||||
message_channel: MessageChannel
|
||||
loops: Loops = []
|
||||
|
||||
result = await get_messages_for_workflow_run(
|
||||
client_id=client_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
websocket=websocket,
|
||||
)
|
||||
if browser_session_id:
|
||||
result = await get_message_channel_for_browser_session(
|
||||
client_id=client_id,
|
||||
browser_session_id=browser_session_id,
|
||||
organization_id=organization_id,
|
||||
websocket=websocket,
|
||||
)
|
||||
elif workflow_run_id:
|
||||
result = await get_message_channel_for_workflow_run(
|
||||
client_id=client_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
websocket=websocket,
|
||||
)
|
||||
else:
|
||||
LOG.error(
|
||||
"Message channel: no browser_session_id or workflow_run_id provided.",
|
||||
client_id=client_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
await websocket.close(code=1002)
|
||||
return
|
||||
|
||||
if not result:
|
||||
LOG.error(
|
||||
"No streaming context found for the workflow run.",
|
||||
"No message channel found.",
|
||||
browser_session_id=browser_session_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
@@ -342,22 +124,25 @@ async def workflow_run_messages(
|
||||
|
||||
try:
|
||||
LOG.info(
|
||||
"Starting message stream loops.",
|
||||
"Starting message channel loops.",
|
||||
browser_session_id=browser_session_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
await collect(loops)
|
||||
except Exception:
|
||||
LOG.exception(
|
||||
"An exception occurred in the message stream function.",
|
||||
"An exception occurred in the message loop function(s).",
|
||||
browser_session_id=browser_session_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
finally:
|
||||
LOG.info(
|
||||
"Closing the message stream session.",
|
||||
"Closing the message channel.",
|
||||
browser_session_id=browser_session_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
|
||||
await message_channel.close(reason="stream-closed")
|
||||
await message_channel.close(reason="message-stream-closed")
|
||||
|
||||
78
skyvern/forge/sdk/routes/streaming/registries.py
Normal file
78
skyvern/forge/sdk/routes/streaming/registries.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""
|
||||
Contains registries for coordinating active WS connections (aka "channels", see
|
||||
`./channels/README.md`).
|
||||
|
||||
NOTE: in AWS we had to turn on what amounts to sticky sessions for frontend apps,
|
||||
so that an individual frontend app instance is guaranteed to always connect to
|
||||
the same backend api instance. This is because the two registries here are
|
||||
tied together via a `client_id` string.
|
||||
|
||||
The tale-of-the-tape is this:
|
||||
- frontend app requires two different channels (WS connections) to the backend api
|
||||
- one dedicated to streaming VNC's RFB protocol
|
||||
- the other dedicated to messaging (JSON)
|
||||
- both of these channels are stateful and need to coordinate with one another
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
import structlog
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from skyvern.forge.sdk.routes.streaming.channels.message import MessageChannel
|
||||
from skyvern.forge.sdk.routes.streaming.channels.vnc import VncChannel
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
# a registry for VNC channels, keyed by `client_id`
|
||||
vnc_channels: dict[str, "VncChannel"] = {}


def add_vnc_channel(vnc_channel: "VncChannel") -> None:
    """
    Register `vnc_channel` under its `client_id`, replacing any existing entry.
    """
    vnc_channels[vnc_channel.client_id] = vnc_channel


def get_vnc_channel(client_id: str) -> t.Union["VncChannel", None]:
    """
    :return: the VNC channel registered for `client_id`, or None if there is none
    """
    return vnc_channels.get(client_id)


def del_vnc_channel(client_id: str) -> None:
    """
    Remove the entry for `client_id`; a missing entry is not an error.
    """
    vnc_channels.pop(client_id, None)
|
||||
|
||||
|
||||
# a registry for message channels, keyed by `client_id`
|
||||
message_channels: dict[str, "MessageChannel"] = {}


def add_message_channel(message_channel: "MessageChannel") -> None:
    """
    Register `message_channel` under its `client_id`, replacing any existing entry.
    """
    message_channels[message_channel.client_id] = message_channel


def get_message_channel(client_id: str) -> t.Union["MessageChannel", None]:
    """
    Return the open message channel registered for `client_id`.

    A registered channel that is no longer open is removed from the registry
    as a side effect, and None is returned.

    :return: the open channel, or None if none is registered or it is not open
    """
    candidate = message_channels.get(client_id)

    if not candidate:
        return None

    if candidate.is_open:
        return candidate

    LOG.info(
        "MessageChannel: message channel is not open; deleting it",
        client_id=candidate.client_id,
    )

    del_message_channel(candidate.client_id)

    return None


def del_message_channel(client_id: str) -> None:
    """
    Remove the entry for `client_id`; a missing entry is not an error.
    """
    message_channels.pop(client_id, None)
|
||||
@@ -1,3 +1,14 @@
|
||||
"""
|
||||
Provides WS endpoints for streaming screenshots.
|
||||
|
||||
Screenshot streaming is created on the basis of one of these database entities:
|
||||
- task (run)
|
||||
- workflow run
|
||||
|
||||
Screenshot streaming is used for a run that is invoked without a browser session.
|
||||
Otherwise, VNC streaming is used.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
from datetime import datetime
|
||||
|
||||
@@ -1,18 +1,43 @@
|
||||
"""
|
||||
Channels (see ./channels/README.md) variously rely on the state of one of
|
||||
these database entities:
|
||||
- browser session
|
||||
- task
|
||||
- workflow run
|
||||
|
||||
That is, channels are created on the basis of one of those entities, and that
|
||||
entity must be in a valid state for the channel to continue.
|
||||
|
||||
So, in order to continue operating a channel, we need to periodically verify
|
||||
that the underlying entity is still valid. This module provides logic to
|
||||
perform those verifications.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import typing as t
|
||||
from datetime import datetime
|
||||
|
||||
import structlog
|
||||
|
||||
import skyvern.forge.sdk.routes.streaming.clients as sc
|
||||
from skyvern.config import settings
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.schemas.persistent_browser_sessions import AddressablePersistentBrowserSession
|
||||
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
|
||||
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRun, WorkflowRunStatus
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from skyvern.forge.sdk.routes.streaming.channels.message import MessageChannel
|
||||
from skyvern.forge.sdk.routes.streaming.channels.vnc import VncChannel
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class Constants:
    """
    Tunables for the verification loops.
    """

    # Seconds a verification loop sleeps between two checks.
    POLL_INTERVAL_FOR_VERIFICATION_SECONDS = 5
|
||||
|
||||
|
||||
async def verify_browser_session(
|
||||
browser_session_id: str,
|
||||
organization_id: str,
|
||||
@@ -122,9 +147,7 @@ async def verify_task(
|
||||
**browser_session.model_dump() | {"browser_address": browser_session.browser_address},
|
||||
)
|
||||
except Exception as e:
|
||||
LOG.error(
|
||||
"streaming-vnc.browser-session-reify-error", task_id=task_id, organization_id=organization_id, error=e
|
||||
)
|
||||
LOG.error("vnc.browser-session-reify-error", task_id=task_id, organization_id=organization_id, error=e)
|
||||
return task, None
|
||||
|
||||
return task, addressable_browser_session
|
||||
@@ -238,7 +261,7 @@ async def verify_workflow_run(
|
||||
return workflow_run, addressable_browser_session
|
||||
|
||||
|
||||
async def loop_verify_browser_session(verifiable: sc.MessageChannel | sc.Streaming) -> None:
|
||||
async def loop_verify_browser_session(verifiable: MessageChannel | VncChannel) -> None:
|
||||
"""
|
||||
Loop until the browser session is cleared or the websocket is closed.
|
||||
"""
|
||||
@@ -251,27 +274,27 @@ async def loop_verify_browser_session(verifiable: sc.MessageChannel | sc.Streami
|
||||
|
||||
verifiable.browser_session = browser_session
|
||||
|
||||
await asyncio.sleep(2)
|
||||
await asyncio.sleep(Constants.POLL_INTERVAL_FOR_VERIFICATION_SECONDS)
|
||||
|
||||
|
||||
async def loop_verify_task(streaming: sc.Streaming) -> None:
|
||||
async def loop_verify_task(vnc_channel: VncChannel) -> None:
|
||||
"""
|
||||
Loop until the task is cleared or the websocket is closed.
|
||||
"""
|
||||
|
||||
while streaming.task and streaming.is_open:
|
||||
while vnc_channel.task and vnc_channel.is_open:
|
||||
task, browser_session = await verify_task(
|
||||
task_id=streaming.task.task_id,
|
||||
organization_id=streaming.organization_id,
|
||||
task_id=vnc_channel.task.task_id,
|
||||
organization_id=vnc_channel.organization_id,
|
||||
)
|
||||
|
||||
streaming.task = task
|
||||
streaming.browser_session = browser_session
|
||||
vnc_channel.task = task
|
||||
vnc_channel.browser_session = browser_session
|
||||
|
||||
await asyncio.sleep(2)
|
||||
await asyncio.sleep(Constants.POLL_INTERVAL_FOR_VERIFICATION_SECONDS)
|
||||
|
||||
|
||||
async def loop_verify_workflow_run(verifiable: sc.MessageChannel | sc.Streaming) -> None:
|
||||
async def loop_verify_workflow_run(verifiable: MessageChannel | VncChannel) -> None:
|
||||
"""
|
||||
Loop until the workflow run is cleared or the websocket is closed.
|
||||
"""
|
||||
@@ -285,4 +308,4 @@ async def loop_verify_workflow_run(verifiable: sc.MessageChannel | sc.Streaming)
|
||||
verifiable.workflow_run = workflow_run
|
||||
verifiable.browser_session = browser_session
|
||||
|
||||
await asyncio.sleep(2)
|
||||
await asyncio.sleep(Constants.POLL_INTERVAL_FOR_VERIFICATION_SECONDS)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""
|
||||
Streaming VNC WebSocket connections.
|
||||
Provides WS endpoints for streaming a remote browser via VNC.
|
||||
|
||||
NOTE(jdo:streaming-local-dev)
|
||||
-----------------------------
|
||||
@@ -9,459 +9,23 @@ NOTE(jdo:streaming-local-dev)
|
||||
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import typing as t
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import structlog
|
||||
import websockets
|
||||
from fastapi import WebSocket, WebSocketDisconnect
|
||||
from websockets import Data
|
||||
from websockets.exceptions import ConnectionClosedError
|
||||
from fastapi import WebSocket
|
||||
|
||||
import skyvern.forge.sdk.routes.streaming.clients as sc
|
||||
from skyvern.config import settings
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
|
||||
from skyvern.forge.sdk.routes.routers import base_router, legacy_base_router
|
||||
from skyvern.forge.sdk.routes.streaming.agent import connected_agent
|
||||
from skyvern.forge.sdk.routes.streaming.auth import auth
|
||||
from skyvern.forge.sdk.routes.streaming.verify import (
|
||||
loop_verify_browser_session,
|
||||
loop_verify_task,
|
||||
loop_verify_workflow_run,
|
||||
verify_browser_session,
|
||||
verify_task,
|
||||
verify_workflow_run,
|
||||
from skyvern.forge.sdk.routes.streaming.channels.vnc import (
|
||||
Loops,
|
||||
VncChannel,
|
||||
get_vnc_channel_for_browser_session,
|
||||
get_vnc_channel_for_task,
|
||||
get_vnc_channel_for_workflow_run,
|
||||
)
|
||||
from skyvern.forge.sdk.utils.aio import collect
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class Constants:
    """Sentinel values used while building a streaming context."""

    # Placeholder API key, used when an organization has no valid API token.
    MissingXApiKey: str = "<missing-x-api-key>"
|
||||
|
||||
|
||||
async def get_x_api_key(organization_id: str) -> str:
    """
    Resolve the API key to use for an organization's streaming session.

    Asks the database for the organization's valid API auth token. When one is
    found its token string is returned; otherwise a warning is logged and the
    `Constants.MissingXApiKey` placeholder is returned instead.
    """

    auth_token = await app.DATABASE.get_valid_org_auth_token(
        organization_id,
        OrganizationAuthTokenType.api.value,
    )

    if auth_token:
        return auth_token.token

    LOG.warning(
        "No valid API key found for organization when streaming.",
        organization_id=organization_id,
    )

    return Constants.MissingXApiKey
|
||||
|
||||
|
||||
async def get_streaming_for_browser_session(
    client_id: str,
    browser_session_id: str,
    organization_id: str,
    websocket: WebSocket,
) -> tuple[sc.Streaming, sc.Loops] | None:
    """
    Build the streaming context for a browser session and start its loops.

    Returns `None` when the browser session cannot be verified. Otherwise
    returns the `sc.Streaming` context together with the already-scheduled
    tasks (session verification and VNC streaming) that the caller should
    await concurrently.
    """

    LOG.info("Getting streaming context for browser session.", browser_session_id=browser_session_id)

    verified_session = await verify_browser_session(
        browser_session_id=browser_session_id,
        organization_id=organization_id,
    )

    if not verified_session:
        LOG.info(
            "No initial browser session found.", browser_session_id=browser_session_id, organization_id=organization_id
        )
        return None

    api_key = await get_x_api_key(organization_id)

    streaming = sc.Streaming(
        client_id=client_id,
        interactor="agent",
        organization_id=organization_id,
        vnc_port=settings.SKYVERN_BROWSER_VNC_PORT,
        browser_session=verified_session,
        x_api_key=api_key,
        websocket=websocket,
    )

    LOG.info("Got streaming context for browser session.", streaming=streaming)

    # Schedule the verification loop first, then the VNC streaming loop.
    loop_factories = (loop_verify_browser_session, loop_stream_vnc)
    loops = [asyncio.create_task(factory(streaming)) for factory in loop_factories]

    return streaming, loops
|
||||
|
||||
|
||||
async def get_streaming_for_task(
    client_id: str,
    task_id: str,
    organization_id: str,
    websocket: WebSocket,
) -> tuple[sc.Streaming, sc.Loops] | None:
    """
    Build the streaming context for a task and start its loops.

    Returns `None` when either the task or its browser session cannot be
    verified. Otherwise returns the `sc.Streaming` context together with the
    already-scheduled tasks (task verification and VNC streaming) that the
    caller should await concurrently.
    """

    verified_task, verified_session = await verify_task(task_id=task_id, organization_id=organization_id)

    if not verified_task:
        LOG.info("No initial task found.", task_id=task_id, organization_id=organization_id)
        return None

    if not verified_session:
        LOG.info("No initial browser session found for task.", task_id=task_id, organization_id=organization_id)
        return None

    api_key = await get_x_api_key(organization_id)

    streaming = sc.Streaming(
        client_id=client_id,
        interactor="agent",
        organization_id=organization_id,
        vnc_port=settings.SKYVERN_BROWSER_VNC_PORT,
        x_api_key=api_key,
        websocket=websocket,
        browser_session=verified_session,
        task=verified_task,
    )

    # Schedule the verification loop first, then the VNC streaming loop.
    loop_factories = (loop_verify_task, loop_stream_vnc)
    loops = [asyncio.create_task(factory(streaming)) for factory in loop_factories]

    return streaming, loops
|
||||
|
||||
|
||||
async def get_streaming_for_workflow_run(
    client_id: str,
    workflow_run_id: str,
    organization_id: str,
    websocket: WebSocket,
) -> tuple[sc.Streaming, sc.Loops] | None:
    """
    Build the streaming context for a workflow run and start its loops.

    Returns `None` when either the workflow run or its browser session cannot
    be verified. Otherwise returns the `sc.Streaming` context together with
    the already-scheduled tasks (workflow-run verification and VNC streaming)
    that the caller should await concurrently.
    """

    LOG.info("Getting streaming context for workflow run.", workflow_run_id=workflow_run_id)

    verified_run, verified_session = await verify_workflow_run(
        workflow_run_id=workflow_run_id,
        organization_id=organization_id,
    )

    if not verified_run:
        LOG.info("No initial workflow run found.", workflow_run_id=workflow_run_id, organization_id=organization_id)
        return None

    if not verified_session:
        LOG.info(
            "No initial browser session found for workflow run.",
            workflow_run_id=workflow_run_id,
            organization_id=organization_id,
        )
        return None

    api_key = await get_x_api_key(organization_id)

    streaming = sc.Streaming(
        client_id=client_id,
        interactor="agent",
        organization_id=organization_id,
        vnc_port=settings.SKYVERN_BROWSER_VNC_PORT,
        browser_session=verified_session,
        workflow_run=verified_run,
        x_api_key=api_key,
        websocket=websocket,
    )

    LOG.info("Got streaming context for workflow run.", streaming=streaming)

    # Schedule the verification loop first, then the VNC streaming loop.
    loop_factories = (loop_verify_workflow_run, loop_stream_vnc)
    loops = [asyncio.create_task(factory(streaming)) for factory in loop_factories]

    return streaming, loops
|
||||
|
||||
|
||||
def verify_message_channel(
    message_channel: sc.MessageChannel | None, streaming: sc.Streaming
) -> sc.MessageChannel | t.Literal[False]:
    """
    Check that a message channel exists and is open.

    Returns the channel itself when it is usable, so callers can bind it with
    the walrus operator. Returns `False` after logging a warning when the
    channel is missing or not open.
    """

    if not message_channel or not message_channel.is_open:
        LOG.warning(
            "No message channel found for client, or it is not open",
            message_channel=message_channel,
            client_id=streaming.client_id,
            organization_id=streaming.organization_id,
        )
        return False

    return message_channel
|
||||
|
||||
|
||||
async def copy_text(streaming: sc.Streaming) -> None:
    """
    Forward the browser's currently selected text to the client.

    Connects an agent to the streaming browser, reads the selected text, and
    sends it over the client's message channel. This is best-effort: any
    exception is logged and suppressed so a failed copy never breaks the
    streaming loop that called it.
    """

    try:
        async with connected_agent(streaming) as agent:
            copied_text = await agent.get_selected_text()

        LOG.info(
            "Retrieved selected text via CDP",
            organization_id=streaming.organization_id,
        )

        message_channel = sc.get_message_client(streaming.client_id)

        # verify_message_channel already logs a warning when the channel is
        # missing or closed, so no second warning is emitted here.
        if cc := verify_message_channel(message_channel, streaming):
            await cc.send_copied_text(copied_text, streaming)
    except Exception:
        LOG.exception(
            "Failed to retrieve selected text via CDP",
            organization_id=streaming.organization_id,
        )
|
||||
|
||||
|
||||
async def ask_for_clipboard(streaming: sc.Streaming) -> None:
    """
    Ask the client, over its message channel, for its clipboard contents.

    Does nothing beyond logging when the client has no open message channel.
    This is best-effort: any exception is logged and suppressed.
    """

    try:
        LOG.info(
            "Asking for clipboard data via CDP",
            organization_id=streaming.organization_id,
        )

        candidate = sc.get_message_client(streaming.client_id)
        channel = verify_message_channel(candidate, streaming)

        if not channel:
            return

        await channel.ask_for_clipboard(streaming)
    except Exception:
        LOG.exception(
            "Failed to ask for clipboard via CDP",
            organization_id=streaming.organization_id,
        )
|
||||
|
||||
|
||||
async def loop_stream_vnc(streaming: sc.Streaming) -> None:
    """
    Actually stream the VNC session data between a frontend and a browser
    session.

    Loops until the task is cleared or the websocket is closed.

    The VNC endpoint is `ws://<host>:<streaming.vnc_port>`. The host is taken
    from the browser session's `ip_address` (with a trailing `:port` removed)
    or, when that is empty, from the hostname of its `browser_address`.

    Two inner loops run concurrently:

    - frontend-to-browser: receives bytes from the frontend websocket, tracks
      keyboard state, triggers copy/paste helpers, drops forbidden keys,
      right-mouse-up events and (unless the interactor is "user") all keyboard
      and mouse input, then forwards the rest to the VNC websocket.
    - browser-to-frontend: relays everything received from the VNC websocket
      to the frontend websocket.

    The streaming context is always closed when this function finishes.
    """

    browser_session = streaming.browser_session

    if not browser_session:
        LOG.info("No browser session found for task.", task=streaming.task, organization_id=streaming.organization_id)
        return

    ip_address = browser_session.ip_address

    if ip_address:
        # Drop a trailing ":port". rsplit leaves a bare host untouched and,
        # unlike unpacking `split(":")` into two names, does not raise
        # ValueError when the value contains more than one colon.
        host = ip_address.rsplit(":", 1)[0]
    else:
        host = urlparse(browser_session.browser_address).hostname

    vnc_url: str = f"ws://{host}:{streaming.vnc_port}"

    # NOTE(jdo:streaming-local-dev)
    # vnc_url = "ws://localhost:9001/ws/novnc"

    LOG.info(
        "Connecting to VNC URL.",
        vnc_url=vnc_url,
        task=streaming.task,
        workflow_run=streaming.workflow_run,
        organization_id=streaming.organization_id,
    )

    async with websockets.connect(vnc_url) as novnc_ws:

        async def frontend_to_browser() -> None:
            LOG.info("Starting frontend-to-browser data transfer.", streaming=streaming)
            data: Data | None = None

            while streaming.is_open:
                # Start every iteration empty so that a swallowed cancellation
                # below cannot cause the previous frame to be sent again.
                data = None

                try:
                    data = await streaming.websocket.receive_bytes()

                    if data:
                        # The first byte identifies the kind of message.
                        message_type = data[0]

                        if message_type == sc.MessageType.Keyboard.value:
                            streaming.update_key_state(data)

                            if streaming.key_state.is_copy(data):
                                await copy_text(streaming)

                            if streaming.key_state.is_paste(data):
                                await ask_for_clipboard(streaming)

                            if streaming.key_state.is_forbidden(data):
                                continue

                        if message_type == sc.MessageType.Mouse.value:
                            if sc.Mouse.Up.Right(data):
                                continue

                        # Only a "user" interactor may drive the browser with
                        # keyboard or mouse input.
                        if not streaming.interactor == "user" and message_type in (
                            sc.MessageType.Keyboard.value,
                            sc.MessageType.Mouse.value,
                        ):
                            LOG.info(
                                "Blocking user message.", task=streaming.task, organization_id=streaming.organization_id
                            )
                            continue

                except WebSocketDisconnect:
                    LOG.info("Frontend disconnected.", task=streaming.task, organization_id=streaming.organization_id)
                    raise
                except ConnectionClosedError:
                    LOG.info(
                        "Frontend closed the streaming session.",
                        task=streaming.task,
                        organization_id=streaming.organization_id,
                    )
                    raise
                except asyncio.CancelledError:
                    # NOTE(review): cancellation is swallowed, so this loop
                    # only ends once `streaming.is_open` turns false; confirm
                    # that this is intended.
                    pass
                except Exception:
                    LOG.exception(
                        "An unexpected exception occurred.",
                        task=streaming.task,
                        organization_id=streaming.organization_id,
                    )
                    raise

                if not data:
                    continue

                try:
                    await novnc_ws.send(data)
                except WebSocketDisconnect:
                    LOG.info(
                        "Browser disconnected from the streaming session.",
                        task=streaming.task,
                        organization_id=streaming.organization_id,
                    )
                    raise
                except ConnectionClosedError:
                    LOG.info(
                        "Browser closed the streaming session.",
                        task=streaming.task,
                        organization_id=streaming.organization_id,
                    )
                    raise
                except asyncio.CancelledError:
                    pass
                except Exception:
                    LOG.exception(
                        "An unexpected exception occurred in frontend-to-browser loop.",
                        task=streaming.task,
                        organization_id=streaming.organization_id,
                    )
                    raise

        async def browser_to_frontend() -> None:
            LOG.info("Starting browser-to-frontend data transfer.", streaming=streaming)
            data: Data | None = None

            while streaming.is_open:
                # Start every iteration empty so that, after a handled
                # disconnect or a swallowed cancellation, the previous frame
                # is not sent to the frontend a second time.
                data = None

                try:
                    data = await novnc_ws.recv()
                except WebSocketDisconnect:
                    LOG.info(
                        "Browser disconnected from the streaming session.",
                        task=streaming.task,
                        organization_id=streaming.organization_id,
                    )
                    await streaming.close(reason="browser-disconnected")
                except ConnectionClosedError:
                    LOG.info(
                        "Browser closed the streaming session.",
                        task=streaming.task,
                        organization_id=streaming.organization_id,
                    )
                    await streaming.close(reason="browser-closed")
                except asyncio.CancelledError:
                    pass
                except Exception:
                    LOG.exception(
                        "An unexpected exception occurred in browser-to-frontend loop.",
                        task=streaming.task,
                        organization_id=streaming.organization_id,
                    )
                    raise

                if not data:
                    continue

                try:
                    await streaming.websocket.send_bytes(data)
                except WebSocketDisconnect:
                    LOG.info(
                        "Frontend disconnected from the streaming session.",
                        task=streaming.task,
                        organization_id=streaming.organization_id,
                    )
                    await streaming.close(reason="frontend-disconnected")
                except ConnectionClosedError:
                    LOG.info(
                        "Frontend closed the streaming session.",
                        task=streaming.task,
                        organization_id=streaming.organization_id,
                    )
                    await streaming.close(reason="frontend-closed")
                except asyncio.CancelledError:
                    pass
                except Exception:
                    LOG.exception(
                        "An unexpected exception occurred.",
                        task=streaming.task,
                        organization_id=streaming.organization_id,
                    )
                    raise

        loops = [
            asyncio.create_task(frontend_to_browser()),
            asyncio.create_task(browser_to_frontend()),
        ]

        try:
            await collect(loops)
        except Exception:
            LOG.exception(
                "An exception occurred in loop stream.", task=streaming.task, organization_id=streaming.organization_id
            )
        finally:
            LOG.info("Closing the loop stream.", task=streaming.task, organization_id=streaming.organization_id)
            await streaming.close(reason="loop-stream-vnc-closed")
|
||||
|
||||
|
||||
@base_router.websocket("/stream/vnc/browser_session/{browser_session_id}")
|
||||
async def browser_session_stream(
|
||||
websocket: WebSocket,
|
||||
@@ -507,7 +71,7 @@ async def stream(
|
||||
) -> None:
|
||||
if not client_id:
|
||||
LOG.error(
|
||||
"Client ID not provided for VNC stream.",
|
||||
"Client ID not provided for vnc stream.",
|
||||
browser_session_id=browser_session_id,
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
@@ -515,7 +79,7 @@ async def stream(
|
||||
return
|
||||
|
||||
LOG.info(
|
||||
"Starting VNC stream.",
|
||||
"Starting vnc stream.",
|
||||
browser_session_id=browser_session_id,
|
||||
client_id=client_id,
|
||||
task_id=task_id,
|
||||
@@ -528,11 +92,11 @@ async def stream(
|
||||
LOG.error("Authentication failed.", task_id=task_id, workflow_run_id=workflow_run_id)
|
||||
return
|
||||
|
||||
streaming: sc.Streaming
|
||||
loops: list[asyncio.Task] = []
|
||||
vnc_channel: VncChannel
|
||||
loops: Loops
|
||||
|
||||
if browser_session_id:
|
||||
result = await get_streaming_for_browser_session(
|
||||
result = await get_vnc_channel_for_browser_session(
|
||||
client_id=client_id,
|
||||
browser_session_id=browser_session_id,
|
||||
organization_id=organization_id,
|
||||
@@ -541,22 +105,22 @@ async def stream(
|
||||
|
||||
if not result:
|
||||
LOG.error(
|
||||
"No streaming context found for the browser session.",
|
||||
"No vnc context found for the browser session.",
|
||||
browser_session_id=browser_session_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
await websocket.close(code=1013)
|
||||
return
|
||||
|
||||
streaming, loops = result
|
||||
vnc_channel, loops = result
|
||||
|
||||
LOG.info(
|
||||
"Starting streaming for browser session.",
|
||||
"Starting vnc for browser session.",
|
||||
browser_session_id=browser_session_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
elif task_id:
|
||||
result = await get_streaming_for_task(
|
||||
result = await get_vnc_channel_for_task(
|
||||
client_id=client_id,
|
||||
task_id=task_id,
|
||||
organization_id=organization_id,
|
||||
@@ -564,19 +128,17 @@ async def stream(
|
||||
)
|
||||
|
||||
if not result:
|
||||
LOG.error("No streaming context found for the task.", task_id=task_id, organization_id=organization_id)
|
||||
LOG.error("No vnc context found for the task.", task_id=task_id, organization_id=organization_id)
|
||||
await websocket.close(code=1013)
|
||||
return
|
||||
|
||||
streaming, loops = result
|
||||
vnc_channel, loops = result
|
||||
|
||||
LOG.info("Starting streaming for task.", task_id=task_id, organization_id=organization_id)
|
||||
LOG.info("Starting vnc for task.", task_id=task_id, organization_id=organization_id)
|
||||
|
||||
elif workflow_run_id:
|
||||
LOG.info(
|
||||
"Starting streaming for workflow run.", workflow_run_id=workflow_run_id, organization_id=organization_id
|
||||
)
|
||||
result = await get_streaming_for_workflow_run(
|
||||
LOG.info("Starting vnc for workflow run.", workflow_run_id=workflow_run_id, organization_id=organization_id)
|
||||
result = await get_vnc_channel_for_workflow_run(
|
||||
client_id=client_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
@@ -585,25 +147,23 @@ async def stream(
|
||||
|
||||
if not result:
|
||||
LOG.error(
|
||||
"No streaming context found for the workflow run.",
|
||||
"No vnc context found for the workflow run.",
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
await websocket.close(code=1013)
|
||||
return
|
||||
|
||||
streaming, loops = result
|
||||
vnc_channel, loops = result
|
||||
|
||||
LOG.info(
|
||||
"Starting streaming for workflow run.", workflow_run_id=workflow_run_id, organization_id=organization_id
|
||||
)
|
||||
LOG.info("Starting vnc for workflow run.", workflow_run_id=workflow_run_id, organization_id=organization_id)
|
||||
else:
|
||||
LOG.error("Neither task ID nor workflow run ID was provided.")
|
||||
return
|
||||
|
||||
try:
|
||||
LOG.info(
|
||||
"Starting streaming loops.",
|
||||
"Starting vnc loops.",
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
@@ -611,16 +171,16 @@ async def stream(
|
||||
await collect(loops)
|
||||
except Exception:
|
||||
LOG.exception(
|
||||
"An exception occurred in the stream function.",
|
||||
"An exception occurred in the vnc loop.",
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
finally:
|
||||
LOG.info(
|
||||
"Closing the streaming session.",
|
||||
"Closing the vnc session.",
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
await streaming.close(reason="stream-closed")
|
||||
await vnc_channel.close(reason="vnc-closed")
|
||||
|
||||
0
skyvern/py.typed
Normal file
0
skyvern/py.typed
Normal file
Reference in New Issue
Block a user