Browser streaming refactor (#4064)
This commit is contained in:
@@ -1,191 +0,0 @@
|
|||||||
"""
|
|
||||||
A lightweight "agent" for interacting with the streaming browser over CDP.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import typing
|
|
||||||
from contextlib import asynccontextmanager
|
|
||||||
|
|
||||||
import structlog
|
|
||||||
from playwright.async_api import Browser, BrowserContext, Page, Playwright, async_playwright
|
|
||||||
|
|
||||||
import skyvern.forge.sdk.routes.streaming.clients as sc
|
|
||||||
from skyvern.config import settings
|
|
||||||
|
|
||||||
LOG = structlog.get_logger()
|
|
||||||
|
|
||||||
|
|
||||||
class StreamingAgent:
|
|
||||||
"""
|
|
||||||
A minimal agent that can connect to a browser via CDP and execute JavaScript.
|
|
||||||
|
|
||||||
Specifically for operations during streaming sessions (like copy/pasting selected text, etc.).
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, streaming: sc.Streaming) -> None:
|
|
||||||
self.streaming = streaming
|
|
||||||
self.browser: Browser | None = None
|
|
||||||
self.browser_context: BrowserContext | None = None
|
|
||||||
self.page: Page | None = None
|
|
||||||
self.pw: Playwright | None = None
|
|
||||||
|
|
||||||
async def connect(self, cdp_url: str | None = None) -> None:
|
|
||||||
url = cdp_url or settings.BROWSER_REMOTE_DEBUGGING_URL
|
|
||||||
|
|
||||||
LOG.info("StreamingAgent connecting to CDP", cdp_url=url)
|
|
||||||
|
|
||||||
pw = self.pw or await async_playwright().start()
|
|
||||||
|
|
||||||
self.pw = pw
|
|
||||||
|
|
||||||
headers = {
|
|
||||||
"x-api-key": self.streaming.x_api_key,
|
|
||||||
}
|
|
||||||
|
|
||||||
self.browser = await pw.chromium.connect_over_cdp(url, headers=headers)
|
|
||||||
|
|
||||||
org_id = self.streaming.organization_id
|
|
||||||
browser_session_id = (
|
|
||||||
self.streaming.browser_session.persistent_browser_session_id if self.streaming.browser_session else None
|
|
||||||
)
|
|
||||||
|
|
||||||
if browser_session_id:
|
|
||||||
cdp_session = await self.browser.new_browser_cdp_session()
|
|
||||||
await cdp_session.send(
|
|
||||||
"Browser.setDownloadBehavior",
|
|
||||||
{
|
|
||||||
"behavior": "allow",
|
|
||||||
"downloadPath": f"/app/downloads/{org_id}/{browser_session_id}",
|
|
||||||
"eventsEnabled": True,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
contexts = self.browser.contexts
|
|
||||||
if contexts:
|
|
||||||
LOG.info("StreamingAgent using existing browser context")
|
|
||||||
self.browser_context = contexts[0]
|
|
||||||
else:
|
|
||||||
LOG.warning("No existing browser context found, creating new one")
|
|
||||||
self.browser_context = await self.browser.new_context()
|
|
||||||
|
|
||||||
pages = self.browser_context.pages
|
|
||||||
if pages:
|
|
||||||
self.page = pages[0]
|
|
||||||
LOG.info("StreamingAgent connected to page", url=self.page.url)
|
|
||||||
else:
|
|
||||||
LOG.warning("No pages found in browser context")
|
|
||||||
|
|
||||||
LOG.info("StreamingAgent connected successfully")
|
|
||||||
|
|
||||||
async def evaluate_js(
|
|
||||||
self, expression: str, arg: str | int | float | bool | list | dict | None = None
|
|
||||||
) -> str | int | float | bool | list | dict | None:
|
|
||||||
if not self.page:
|
|
||||||
raise RuntimeError("StreamingAgent is not connected to a page. Call connect() first.")
|
|
||||||
|
|
||||||
LOG.info("StreamingAgent evaluating JS", expression=expression[:100])
|
|
||||||
|
|
||||||
try:
|
|
||||||
result = await self.page.evaluate(expression, arg)
|
|
||||||
LOG.info("StreamingAgent JS evaluation successful")
|
|
||||||
return result
|
|
||||||
except Exception as ex:
|
|
||||||
LOG.exception("StreamingAgent JS evaluation failed", expression=expression, ex=str(ex))
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def get_selected_text(self) -> str:
|
|
||||||
LOG.info("StreamingAgent getting selected text")
|
|
||||||
|
|
||||||
js_expression = """
|
|
||||||
() => {
|
|
||||||
const selection = window.getSelection();
|
|
||||||
return selection ? selection.toString() : '';
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
|
|
||||||
selected_text = await self.evaluate_js(js_expression)
|
|
||||||
|
|
||||||
if isinstance(selected_text, str) or selected_text is None:
|
|
||||||
LOG.info("StreamingAgent got selected text", length=len(selected_text) if selected_text else 0)
|
|
||||||
return selected_text or ""
|
|
||||||
|
|
||||||
raise RuntimeError(f"StreamingAgent selected text is not a string, but a(n) '{type(selected_text)}'")
|
|
||||||
|
|
||||||
async def paste_text(self, text: str) -> None:
|
|
||||||
LOG.info("StreamingAgent pasting text")
|
|
||||||
|
|
||||||
js_expression = """
|
|
||||||
(text) => {
|
|
||||||
const activeElement = document.activeElement;
|
|
||||||
if (activeElement && (activeElement.tagName === 'INPUT' || activeElement.tagName === 'TEXTAREA' || activeElement.isContentEditable)) {
|
|
||||||
const start = activeElement.selectionStart || 0;
|
|
||||||
const end = activeElement.selectionEnd || 0;
|
|
||||||
const value = activeElement.value || '';
|
|
||||||
activeElement.value = value.slice(0, start) + text + value.slice(end);
|
|
||||||
const newCursorPos = start + text.length;
|
|
||||||
activeElement.setSelectionRange(newCursorPos, newCursorPos);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
|
|
||||||
await self.evaluate_js(js_expression, text)
|
|
||||||
|
|
||||||
LOG.info("StreamingAgent pasted text successfully")
|
|
||||||
|
|
||||||
async def close(self) -> None:
|
|
||||||
LOG.info("StreamingAgent closing connection")
|
|
||||||
|
|
||||||
if self.browser:
|
|
||||||
await self.browser.close()
|
|
||||||
self.browser = None
|
|
||||||
self.browser_context = None
|
|
||||||
self.page = None
|
|
||||||
|
|
||||||
if self.pw:
|
|
||||||
await self.pw.stop()
|
|
||||||
self.pw = None
|
|
||||||
|
|
||||||
LOG.info("StreamingAgent closed")
|
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
|
||||||
async def connected_agent(streaming: sc.Streaming | None) -> typing.AsyncIterator[StreamingAgent]:
|
|
||||||
"""
|
|
||||||
The first pass at this has us doing the following for every operation:
|
|
||||||
- creating a new agent
|
|
||||||
- connecting
|
|
||||||
- [doing smth]
|
|
||||||
- closing the agent
|
|
||||||
|
|
||||||
This may add latency, but locally it is pretty fast. This keeps things stateless for now.
|
|
||||||
|
|
||||||
If it turns out it's too slow, we can refactor to keep a persistent agent per streaming client.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if not streaming:
|
|
||||||
msg = "connected_agent: no streaming client provided."
|
|
||||||
LOG.error(msg)
|
|
||||||
|
|
||||||
raise Exception(msg)
|
|
||||||
|
|
||||||
if not streaming.browser_session or not streaming.browser_session.browser_address:
|
|
||||||
msg = "connected_agent: no browser session or browser address found for streaming client."
|
|
||||||
|
|
||||||
LOG.error(
|
|
||||||
msg,
|
|
||||||
client_id=streaming.client_id,
|
|
||||||
organization_id=streaming.organization_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
raise Exception(msg)
|
|
||||||
|
|
||||||
agent = StreamingAgent(streaming=streaming)
|
|
||||||
|
|
||||||
try:
|
|
||||||
await agent.connect(streaming.browser_session.browser_address)
|
|
||||||
|
|
||||||
# NOTE(jdo:streaming-local-dev): use BROWSER_REMOTE_DEBUGGING_URL from settings
|
|
||||||
# await agent.connect()
|
|
||||||
|
|
||||||
yield agent
|
|
||||||
finally:
|
|
||||||
await agent.close()
|
|
||||||
@@ -6,11 +6,35 @@ import structlog
|
|||||||
from fastapi import WebSocket
|
from fastapi import WebSocket
|
||||||
from websockets.exceptions import ConnectionClosedOK
|
from websockets.exceptions import ConnectionClosedOK
|
||||||
|
|
||||||
|
from skyvern.forge import app
|
||||||
|
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
|
||||||
from skyvern.forge.sdk.services.org_auth_service import get_current_org
|
from skyvern.forge.sdk.services.org_auth_service import get_current_org
|
||||||
|
|
||||||
LOG = structlog.get_logger()
|
LOG = structlog.get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class Constants:
|
||||||
|
MISSING_API_KEY = "<missing-x-api-key>"
|
||||||
|
|
||||||
|
|
||||||
|
async def get_x_api_key(organization_id: str) -> str:
|
||||||
|
token = await app.DATABASE.get_valid_org_auth_token(
|
||||||
|
organization_id,
|
||||||
|
OrganizationAuthTokenType.api.value,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not token:
|
||||||
|
LOG.warning(
|
||||||
|
"No valid API key found for organization when streaming.",
|
||||||
|
organization_id=organization_id,
|
||||||
|
)
|
||||||
|
x_api_key = Constants.MISSING_API_KEY
|
||||||
|
else:
|
||||||
|
x_api_key = token.token
|
||||||
|
|
||||||
|
return x_api_key
|
||||||
|
|
||||||
|
|
||||||
async def auth(apikey: str | None, token: str | None, websocket: WebSocket) -> str | None:
|
async def auth(apikey: str | None, token: str | None, websocket: WebSocket) -> str | None:
|
||||||
"""
|
"""
|
||||||
Accepts the websocket connection.
|
Accepts the websocket connection.
|
||||||
|
|||||||
336
skyvern/forge/sdk/routes/streaming/channels/README.md
Normal file
336
skyvern/forge/sdk/routes/streaming/channels/README.md
Normal file
@@ -0,0 +1,336 @@
|
|||||||
|
# Channels
|
||||||
|
|
||||||
|
A "channel", as used within the streaming mechanism of our remote browsers,
|
||||||
|
is a WebSocket fit to some particular purpose.
|
||||||
|
|
||||||
|
There is/are:
|
||||||
|
- a "VNC" channel that transmits NoVNC's RFB protocol data
|
||||||
|
- a "Message" channel that transmits JSON between the frontend app and
|
||||||
|
the api server
|
||||||
|
- "CDP" channels that send messages to a remote browser using CDP protocol
|
||||||
|
data
|
||||||
|
- an "Execution" channel (one-off executions)
|
||||||
|
- a soon-to-be "Exfiltration" channel (user event streaming)
|
||||||
|
|
||||||
|
In all cases, these are just WebSockets. They have been bucketed into "named channels"
|
||||||
|
to aid understanding.
|
||||||
|
|
||||||
|
These channels are described at the top of their respective files.
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
WARN: below is an AI-generated architecture document for all of the code beneath
|
||||||
|
the `skyvern/forge/sdk/routes/streaming` directory. It looks correct.
|
||||||
|
|
||||||
|
### High-Level Component Diagram
|
||||||
|
|
||||||
|
```
|
||||||
|
┌─────────────────┐
|
||||||
|
│ Frontend App │
|
||||||
|
│ (Skyvern) │
|
||||||
|
└────────┬────────┘
|
||||||
|
│
|
||||||
|
│ Two WebSocket Connections (paired via client_id)
|
||||||
|
│
|
||||||
|
┌────┴────┬──────────────────────────────────────────┐
|
||||||
|
│ │ │
|
||||||
|
│ VNC Channel Message Channel
|
||||||
|
http (RFB Protocol) (JSON Messages)
|
||||||
|
│ │ │
|
||||||
|
│ │ │
|
||||||
|
┌───▼─────────▼──────────────────────────────────────────▼────┐
|
||||||
|
│ API Server │
|
||||||
|
│ ┌────────────────────────────────────────────────────────┐ │
|
||||||
|
│ │ Registries (In-Memory State) │ │
|
||||||
|
│ │ - vnc_channels: dict[client_id -> VncChannel] │ │
|
||||||
|
│ │ - message_channels: dict[client_id -> MessageChannel] │ │
|
||||||
|
│ └────────────────────────────────────────────────────────┘ │
|
||||||
|
│ │
|
||||||
|
│ VNC Channel Logic: Message Channel Logic: │
|
||||||
|
│ - RFB pass-through - Copy/paste coordination │
|
||||||
|
│ - Keyboard/mouse filtering - Control handoff (agent/user)│
|
||||||
|
│ - Interactor control - Clipboard management │
|
||||||
|
│ - Copy/paste detection - Channel coordination │
|
||||||
|
│ │
|
||||||
|
│ CDP Channels (created on-demand): │
|
||||||
|
│ - ExecutionChannel: JS evaluation (paste, get selected) │
|
||||||
|
│ - ExfiltrationChannel: (future) user event streaming │
|
||||||
|
│ │
|
||||||
|
└────┬─────────────────────────────────────────────────────┬──┘
|
||||||
|
│ │
|
||||||
|
│ WebSocket (RFB) Playwright (CDP)│
|
||||||
|
│ │
|
||||||
|
┌────▼─────────────────────────────────────────────────────▼──┐
|
||||||
|
│ Persistent Browser Session │
|
||||||
|
│ ┌──────────────────────────────────────────────────────┐ │
|
||||||
|
│ │ noVNC Server Chrome/Chromium │ │
|
||||||
|
│ │ (websockify) │ │
|
||||||
|
│ └──────────────────────────────────────────────────────┘ │
|
||||||
|
└─────────────────────────────────────────────────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
### Channel Pairing & Sticky Sessions
|
||||||
|
|
||||||
|
**Critical Design Constraint**: The VNC and Message channels for a given frontend
|
||||||
|
instance MUST connect to the same API server instance because they coordinate
|
||||||
|
via in-memory registries keyed by `client_id`.
|
||||||
|
|
||||||
|
```
|
||||||
|
Frontend Instance (client_id="abc123")
|
||||||
|
│
|
||||||
|
├─→ VNC Channel ─────→ API Server Instance #2
|
||||||
|
│ ↓
|
||||||
|
│ vnc_channels["abc123"] = VncChannel
|
||||||
|
│ ↕ (coordinate via client_id)
|
||||||
|
└─→ Message Channel ──→ API Server Instance #2
|
||||||
|
↓
|
||||||
|
message_channels["abc123"] = MessageChannel
|
||||||
|
```
|
||||||
|
|
||||||
|
**Deployment Requirement**: Load balancer must use sticky sessions (e.g., cookie-based
|
||||||
|
or IP-based affinity) to ensure both WebSocket connections from the same client_id
|
||||||
|
reach the same backend instance.
|
||||||
|
|
||||||
|
### Channel Lifecycle & Verification
|
||||||
|
|
||||||
|
```
|
||||||
|
┌──────────────────────────────────────────────────────────────┐
|
||||||
|
│ Channel Creation (per browser_session/task/workflow_run) │
|
||||||
|
└────────────────────┬─────────────────────────────────────────┘
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
┌────────────────────────────────────────┐
|
||||||
|
│ Initial Verification │
|
||||||
|
│ - verify_browser_session() │
|
||||||
|
│ - verify_task() │
|
||||||
|
│ - verify_workflow_run() │
|
||||||
|
│ Returns: entity + browser session │
|
||||||
|
└────────┬───────────────────────────────┘
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
┌────────────────────────────────────────┐
|
||||||
|
│ Channel + Loops Created │
|
||||||
|
│ VncChannel/MessageChannel initialized │
|
||||||
|
│ + Added to registry │
|
||||||
|
└────────┬───────────────────────────────┘
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
┌────────────────────────────────────────┐
|
||||||
|
│ Concurrent Loop Execution │
|
||||||
|
│ (via collect() - fail-fast) │
|
||||||
|
│ │
|
||||||
|
│ Loop 1: Verification Loop │
|
||||||
|
│ - Polls every 5s │
|
||||||
|
│ - Updates channel state │
|
||||||
|
│ - Exits if entity invalid │
|
||||||
|
│ │
|
||||||
|
│ Loop 2: Data Streaming Loop │
|
||||||
|
│ - VNC: bidirectional RFB │
|
||||||
|
│ - Message: JSON messages │
|
||||||
|
│ - Exits on disconnect │
|
||||||
|
└────────┬───────────────────────────────┘
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
┌────────────────────────────────────────┐
|
||||||
|
│ Channel Cleanup │
|
||||||
|
│ - Close WebSocket │
|
||||||
|
│ - Remove from registry │
|
||||||
|
│ - Clear channel state │
|
||||||
|
└────────────────────────────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
### VNC Channel Data Flow
|
||||||
|
|
||||||
|
```
|
||||||
|
User Keyboard/Mouse Input
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
┌───────────────────────────────────────────────────────────┐
|
||||||
|
│ Frontend (noVNC client) │
|
||||||
|
│ Encodes input as RFB protocol bytes │
|
||||||
|
└────────┬──────────────────────────────────────────────────┘
|
||||||
|
│ WebSocket (bytes)
|
||||||
|
▼
|
||||||
|
┌───────────────────────────────────────────────────────────┐
|
||||||
|
│ API Server: VncChannel.loop_stream_vnc() │
|
||||||
|
│ frontend_to_browser() coroutine │
|
||||||
|
│ │
|
||||||
|
│ 1. Receive RFB bytes from frontend │
|
||||||
|
│ 2. Detect message type (keyboard=4, mouse=5) │
|
||||||
|
│ 3. Update key_state tracking │
|
||||||
|
│ 4. Check for special key combinations: │
|
||||||
|
│ - Ctrl+C / Cmd+C → copy_text() via CDP │
|
||||||
|
│ - Ctrl+V / Cmd+V → ask_for_clipboard() via Message │
|
||||||
|
│ - Ctrl+O → BLOCK (forbidden) │
|
||||||
|
│ 5. Check interactor mode: │
|
||||||
|
│ - If interactor=="agent" → BLOCK user input │
|
||||||
|
│ - If interactor=="user" → PASS THROUGH │
|
||||||
|
│ 6. Block right-mouse-button (security) │
|
||||||
|
│ 7. Forward to noVNC server │
|
||||||
|
└────────┬──────────────────────────────────────────────────┘
|
||||||
|
│ WebSocket (bytes)
|
||||||
|
▼
|
||||||
|
┌───────────────────────────────────────────────────────────┐
|
||||||
|
│ Persistent Browser: noVNC Server (websockify) │
|
||||||
|
│ Translates RFB → VNC protocol │
|
||||||
|
└────────┬──────────────────────────────────────────────────┘
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
┌───────────────────────────────────────────────────────────┐
|
||||||
|
│ Browser Display Update │
|
||||||
|
└───────────────────────────────────────────────────────────┘
|
||||||
|
|
||||||
|
Screen Updates (reverse direction):
|
||||||
|
Browser → noVNC → VncChannel.browser_to_frontend() → Frontend
|
||||||
|
```
|
||||||
|
|
||||||
|
### Message Channel + CDP Execution Flow
|
||||||
|
|
||||||
|
```
|
||||||
|
User Pastes (Ctrl+V detected in VNC channel)
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
┌───────────────────────────────────────────────────────────┐
|
||||||
|
│ VncChannel: ask_for_clipboard() │
|
||||||
|
│ Finds MessageChannel via registry[client_id] │
|
||||||
|
└────────┬──────────────────────────────────────────────────┘
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
┌───────────────────────────────────────────────────────────┐
|
||||||
|
│ MessageChannel.ask_for_clipboard() │
|
||||||
|
│ Sends: {"kind": "ask-for-clipboard"} │
|
||||||
|
└────────┬──────────────────────────────────────────────────┘
|
||||||
|
│ WebSocket (JSON)
|
||||||
|
▼
|
||||||
|
┌───────────────────────────────────────────────────────────┐
|
||||||
|
│ Frontend: User's clipboard content │
|
||||||
|
│ Responds: {"kind": "ask-for-clipboard-response", │
|
||||||
|
│ "text": "clipboard content"} │
|
||||||
|
└────────┬──────────────────────────────────────────────────┘
|
||||||
|
│ WebSocket (JSON)
|
||||||
|
▼
|
||||||
|
┌───────────────────────────────────────────────────────────┐
|
||||||
|
│ MessageChannel: handle_data() │
|
||||||
|
│ Receives clipboard text │
|
||||||
|
│ Finds VncChannel via registry[client_id] │
|
||||||
|
└────────┬──────────────────────────────────────────────────┘
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
┌───────────────────────────────────────────────────────────┐
|
||||||
|
│ ExecutionChannel (CDP) │
|
||||||
|
│ 1. Connect to browser via Playwright │
|
||||||
|
│ 2. Get browser context + page │
|
||||||
|
│ 3. evaluate_js(paste_text_script, clipboard_text) │
|
||||||
|
│ 4. Close CDP connection │
|
||||||
|
└────────┬──────────────────────────────────────────────────┘
|
||||||
|
│ CDP over WebSocket
|
||||||
|
▼
|
||||||
|
┌───────────────────────────────────────────────────────────┐
|
||||||
|
│ Browser: Text pasted into active element │
|
||||||
|
└───────────────────────────────────────────────────────────┘
|
||||||
|
|
||||||
|
Similar flow for Copy (Ctrl+C):
|
||||||
|
VNC detects → ExecutionChannel.get_selected_text() →
|
||||||
|
MessageChannel sends {"kind": "copied-text", "text": "..."}
|
||||||
|
→ Frontend updates clipboard
|
||||||
|
```
|
||||||
|
|
||||||
|
### Control Flow: Agent ↔ User Interaction
|
||||||
|
|
||||||
|
NOTE: we don't really have an "agent" at this time. But any control of the
|
||||||
|
browser that is not user-originated is kinda' agent-like, by some
|
||||||
|
definition of "agent". Here, we do not have an "AI agent". Future work may
|
||||||
|
alter this state of affairs - and some "agent" could operate the browser
|
||||||
|
automatically.
|
||||||
|
|
||||||
|
```
|
||||||
|
┌───────────────────────────────────────────────────────────┐
|
||||||
|
│ Initial State: interactor = "agent" │
|
||||||
|
│ - User keyboard/mouse input is BLOCKED │
|
||||||
|
│ - Agent can control browser via CDP │
|
||||||
|
└────────┬──────────────────────────────────────────────────┘
|
||||||
|
│
|
||||||
|
│ User clicks "Take Control" in frontend
|
||||||
|
▼
|
||||||
|
┌───────────────────────────────────────────────────────────┐
|
||||||
|
│ Frontend → MessageChannel │
|
||||||
|
│ {"kind": "take-control"} │
|
||||||
|
└────────┬──────────────────────────────────────────────────┘
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
┌───────────────────────────────────────────────────────────┐
|
||||||
|
│ MessageChannel.handle_data() │
|
||||||
|
│ vnc_channel.interactor = "user" │
|
||||||
|
└────────┬──────────────────────────────────────────────────┘
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
┌───────────────────────────────────────────────────────────┐
|
||||||
|
│ New State: interactor = "user" │
|
||||||
|
│ - User keyboard/mouse input is PASSED THROUGH │
|
||||||
|
│ - Agent should pause automation │
|
||||||
|
└────────┬──────────────────────────────────────────────────┘
|
||||||
|
│
|
||||||
|
│ User clicks "Cede Control" in frontend
|
||||||
|
▼
|
||||||
|
┌───────────────────────────────────────────────────────────┐
|
||||||
|
│ Frontend → MessageChannel │
|
||||||
|
│ {"kind": "cede-control"} │
|
||||||
|
└────────┬──────────────────────────────────────────────────┘
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
┌───────────────────────────────────────────────────────────┐
|
||||||
|
│ MessageChannel.handle_data() │
|
||||||
|
│ vnc_channel.interactor = "agent" │
|
||||||
|
│ → Back to initial state │
|
||||||
|
└───────────────────────────────────────────────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
### Error Propagation & Cleanup
|
||||||
|
|
||||||
|
The system uses `collect()` (fail-fast gather) for loop management:
|
||||||
|
|
||||||
|
```
|
||||||
|
Channel has 2 concurrent loops:
|
||||||
|
- Verification loop (polls DB every 5s)
|
||||||
|
- Streaming loop (handles WebSocket I/O)
|
||||||
|
|
||||||
|
collect() behavior:
|
||||||
|
1. Waits for ANY loop to fail or complete
|
||||||
|
2. Cancels all other loops
|
||||||
|
3. Propagates the first exception
|
||||||
|
|
||||||
|
Cleanup (always executed via finally):
|
||||||
|
- channel.close()
|
||||||
|
- Sets browser_session/task/workflow_run = None
|
||||||
|
- Closes WebSocket
|
||||||
|
- Removes from registry
|
||||||
|
```
|
||||||
|
|
||||||
|
### Database Entity Relationships
|
||||||
|
|
||||||
|
```
|
||||||
|
Organization
|
||||||
|
│
|
||||||
|
├─→ BrowserSession ────────┐
|
||||||
|
│ │
|
||||||
|
├─→ Task ──────────────────┤
|
||||||
|
│ (has optional │
|
||||||
|
│ browser_session) │
|
||||||
|
│ │
|
||||||
|
└─→ WorkflowRun ───────────┤
|
||||||
|
(has optional │
|
||||||
|
browser_session) │
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
VncChannel + MessageChannel
|
||||||
|
(in-memory, paired by client_id)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Key Design Patterns
|
||||||
|
|
||||||
|
1. **Channel Pairing**: Two WebSocket connections coordinated via in-memory registry
|
||||||
|
2. **Fail-Fast Loops**: `collect()` ensures any loop failure closes the entire channel
|
||||||
|
3. **Interactor Mode**: Binary state controlling whether user input is allowed
|
||||||
|
4. **On-Demand CDP**: ExecutionChannel creates temporary connections for each operation
|
||||||
|
5. **Polling Verification**: Every 5s, channels verify their backing entity still exists
|
||||||
|
6. **Pass-Through Proxy**: API server intercepts but doesn't transform RFB data
|
||||||
185
skyvern/forge/sdk/routes/streaming/channels/cdp.py
Normal file
185
skyvern/forge/sdk/routes/streaming/channels/cdp.py
Normal file
@@ -0,0 +1,185 @@
|
|||||||
|
"""
|
||||||
|
A channel for connecting to a persistent browser instance.
|
||||||
|
|
||||||
|
What this channel looks like:
|
||||||
|
|
||||||
|
[API Server] <--> [Browser (CDP)]
|
||||||
|
|
||||||
|
Channel data:
|
||||||
|
|
||||||
|
CDP protocol data, with Playwright thrown in.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import typing as t
|
||||||
|
|
||||||
|
import structlog
|
||||||
|
from playwright.async_api import Browser, BrowserContext, Page, Playwright, async_playwright
|
||||||
|
|
||||||
|
from skyvern.config import settings
|
||||||
|
|
||||||
|
if t.TYPE_CHECKING:
|
||||||
|
from skyvern.forge.sdk.routes.streaming.channels.vnc import VncChannel
|
||||||
|
|
||||||
|
LOG = structlog.get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class CdpChannel:
|
||||||
|
"""
|
||||||
|
CdpChannel. Relies on a VncChannel - without one, a CdpChannel has no
|
||||||
|
r'aison d'etre.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __new__(cls, *_: t.Iterable[t.Any], **__: t.Mapping[str, t.Any]) -> t.Self: # noqa: N805
|
||||||
|
if cls is CdpChannel:
|
||||||
|
raise TypeError("CdpChannel class cannot be instantiated directly.")
|
||||||
|
|
||||||
|
return super().__new__(cls)
|
||||||
|
|
||||||
|
def __init__(self, *, vnc_channel: VncChannel) -> None:
|
||||||
|
self.vnc_channel = vnc_channel
|
||||||
|
# --
|
||||||
|
self.browser: Browser | None = None
|
||||||
|
self.browser_context: BrowserContext | None = None
|
||||||
|
self.page: Page | None = None
|
||||||
|
self.pw: Playwright | None = None
|
||||||
|
self.url: str | None = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def class_name(self) -> str:
|
||||||
|
return self.__class__.__name__
|
||||||
|
|
||||||
|
@property
|
||||||
|
def identity(self) -> t.Dict[str, t.Any]:
|
||||||
|
base = self.vnc_channel.identity
|
||||||
|
|
||||||
|
return base | {"url": self.url}
|
||||||
|
|
||||||
|
async def connect(self, cdp_url: str | None = None) -> t.Self:
|
||||||
|
"""
|
||||||
|
Idempotent.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if self.browser and self.browser.is_connected():
|
||||||
|
return self
|
||||||
|
|
||||||
|
await self.close()
|
||||||
|
|
||||||
|
if cdp_url:
|
||||||
|
url = cdp_url
|
||||||
|
elif self.vnc_channel.browser_session and self.vnc_channel.browser_session.browser_address:
|
||||||
|
url = self.vnc_channel.browser_session.browser_address
|
||||||
|
else:
|
||||||
|
url = settings.BROWSER_REMOTE_DEBUGGING_URL
|
||||||
|
|
||||||
|
self.url = url
|
||||||
|
|
||||||
|
LOG.info(f"{self.class_name} connecting to CDP", **self.identity)
|
||||||
|
|
||||||
|
pw = self.pw or await async_playwright().start()
|
||||||
|
|
||||||
|
self.pw = pw
|
||||||
|
|
||||||
|
headers = (
|
||||||
|
{
|
||||||
|
"x-api-key": self.vnc_channel.x_api_key,
|
||||||
|
}
|
||||||
|
if self.vnc_channel.x_api_key
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
def on_close() -> None:
|
||||||
|
LOG.warning(
|
||||||
|
f"{self.class_name} closing because the persistent browser disconnected itself.", **self.identity
|
||||||
|
)
|
||||||
|
close_task = asyncio.create_task(self.close())
|
||||||
|
close_task.add_done_callback(lambda _: asyncio.create_task(self.connect())) # TODO: avoid blind reconnect
|
||||||
|
|
||||||
|
self.browser = await pw.chromium.connect_over_cdp(url, headers=headers)
|
||||||
|
self.browser.on("disconnected", on_close)
|
||||||
|
|
||||||
|
await self.apply_download_behavior(self.browser)
|
||||||
|
|
||||||
|
contexts = self.browser.contexts
|
||||||
|
if contexts:
|
||||||
|
LOG.info(f"{self.class_name} using existing browser context", **self.identity)
|
||||||
|
self.browser_context = contexts[0]
|
||||||
|
else:
|
||||||
|
LOG.warning(f"{self.class_name} No existing browser context found, creating new one", **self.identity)
|
||||||
|
self.browser_context = await self.browser.new_context()
|
||||||
|
|
||||||
|
pages = self.browser_context.pages
|
||||||
|
if pages:
|
||||||
|
self.page = pages[0]
|
||||||
|
LOG.info(f"{self.class_name} connected to page", **self.identity)
|
||||||
|
else:
|
||||||
|
LOG.warning(f"{self.class_name} No pages found in browser context", **self.identity)
|
||||||
|
|
||||||
|
LOG.info(f"{self.class_name} connected successfully", **self.identity)
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def apply_download_behavior(self, browser: Browser) -> t.Self:
|
||||||
|
org_id = self.vnc_channel.organization_id
|
||||||
|
|
||||||
|
browser_session_id = (
|
||||||
|
self.vnc_channel.browser_session.persistent_browser_session_id if self.vnc_channel.browser_session else None
|
||||||
|
)
|
||||||
|
|
||||||
|
download_path = f"/app/downloads/{org_id}/{browser_session_id}" if browser_session_id else "/app/downloads/"
|
||||||
|
|
||||||
|
cdp_session = await browser.new_browser_cdp_session()
|
||||||
|
|
||||||
|
await cdp_session.send(
|
||||||
|
"Browser.setDownloadBehavior",
|
||||||
|
{
|
||||||
|
"behavior": "allow",
|
||||||
|
"downloadPath": download_path,
|
||||||
|
"eventsEnabled": True,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
await cdp_session.detach()
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
LOG.info(f"{self.class_name} closing connection", **self.identity)
|
||||||
|
|
||||||
|
if self.browser:
|
||||||
|
try:
|
||||||
|
await self.browser.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
self.browser = None
|
||||||
|
|
||||||
|
if self.pw:
|
||||||
|
await self.pw.stop()
|
||||||
|
self.pw = None
|
||||||
|
|
||||||
|
self.browser_context = None
|
||||||
|
self.page = None
|
||||||
|
|
||||||
|
LOG.info(f"{self.class_name} closed", **self.identity)
|
||||||
|
|
||||||
|
async def evaluate_js(
|
||||||
|
self,
|
||||||
|
expression: str,
|
||||||
|
arg: str | int | float | bool | list | dict | None = None,
|
||||||
|
) -> str | int | float | bool | list | dict | None:
|
||||||
|
await self.connect()
|
||||||
|
|
||||||
|
if not self.page:
|
||||||
|
raise RuntimeError(f"{self.class_name} evaluate_js: not connected to a page. Call connect() first.")
|
||||||
|
|
||||||
|
LOG.info(f"{self.class_name} evaluating js", expression=expression[:100], **self.identity)
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await self.page.evaluate(expression, arg)
|
||||||
|
LOG.info(f"{self.class_name} evaluated js successfully", **self.identity)
|
||||||
|
return result
|
||||||
|
except Exception:
|
||||||
|
LOG.exception(f"{self.class_name} failed to evaluate js", expression=expression, **self.identity)
|
||||||
|
raise
|
||||||
121
skyvern/forge/sdk/routes/streaming/channels/execution.py
Normal file
121
skyvern/forge/sdk/routes/streaming/channels/execution.py
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
"""
|
||||||
|
A channel for executing JavaScript against a persistent browser instance.
|
||||||
|
|
||||||
|
What this channel looks like:
|
||||||
|
|
||||||
|
[API Server] <--> [Browser (CDP)]
|
||||||
|
|
||||||
|
Channel data:
|
||||||
|
|
||||||
|
Chrome DevTools Protocol (CDP) over WebSockets. We cheat and use Playwright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import typing as t
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
import structlog
|
||||||
|
|
||||||
|
from skyvern.forge.sdk.routes.streaming.channels.cdp import CdpChannel
|
||||||
|
|
||||||
|
if t.TYPE_CHECKING:
|
||||||
|
from skyvern.forge.sdk.routes.streaming.channels.vnc import VncChannel
|
||||||
|
|
||||||
|
LOG = structlog.get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class ExecutionChannel(CdpChannel):
|
||||||
|
"""
|
||||||
|
ExecutionChannel.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def class_name(self) -> str:
|
||||||
|
return self.__class__.__name__
|
||||||
|
|
||||||
|
async def get_selected_text(self) -> str:
|
||||||
|
LOG.info(f"{self.class_name} getting selected text", **self.identity)
|
||||||
|
|
||||||
|
js_expression = """
|
||||||
|
() => {
|
||||||
|
const selection = window.getSelection();
|
||||||
|
return selection ? selection.toString() : '';
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
selected_text = await self.evaluate_js(js_expression, self.page)
|
||||||
|
|
||||||
|
if isinstance(selected_text, str) or selected_text is None:
|
||||||
|
LOG.info(
|
||||||
|
f"{self.class_name} got selected text",
|
||||||
|
length=len(selected_text) if selected_text else 0,
|
||||||
|
**self.identity,
|
||||||
|
)
|
||||||
|
return selected_text or ""
|
||||||
|
|
||||||
|
raise RuntimeError(f"{self.class_name} selected text is not a string, but a(n) '{type(selected_text)}'")
|
||||||
|
|
||||||
|
async def paste_text(self, text: str) -> None:
|
||||||
|
LOG.info(f"{self.class_name} pasting text", **self.identity)
|
||||||
|
|
||||||
|
js_expression = """
|
||||||
|
(text) => {
|
||||||
|
const activeElement = document.activeElement;
|
||||||
|
if (activeElement && (activeElement.tagName === 'INPUT' || activeElement.tagName === 'TEXTAREA' || activeElement.isContentEditable)) {
|
||||||
|
const start = activeElement.selectionStart || 0;
|
||||||
|
const end = activeElement.selectionEnd || 0;
|
||||||
|
const value = activeElement.value || '';
|
||||||
|
activeElement.value = value.slice(0, start) + text + value.slice(end);
|
||||||
|
const newCursorPos = start + text.length;
|
||||||
|
activeElement.setSelectionRange(newCursorPos, newCursorPos);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
await self.evaluate_js(js_expression, text)
|
||||||
|
|
||||||
|
LOG.info(f"{self.class_name} pasted text successfully", **self.identity)
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
LOG.info(f"{self.class_name} closing connection", **self.identity)
|
||||||
|
|
||||||
|
if self.browser:
|
||||||
|
await self.browser.close()
|
||||||
|
self.browser = None
|
||||||
|
self.browser_context = None
|
||||||
|
self.page = None
|
||||||
|
|
||||||
|
if self.pw:
|
||||||
|
await self.pw.stop()
|
||||||
|
self.pw = None
|
||||||
|
|
||||||
|
LOG.info(f"{self.class_name} closed", **self.identity)
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def execution_channel(vnc_channel: VncChannel) -> t.AsyncIterator[ExecutionChannel]:
|
||||||
|
"""
|
||||||
|
The first pass at this has us doing the following for every operation:
|
||||||
|
- creating a new channel
|
||||||
|
- connecting
|
||||||
|
- [doing smth]
|
||||||
|
- closing the channel
|
||||||
|
|
||||||
|
This may add latency, but locally it is pretty fast. This keeps things stateless for now.
|
||||||
|
|
||||||
|
If it turns out it's too slow, we can refactor to keep a persistent channel per vnc client.
|
||||||
|
"""
|
||||||
|
|
||||||
|
channel = ExecutionChannel(vnc_channel=vnc_channel)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await channel.connect()
|
||||||
|
|
||||||
|
# NOTE(jdo:streaming-local-dev)
|
||||||
|
# from skyvern.config import settings
|
||||||
|
# await channel.connect(settings.BROWSER_REMOTE_DEBUGGING_URL)
|
||||||
|
|
||||||
|
yield channel
|
||||||
|
finally:
|
||||||
|
await channel.close()
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
# stub
|
||||||
456
skyvern/forge/sdk/routes/streaming/channels/message.py
Normal file
456
skyvern/forge/sdk/routes/streaming/channels/message.py
Normal file
@@ -0,0 +1,456 @@
|
|||||||
|
"""
|
||||||
|
A channel for streaming whole messages between our frontend and our API server.
|
||||||
|
This channel can access a persistent browser instance through the execution channel.
|
||||||
|
|
||||||
|
What this channel looks like:
|
||||||
|
|
||||||
|
[Skyvern App] <--> [API Server]
|
||||||
|
|
||||||
|
Channel data:
|
||||||
|
|
||||||
|
JSON over WebSockets. Semantics are fire and forget. Req-resp is built on
|
||||||
|
top of that using message types.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import dataclasses
|
||||||
|
import typing as t
|
||||||
|
|
||||||
|
import structlog
|
||||||
|
from fastapi import WebSocket, WebSocketDisconnect
|
||||||
|
from starlette.websockets import WebSocketState
|
||||||
|
from websockets.exceptions import ConnectionClosedError
|
||||||
|
|
||||||
|
from skyvern.forge.sdk.routes.streaming.channels.execution import execution_channel
|
||||||
|
from skyvern.forge.sdk.routes.streaming.registries import (
|
||||||
|
add_message_channel,
|
||||||
|
del_message_channel,
|
||||||
|
get_vnc_channel,
|
||||||
|
)
|
||||||
|
from skyvern.forge.sdk.routes.streaming.verify import (
|
||||||
|
loop_verify_browser_session,
|
||||||
|
loop_verify_workflow_run,
|
||||||
|
verify_browser_session,
|
||||||
|
verify_workflow_run,
|
||||||
|
)
|
||||||
|
from skyvern.forge.sdk.schemas.persistent_browser_sessions import AddressablePersistentBrowserSession
|
||||||
|
from skyvern.forge.sdk.utils.aio import collect
|
||||||
|
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRun
|
||||||
|
|
||||||
|
LOG = structlog.get_logger()
|
||||||
|
|
||||||
|
Loops = list[asyncio.Task] # aka "queue-less actors"; or "programs"
|
||||||
|
|
||||||
|
|
||||||
|
MessageKinds = t.Literal[
|
||||||
|
"ask-for-clipboard-response",
|
||||||
|
"cede-control",
|
||||||
|
"take-control",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class Message:
|
||||||
|
kind: MessageKinds
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class MessageInTakeControl(Message):
|
||||||
|
kind: t.Literal["take-control"] = "take-control"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class MessageInCedeControl(Message):
|
||||||
|
kind: t.Literal["cede-control"] = "cede-control"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class MessageInAskForClipboardResponse(Message):
|
||||||
|
kind: t.Literal["ask-for-clipboard-response"] = "ask-for-clipboard-response"
|
||||||
|
text: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
ChannelMessage = t.Union[
|
||||||
|
MessageInAskForClipboardResponse,
|
||||||
|
MessageInCedeControl,
|
||||||
|
MessageInTakeControl,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def reify_channel_message(data: dict) -> ChannelMessage:
|
||||||
|
kind = data.get("kind", None)
|
||||||
|
|
||||||
|
match kind:
|
||||||
|
case "ask-for-clipboard-response":
|
||||||
|
text = data.get("text") or ""
|
||||||
|
return MessageInAskForClipboardResponse(text=text)
|
||||||
|
case "cede-control":
|
||||||
|
return MessageInCedeControl()
|
||||||
|
case "take-control":
|
||||||
|
return MessageInTakeControl()
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Unknown message kind: '{kind}'")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class MessageChannel:
|
||||||
|
"""
|
||||||
|
A message channel for streaming JSON messages between our frontend and our API server.
|
||||||
|
"""
|
||||||
|
|
||||||
|
client_id: str
|
||||||
|
organization_id: str
|
||||||
|
websocket: WebSocket
|
||||||
|
# --
|
||||||
|
out_queue: asyncio.Queue = dataclasses.field(default_factory=asyncio.Queue) # warn: unbounded
|
||||||
|
browser_session: AddressablePersistentBrowserSession | None = None
|
||||||
|
workflow_run: WorkflowRun | None = None
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
add_message_channel(self)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def class_name(self) -> str:
|
||||||
|
return self.__class__.__name__
|
||||||
|
|
||||||
|
@property
|
||||||
|
def identity(self) -> dict[str, str]:
|
||||||
|
base = {"organization_id": self.organization_id}
|
||||||
|
|
||||||
|
if self.browser_session:
|
||||||
|
return base | {"browser_session_id": self.browser_session.persistent_browser_session_id}
|
||||||
|
|
||||||
|
if self.workflow_run:
|
||||||
|
return base | {"workflow_run_id": self.workflow_run.id}
|
||||||
|
|
||||||
|
return base
|
||||||
|
|
||||||
|
async def close(self, code: int = 1000, reason: str | None = None) -> "MessageChannel":
|
||||||
|
LOG.info(f"{self.class_name} closing message stream.", reason=reason, code=code, **self.identity)
|
||||||
|
|
||||||
|
self.browser_session = None
|
||||||
|
self.workflow_run = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self.websocket.close(code=code, reason=reason)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
del_message_channel(self.client_id)
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_open(self) -> bool:
|
||||||
|
if self.websocket.client_state != WebSocketState.CONNECTED:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def drain(self) -> list[dict]:
|
||||||
|
datums: list[dict] = []
|
||||||
|
|
||||||
|
tasks = [
|
||||||
|
asyncio.create_task(self.receive_from_out_queue()),
|
||||||
|
asyncio.create_task(self.receive_from_user()),
|
||||||
|
]
|
||||||
|
|
||||||
|
results = await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
for result in results:
|
||||||
|
datums.extend(result)
|
||||||
|
|
||||||
|
if datums:
|
||||||
|
LOG.info(f"{self.class_name} Drained {len(datums)} messages from message channel.", **self.identity)
|
||||||
|
|
||||||
|
return datums
|
||||||
|
|
||||||
|
async def receive_from_user(self) -> list[dict]:
|
||||||
|
datums: list[dict] = []
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
data = await asyncio.wait_for(self.websocket.receive_json(), timeout=0.001)
|
||||||
|
datums.append(data)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
break
|
||||||
|
except RuntimeError:
|
||||||
|
if "not connected" in str(RuntimeError).lower():
|
||||||
|
break
|
||||||
|
except Exception:
|
||||||
|
LOG.exception(f"{self.class_name} Failed to receive message from message channel", **self.identity)
|
||||||
|
break
|
||||||
|
|
||||||
|
return datums
|
||||||
|
|
||||||
|
async def receive_from_out_queue(self) -> list[dict]:
|
||||||
|
datums: list[dict] = []
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
data = await asyncio.wait_for(self.out_queue.get(), timeout=0.001)
|
||||||
|
datums.append(data)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
break
|
||||||
|
except asyncio.QueueEmpty:
|
||||||
|
break
|
||||||
|
|
||||||
|
return datums
|
||||||
|
|
||||||
|
def receive_from_out_queue_nowait(self) -> list[dict]:
|
||||||
|
datums: list[dict] = []
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
data = self.out_queue.get_nowait()
|
||||||
|
datums.append(data)
|
||||||
|
except asyncio.QueueEmpty:
|
||||||
|
break
|
||||||
|
|
||||||
|
return datums
|
||||||
|
|
||||||
|
async def send(self, *, messages: list[dict]) -> t.Self:
|
||||||
|
for message in messages:
|
||||||
|
await self.out_queue.put(message)
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
def send_nowait(self, *, messages: list[dict]) -> t.Self:
|
||||||
|
for message in messages:
|
||||||
|
self.out_queue.put_nowait(message)
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def ask_for_clipboard(self) -> None:
|
||||||
|
LOG.info(f"{self.class_name} Sending ask-for-clipboard to message channel", **self.identity)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self.websocket.send_json(
|
||||||
|
{
|
||||||
|
"kind": "ask-for-clipboard",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
LOG.exception(f"{self.class_name} Failed to send ask-for-clipboard to message channel", **self.identity)
|
||||||
|
|
||||||
|
async def send_copied_text(self, copied_text: str) -> None:
|
||||||
|
LOG.info(f"{self.class_name} Sending copied text to message channel", **self.identity)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self.websocket.send_json(
|
||||||
|
{
|
||||||
|
"kind": "copied-text",
|
||||||
|
"text": copied_text,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
LOG.exception(f"{self.class_name} Failed to send copied text to message channel", **self.identity)
|
||||||
|
|
||||||
|
|
||||||
|
async def loop_stream_messages(message_channel: MessageChannel) -> None:
|
||||||
|
"""
|
||||||
|
Stream messages and their results back and forth.
|
||||||
|
|
||||||
|
Loops until the websocket is closed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class_name = message_channel.class_name
|
||||||
|
|
||||||
|
async def handle_data(data: dict) -> None:
|
||||||
|
nonlocal class_name
|
||||||
|
|
||||||
|
try:
|
||||||
|
message = reify_channel_message(data)
|
||||||
|
except ValueError:
|
||||||
|
LOG.error(f"MessageChannel: cannot reify channel message from data: {data}", **message_channel.identity)
|
||||||
|
return
|
||||||
|
|
||||||
|
message_kind = message.kind
|
||||||
|
|
||||||
|
match message_kind:
|
||||||
|
case "ask-for-clipboard-response":
|
||||||
|
if not isinstance(message, MessageInAskForClipboardResponse):
|
||||||
|
LOG.error(
|
||||||
|
f"{class_name} invalid message type for ask-for-clipboard-response.",
|
||||||
|
message=message,
|
||||||
|
**message_channel.identity,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
vnc_channel = get_vnc_channel(message_channel.client_id)
|
||||||
|
|
||||||
|
if not vnc_channel:
|
||||||
|
LOG.error(
|
||||||
|
f"{class_name} no vnc channel found for message channel.",
|
||||||
|
message=message,
|
||||||
|
**message_channel.identity,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
text = message.text
|
||||||
|
|
||||||
|
async with execution_channel(vnc_channel) as execute:
|
||||||
|
await execute.paste_text(text)
|
||||||
|
|
||||||
|
case "cede-control":
|
||||||
|
vnc_channel = get_vnc_channel(message_channel.client_id)
|
||||||
|
|
||||||
|
if not vnc_channel:
|
||||||
|
LOG.error(
|
||||||
|
f"{class_name} no vnc channel client found for message channel.",
|
||||||
|
message=message,
|
||||||
|
**message_channel.identity,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
vnc_channel.interactor = "agent"
|
||||||
|
|
||||||
|
case "take-control":
|
||||||
|
LOG.info(f"{class_name} processing take-control message.", **message_channel.identity)
|
||||||
|
vnc_channel = get_vnc_channel(message_channel.client_id)
|
||||||
|
|
||||||
|
if not vnc_channel:
|
||||||
|
LOG.error(
|
||||||
|
f"{class_name} no vnc channel client found for message channel.",
|
||||||
|
message=message,
|
||||||
|
**message_channel.identity,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
vnc_channel.interactor = "user"
|
||||||
|
|
||||||
|
case _:
|
||||||
|
LOG.error(f"{class_name} unknown message kind: '{message_kind}'", **message_channel.identity)
|
||||||
|
return
|
||||||
|
|
||||||
|
async def frontend_to_backend() -> None:
|
||||||
|
nonlocal class_name
|
||||||
|
|
||||||
|
LOG.info(f"{class_name} starting frontend-to-backend loop.", **message_channel.identity)
|
||||||
|
|
||||||
|
while message_channel.is_open:
|
||||||
|
try:
|
||||||
|
datums = await message_channel.drain()
|
||||||
|
|
||||||
|
for data in datums:
|
||||||
|
if not isinstance(data, dict):
|
||||||
|
LOG.error(
|
||||||
|
f"{class_name} cannot create message: expected dict, got {type(data)}",
|
||||||
|
**message_channel.identity,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
await handle_data(data)
|
||||||
|
|
||||||
|
except WebSocketDisconnect:
|
||||||
|
LOG.info(f"{class_name} frontend disconnected.", **message_channel.identity)
|
||||||
|
raise
|
||||||
|
except ConnectionClosedError:
|
||||||
|
LOG.info(f"{class_name} frontend closed channel.", **message_channel.identity)
|
||||||
|
raise
|
||||||
|
except Exception:
|
||||||
|
LOG.exception(f"{class_name} An unexpected exception occurred.", **message_channel.identity)
|
||||||
|
raise
|
||||||
|
|
||||||
|
loops = [
|
||||||
|
asyncio.create_task(frontend_to_backend()),
|
||||||
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
await collect(loops)
|
||||||
|
except Exception:
|
||||||
|
LOG.exception(f"{class_name} An exception occurred in loop message channel stream.", **message_channel.identity)
|
||||||
|
finally:
|
||||||
|
LOG.info(f"{class_name} Closing the message channel stream.", **message_channel.identity)
|
||||||
|
await message_channel.close(reason="loop-channel-closed")
|
||||||
|
|
||||||
|
|
||||||
|
async def get_message_channel_for_browser_session(
|
||||||
|
client_id: str,
|
||||||
|
browser_session_id: str,
|
||||||
|
organization_id: str,
|
||||||
|
websocket: WebSocket,
|
||||||
|
) -> tuple[MessageChannel, Loops] | None:
|
||||||
|
"""
|
||||||
|
Return a message channel for a browser session, with a list of loops to run concurrently.
|
||||||
|
"""
|
||||||
|
|
||||||
|
LOG.info("Getting message channel for browser session.", browser_session_id=browser_session_id)
|
||||||
|
|
||||||
|
browser_session = await verify_browser_session(
|
||||||
|
browser_session_id=browser_session_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not browser_session:
|
||||||
|
LOG.info(
|
||||||
|
"Message channel: no initial browser session found.",
|
||||||
|
browser_session_id=browser_session_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
message_channel = MessageChannel(
|
||||||
|
client_id=client_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
browser_session=browser_session,
|
||||||
|
websocket=websocket,
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG.info("Got message channel for browser session.", message_channel=message_channel)
|
||||||
|
|
||||||
|
loops = [
|
||||||
|
asyncio.create_task(loop_verify_browser_session(message_channel)),
|
||||||
|
asyncio.create_task(loop_stream_messages(message_channel)),
|
||||||
|
]
|
||||||
|
|
||||||
|
return message_channel, loops
|
||||||
|
|
||||||
|
|
||||||
|
async def get_message_channel_for_workflow_run(
|
||||||
|
client_id: str,
|
||||||
|
workflow_run_id: str,
|
||||||
|
organization_id: str,
|
||||||
|
websocket: WebSocket,
|
||||||
|
) -> tuple[MessageChannel, Loops] | None:
|
||||||
|
"""
|
||||||
|
Return a message channel for a workflow run, with a list of loops to run concurrently.
|
||||||
|
"""
|
||||||
|
|
||||||
|
LOG.info("Getting message channel for workflow run.", workflow_run_id=workflow_run_id)
|
||||||
|
|
||||||
|
workflow_run, browser_session = await verify_workflow_run(
|
||||||
|
workflow_run_id=workflow_run_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not workflow_run:
|
||||||
|
LOG.info(
|
||||||
|
"Message channel: no initial workflow run found.",
|
||||||
|
workflow_run_id=workflow_run_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not browser_session:
|
||||||
|
LOG.info(
|
||||||
|
"Message channel: no initial browser session found for workflow run.",
|
||||||
|
workflow_run_id=workflow_run_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
message_channel = MessageChannel(
|
||||||
|
client_id,
|
||||||
|
organization_id,
|
||||||
|
browser_session=browser_session,
|
||||||
|
websocket=websocket,
|
||||||
|
workflow_run=workflow_run,
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG.info("Got message channel for workflow run.", message_channel=message_channel)
|
||||||
|
|
||||||
|
loops = [
|
||||||
|
asyncio.create_task(loop_verify_workflow_run(message_channel)),
|
||||||
|
asyncio.create_task(loop_stream_messages(message_channel)),
|
||||||
|
]
|
||||||
|
|
||||||
|
return message_channel, loops
|
||||||
592
skyvern/forge/sdk/routes/streaming/channels/vnc.py
Normal file
592
skyvern/forge/sdk/routes/streaming/channels/vnc.py
Normal file
@@ -0,0 +1,592 @@
|
|||||||
|
"""
|
||||||
|
A channel for streaming the VNC protocol data between our frontend and a
|
||||||
|
persistent browser instance.
|
||||||
|
|
||||||
|
This is a pass-thru channel, through our API server. As such, we can monitor and/or
|
||||||
|
intercept RFB protocol messages as needed.
|
||||||
|
|
||||||
|
What this channel looks like:
|
||||||
|
|
||||||
|
[Skyvern App] <--> [API Server] <--> [websockified noVNC] <--> [Browser]
|
||||||
|
|
||||||
|
|
||||||
|
Channel data:
|
||||||
|
|
||||||
|
One could call this RFB over WebSockets (rockets?), as the protocol data streaming
|
||||||
|
over the WebSocket is raw RFB protocol data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import dataclasses
|
||||||
|
import typing as t
|
||||||
|
from enum import IntEnum
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
import structlog
|
||||||
|
import websockets
|
||||||
|
from fastapi import WebSocket, WebSocketDisconnect
|
||||||
|
from starlette.websockets import WebSocketState
|
||||||
|
from websockets import ConnectionClosedError, Data
|
||||||
|
|
||||||
|
from skyvern.config import settings
|
||||||
|
from skyvern.forge.sdk.routes.streaming.auth import get_x_api_key
|
||||||
|
from skyvern.forge.sdk.routes.streaming.channels.execution import execution_channel
|
||||||
|
from skyvern.forge.sdk.routes.streaming.registries import (
|
||||||
|
add_vnc_channel,
|
||||||
|
del_vnc_channel,
|
||||||
|
get_message_channel,
|
||||||
|
get_vnc_channel,
|
||||||
|
)
|
||||||
|
from skyvern.forge.sdk.routes.streaming.verify import (
|
||||||
|
loop_verify_browser_session,
|
||||||
|
loop_verify_task,
|
||||||
|
loop_verify_workflow_run,
|
||||||
|
verify_browser_session,
|
||||||
|
verify_task,
|
||||||
|
verify_workflow_run,
|
||||||
|
)
|
||||||
|
from skyvern.forge.sdk.schemas.persistent_browser_sessions import AddressablePersistentBrowserSession
|
||||||
|
from skyvern.forge.sdk.schemas.tasks import Task
|
||||||
|
from skyvern.forge.sdk.utils.aio import collect
|
||||||
|
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRun
|
||||||
|
|
||||||
|
LOG = structlog.get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
Interactor = t.Literal["agent", "user"]
|
||||||
|
"""
|
||||||
|
NOTE: we don't really have an "agent" at this time. But any control of the
|
||||||
|
browser that is not user-originated is kinda' agent-like, by some
|
||||||
|
definition of "agent". Here, we do not have an "AI agent". Future work may
|
||||||
|
alter this state of affairs - and some "agent" could operate the browser
|
||||||
|
automatically. In any case, if the interactor is not a "user", we assume
|
||||||
|
it is an "agent".
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
Loops = list[asyncio.Task] # aka "queue-less actors"; or "programs"
|
||||||
|
|
||||||
|
|
||||||
|
class MessageType(IntEnum):
|
||||||
|
Keyboard = 4
|
||||||
|
Mouse = 5
|
||||||
|
|
||||||
|
|
||||||
|
class Keys:
|
||||||
|
"""
|
||||||
|
VNC RFB keycodes. There's likely a pithier repr (indexes 6-7). This is ok for now.
|
||||||
|
|
||||||
|
ref: https://www.notion.so/References-21c426c42cd480fb9258ecc9eb8f09b4
|
||||||
|
ref: https://github.com/novnc/noVNC/blob/master/docs/rfbproto-3.8.pdf
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Down:
|
||||||
|
Ctrl = b"\x04\x01\x00\x00\x00\x00\xff\xe3"
|
||||||
|
Cmd = b"\x04\x01\x00\x00\x00\x00\xff\xe9"
|
||||||
|
Alt = b"\x04\x01\x00\x00\x00\x00\xff~" # option
|
||||||
|
CKey = b"\x04\x01\x00\x00\x00\x00\x00c"
|
||||||
|
OKey = b"\x04\x01\x00\x00\x00\x00\x00o"
|
||||||
|
VKey = b"\x04\x01\x00\x00\x00\x00\x00v"
|
||||||
|
|
||||||
|
class Up:
|
||||||
|
Ctrl = b"\x04\x00\x00\x00\x00\x00\xff\xe3"
|
||||||
|
Cmd = b"\x04\x00\x00\x00\x00\x00\xff\xe9"
|
||||||
|
Alt = b"\x04\x00\x00\x00\x00\x00\xff\x7e" # option
|
||||||
|
|
||||||
|
|
||||||
|
def is_rmb(data: bytes) -> bool:
|
||||||
|
return data[0:2] == b"\x05\x04"
|
||||||
|
|
||||||
|
|
||||||
|
class Mouse:
|
||||||
|
class Up:
|
||||||
|
Right = is_rmb
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class KeyState:
|
||||||
|
ctrl_is_down: bool = False
|
||||||
|
alt_is_down: bool = False
|
||||||
|
cmd_is_down: bool = False
|
||||||
|
|
||||||
|
def is_forbidden(self, data: bytes) -> bool:
|
||||||
|
"""
|
||||||
|
:return: True if the key is forbidden, else False
|
||||||
|
"""
|
||||||
|
return self.is_ctrl_o(data)
|
||||||
|
|
||||||
|
def is_ctrl_o(self, data: bytes) -> bool:
|
||||||
|
"""
|
||||||
|
Do not allow the opening of files.
|
||||||
|
"""
|
||||||
|
return self.ctrl_is_down and data == Keys.Down.OKey
|
||||||
|
|
||||||
|
def is_copy(self, data: bytes) -> bool:
|
||||||
|
"""
|
||||||
|
Detect Ctrl+C or Cmd+C for copy.
|
||||||
|
"""
|
||||||
|
return (self.ctrl_is_down or self.cmd_is_down) and data == Keys.Down.CKey
|
||||||
|
|
||||||
|
def is_paste(self, data: bytes) -> bool:
|
||||||
|
"""
|
||||||
|
Detect Ctrl+V or Cmd+V for paste.
|
||||||
|
"""
|
||||||
|
return (self.ctrl_is_down or self.cmd_is_down) and data == Keys.Down.VKey
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class VncChannel:
|
||||||
|
"""
|
||||||
|
A VNC channel for streaming RFB protocol data between our frontend app, and
|
||||||
|
a remote browser.
|
||||||
|
"""
|
||||||
|
|
||||||
|
client_id: str
|
||||||
|
"""
|
||||||
|
Unique to a frontend app instance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
organization_id: str
|
||||||
|
vnc_port: int
|
||||||
|
x_api_key: str
|
||||||
|
websocket: WebSocket
|
||||||
|
|
||||||
|
initial_interactor: dataclasses.InitVar[Interactor]
|
||||||
|
"""
|
||||||
|
The role of the entity interacting with the channel, either "agent" or "user".
|
||||||
|
"""
|
||||||
|
|
||||||
|
_interactor: Interactor = dataclasses.field(init=False, repr=False)
|
||||||
|
|
||||||
|
# --
|
||||||
|
|
||||||
|
browser_session: AddressablePersistentBrowserSession | None = None
|
||||||
|
key_state: KeyState = dataclasses.field(default_factory=KeyState)
|
||||||
|
task: Task | None = None
|
||||||
|
workflow_run: WorkflowRun | None = None
|
||||||
|
|
||||||
|
def __post_init__(self, initial_interactor: Interactor) -> None:
|
||||||
|
self.interactor = initial_interactor
|
||||||
|
add_vnc_channel(self)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def class_name(self) -> str:
|
||||||
|
return self.__class__.__name__
|
||||||
|
|
||||||
|
@property
|
||||||
|
def identity(self) -> dict:
|
||||||
|
base = {"organization_id": self.organization_id}
|
||||||
|
|
||||||
|
if self.task:
|
||||||
|
return base | {"task_id": self.task.task_id}
|
||||||
|
elif self.workflow_run:
|
||||||
|
return base | {"workflow_run_id": self.workflow_run.workflow_run_id}
|
||||||
|
elif self.browser_session:
|
||||||
|
return base | {"browser_session_id": self.browser_session.persistent_browser_session_id}
|
||||||
|
else:
|
||||||
|
return base | {"client_id": self.client_id}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def interactor(self) -> Interactor:
|
||||||
|
return self._interactor
|
||||||
|
|
||||||
|
@interactor.setter
|
||||||
|
def interactor(self, value: Interactor) -> None:
|
||||||
|
self._interactor = value
|
||||||
|
|
||||||
|
LOG.info(f"{self.class_name} Setting interactor to {value}", **self.identity)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_open(self) -> bool:
|
||||||
|
if self.websocket.client_state != WebSocketState.CONNECTED:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if not self.task and not self.workflow_run and not self.browser_session:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if not get_vnc_channel(self.client_id):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def close(self, code: int = 1000, reason: str | None = None) -> t.Self:
|
||||||
|
LOG.info(f"{self.class_name} closing.", reason=reason, code=code, **self.identity)
|
||||||
|
|
||||||
|
self.browser_session = None
|
||||||
|
self.task = None
|
||||||
|
self.workflow_run = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self.websocket.close(code=code, reason=reason)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
del_vnc_channel(self.client_id)
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
def update_key_state(self, data: bytes) -> None:
|
||||||
|
if data == Keys.Down.Ctrl:
|
||||||
|
self.key_state.ctrl_is_down = True
|
||||||
|
elif data == Keys.Up.Ctrl:
|
||||||
|
self.key_state.ctrl_is_down = False
|
||||||
|
elif data == Keys.Down.Alt:
|
||||||
|
self.key_state.alt_is_down = True
|
||||||
|
elif data == Keys.Up.Alt:
|
||||||
|
self.key_state.alt_is_down = False
|
||||||
|
elif data == Keys.Down.Cmd:
|
||||||
|
self.key_state.cmd_is_down = True
|
||||||
|
elif data == Keys.Up.Cmd:
|
||||||
|
self.key_state.cmd_is_down = False
|
||||||
|
|
||||||
|
|
||||||
|
async def copy_text(vnc_channel: VncChannel) -> None:
|
||||||
|
class_name = vnc_channel.class_name
|
||||||
|
LOG.info(f"{class_name} Retrieving selected text via CDP", **vnc_channel.identity)
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with execution_channel(vnc_channel) as execute:
|
||||||
|
copied_text = await execute.get_selected_text()
|
||||||
|
|
||||||
|
message_channel = get_message_channel(vnc_channel.client_id)
|
||||||
|
|
||||||
|
if message_channel:
|
||||||
|
await message_channel.send_copied_text(copied_text)
|
||||||
|
else:
|
||||||
|
LOG.warning(
|
||||||
|
f"{class_name} No message channel found for client, or it is not open",
|
||||||
|
message_channel=message_channel,
|
||||||
|
**vnc_channel.identity,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
LOG.exception(f"{class_name} Failed to retrieve selected text via CDP", **vnc_channel.identity)
|
||||||
|
|
||||||
|
|
||||||
|
async def ask_for_clipboard(vnc_channel: VncChannel) -> None:
|
||||||
|
class_name = vnc_channel.class_name
|
||||||
|
LOG.info(f"{class_name} Asking for clipboard data via CDP", **vnc_channel.identity)
|
||||||
|
|
||||||
|
try:
|
||||||
|
message_channel = get_message_channel(vnc_channel.client_id)
|
||||||
|
|
||||||
|
if message_channel:
|
||||||
|
await message_channel.ask_for_clipboard()
|
||||||
|
else:
|
||||||
|
LOG.warning(
|
||||||
|
f"{class_name} No message channel found for client, or it is not open",
|
||||||
|
message_channel=message_channel,
|
||||||
|
**vnc_channel.identity,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
LOG.exception(f"{class_name} Failed to ask for clipboard via CDP", **vnc_channel.identity)
|
||||||
|
|
||||||
|
|
||||||
|
async def loop_stream_vnc(vnc_channel: VncChannel) -> None:
|
||||||
|
"""
|
||||||
|
Actually stream the VNC data between a frontend and a browser.
|
||||||
|
|
||||||
|
Loops until the task is cleared or the websocket is closed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
vnc_url: str = ""
|
||||||
|
browser_session = vnc_channel.browser_session
|
||||||
|
class_name = vnc_channel.class_name
|
||||||
|
|
||||||
|
if browser_session:
|
||||||
|
if browser_session.ip_address:
|
||||||
|
if ":" in browser_session.ip_address:
|
||||||
|
ip, _ = browser_session.ip_address.split(":")
|
||||||
|
vnc_url = f"ws://{ip}:{vnc_channel.vnc_port}"
|
||||||
|
else:
|
||||||
|
vnc_url = f"ws://{browser_session.ip_address}:{vnc_channel.vnc_port}"
|
||||||
|
else:
|
||||||
|
browser_address = browser_session.browser_address
|
||||||
|
|
||||||
|
parsed_browser_address = urlparse(browser_address)
|
||||||
|
host = parsed_browser_address.hostname
|
||||||
|
vnc_url = f"ws://{host}:{vnc_channel.vnc_port}"
|
||||||
|
else:
|
||||||
|
raise Exception(f"{class_name} No browser session associated with vnc channel.")
|
||||||
|
|
||||||
|
# NOTE(jdo:streaming-local-dev)
|
||||||
|
# vnc_url = "ws://localhost:6080"
|
||||||
|
|
||||||
|
LOG.info(
|
||||||
|
f"{class_name} Connecting to vnc url.",
|
||||||
|
vnc_url=vnc_url,
|
||||||
|
**vnc_channel.identity,
|
||||||
|
)
|
||||||
|
|
||||||
|
async with websockets.connect(vnc_url) as novnc_ws:
|
||||||
|
|
||||||
|
async def frontend_to_browser() -> None:
|
||||||
|
nonlocal class_name
|
||||||
|
|
||||||
|
LOG.info(f"{class_name} Starting frontend-to-browser data transfer.", **vnc_channel.identity)
|
||||||
|
data: Data | None = None
|
||||||
|
|
||||||
|
while vnc_channel.is_open:
|
||||||
|
try:
|
||||||
|
data = await vnc_channel.websocket.receive_bytes()
|
||||||
|
|
||||||
|
if data:
|
||||||
|
message_type = data[0]
|
||||||
|
|
||||||
|
if message_type == MessageType.Keyboard.value:
|
||||||
|
vnc_channel.update_key_state(data)
|
||||||
|
|
||||||
|
if vnc_channel.key_state.is_copy(data):
|
||||||
|
await copy_text(vnc_channel)
|
||||||
|
|
||||||
|
if vnc_channel.key_state.is_paste(data):
|
||||||
|
await ask_for_clipboard(vnc_channel)
|
||||||
|
|
||||||
|
if vnc_channel.key_state.is_forbidden(data):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# prevent right-mouse-button clicks for "security" reasons
|
||||||
|
if message_type == MessageType.Mouse.value:
|
||||||
|
if Mouse.Up.Right(data):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not vnc_channel.interactor == "user" and message_type in (
|
||||||
|
MessageType.Keyboard.value,
|
||||||
|
MessageType.Mouse.value,
|
||||||
|
):
|
||||||
|
LOG.debug(f"{class_name} Blocking user message.", **vnc_channel.identity)
|
||||||
|
continue
|
||||||
|
|
||||||
|
except WebSocketDisconnect:
|
||||||
|
LOG.info(f"{class_name} Frontend disconnected.", **vnc_channel.identity)
|
||||||
|
raise
|
||||||
|
except ConnectionClosedError:
|
||||||
|
LOG.info(f"{class_name} Frontend closed the vnc channel.", **vnc_channel.identity)
|
||||||
|
raise
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
except Exception:
|
||||||
|
LOG.exception(f"{class_name} An unexpected exception occurred.", **vnc_channel.identity)
|
||||||
|
raise
|
||||||
|
|
||||||
|
if not data:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
await novnc_ws.send(data)
|
||||||
|
except WebSocketDisconnect:
|
||||||
|
LOG.info(f"{class_name} Browser disconnected from vnc.", **vnc_channel.identity)
|
||||||
|
raise
|
||||||
|
except ConnectionClosedError:
|
||||||
|
LOG.info(f"{class_name} Browser closed vnc.", **vnc_channel.identity)
|
||||||
|
raise
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
except Exception:
|
||||||
|
LOG.exception(
|
||||||
|
f"{class_name} An unexpected exception occurred in frontend-to-browser loop.",
|
||||||
|
**vnc_channel.identity,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def browser_to_frontend() -> None:
|
||||||
|
nonlocal class_name
|
||||||
|
|
||||||
|
LOG.info(f"{class_name} Starting browser-to-frontend data transfer.", **vnc_channel.identity)
|
||||||
|
data: Data | None = None
|
||||||
|
|
||||||
|
while vnc_channel.is_open:
|
||||||
|
try:
|
||||||
|
data = await novnc_ws.recv()
|
||||||
|
|
||||||
|
except WebSocketDisconnect:
|
||||||
|
LOG.info(f"{class_name} Browser disconnected from the vnc channel session.", **vnc_channel.identity)
|
||||||
|
await vnc_channel.close(reason="browser-disconnected")
|
||||||
|
except ConnectionClosedError:
|
||||||
|
LOG.info(f"{class_name} Browser closed the vnc channel session.", **vnc_channel.identity)
|
||||||
|
await vnc_channel.close(reason="browser-closed")
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
except Exception:
|
||||||
|
LOG.exception(
|
||||||
|
f"{class_name} An unexpected exception occurred in browser-to-frontend loop.",
|
||||||
|
**vnc_channel.identity,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
if not data:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
if vnc_channel.websocket.client_state != WebSocketState.CONNECTED:
|
||||||
|
continue
|
||||||
|
await vnc_channel.websocket.send_bytes(data)
|
||||||
|
except WebSocketDisconnect:
|
||||||
|
LOG.info(
|
||||||
|
f"{class_name} Frontend disconnected from the vnc channel session.", **vnc_channel.identity
|
||||||
|
)
|
||||||
|
await vnc_channel.close(reason="frontend-disconnected")
|
||||||
|
except ConnectionClosedError:
|
||||||
|
LOG.info(f"{class_name} Frontend closed the vnc channel session.", **vnc_channel.identity)
|
||||||
|
await vnc_channel.close(reason="frontend-closed")
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
except Exception:
|
||||||
|
LOG.exception(f"{class_name} An unexpected exception occurred.", **vnc_channel.identity)
|
||||||
|
raise
|
||||||
|
|
||||||
|
loops = [
|
||||||
|
asyncio.create_task(frontend_to_browser()),
|
||||||
|
asyncio.create_task(browser_to_frontend()),
|
||||||
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
await collect(loops)
|
||||||
|
except WebSocketDisconnect:
|
||||||
|
pass
|
||||||
|
except Exception:
|
||||||
|
LOG.exception(f"{class_name} An exception occurred in loop stream.", **vnc_channel.identity)
|
||||||
|
finally:
|
||||||
|
LOG.info(f"{class_name} Closing the loop stream.", **vnc_channel.identity)
|
||||||
|
await vnc_channel.close(reason="loop-stream-vnc-closed")
|
||||||
|
|
||||||
|
|
||||||
|
async def get_vnc_channel_for_browser_session(
|
||||||
|
client_id: str,
|
||||||
|
browser_session_id: str,
|
||||||
|
organization_id: str,
|
||||||
|
websocket: WebSocket,
|
||||||
|
) -> tuple[VncChannel, Loops] | None:
|
||||||
|
"""
|
||||||
|
Return a vnc channel for a browser session, with a list of loops to run concurrently.
|
||||||
|
"""
|
||||||
|
|
||||||
|
LOG.info("Getting vnc context for browser session.", browser_session_id=browser_session_id)
|
||||||
|
|
||||||
|
browser_session = await verify_browser_session(
|
||||||
|
browser_session_id=browser_session_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not browser_session:
|
||||||
|
LOG.info(
|
||||||
|
"No initial browser session found.", browser_session_id=browser_session_id, organization_id=organization_id
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
x_api_key = await get_x_api_key(organization_id)
|
||||||
|
|
||||||
|
try:
|
||||||
|
vnc_channel = VncChannel(
|
||||||
|
client_id=client_id,
|
||||||
|
initial_interactor="agent",
|
||||||
|
organization_id=organization_id,
|
||||||
|
vnc_port=settings.SKYVERN_BROWSER_VNC_PORT,
|
||||||
|
browser_session=browser_session,
|
||||||
|
x_api_key=x_api_key,
|
||||||
|
websocket=websocket,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
LOG.exception("Failed to create VncChannel.", error=str(e))
|
||||||
|
return None
|
||||||
|
|
||||||
|
LOG.info("Got vnc context for browser session.", vnc_channel=vnc_channel)
|
||||||
|
|
||||||
|
loops = [
|
||||||
|
asyncio.create_task(loop_verify_browser_session(vnc_channel)),
|
||||||
|
asyncio.create_task(loop_stream_vnc(vnc_channel)),
|
||||||
|
]
|
||||||
|
|
||||||
|
return vnc_channel, loops
|
||||||
|
|
||||||
|
|
||||||
|
async def get_vnc_channel_for_task(
|
||||||
|
client_id: str,
|
||||||
|
task_id: str,
|
||||||
|
organization_id: str,
|
||||||
|
websocket: WebSocket,
|
||||||
|
) -> tuple[VncChannel, Loops] | None:
|
||||||
|
"""
|
||||||
|
Return a vnc channel for a task, with a list of loops to run concurrently.
|
||||||
|
"""
|
||||||
|
|
||||||
|
task, browser_session = await verify_task(task_id=task_id, organization_id=organization_id)
|
||||||
|
|
||||||
|
if not task:
|
||||||
|
LOG.info("No initial task found.", task_id=task_id, organization_id=organization_id)
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not browser_session:
|
||||||
|
LOG.info("No initial browser session found for task.", task_id=task_id, organization_id=organization_id)
|
||||||
|
return None
|
||||||
|
|
||||||
|
x_api_key = await get_x_api_key(organization_id)
|
||||||
|
|
||||||
|
vnc_channel = VncChannel(
|
||||||
|
client_id=client_id,
|
||||||
|
initial_interactor="agent",
|
||||||
|
organization_id=organization_id,
|
||||||
|
vnc_port=settings.SKYVERN_BROWSER_VNC_PORT,
|
||||||
|
x_api_key=x_api_key,
|
||||||
|
websocket=websocket,
|
||||||
|
browser_session=browser_session,
|
||||||
|
task=task,
|
||||||
|
)
|
||||||
|
|
||||||
|
loops = [
|
||||||
|
asyncio.create_task(loop_verify_task(vnc_channel)),
|
||||||
|
asyncio.create_task(loop_stream_vnc(vnc_channel)),
|
||||||
|
]
|
||||||
|
|
||||||
|
return vnc_channel, loops
|
||||||
|
|
||||||
|
|
||||||
|
async def get_vnc_channel_for_workflow_run(
|
||||||
|
client_id: str,
|
||||||
|
workflow_run_id: str,
|
||||||
|
organization_id: str,
|
||||||
|
websocket: WebSocket,
|
||||||
|
) -> tuple[VncChannel, Loops] | None:
|
||||||
|
"""
|
||||||
|
Return a vnc channel for a workflow run, with a list of loops to run concurrently.
|
||||||
|
"""
|
||||||
|
|
||||||
|
LOG.info("Getting vnc channel for workflow run.", workflow_run_id=workflow_run_id)
|
||||||
|
|
||||||
|
workflow_run, browser_session = await verify_workflow_run(
|
||||||
|
workflow_run_id=workflow_run_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not workflow_run:
|
||||||
|
LOG.info("No initial workflow run found.", workflow_run_id=workflow_run_id, organization_id=organization_id)
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not browser_session:
|
||||||
|
LOG.info(
|
||||||
|
"No initial browser session found for workflow run.",
|
||||||
|
workflow_run_id=workflow_run_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
x_api_key = await get_x_api_key(organization_id)
|
||||||
|
|
||||||
|
vnc_channel = VncChannel(
|
||||||
|
client_id=client_id,
|
||||||
|
initial_interactor="agent",
|
||||||
|
organization_id=organization_id,
|
||||||
|
vnc_port=settings.SKYVERN_BROWSER_VNC_PORT,
|
||||||
|
browser_session=browser_session,
|
||||||
|
workflow_run=workflow_run,
|
||||||
|
x_api_key=x_api_key,
|
||||||
|
websocket=websocket,
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG.info("Got vnc channel context for workflow run.", vnc_channel=vnc_channel)
|
||||||
|
|
||||||
|
loops = [
|
||||||
|
asyncio.create_task(loop_verify_workflow_run(vnc_channel)),
|
||||||
|
asyncio.create_task(loop_stream_vnc(vnc_channel)),
|
||||||
|
]
|
||||||
|
|
||||||
|
return vnc_channel, loops
|
||||||
@@ -1,328 +0,0 @@
|
|||||||
"""
|
|
||||||
Streaming types.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import dataclasses
|
|
||||||
import typing as t
|
|
||||||
from enum import IntEnum
|
|
||||||
|
|
||||||
import structlog
|
|
||||||
from fastapi import WebSocket
|
|
||||||
from starlette.websockets import WebSocketState
|
|
||||||
|
|
||||||
from skyvern.forge.sdk.schemas.persistent_browser_sessions import AddressablePersistentBrowserSession
|
|
||||||
from skyvern.forge.sdk.schemas.tasks import Task
|
|
||||||
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRun
|
|
||||||
|
|
||||||
LOG = structlog.get_logger()
|
|
||||||
|
|
||||||
|
|
||||||
Interactor = t.Literal["agent", "user"]
|
|
||||||
Loops = list[asyncio.Task] # aka "queue-less actors"; or "programs"
|
|
||||||
|
|
||||||
|
|
||||||
# Messages
|
|
||||||
|
|
||||||
|
|
||||||
# a global registry for WS message clients
|
|
||||||
message_channels: dict[str, "MessageChannel"] = {}
|
|
||||||
|
|
||||||
|
|
||||||
def add_message_client(message_channel: "MessageChannel") -> None:
|
|
||||||
message_channels[message_channel.client_id] = message_channel
|
|
||||||
|
|
||||||
|
|
||||||
def get_message_client(client_id: str) -> t.Union["MessageChannel", None]:
|
|
||||||
return message_channels.get(client_id, None)
|
|
||||||
|
|
||||||
|
|
||||||
def del_message_client(client_id: str) -> None:
|
|
||||||
try:
|
|
||||||
del message_channels[client_id]
|
|
||||||
except KeyError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class MessageChannel:
|
|
||||||
client_id: str
|
|
||||||
organization_id: str
|
|
||||||
websocket: WebSocket
|
|
||||||
|
|
||||||
# --
|
|
||||||
|
|
||||||
browser_session: AddressablePersistentBrowserSession | None = None
|
|
||||||
workflow_run: WorkflowRun | None = None
|
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
|
||||||
add_message_client(self)
|
|
||||||
|
|
||||||
async def close(self, code: int = 1000, reason: str | None = None) -> "MessageChannel":
|
|
||||||
LOG.info("Closing message stream.", reason=reason, code=code)
|
|
||||||
|
|
||||||
self.browser_session = None
|
|
||||||
self.workflow_run = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
await self.websocket.close(code=code, reason=reason)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
del_message_client(self.client_id)
|
|
||||||
|
|
||||||
return self
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_open(self) -> bool:
|
|
||||||
if self.websocket.client_state not in (WebSocketState.CONNECTED, WebSocketState.CONNECTING):
|
|
||||||
return False
|
|
||||||
|
|
||||||
if not self.workflow_run and not self.browser_session:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if not get_message_client(self.client_id):
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def ask_for_clipboard(self, streaming: "Streaming") -> None:
|
|
||||||
try:
|
|
||||||
await self.websocket.send_json(
|
|
||||||
{
|
|
||||||
"kind": "ask-for-clipboard",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
LOG.info(
|
|
||||||
"Sent ask-for-clipboard to message channel",
|
|
||||||
organization_id=streaming.organization_id,
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
LOG.exception(
|
|
||||||
"Failed to send ask-for-clipboard to message channel",
|
|
||||||
organization_id=streaming.organization_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def send_copied_text(self, copied_text: str, streaming: "Streaming") -> None:
|
|
||||||
try:
|
|
||||||
await self.websocket.send_json(
|
|
||||||
{
|
|
||||||
"kind": "copied-text",
|
|
||||||
"text": copied_text,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
LOG.info(
|
|
||||||
"Sent copied text to message channel",
|
|
||||||
organization_id=streaming.organization_id,
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
LOG.exception(
|
|
||||||
"Failed to send copied text to message channel",
|
|
||||||
organization_id=streaming.organization_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
MessageKinds = t.Literal["take-control", "cede-control", "ask-for-clipboard-response"]
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class Message:
|
|
||||||
kind: MessageKinds
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class MessageTakeControl(Message):
|
|
||||||
kind: t.Literal["take-control"] = "take-control"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class MessageCedeControl(Message):
|
|
||||||
kind: t.Literal["cede-control"] = "cede-control"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class MessageInAskForClipboardResponse(Message):
|
|
||||||
kind: t.Literal["ask-for-clipboard-response"] = "ask-for-clipboard-response"
|
|
||||||
text: str = ""
|
|
||||||
|
|
||||||
|
|
||||||
ChannelMessage = t.Union[MessageTakeControl, MessageCedeControl, MessageInAskForClipboardResponse]
|
|
||||||
|
|
||||||
|
|
||||||
def reify_channel_message(data: dict) -> ChannelMessage:
|
|
||||||
kind = data.get("kind", None)
|
|
||||||
|
|
||||||
match kind:
|
|
||||||
case "take-control":
|
|
||||||
return MessageTakeControl()
|
|
||||||
case "cede-control":
|
|
||||||
return MessageCedeControl()
|
|
||||||
case "ask-for-clipboard-response":
|
|
||||||
text = data.get("text") or ""
|
|
||||||
return MessageInAskForClipboardResponse(text=text)
|
|
||||||
case _:
|
|
||||||
raise ValueError(f"Unknown message kind: '{kind}'")
|
|
||||||
|
|
||||||
|
|
||||||
# Streaming
|
|
||||||
|
|
||||||
|
|
||||||
# a global registry for WS streaming VNC clients
|
|
||||||
streaming_clients: dict[str, "Streaming"] = {}
|
|
||||||
|
|
||||||
|
|
||||||
def add_streaming_client(streaming: "Streaming") -> None:
|
|
||||||
streaming_clients[streaming.client_id] = streaming
|
|
||||||
|
|
||||||
|
|
||||||
def get_streaming_client(client_id: str) -> t.Union["Streaming", None]:
|
|
||||||
return streaming_clients.get(client_id, None)
|
|
||||||
|
|
||||||
|
|
||||||
def del_streaming_client(client_id: str) -> None:
|
|
||||||
try:
|
|
||||||
del streaming_clients[client_id]
|
|
||||||
except KeyError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class MessageType(IntEnum):
|
|
||||||
Keyboard = 4
|
|
||||||
Mouse = 5
|
|
||||||
|
|
||||||
|
|
||||||
class Keys:
|
|
||||||
"""
|
|
||||||
VNC RFB keycodes. There's likely a pithier repr (indexes 6-7). This is ok for now.
|
|
||||||
|
|
||||||
ref: https://www.notion.so/References-21c426c42cd480fb9258ecc9eb8f09b4
|
|
||||||
ref: https://github.com/novnc/noVNC/blob/master/docs/rfbproto-3.8.pdf
|
|
||||||
"""
|
|
||||||
|
|
||||||
class Down:
|
|
||||||
Ctrl = b"\x04\x01\x00\x00\x00\x00\xff\xe3"
|
|
||||||
Cmd = b"\x04\x01\x00\x00\x00\x00\xff\xe9"
|
|
||||||
Alt = b"\x04\x01\x00\x00\x00\x00\xff~" # option
|
|
||||||
CKey = b"\x04\x01\x00\x00\x00\x00\x00c"
|
|
||||||
OKey = b"\x04\x01\x00\x00\x00\x00\x00o"
|
|
||||||
VKey = b"\x04\x01\x00\x00\x00\x00\x00v"
|
|
||||||
|
|
||||||
class Up:
|
|
||||||
Ctrl = b"\x04\x00\x00\x00\x00\x00\xff\xe3"
|
|
||||||
Cmd = b"\x04\x00\x00\x00\x00\x00\xff\xe9"
|
|
||||||
Alt = b"\x04\x00\x00\x00\x00\x00\xff\x7e" # option
|
|
||||||
|
|
||||||
|
|
||||||
def is_rmb(data: bytes) -> bool:
|
|
||||||
return data[0:2] == b"\x05\x04"
|
|
||||||
|
|
||||||
|
|
||||||
class Mouse:
|
|
||||||
class Up:
|
|
||||||
Right = is_rmb
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class KeyState:
|
|
||||||
ctrl_is_down: bool = False
|
|
||||||
alt_is_down: bool = False
|
|
||||||
cmd_is_down: bool = False
|
|
||||||
|
|
||||||
def is_forbidden(self, data: bytes) -> bool:
|
|
||||||
"""
|
|
||||||
:return: True if the key is forbidden, else False
|
|
||||||
"""
|
|
||||||
return self.is_ctrl_o(data)
|
|
||||||
|
|
||||||
def is_ctrl_o(self, data: bytes) -> bool:
|
|
||||||
"""
|
|
||||||
Do not allow the opening of files.
|
|
||||||
"""
|
|
||||||
return self.ctrl_is_down and data == Keys.Down.OKey
|
|
||||||
|
|
||||||
def is_copy(self, data: bytes) -> bool:
|
|
||||||
"""
|
|
||||||
Detect Ctrl+C or Cmd+C for copy.
|
|
||||||
"""
|
|
||||||
return (self.ctrl_is_down or self.cmd_is_down) and data == Keys.Down.CKey
|
|
||||||
|
|
||||||
def is_paste(self, data: bytes) -> bool:
|
|
||||||
"""
|
|
||||||
Detect Ctrl+V or Cmd+V for paste.
|
|
||||||
"""
|
|
||||||
return (self.ctrl_is_down or self.cmd_is_down) and data == Keys.Down.VKey
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class Streaming:
|
|
||||||
"""
|
|
||||||
Streaming state.
|
|
||||||
"""
|
|
||||||
|
|
||||||
client_id: str
|
|
||||||
"""
|
|
||||||
Unique to frontend app instance.
|
|
||||||
"""
|
|
||||||
|
|
||||||
interactor: Interactor
|
|
||||||
"""
|
|
||||||
Whether the user or the agent are the interactor.
|
|
||||||
"""
|
|
||||||
|
|
||||||
organization_id: str
|
|
||||||
vnc_port: int
|
|
||||||
x_api_key: str
|
|
||||||
websocket: WebSocket
|
|
||||||
|
|
||||||
# --
|
|
||||||
|
|
||||||
browser_session: AddressablePersistentBrowserSession | None = None
|
|
||||||
key_state: KeyState = dataclasses.field(default_factory=KeyState)
|
|
||||||
task: Task | None = None
|
|
||||||
workflow_run: WorkflowRun | None = None
|
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
|
||||||
add_streaming_client(self)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_open(self) -> bool:
|
|
||||||
if self.websocket.client_state not in (WebSocketState.CONNECTED, WebSocketState.CONNECTING):
|
|
||||||
return False
|
|
||||||
|
|
||||||
if not self.task and not self.workflow_run and not self.browser_session:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if not get_streaming_client(self.client_id):
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def close(self, code: int = 1000, reason: str | None = None) -> "Streaming":
|
|
||||||
LOG.info("Closing Streaming.", reason=reason, code=code)
|
|
||||||
|
|
||||||
self.browser_session = None
|
|
||||||
self.task = None
|
|
||||||
self.workflow_run = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
await self.websocket.close(code=code, reason=reason)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
del_streaming_client(self.client_id)
|
|
||||||
|
|
||||||
return self
|
|
||||||
|
|
||||||
def update_key_state(self, data: bytes) -> None:
|
|
||||||
if data == Keys.Down.Ctrl:
|
|
||||||
self.key_state.ctrl_is_down = True
|
|
||||||
elif data == Keys.Up.Ctrl:
|
|
||||||
self.key_state.ctrl_is_down = False
|
|
||||||
elif data == Keys.Down.Alt:
|
|
||||||
self.key_state.alt_is_down = True
|
|
||||||
elif data == Keys.Up.Alt:
|
|
||||||
self.key_state.alt_is_down = False
|
|
||||||
elif data == Keys.Down.Cmd:
|
|
||||||
self.key_state.cmd_is_down = True
|
|
||||||
elif data == Keys.Up.Cmd:
|
|
||||||
self.key_state.cmd_is_down = False
|
|
||||||
@@ -1,240 +1,24 @@
|
|||||||
"""
|
"""
|
||||||
Streaming messages for WebSocket connections.
|
Provides WS endpoints for streaming messages to/from our frontend application.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
import structlog
|
import structlog
|
||||||
from fastapi import WebSocket, WebSocketDisconnect
|
from fastapi import WebSocket
|
||||||
from websockets.exceptions import ConnectionClosedError
|
|
||||||
|
|
||||||
import skyvern.forge.sdk.routes.streaming.clients as sc
|
|
||||||
from skyvern.forge.sdk.routes.routers import base_router, legacy_base_router
|
from skyvern.forge.sdk.routes.routers import base_router, legacy_base_router
|
||||||
from skyvern.forge.sdk.routes.streaming.agent import connected_agent
|
|
||||||
from skyvern.forge.sdk.routes.streaming.auth import auth
|
from skyvern.forge.sdk.routes.streaming.auth import auth
|
||||||
from skyvern.forge.sdk.routes.streaming.verify import (
|
from skyvern.forge.sdk.routes.streaming.channels.message import (
|
||||||
loop_verify_browser_session,
|
Loops,
|
||||||
loop_verify_workflow_run,
|
MessageChannel,
|
||||||
verify_browser_session,
|
get_message_channel_for_browser_session,
|
||||||
verify_workflow_run,
|
get_message_channel_for_workflow_run,
|
||||||
)
|
)
|
||||||
from skyvern.forge.sdk.utils.aio import collect
|
from skyvern.forge.sdk.utils.aio import collect
|
||||||
|
|
||||||
LOG = structlog.get_logger()
|
LOG = structlog.get_logger()
|
||||||
|
|
||||||
|
|
||||||
async def get_messages_for_browser_session(
|
|
||||||
client_id: str,
|
|
||||||
browser_session_id: str,
|
|
||||||
organization_id: str,
|
|
||||||
websocket: WebSocket,
|
|
||||||
) -> tuple[sc.MessageChannel, sc.Loops] | None:
|
|
||||||
"""
|
|
||||||
Return a message channel for a browser session, with a list of loops to run concurrently.
|
|
||||||
"""
|
|
||||||
|
|
||||||
LOG.info("Getting message channel for browser session.", browser_session_id=browser_session_id)
|
|
||||||
|
|
||||||
browser_session = await verify_browser_session(
|
|
||||||
browser_session_id=browser_session_id,
|
|
||||||
organization_id=organization_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not browser_session:
|
|
||||||
LOG.info(
|
|
||||||
"Message channel: no initial browser session found.",
|
|
||||||
browser_session_id=browser_session_id,
|
|
||||||
organization_id=organization_id,
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
message_channel = sc.MessageChannel(
|
|
||||||
client_id=client_id,
|
|
||||||
organization_id=organization_id,
|
|
||||||
browser_session=browser_session,
|
|
||||||
websocket=websocket,
|
|
||||||
)
|
|
||||||
|
|
||||||
LOG.info("Got message channel for browser session.", message_channel=message_channel)
|
|
||||||
|
|
||||||
loops = [
|
|
||||||
asyncio.create_task(loop_verify_browser_session(message_channel)),
|
|
||||||
asyncio.create_task(loop_channel(message_channel)),
|
|
||||||
]
|
|
||||||
|
|
||||||
return message_channel, loops
|
|
||||||
|
|
||||||
|
|
||||||
async def get_messages_for_workflow_run(
|
|
||||||
client_id: str,
|
|
||||||
workflow_run_id: str,
|
|
||||||
organization_id: str,
|
|
||||||
websocket: WebSocket,
|
|
||||||
) -> tuple[sc.MessageChannel, sc.Loops] | None:
|
|
||||||
"""
|
|
||||||
Return a message channel for a workflow run, with a list of loops to run concurrently.
|
|
||||||
"""
|
|
||||||
|
|
||||||
LOG.info("Getting message channel for workflow run.", workflow_run_id=workflow_run_id)
|
|
||||||
|
|
||||||
workflow_run, browser_session = await verify_workflow_run(
|
|
||||||
workflow_run_id=workflow_run_id,
|
|
||||||
organization_id=organization_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not workflow_run:
|
|
||||||
LOG.info(
|
|
||||||
"Message channel: no initial workflow run found.",
|
|
||||||
workflow_run_id=workflow_run_id,
|
|
||||||
organization_id=organization_id,
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
if not browser_session:
|
|
||||||
LOG.info(
|
|
||||||
"Message channel: no initial browser session found for workflow run.",
|
|
||||||
workflow_run_id=workflow_run_id,
|
|
||||||
organization_id=organization_id,
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
message_channel = sc.MessageChannel(
|
|
||||||
client_id=client_id,
|
|
||||||
organization_id=organization_id,
|
|
||||||
browser_session=browser_session,
|
|
||||||
workflow_run=workflow_run,
|
|
||||||
websocket=websocket,
|
|
||||||
)
|
|
||||||
|
|
||||||
LOG.info("Got message channel for workflow run.", message_channel=message_channel)
|
|
||||||
|
|
||||||
loops = [
|
|
||||||
asyncio.create_task(loop_verify_workflow_run(message_channel)),
|
|
||||||
asyncio.create_task(loop_channel(message_channel)),
|
|
||||||
]
|
|
||||||
|
|
||||||
return message_channel, loops
|
|
||||||
|
|
||||||
|
|
||||||
async def loop_channel(message_channel: sc.MessageChannel) -> None:
|
|
||||||
"""
|
|
||||||
Stream messages and their results back and forth.
|
|
||||||
|
|
||||||
Loops until the workflow run is cleared or the websocket is closed.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if not message_channel.browser_session:
|
|
||||||
LOG.info(
|
|
||||||
"No browser session found for workflow run.",
|
|
||||||
workflow_run=message_channel.workflow_run,
|
|
||||||
organization_id=message_channel.organization_id,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
async def frontend_to_backend() -> None:
|
|
||||||
LOG.info("Starting frontend-to-backend channel loop.", message_channel=message_channel)
|
|
||||||
|
|
||||||
while message_channel.is_open:
|
|
||||||
try:
|
|
||||||
data = await message_channel.websocket.receive_json()
|
|
||||||
|
|
||||||
if not isinstance(data, dict):
|
|
||||||
LOG.error(f"Cannot create channel message: expected dict, got {type(data)}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
|
||||||
message = sc.reify_channel_message(data)
|
|
||||||
except ValueError:
|
|
||||||
continue
|
|
||||||
|
|
||||||
message_kind = message.kind
|
|
||||||
|
|
||||||
match message_kind:
|
|
||||||
case "take-control":
|
|
||||||
streaming = sc.get_streaming_client(message_channel.client_id)
|
|
||||||
if not streaming:
|
|
||||||
LOG.error(
|
|
||||||
"No streaming client found for message.",
|
|
||||||
message_channel=message_channel,
|
|
||||||
message=message,
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
streaming.interactor = "user"
|
|
||||||
case "cede-control":
|
|
||||||
streaming = sc.get_streaming_client(message_channel.client_id)
|
|
||||||
if not streaming:
|
|
||||||
LOG.error(
|
|
||||||
"No streaming client found for message.",
|
|
||||||
message_channel=message_channel,
|
|
||||||
message=message,
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
streaming.interactor = "agent"
|
|
||||||
case "ask-for-clipboard-response":
|
|
||||||
if not isinstance(message, sc.MessageInAskForClipboardResponse):
|
|
||||||
LOG.error(
|
|
||||||
"Invalid message type for ask-for-clipboard-response.",
|
|
||||||
message_channel=message_channel,
|
|
||||||
message=message,
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
streaming = sc.get_streaming_client(message_channel.client_id)
|
|
||||||
text = message.text
|
|
||||||
|
|
||||||
async with connected_agent(streaming) as agent:
|
|
||||||
await agent.paste_text(text)
|
|
||||||
case _:
|
|
||||||
LOG.error(f"Unknown message kind: '{message_kind}'")
|
|
||||||
continue
|
|
||||||
|
|
||||||
except WebSocketDisconnect:
|
|
||||||
LOG.info(
|
|
||||||
"Frontend disconnected.",
|
|
||||||
workflow_run=message_channel.workflow_run,
|
|
||||||
organization_id=message_channel.organization_id,
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
except ConnectionClosedError:
|
|
||||||
LOG.info(
|
|
||||||
"Frontend closed the streaming session.",
|
|
||||||
workflow_run=message_channel.workflow_run,
|
|
||||||
organization_id=message_channel.organization_id,
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
except Exception:
|
|
||||||
LOG.exception(
|
|
||||||
"An unexpected exception occurred.",
|
|
||||||
workflow_run=message_channel.workflow_run,
|
|
||||||
organization_id=message_channel.organization_id,
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
|
|
||||||
loops = [
|
|
||||||
asyncio.create_task(frontend_to_backend()),
|
|
||||||
]
|
|
||||||
|
|
||||||
try:
|
|
||||||
await collect(loops)
|
|
||||||
except Exception:
|
|
||||||
LOG.exception(
|
|
||||||
"An exception occurred in loop channel stream.",
|
|
||||||
workflow_run=message_channel.workflow_run,
|
|
||||||
organization_id=message_channel.organization_id,
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
LOG.info(
|
|
||||||
"Closing the loop channel stream.",
|
|
||||||
workflow_run=message_channel.workflow_run,
|
|
||||||
organization_id=message_channel.organization_id,
|
|
||||||
)
|
|
||||||
await message_channel.close(reason="loop-channel-closed")
|
|
||||||
|
|
||||||
|
|
||||||
@base_router.websocket("/stream/messages/browser_session/{browser_session_id}")
|
@base_router.websocket("/stream/messages/browser_session/{browser_session_id}")
|
||||||
@base_router.websocket("/stream/commands/browser_session/{browser_session_id}")
|
|
||||||
async def browser_session_messages(
|
async def browser_session_messages(
|
||||||
websocket: WebSocket,
|
websocket: WebSocket,
|
||||||
browser_session_id: str,
|
browser_session_id: str,
|
||||||
@@ -242,64 +26,16 @@ async def browser_session_messages(
|
|||||||
client_id: str | None = None,
|
client_id: str | None = None,
|
||||||
token: str | None = None,
|
token: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
LOG.info("Starting message stream for browser session.", browser_session_id=browser_session_id)
|
return await messages(
|
||||||
|
|
||||||
organization_id = await auth(apikey=apikey, token=token, websocket=websocket)
|
|
||||||
|
|
||||||
if not organization_id:
|
|
||||||
LOG.error("Authentication failed.", browser_session_id=browser_session_id)
|
|
||||||
return
|
|
||||||
|
|
||||||
if not client_id:
|
|
||||||
LOG.error("No client ID provided.", browser_session_id=browser_session_id)
|
|
||||||
return
|
|
||||||
|
|
||||||
message_channel: sc.MessageChannel
|
|
||||||
loops: list[asyncio.Task] = []
|
|
||||||
|
|
||||||
result = await get_messages_for_browser_session(
|
|
||||||
client_id=client_id,
|
|
||||||
browser_session_id=browser_session_id,
|
|
||||||
organization_id=organization_id,
|
|
||||||
websocket=websocket,
|
websocket=websocket,
|
||||||
|
browser_session_id=browser_session_id,
|
||||||
|
apikey=apikey,
|
||||||
|
client_id=client_id,
|
||||||
|
token=token,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not result:
|
|
||||||
LOG.error(
|
|
||||||
"No streaming context found for the browser session.",
|
|
||||||
browser_session_id=browser_session_id,
|
|
||||||
organization_id=organization_id,
|
|
||||||
)
|
|
||||||
await websocket.close(code=1013)
|
|
||||||
return
|
|
||||||
|
|
||||||
message_channel, loops = result
|
|
||||||
|
|
||||||
try:
|
|
||||||
LOG.info(
|
|
||||||
"Starting message stream loops for browser session.",
|
|
||||||
browser_session_id=browser_session_id,
|
|
||||||
organization_id=organization_id,
|
|
||||||
)
|
|
||||||
await collect(loops)
|
|
||||||
except Exception:
|
|
||||||
LOG.exception(
|
|
||||||
"An exception occurred in the message stream function for browser session.",
|
|
||||||
browser_session_id=browser_session_id,
|
|
||||||
organization_id=organization_id,
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
LOG.info(
|
|
||||||
"Closing the message stream session for browser session.",
|
|
||||||
browser_session_id=browser_session_id,
|
|
||||||
organization_id=organization_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
await message_channel.close(reason="stream-closed")
|
|
||||||
|
|
||||||
|
|
||||||
@legacy_base_router.websocket("/stream/messages/workflow_run/{workflow_run_id}")
|
@legacy_base_router.websocket("/stream/messages/workflow_run/{workflow_run_id}")
|
||||||
@legacy_base_router.websocket("/stream/commands/workflow_run/{workflow_run_id}")
|
|
||||||
async def workflow_run_messages(
|
async def workflow_run_messages(
|
||||||
websocket: WebSocket,
|
websocket: WebSocket,
|
||||||
workflow_run_id: str,
|
workflow_run_id: str,
|
||||||
@@ -307,31 +43,77 @@ async def workflow_run_messages(
|
|||||||
client_id: str | None = None,
|
client_id: str | None = None,
|
||||||
token: str | None = None,
|
token: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
LOG.info("Starting message stream.", workflow_run_id=workflow_run_id)
|
return await messages(
|
||||||
|
websocket=websocket,
|
||||||
|
workflow_run_id=workflow_run_id,
|
||||||
|
apikey=apikey,
|
||||||
|
client_id=client_id,
|
||||||
|
token=token,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def messages(
|
||||||
|
websocket: WebSocket,
|
||||||
|
browser_session_id: str | None = None,
|
||||||
|
workflow_run_id: str | None = None,
|
||||||
|
apikey: str | None = None,
|
||||||
|
client_id: str | None = None,
|
||||||
|
token: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
LOG.info(
|
||||||
|
"Starting message stream.",
|
||||||
|
browser_session_id=browser_session_id,
|
||||||
|
workflow_run_id=workflow_run_id,
|
||||||
|
)
|
||||||
|
|
||||||
organization_id = await auth(apikey=apikey, token=token, websocket=websocket)
|
organization_id = await auth(apikey=apikey, token=token, websocket=websocket)
|
||||||
|
|
||||||
if not organization_id:
|
if not organization_id:
|
||||||
LOG.error("Authentication failed.", workflow_run_id=workflow_run_id)
|
LOG.error(
|
||||||
|
"Authentication failed.",
|
||||||
|
browser_session_id=browser_session_id,
|
||||||
|
workflow_run_id=workflow_run_id,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
if not client_id:
|
if not client_id:
|
||||||
LOG.error("No client ID provided.", workflow_run_id=workflow_run_id)
|
LOG.error(
|
||||||
|
"No client ID provided.",
|
||||||
|
browser_session_id=browser_session_id,
|
||||||
|
workflow_run_id=workflow_run_id,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
message_channel: sc.MessageChannel
|
message_channel: MessageChannel
|
||||||
loops: list[asyncio.Task] = []
|
loops: Loops = []
|
||||||
|
|
||||||
result = await get_messages_for_workflow_run(
|
if browser_session_id:
|
||||||
client_id=client_id,
|
result = await get_message_channel_for_browser_session(
|
||||||
workflow_run_id=workflow_run_id,
|
client_id=client_id,
|
||||||
organization_id=organization_id,
|
browser_session_id=browser_session_id,
|
||||||
websocket=websocket,
|
organization_id=organization_id,
|
||||||
)
|
websocket=websocket,
|
||||||
|
)
|
||||||
|
elif workflow_run_id:
|
||||||
|
result = await get_message_channel_for_workflow_run(
|
||||||
|
client_id=client_id,
|
||||||
|
workflow_run_id=workflow_run_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
websocket=websocket,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
LOG.error(
|
||||||
|
"Message channel: no browser_session_id or workflow_run_id provided.",
|
||||||
|
client_id=client_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
)
|
||||||
|
await websocket.close(code=1002)
|
||||||
|
return
|
||||||
|
|
||||||
if not result:
|
if not result:
|
||||||
LOG.error(
|
LOG.error(
|
||||||
"No streaming context found for the workflow run.",
|
"No message channel found.",
|
||||||
|
browser_session_id=browser_session_id,
|
||||||
workflow_run_id=workflow_run_id,
|
workflow_run_id=workflow_run_id,
|
||||||
organization_id=organization_id,
|
organization_id=organization_id,
|
||||||
)
|
)
|
||||||
@@ -342,22 +124,25 @@ async def workflow_run_messages(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
LOG.info(
|
LOG.info(
|
||||||
"Starting message stream loops.",
|
"Starting message channel loops.",
|
||||||
|
browser_session_id=browser_session_id,
|
||||||
workflow_run_id=workflow_run_id,
|
workflow_run_id=workflow_run_id,
|
||||||
organization_id=organization_id,
|
organization_id=organization_id,
|
||||||
)
|
)
|
||||||
await collect(loops)
|
await collect(loops)
|
||||||
except Exception:
|
except Exception:
|
||||||
LOG.exception(
|
LOG.exception(
|
||||||
"An exception occurred in the message stream function.",
|
"An exception occurred in the message loop function(s).",
|
||||||
|
browser_session_id=browser_session_id,
|
||||||
workflow_run_id=workflow_run_id,
|
workflow_run_id=workflow_run_id,
|
||||||
organization_id=organization_id,
|
organization_id=organization_id,
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
LOG.info(
|
LOG.info(
|
||||||
"Closing the message stream session.",
|
"Closing the message channel.",
|
||||||
|
browser_session_id=browser_session_id,
|
||||||
workflow_run_id=workflow_run_id,
|
workflow_run_id=workflow_run_id,
|
||||||
organization_id=organization_id,
|
organization_id=organization_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
await message_channel.close(reason="stream-closed")
|
await message_channel.close(reason="message-stream-closed")
|
||||||
|
|||||||
78
skyvern/forge/sdk/routes/streaming/registries.py
Normal file
78
skyvern/forge/sdk/routes/streaming/registries.py
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
"""
|
||||||
|
Contains registries for coordinating active WS connections (aka "channels", see
|
||||||
|
`./channels/README.md`).
|
||||||
|
|
||||||
|
NOTE: in AWS we had to turn on what amounts to sticky sessions for frontend apps,
|
||||||
|
so that an individual frontend app instance is guaranteed to always connect to
|
||||||
|
the same backend api instance. This is beccause the two registries here are
|
||||||
|
tied together via a `client_id` string.
|
||||||
|
|
||||||
|
The tale-of-the-tape is this:
|
||||||
|
- frontend app requires two different channels (WS connections) to the backend api
|
||||||
|
- one dedicated to streaming VNC's RFB protocol
|
||||||
|
- the other dedicated to messaging (JSON)
|
||||||
|
- both of these channels are stateful and need to coordinate with one another
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import typing as t
|
||||||
|
|
||||||
|
import structlog
|
||||||
|
|
||||||
|
if t.TYPE_CHECKING:
|
||||||
|
from skyvern.forge.sdk.routes.streaming.channels.message import MessageChannel
|
||||||
|
from skyvern.forge.sdk.routes.streaming.channels.vnc import VncChannel
|
||||||
|
|
||||||
|
LOG = structlog.get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
# a registry for VNC channels, keyed by `client_id`
|
||||||
|
vnc_channels: dict[str, VncChannel] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def add_vnc_channel(vnc_channel: VncChannel) -> None:
|
||||||
|
vnc_channels[vnc_channel.client_id] = vnc_channel
|
||||||
|
|
||||||
|
|
||||||
|
def get_vnc_channel(client_id: str) -> t.Union[VncChannel, None]:
|
||||||
|
return vnc_channels.get(client_id, None)
|
||||||
|
|
||||||
|
|
||||||
|
def del_vnc_channel(client_id: str) -> None:
|
||||||
|
try:
|
||||||
|
del vnc_channels[client_id]
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# a registry for message channels, keyed by `client_id`
|
||||||
|
message_channels: dict[str, MessageChannel] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def add_message_channel(message_channel: MessageChannel) -> None:
|
||||||
|
message_channels[message_channel.client_id] = message_channel
|
||||||
|
|
||||||
|
|
||||||
|
def get_message_channel(client_id: str) -> t.Union[MessageChannel, None]:
|
||||||
|
candidate = message_channels.get(client_id, None)
|
||||||
|
|
||||||
|
if candidate and candidate.is_open:
|
||||||
|
return candidate
|
||||||
|
|
||||||
|
if candidate:
|
||||||
|
LOG.info(
|
||||||
|
"MessageChannel: message channel is not open; deleting it",
|
||||||
|
client_id=candidate.client_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
del_message_channel(candidate.client_id)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def del_message_channel(client_id: str) -> None:
|
||||||
|
try:
|
||||||
|
del message_channels[client_id]
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
@@ -1,3 +1,14 @@
|
|||||||
|
"""
|
||||||
|
Provides WS endpoints for streaming screenshots.
|
||||||
|
|
||||||
|
Screenshot streaming is created on the basis of one of these database entities:
|
||||||
|
- task (run)
|
||||||
|
- workflow run
|
||||||
|
|
||||||
|
Screenshot streaming is used for a run that is invoked without a browser session.
|
||||||
|
Otherwise, VNC streaming is used.
|
||||||
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import base64
|
import base64
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|||||||
@@ -1,18 +1,43 @@
|
|||||||
|
"""
|
||||||
|
Channels (see ./channels/README.md) variously rely on the state of one of
|
||||||
|
these database entities:
|
||||||
|
- browser session
|
||||||
|
- task
|
||||||
|
- workflow run
|
||||||
|
|
||||||
|
That is, channels are created on the basis of one of those entities, and that
|
||||||
|
entity must be in a valid state for the channel to continue.
|
||||||
|
|
||||||
|
So, in order to continue operating a channel, we need to periodically verify
|
||||||
|
that the underlying entity is still valid. This module provides logic to
|
||||||
|
perform those verifications.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import typing as t
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
import structlog
|
import structlog
|
||||||
|
|
||||||
import skyvern.forge.sdk.routes.streaming.clients as sc
|
|
||||||
from skyvern.config import settings
|
from skyvern.config import settings
|
||||||
from skyvern.forge import app
|
from skyvern.forge import app
|
||||||
from skyvern.forge.sdk.schemas.persistent_browser_sessions import AddressablePersistentBrowserSession
|
from skyvern.forge.sdk.schemas.persistent_browser_sessions import AddressablePersistentBrowserSession
|
||||||
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
|
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
|
||||||
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRun, WorkflowRunStatus
|
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRun, WorkflowRunStatus
|
||||||
|
|
||||||
|
if t.TYPE_CHECKING:
|
||||||
|
from skyvern.forge.sdk.routes.streaming.channels.message import MessageChannel
|
||||||
|
from skyvern.forge.sdk.routes.streaming.channels.vnc import VncChannel
|
||||||
|
|
||||||
LOG = structlog.get_logger()
|
LOG = structlog.get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class Constants:
|
||||||
|
POLL_INTERVAL_FOR_VERIFICATION_SECONDS = 5
|
||||||
|
|
||||||
|
|
||||||
async def verify_browser_session(
|
async def verify_browser_session(
|
||||||
browser_session_id: str,
|
browser_session_id: str,
|
||||||
organization_id: str,
|
organization_id: str,
|
||||||
@@ -122,9 +147,7 @@ async def verify_task(
|
|||||||
**browser_session.model_dump() | {"browser_address": browser_session.browser_address},
|
**browser_session.model_dump() | {"browser_address": browser_session.browser_address},
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
LOG.error(
|
LOG.error("vnc.browser-session-reify-error", task_id=task_id, organization_id=organization_id, error=e)
|
||||||
"streaming-vnc.browser-session-reify-error", task_id=task_id, organization_id=organization_id, error=e
|
|
||||||
)
|
|
||||||
return task, None
|
return task, None
|
||||||
|
|
||||||
return task, addressable_browser_session
|
return task, addressable_browser_session
|
||||||
@@ -238,7 +261,7 @@ async def verify_workflow_run(
|
|||||||
return workflow_run, addressable_browser_session
|
return workflow_run, addressable_browser_session
|
||||||
|
|
||||||
|
|
||||||
async def loop_verify_browser_session(verifiable: sc.MessageChannel | sc.Streaming) -> None:
|
async def loop_verify_browser_session(verifiable: MessageChannel | VncChannel) -> None:
|
||||||
"""
|
"""
|
||||||
Loop until the browser session is cleared or the websocket is closed.
|
Loop until the browser session is cleared or the websocket is closed.
|
||||||
"""
|
"""
|
||||||
@@ -251,27 +274,27 @@ async def loop_verify_browser_session(verifiable: sc.MessageChannel | sc.Streami
|
|||||||
|
|
||||||
verifiable.browser_session = browser_session
|
verifiable.browser_session = browser_session
|
||||||
|
|
||||||
await asyncio.sleep(2)
|
await asyncio.sleep(Constants.POLL_INTERVAL_FOR_VERIFICATION_SECONDS)
|
||||||
|
|
||||||
|
|
||||||
async def loop_verify_task(streaming: sc.Streaming) -> None:
|
async def loop_verify_task(vnc_channel: VncChannel) -> None:
|
||||||
"""
|
"""
|
||||||
Loop until the task is cleared or the websocket is closed.
|
Loop until the task is cleared or the websocket is closed.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
while streaming.task and streaming.is_open:
|
while vnc_channel.task and vnc_channel.is_open:
|
||||||
task, browser_session = await verify_task(
|
task, browser_session = await verify_task(
|
||||||
task_id=streaming.task.task_id,
|
task_id=vnc_channel.task.task_id,
|
||||||
organization_id=streaming.organization_id,
|
organization_id=vnc_channel.organization_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
streaming.task = task
|
vnc_channel.task = task
|
||||||
streaming.browser_session = browser_session
|
vnc_channel.browser_session = browser_session
|
||||||
|
|
||||||
await asyncio.sleep(2)
|
await asyncio.sleep(Constants.POLL_INTERVAL_FOR_VERIFICATION_SECONDS)
|
||||||
|
|
||||||
|
|
||||||
async def loop_verify_workflow_run(verifiable: sc.MessageChannel | sc.Streaming) -> None:
|
async def loop_verify_workflow_run(verifiable: MessageChannel | VncChannel) -> None:
|
||||||
"""
|
"""
|
||||||
Loop until the workflow run is cleared or the websocket is closed.
|
Loop until the workflow run is cleared or the websocket is closed.
|
||||||
"""
|
"""
|
||||||
@@ -285,4 +308,4 @@ async def loop_verify_workflow_run(verifiable: sc.MessageChannel | sc.Streaming)
|
|||||||
verifiable.workflow_run = workflow_run
|
verifiable.workflow_run = workflow_run
|
||||||
verifiable.browser_session = browser_session
|
verifiable.browser_session = browser_session
|
||||||
|
|
||||||
await asyncio.sleep(2)
|
await asyncio.sleep(Constants.POLL_INTERVAL_FOR_VERIFICATION_SECONDS)
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
Streaming VNC WebSocket connections.
|
Provides WS endpoints for streaming a remote browser via VNC.
|
||||||
|
|
||||||
NOTE(jdo:streaming-local-dev)
|
NOTE(jdo:streaming-local-dev)
|
||||||
-----------------------------
|
-----------------------------
|
||||||
@@ -9,459 +9,23 @@ NOTE(jdo:streaming-local-dev)
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import typing as t
|
|
||||||
from urllib.parse import urlparse
|
|
||||||
|
|
||||||
import structlog
|
import structlog
|
||||||
import websockets
|
from fastapi import WebSocket
|
||||||
from fastapi import WebSocket, WebSocketDisconnect
|
|
||||||
from websockets import Data
|
|
||||||
from websockets.exceptions import ConnectionClosedError
|
|
||||||
|
|
||||||
import skyvern.forge.sdk.routes.streaming.clients as sc
|
|
||||||
from skyvern.config import settings
|
|
||||||
from skyvern.forge import app
|
|
||||||
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
|
|
||||||
from skyvern.forge.sdk.routes.routers import base_router, legacy_base_router
|
from skyvern.forge.sdk.routes.routers import base_router, legacy_base_router
|
||||||
from skyvern.forge.sdk.routes.streaming.agent import connected_agent
|
|
||||||
from skyvern.forge.sdk.routes.streaming.auth import auth
|
from skyvern.forge.sdk.routes.streaming.auth import auth
|
||||||
from skyvern.forge.sdk.routes.streaming.verify import (
|
from skyvern.forge.sdk.routes.streaming.channels.vnc import (
|
||||||
loop_verify_browser_session,
|
Loops,
|
||||||
loop_verify_task,
|
VncChannel,
|
||||||
loop_verify_workflow_run,
|
get_vnc_channel_for_browser_session,
|
||||||
verify_browser_session,
|
get_vnc_channel_for_task,
|
||||||
verify_task,
|
get_vnc_channel_for_workflow_run,
|
||||||
verify_workflow_run,
|
|
||||||
)
|
)
|
||||||
from skyvern.forge.sdk.utils.aio import collect
|
from skyvern.forge.sdk.utils.aio import collect
|
||||||
|
|
||||||
LOG = structlog.get_logger()
|
LOG = structlog.get_logger()
|
||||||
|
|
||||||
|
|
||||||
class Constants:
|
|
||||||
MissingXApiKey = "<missing-x-api-key>"
|
|
||||||
|
|
||||||
|
|
||||||
async def get_x_api_key(organization_id: str) -> str:
|
|
||||||
token = await app.DATABASE.get_valid_org_auth_token(
|
|
||||||
organization_id,
|
|
||||||
OrganizationAuthTokenType.api.value,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not token:
|
|
||||||
LOG.warning(
|
|
||||||
"No valid API key found for organization when streaming.",
|
|
||||||
organization_id=organization_id,
|
|
||||||
)
|
|
||||||
x_api_key = Constants.MissingXApiKey
|
|
||||||
else:
|
|
||||||
x_api_key = token.token
|
|
||||||
|
|
||||||
return x_api_key
|
|
||||||
|
|
||||||
|
|
||||||
async def get_streaming_for_browser_session(
|
|
||||||
client_id: str,
|
|
||||||
browser_session_id: str,
|
|
||||||
organization_id: str,
|
|
||||||
websocket: WebSocket,
|
|
||||||
) -> tuple[sc.Streaming, sc.Loops] | None:
|
|
||||||
"""
|
|
||||||
Return a streaming context for a browser session, with a list of loops to run concurrently.
|
|
||||||
"""
|
|
||||||
|
|
||||||
LOG.info("Getting streaming context for browser session.", browser_session_id=browser_session_id)
|
|
||||||
|
|
||||||
browser_session = await verify_browser_session(
|
|
||||||
browser_session_id=browser_session_id,
|
|
||||||
organization_id=organization_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not browser_session:
|
|
||||||
LOG.info(
|
|
||||||
"No initial browser session found.", browser_session_id=browser_session_id, organization_id=organization_id
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
x_api_key = await get_x_api_key(organization_id)
|
|
||||||
|
|
||||||
streaming = sc.Streaming(
|
|
||||||
client_id=client_id,
|
|
||||||
interactor="agent",
|
|
||||||
organization_id=organization_id,
|
|
||||||
vnc_port=settings.SKYVERN_BROWSER_VNC_PORT,
|
|
||||||
browser_session=browser_session,
|
|
||||||
x_api_key=x_api_key,
|
|
||||||
websocket=websocket,
|
|
||||||
)
|
|
||||||
|
|
||||||
LOG.info("Got streaming context for browser session.", streaming=streaming)
|
|
||||||
|
|
||||||
loops = [
|
|
||||||
asyncio.create_task(loop_verify_browser_session(streaming)),
|
|
||||||
asyncio.create_task(loop_stream_vnc(streaming)),
|
|
||||||
]
|
|
||||||
|
|
||||||
return streaming, loops
|
|
||||||
|
|
||||||
|
|
||||||
async def get_streaming_for_task(
|
|
||||||
client_id: str,
|
|
||||||
task_id: str,
|
|
||||||
organization_id: str,
|
|
||||||
websocket: WebSocket,
|
|
||||||
) -> tuple[sc.Streaming, sc.Loops] | None:
|
|
||||||
"""
|
|
||||||
Return a streaming context for a task, with a list of loops to run concurrently.
|
|
||||||
"""
|
|
||||||
|
|
||||||
task, browser_session = await verify_task(task_id=task_id, organization_id=organization_id)
|
|
||||||
|
|
||||||
if not task:
|
|
||||||
LOG.info("No initial task found.", task_id=task_id, organization_id=organization_id)
|
|
||||||
return None
|
|
||||||
|
|
||||||
if not browser_session:
|
|
||||||
LOG.info("No initial browser session found for task.", task_id=task_id, organization_id=organization_id)
|
|
||||||
return None
|
|
||||||
|
|
||||||
x_api_key = await get_x_api_key(organization_id)
|
|
||||||
|
|
||||||
streaming = sc.Streaming(
|
|
||||||
client_id=client_id,
|
|
||||||
interactor="agent",
|
|
||||||
organization_id=organization_id,
|
|
||||||
vnc_port=settings.SKYVERN_BROWSER_VNC_PORT,
|
|
||||||
x_api_key=x_api_key,
|
|
||||||
websocket=websocket,
|
|
||||||
browser_session=browser_session,
|
|
||||||
task=task,
|
|
||||||
)
|
|
||||||
|
|
||||||
loops = [
|
|
||||||
asyncio.create_task(loop_verify_task(streaming)),
|
|
||||||
asyncio.create_task(loop_stream_vnc(streaming)),
|
|
||||||
]
|
|
||||||
|
|
||||||
return streaming, loops
|
|
||||||
|
|
||||||
|
|
||||||
async def get_streaming_for_workflow_run(
|
|
||||||
client_id: str,
|
|
||||||
workflow_run_id: str,
|
|
||||||
organization_id: str,
|
|
||||||
websocket: WebSocket,
|
|
||||||
) -> tuple[sc.Streaming, sc.Loops] | None:
|
|
||||||
"""
|
|
||||||
Return a streaming context for a workflow run, with a list of loops to run concurrently.
|
|
||||||
"""
|
|
||||||
|
|
||||||
LOG.info("Getting streaming context for workflow run.", workflow_run_id=workflow_run_id)
|
|
||||||
|
|
||||||
workflow_run, browser_session = await verify_workflow_run(
|
|
||||||
workflow_run_id=workflow_run_id,
|
|
||||||
organization_id=organization_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not workflow_run:
|
|
||||||
LOG.info("No initial workflow run found.", workflow_run_id=workflow_run_id, organization_id=organization_id)
|
|
||||||
return None
|
|
||||||
|
|
||||||
if not browser_session:
|
|
||||||
LOG.info(
|
|
||||||
"No initial browser session found for workflow run.",
|
|
||||||
workflow_run_id=workflow_run_id,
|
|
||||||
organization_id=organization_id,
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
x_api_key = await get_x_api_key(organization_id)
|
|
||||||
|
|
||||||
streaming = sc.Streaming(
|
|
||||||
client_id=client_id,
|
|
||||||
interactor="agent",
|
|
||||||
organization_id=organization_id,
|
|
||||||
vnc_port=settings.SKYVERN_BROWSER_VNC_PORT,
|
|
||||||
browser_session=browser_session,
|
|
||||||
workflow_run=workflow_run,
|
|
||||||
x_api_key=x_api_key,
|
|
||||||
websocket=websocket,
|
|
||||||
)
|
|
||||||
|
|
||||||
LOG.info("Got streaming context for workflow run.", streaming=streaming)
|
|
||||||
|
|
||||||
loops = [
|
|
||||||
asyncio.create_task(loop_verify_workflow_run(streaming)),
|
|
||||||
asyncio.create_task(loop_stream_vnc(streaming)),
|
|
||||||
]
|
|
||||||
|
|
||||||
return streaming, loops
|
|
||||||
|
|
||||||
|
|
||||||
def verify_message_channel(
|
|
||||||
message_channel: sc.MessageChannel | None, streaming: sc.Streaming
|
|
||||||
) -> sc.MessageChannel | t.Literal[False]:
|
|
||||||
if message_channel and message_channel.is_open:
|
|
||||||
return message_channel
|
|
||||||
|
|
||||||
LOG.warning(
|
|
||||||
"No message channel found for client, or it is not open",
|
|
||||||
message_channel=message_channel,
|
|
||||||
client_id=streaming.client_id,
|
|
||||||
organization_id=streaming.organization_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
async def copy_text(streaming: sc.Streaming) -> None:
|
|
||||||
try:
|
|
||||||
async with connected_agent(streaming) as agent:
|
|
||||||
copied_text = await agent.get_selected_text()
|
|
||||||
|
|
||||||
LOG.info(
|
|
||||||
"Retrieved selected text via CDP",
|
|
||||||
organization_id=streaming.organization_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
message_channel = sc.get_message_client(streaming.client_id)
|
|
||||||
|
|
||||||
if cc := verify_message_channel(message_channel, streaming):
|
|
||||||
await cc.send_copied_text(copied_text, streaming)
|
|
||||||
else:
|
|
||||||
LOG.warning(
|
|
||||||
"No message channel found for client, or it is not open",
|
|
||||||
message_channel=message_channel,
|
|
||||||
client_id=streaming.client_id,
|
|
||||||
organization_id=streaming.organization_id,
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
LOG.exception(
|
|
||||||
"Failed to retrieve selected text via CDP",
|
|
||||||
organization_id=streaming.organization_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def ask_for_clipboard(streaming: sc.Streaming) -> None:
|
|
||||||
try:
|
|
||||||
LOG.info(
|
|
||||||
"Asking for clipboard data via CDP",
|
|
||||||
organization_id=streaming.organization_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
message_channel = sc.get_message_client(streaming.client_id)
|
|
||||||
|
|
||||||
if cc := verify_message_channel(message_channel, streaming):
|
|
||||||
await cc.ask_for_clipboard(streaming)
|
|
||||||
except Exception:
|
|
||||||
LOG.exception(
|
|
||||||
"Failed to ask for clipboard via CDP",
|
|
||||||
organization_id=streaming.organization_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def loop_stream_vnc(streaming: sc.Streaming) -> None:
|
|
||||||
"""
|
|
||||||
Actually stream the VNC session data between a frontend and a browser
|
|
||||||
session.
|
|
||||||
|
|
||||||
Loops until the task is cleared or the websocket is closed.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if not streaming.browser_session:
|
|
||||||
LOG.info("No browser session found for task.", task=streaming.task, organization_id=streaming.organization_id)
|
|
||||||
return
|
|
||||||
|
|
||||||
vnc_url: str = ""
|
|
||||||
if streaming.browser_session.ip_address:
|
|
||||||
if ":" in streaming.browser_session.ip_address:
|
|
||||||
ip, _ = streaming.browser_session.ip_address.split(":")
|
|
||||||
vnc_url = f"ws://{ip}:{streaming.vnc_port}"
|
|
||||||
else:
|
|
||||||
vnc_url = f"ws://{streaming.browser_session.ip_address}:{streaming.vnc_port}"
|
|
||||||
else:
|
|
||||||
browser_address = streaming.browser_session.browser_address
|
|
||||||
|
|
||||||
parsed_browser_address = urlparse(browser_address)
|
|
||||||
host = parsed_browser_address.hostname
|
|
||||||
vnc_url = f"ws://{host}:{streaming.vnc_port}"
|
|
||||||
|
|
||||||
# NOTE(jdo:streaming-local-dev)
|
|
||||||
# vnc_url = "ws://localhost:9001/ws/novnc"
|
|
||||||
|
|
||||||
LOG.info(
|
|
||||||
"Connecting to VNC URL.",
|
|
||||||
vnc_url=vnc_url,
|
|
||||||
task=streaming.task,
|
|
||||||
workflow_run=streaming.workflow_run,
|
|
||||||
organization_id=streaming.organization_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
async with websockets.connect(vnc_url) as novnc_ws:
|
|
||||||
|
|
||||||
async def frontend_to_browser() -> None:
|
|
||||||
LOG.info("Starting frontend-to-browser data transfer.", streaming=streaming)
|
|
||||||
data: Data | None = None
|
|
||||||
|
|
||||||
while streaming.is_open:
|
|
||||||
try:
|
|
||||||
data = await streaming.websocket.receive_bytes()
|
|
||||||
|
|
||||||
if data:
|
|
||||||
message_type = data[0]
|
|
||||||
|
|
||||||
if message_type == sc.MessageType.Keyboard.value:
|
|
||||||
streaming.update_key_state(data)
|
|
||||||
|
|
||||||
if streaming.key_state.is_copy(data):
|
|
||||||
await copy_text(streaming)
|
|
||||||
|
|
||||||
if streaming.key_state.is_paste(data):
|
|
||||||
await ask_for_clipboard(streaming)
|
|
||||||
|
|
||||||
if streaming.key_state.is_forbidden(data):
|
|
||||||
continue
|
|
||||||
|
|
||||||
if message_type == sc.MessageType.Mouse.value:
|
|
||||||
if sc.Mouse.Up.Right(data):
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not streaming.interactor == "user" and message_type in (
|
|
||||||
sc.MessageType.Keyboard.value,
|
|
||||||
sc.MessageType.Mouse.value,
|
|
||||||
):
|
|
||||||
LOG.info(
|
|
||||||
"Blocking user message.", task=streaming.task, organization_id=streaming.organization_id
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
except WebSocketDisconnect:
|
|
||||||
LOG.info("Frontend disconnected.", task=streaming.task, organization_id=streaming.organization_id)
|
|
||||||
raise
|
|
||||||
except ConnectionClosedError:
|
|
||||||
LOG.info(
|
|
||||||
"Frontend closed the streaming session.",
|
|
||||||
task=streaming.task,
|
|
||||||
organization_id=streaming.organization_id,
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
except Exception:
|
|
||||||
LOG.exception(
|
|
||||||
"An unexpected exception occurred.",
|
|
||||||
task=streaming.task,
|
|
||||||
organization_id=streaming.organization_id,
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
|
|
||||||
if not data:
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
|
||||||
await novnc_ws.send(data)
|
|
||||||
except WebSocketDisconnect:
|
|
||||||
LOG.info(
|
|
||||||
"Browser disconnected from the streaming session.",
|
|
||||||
task=streaming.task,
|
|
||||||
organization_id=streaming.organization_id,
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
except ConnectionClosedError:
|
|
||||||
LOG.info(
|
|
||||||
"Browser closed the streaming session.",
|
|
||||||
task=streaming.task,
|
|
||||||
organization_id=streaming.organization_id,
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
except Exception:
|
|
||||||
LOG.exception(
|
|
||||||
"An unexpected exception occurred in frontend-to-browser loop.",
|
|
||||||
task=streaming.task,
|
|
||||||
organization_id=streaming.organization_id,
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def browser_to_frontend() -> None:
|
|
||||||
LOG.info("Starting browser-to-frontend data transfer.", streaming=streaming)
|
|
||||||
data: Data | None = None
|
|
||||||
|
|
||||||
while streaming.is_open:
|
|
||||||
try:
|
|
||||||
data = await novnc_ws.recv()
|
|
||||||
|
|
||||||
except WebSocketDisconnect:
|
|
||||||
LOG.info(
|
|
||||||
"Browser disconnected from the streaming session.",
|
|
||||||
task=streaming.task,
|
|
||||||
organization_id=streaming.organization_id,
|
|
||||||
)
|
|
||||||
await streaming.close(reason="browser-disconnected")
|
|
||||||
except ConnectionClosedError:
|
|
||||||
LOG.info(
|
|
||||||
"Browser closed the streaming session.",
|
|
||||||
task=streaming.task,
|
|
||||||
organization_id=streaming.organization_id,
|
|
||||||
)
|
|
||||||
await streaming.close(reason="browser-closed")
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
except Exception:
|
|
||||||
LOG.exception(
|
|
||||||
"An unexpected exception occurred in browser-to-frontend loop.",
|
|
||||||
task=streaming.task,
|
|
||||||
organization_id=streaming.organization_id,
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
|
|
||||||
if not data:
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
|
||||||
await streaming.websocket.send_bytes(data)
|
|
||||||
except WebSocketDisconnect:
|
|
||||||
LOG.info(
|
|
||||||
"Frontend disconnected from the streaming session.",
|
|
||||||
task=streaming.task,
|
|
||||||
organization_id=streaming.organization_id,
|
|
||||||
)
|
|
||||||
await streaming.close(reason="frontend-disconnected")
|
|
||||||
except ConnectionClosedError:
|
|
||||||
LOG.info(
|
|
||||||
"Frontend closed the streaming session.",
|
|
||||||
task=streaming.task,
|
|
||||||
organization_id=streaming.organization_id,
|
|
||||||
)
|
|
||||||
await streaming.close(reason="frontend-closed")
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
except Exception:
|
|
||||||
LOG.exception(
|
|
||||||
"An unexpected exception occurred.",
|
|
||||||
task=streaming.task,
|
|
||||||
organization_id=streaming.organization_id,
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
|
|
||||||
loops = [
|
|
||||||
asyncio.create_task(frontend_to_browser()),
|
|
||||||
asyncio.create_task(browser_to_frontend()),
|
|
||||||
]
|
|
||||||
|
|
||||||
try:
|
|
||||||
await collect(loops)
|
|
||||||
except Exception:
|
|
||||||
LOG.exception(
|
|
||||||
"An exception occurred in loop stream.", task=streaming.task, organization_id=streaming.organization_id
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
LOG.info("Closing the loop stream.", task=streaming.task, organization_id=streaming.organization_id)
|
|
||||||
await streaming.close(reason="loop-stream-vnc-closed")
|
|
||||||
|
|
||||||
|
|
||||||
@base_router.websocket("/stream/vnc/browser_session/{browser_session_id}")
|
@base_router.websocket("/stream/vnc/browser_session/{browser_session_id}")
|
||||||
async def browser_session_stream(
|
async def browser_session_stream(
|
||||||
websocket: WebSocket,
|
websocket: WebSocket,
|
||||||
@@ -507,7 +71,7 @@ async def stream(
|
|||||||
) -> None:
|
) -> None:
|
||||||
if not client_id:
|
if not client_id:
|
||||||
LOG.error(
|
LOG.error(
|
||||||
"Client ID not provided for VNC stream.",
|
"Client ID not provided for vnc stream.",
|
||||||
browser_session_id=browser_session_id,
|
browser_session_id=browser_session_id,
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
workflow_run_id=workflow_run_id,
|
workflow_run_id=workflow_run_id,
|
||||||
@@ -515,7 +79,7 @@ async def stream(
|
|||||||
return
|
return
|
||||||
|
|
||||||
LOG.info(
|
LOG.info(
|
||||||
"Starting VNC stream.",
|
"Starting vnc stream.",
|
||||||
browser_session_id=browser_session_id,
|
browser_session_id=browser_session_id,
|
||||||
client_id=client_id,
|
client_id=client_id,
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
@@ -528,11 +92,11 @@ async def stream(
|
|||||||
LOG.error("Authentication failed.", task_id=task_id, workflow_run_id=workflow_run_id)
|
LOG.error("Authentication failed.", task_id=task_id, workflow_run_id=workflow_run_id)
|
||||||
return
|
return
|
||||||
|
|
||||||
streaming: sc.Streaming
|
vnc_channel: VncChannel
|
||||||
loops: list[asyncio.Task] = []
|
loops: Loops
|
||||||
|
|
||||||
if browser_session_id:
|
if browser_session_id:
|
||||||
result = await get_streaming_for_browser_session(
|
result = await get_vnc_channel_for_browser_session(
|
||||||
client_id=client_id,
|
client_id=client_id,
|
||||||
browser_session_id=browser_session_id,
|
browser_session_id=browser_session_id,
|
||||||
organization_id=organization_id,
|
organization_id=organization_id,
|
||||||
@@ -541,22 +105,22 @@ async def stream(
|
|||||||
|
|
||||||
if not result:
|
if not result:
|
||||||
LOG.error(
|
LOG.error(
|
||||||
"No streaming context found for the browser session.",
|
"No vnc context found for the browser session.",
|
||||||
browser_session_id=browser_session_id,
|
browser_session_id=browser_session_id,
|
||||||
organization_id=organization_id,
|
organization_id=organization_id,
|
||||||
)
|
)
|
||||||
await websocket.close(code=1013)
|
await websocket.close(code=1013)
|
||||||
return
|
return
|
||||||
|
|
||||||
streaming, loops = result
|
vnc_channel, loops = result
|
||||||
|
|
||||||
LOG.info(
|
LOG.info(
|
||||||
"Starting streaming for browser session.",
|
"Starting vnc for browser session.",
|
||||||
browser_session_id=browser_session_id,
|
browser_session_id=browser_session_id,
|
||||||
organization_id=organization_id,
|
organization_id=organization_id,
|
||||||
)
|
)
|
||||||
elif task_id:
|
elif task_id:
|
||||||
result = await get_streaming_for_task(
|
result = await get_vnc_channel_for_task(
|
||||||
client_id=client_id,
|
client_id=client_id,
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
organization_id=organization_id,
|
organization_id=organization_id,
|
||||||
@@ -564,19 +128,17 @@ async def stream(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if not result:
|
if not result:
|
||||||
LOG.error("No streaming context found for the task.", task_id=task_id, organization_id=organization_id)
|
LOG.error("No vnc context found for the task.", task_id=task_id, organization_id=organization_id)
|
||||||
await websocket.close(code=1013)
|
await websocket.close(code=1013)
|
||||||
return
|
return
|
||||||
|
|
||||||
streaming, loops = result
|
vnc_channel, loops = result
|
||||||
|
|
||||||
LOG.info("Starting streaming for task.", task_id=task_id, organization_id=organization_id)
|
LOG.info("Starting vnc for task.", task_id=task_id, organization_id=organization_id)
|
||||||
|
|
||||||
elif workflow_run_id:
|
elif workflow_run_id:
|
||||||
LOG.info(
|
LOG.info("Starting vnc for workflow run.", workflow_run_id=workflow_run_id, organization_id=organization_id)
|
||||||
"Starting streaming for workflow run.", workflow_run_id=workflow_run_id, organization_id=organization_id
|
result = await get_vnc_channel_for_workflow_run(
|
||||||
)
|
|
||||||
result = await get_streaming_for_workflow_run(
|
|
||||||
client_id=client_id,
|
client_id=client_id,
|
||||||
workflow_run_id=workflow_run_id,
|
workflow_run_id=workflow_run_id,
|
||||||
organization_id=organization_id,
|
organization_id=organization_id,
|
||||||
@@ -585,25 +147,23 @@ async def stream(
|
|||||||
|
|
||||||
if not result:
|
if not result:
|
||||||
LOG.error(
|
LOG.error(
|
||||||
"No streaming context found for the workflow run.",
|
"No vnc context found for the workflow run.",
|
||||||
workflow_run_id=workflow_run_id,
|
workflow_run_id=workflow_run_id,
|
||||||
organization_id=organization_id,
|
organization_id=organization_id,
|
||||||
)
|
)
|
||||||
await websocket.close(code=1013)
|
await websocket.close(code=1013)
|
||||||
return
|
return
|
||||||
|
|
||||||
streaming, loops = result
|
vnc_channel, loops = result
|
||||||
|
|
||||||
LOG.info(
|
LOG.info("Starting vnc for workflow run.", workflow_run_id=workflow_run_id, organization_id=organization_id)
|
||||||
"Starting streaming for workflow run.", workflow_run_id=workflow_run_id, organization_id=organization_id
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
LOG.error("Neither task ID nor workflow run ID was provided.")
|
LOG.error("Neither task ID nor workflow run ID was provided.")
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
LOG.info(
|
LOG.info(
|
||||||
"Starting streaming loops.",
|
"Starting vnc loops.",
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
workflow_run_id=workflow_run_id,
|
workflow_run_id=workflow_run_id,
|
||||||
organization_id=organization_id,
|
organization_id=organization_id,
|
||||||
@@ -611,16 +171,16 @@ async def stream(
|
|||||||
await collect(loops)
|
await collect(loops)
|
||||||
except Exception:
|
except Exception:
|
||||||
LOG.exception(
|
LOG.exception(
|
||||||
"An exception occurred in the stream function.",
|
"An exception occurred in the vnc loop.",
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
workflow_run_id=workflow_run_id,
|
workflow_run_id=workflow_run_id,
|
||||||
organization_id=organization_id,
|
organization_id=organization_id,
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
LOG.info(
|
LOG.info(
|
||||||
"Closing the streaming session.",
|
"Closing the vnc session.",
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
workflow_run_id=workflow_run_id,
|
workflow_run_id=workflow_run_id,
|
||||||
organization_id=organization_id,
|
organization_id=organization_id,
|
||||||
)
|
)
|
||||||
await streaming.close(reason="stream-closed")
|
await vnc_channel.close(reason="vnc-closed")
|
||||||
|
|||||||
0
skyvern/py.typed
Normal file
0
skyvern/py.typed
Normal file
Reference in New Issue
Block a user