Browser streaming: reorganize & rename (#4033)

This commit is contained in:
Jonathan Dobson
2025-11-19 09:35:05 -05:00
committed by GitHub
parent 1559160aef
commit 2253ca2004
8 changed files with 18 additions and 14 deletions

View File

@@ -0,0 +1,191 @@
"""
A lightweight "agent" for interacting with the streaming browser over CDP.
"""
import typing
from contextlib import asynccontextmanager
import structlog
from playwright.async_api import Browser, BrowserContext, Page, Playwright, async_playwright
import skyvern.forge.sdk.routes.streaming.clients as sc
from skyvern.config import settings
LOG = structlog.get_logger()
class StreamingAgent:
"""
A minimal agent that can connect to a browser via CDP and execute JavaScript.
Specifically for operations during streaming sessions (like copy/pasting selected text, etc.).
"""
def __init__(self, streaming: sc.Streaming) -> None:
self.streaming = streaming
self.browser: Browser | None = None
self.browser_context: BrowserContext | None = None
self.page: Page | None = None
self.pw: Playwright | None = None
async def connect(self, cdp_url: str | None = None) -> None:
url = cdp_url or settings.BROWSER_REMOTE_DEBUGGING_URL
LOG.info("StreamingAgent connecting to CDP", cdp_url=url)
pw = self.pw or await async_playwright().start()
self.pw = pw
headers = {
"x-api-key": self.streaming.x_api_key,
}
self.browser = await pw.chromium.connect_over_cdp(url, headers=headers)
org_id = self.streaming.organization_id
browser_session_id = (
self.streaming.browser_session.persistent_browser_session_id if self.streaming.browser_session else None
)
if browser_session_id:
cdp_session = await self.browser.new_browser_cdp_session()
await cdp_session.send(
"Browser.setDownloadBehavior",
{
"behavior": "allow",
"downloadPath": f"/app/downloads/{org_id}/{browser_session_id}",
"eventsEnabled": True,
},
)
contexts = self.browser.contexts
if contexts:
LOG.info("StreamingAgent using existing browser context")
self.browser_context = contexts[0]
else:
LOG.warning("No existing browser context found, creating new one")
self.browser_context = await self.browser.new_context()
pages = self.browser_context.pages
if pages:
self.page = pages[0]
LOG.info("StreamingAgent connected to page", url=self.page.url)
else:
LOG.warning("No pages found in browser context")
LOG.info("StreamingAgent connected successfully")
async def evaluate_js(
self, expression: str, arg: str | int | float | bool | list | dict | None = None
) -> str | int | float | bool | list | dict | None:
if not self.page:
raise RuntimeError("StreamingAgent is not connected to a page. Call connect() first.")
LOG.info("StreamingAgent evaluating JS", expression=expression[:100])
try:
result = await self.page.evaluate(expression, arg)
LOG.info("StreamingAgent JS evaluation successful")
return result
except Exception as ex:
LOG.exception("StreamingAgent JS evaluation failed", expression=expression, ex=str(ex))
raise
async def get_selected_text(self) -> str:
LOG.info("StreamingAgent getting selected text")
js_expression = """
() => {
const selection = window.getSelection();
return selection ? selection.toString() : '';
}
"""
selected_text = await self.evaluate_js(js_expression)
if isinstance(selected_text, str) or selected_text is None:
LOG.info("StreamingAgent got selected text", length=len(selected_text) if selected_text else 0)
return selected_text or ""
raise RuntimeError(f"StreamingAgent selected text is not a string, but a(n) '{type(selected_text)}'")
async def paste_text(self, text: str) -> None:
LOG.info("StreamingAgent pasting text")
js_expression = """
(text) => {
const activeElement = document.activeElement;
if (activeElement && (activeElement.tagName === 'INPUT' || activeElement.tagName === 'TEXTAREA' || activeElement.isContentEditable)) {
const start = activeElement.selectionStart || 0;
const end = activeElement.selectionEnd || 0;
const value = activeElement.value || '';
activeElement.value = value.slice(0, start) + text + value.slice(end);
const newCursorPos = start + text.length;
activeElement.setSelectionRange(newCursorPos, newCursorPos);
}
}
"""
await self.evaluate_js(js_expression, text)
LOG.info("StreamingAgent pasted text successfully")
async def close(self) -> None:
LOG.info("StreamingAgent closing connection")
if self.browser:
await self.browser.close()
self.browser = None
self.browser_context = None
self.page = None
if self.pw:
await self.pw.stop()
self.pw = None
LOG.info("StreamingAgent closed")
@asynccontextmanager
async def connected_agent(streaming: sc.Streaming | None) -> typing.AsyncIterator[StreamingAgent]:
"""
The first pass at this has us doing the following for every operation:
- creating a new agent
- connecting
- [doing smth]
- closing the agent
This may add latency, but locally it is pretty fast. This keeps things stateless for now.
If it turns out it's too slow, we can refactor to keep a persistent agent per streaming client.
"""
if not streaming:
msg = "connected_agent: no streaming client provided."
LOG.error(msg)
raise Exception(msg)
if not streaming.browser_session or not streaming.browser_session.browser_address:
msg = "connected_agent: no browser session or browser address found for streaming client."
LOG.error(
msg,
client_id=streaming.client_id,
organization_id=streaming.organization_id,
)
raise Exception(msg)
agent = StreamingAgent(streaming=streaming)
try:
await agent.connect(streaming.browser_session.browser_address)
# NOTE(jdo:streaming-local-dev): use BROWSER_REMOTE_DEBUGGING_URL from settings
# await agent.connect()
yield agent
finally:
await agent.close()

View File

@@ -0,0 +1,61 @@
"""
Streaming auth.
"""
import structlog
from fastapi import WebSocket
from websockets.exceptions import ConnectionClosedOK
from skyvern.forge.sdk.services.org_auth_service import get_current_org
LOG = structlog.get_logger()
async def auth(apikey: str | None, token: str | None, websocket: WebSocket) -> str | None:
"""
Accepts the websocket connection.
Authenticates the user; cannot proceed with WS connection if an organization_id cannot be
determined.
"""
try:
await websocket.accept()
if not token and not apikey:
await websocket.close(code=1002)
return None
except ConnectionClosedOK:
LOG.info("WebSocket connection closed cleanly.")
return None
try:
organization = await get_current_org(x_api_key=apikey, authorization=token)
organization_id = organization.organization_id
if not organization_id:
await websocket.close(code=1002)
return None
except Exception:
LOG.exception("Error occurred while retrieving organization information.")
try:
await websocket.close(code=1002)
except ConnectionClosedOK:
LOG.info("WebSocket connection closed due to invalid credentials.")
return None
return organization_id
# NOTE(jdo:streaming-local-dev): use this instead of the above `auth`
async def _auth(apikey: str | None, token: str | None, websocket: WebSocket) -> str | None:
"""
Dummy auth for local testing.
"""
try:
await websocket.accept()
except ConnectionClosedOK:
LOG.info("WebSocket connection closed cleanly.")
return None
return "o_temp123"

View File

@@ -0,0 +1,328 @@
"""
Streaming types.
"""
import asyncio
import dataclasses
import typing as t
from enum import IntEnum
import structlog
from fastapi import WebSocket
from starlette.websockets import WebSocketState
from skyvern.forge.sdk.schemas.persistent_browser_sessions import AddressablePersistentBrowserSession
from skyvern.forge.sdk.schemas.tasks import Task
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRun
LOG = structlog.get_logger()
Interactor = t.Literal["agent", "user"]
Loops = list[asyncio.Task] # aka "queue-less actors"; or "programs"
# Messages
# a global registry for WS message clients
message_channels: dict[str, "MessageChannel"] = {}
def add_message_client(message_channel: "MessageChannel") -> None:
message_channels[message_channel.client_id] = message_channel
def get_message_client(client_id: str) -> t.Union["MessageChannel", None]:
return message_channels.get(client_id, None)
def del_message_client(client_id: str) -> None:
try:
del message_channels[client_id]
except KeyError:
pass
@dataclasses.dataclass
class MessageChannel:
client_id: str
organization_id: str
websocket: WebSocket
# --
browser_session: AddressablePersistentBrowserSession | None = None
workflow_run: WorkflowRun | None = None
def __post_init__(self) -> None:
add_message_client(self)
async def close(self, code: int = 1000, reason: str | None = None) -> "MessageChannel":
LOG.info("Closing message stream.", reason=reason, code=code)
self.browser_session = None
self.workflow_run = None
try:
await self.websocket.close(code=code, reason=reason)
except Exception:
pass
del_message_client(self.client_id)
return self
@property
def is_open(self) -> bool:
if self.websocket.client_state not in (WebSocketState.CONNECTED, WebSocketState.CONNECTING):
return False
if not self.workflow_run and not self.browser_session:
return False
if not get_message_client(self.client_id):
return False
return True
async def ask_for_clipboard(self, streaming: "Streaming") -> None:
try:
await self.websocket.send_json(
{
"kind": "ask-for-clipboard",
}
)
LOG.info(
"Sent ask-for-clipboard to message channel",
organization_id=streaming.organization_id,
)
except Exception:
LOG.exception(
"Failed to send ask-for-clipboard to message channel",
organization_id=streaming.organization_id,
)
async def send_copied_text(self, copied_text: str, streaming: "Streaming") -> None:
try:
await self.websocket.send_json(
{
"kind": "copied-text",
"text": copied_text,
}
)
LOG.info(
"Sent copied text to message channel",
organization_id=streaming.organization_id,
)
except Exception:
LOG.exception(
"Failed to send copied text to message channel",
organization_id=streaming.organization_id,
)
MessageKinds = t.Literal["take-control", "cede-control", "ask-for-clipboard-response"]
@dataclasses.dataclass
class Message:
kind: MessageKinds
@dataclasses.dataclass
class MessageTakeControl(Message):
kind: t.Literal["take-control"] = "take-control"
@dataclasses.dataclass
class MessageCedeControl(Message):
kind: t.Literal["cede-control"] = "cede-control"
@dataclasses.dataclass
class MessageInAskForClipboardResponse(Message):
kind: t.Literal["ask-for-clipboard-response"] = "ask-for-clipboard-response"
text: str = ""
ChannelMessage = t.Union[MessageTakeControl, MessageCedeControl, MessageInAskForClipboardResponse]
def reify_channel_message(data: dict) -> ChannelMessage:
kind = data.get("kind", None)
match kind:
case "take-control":
return MessageTakeControl()
case "cede-control":
return MessageCedeControl()
case "ask-for-clipboard-response":
text = data.get("text") or ""
return MessageInAskForClipboardResponse(text=text)
case _:
raise ValueError(f"Unknown message kind: '{kind}'")
# Streaming
# a global registry for WS streaming VNC clients
streaming_clients: dict[str, "Streaming"] = {}
def add_streaming_client(streaming: "Streaming") -> None:
streaming_clients[streaming.client_id] = streaming
def get_streaming_client(client_id: str) -> t.Union["Streaming", None]:
return streaming_clients.get(client_id, None)
def del_streaming_client(client_id: str) -> None:
try:
del streaming_clients[client_id]
except KeyError:
pass
class MessageType(IntEnum):
Keyboard = 4
Mouse = 5
class Keys:
"""
VNC RFB keycodes. There's likely a pithier repr (indexes 6-7). This is ok for now.
ref: https://www.notion.so/References-21c426c42cd480fb9258ecc9eb8f09b4
ref: https://github.com/novnc/noVNC/blob/master/docs/rfbproto-3.8.pdf
"""
class Down:
Ctrl = b"\x04\x01\x00\x00\x00\x00\xff\xe3"
Cmd = b"\x04\x01\x00\x00\x00\x00\xff\xe9"
Alt = b"\x04\x01\x00\x00\x00\x00\xff~" # option
CKey = b"\x04\x01\x00\x00\x00\x00\x00c"
OKey = b"\x04\x01\x00\x00\x00\x00\x00o"
VKey = b"\x04\x01\x00\x00\x00\x00\x00v"
class Up:
Ctrl = b"\x04\x00\x00\x00\x00\x00\xff\xe3"
Cmd = b"\x04\x00\x00\x00\x00\x00\xff\xe9"
Alt = b"\x04\x00\x00\x00\x00\x00\xff\x7e" # option
def is_rmb(data: bytes) -> bool:
return data[0:2] == b"\x05\x04"
class Mouse:
class Up:
Right = is_rmb
@dataclasses.dataclass
class KeyState:
ctrl_is_down: bool = False
alt_is_down: bool = False
cmd_is_down: bool = False
def is_forbidden(self, data: bytes) -> bool:
"""
:return: True if the key is forbidden, else False
"""
return self.is_ctrl_o(data)
def is_ctrl_o(self, data: bytes) -> bool:
"""
Do not allow the opening of files.
"""
return self.ctrl_is_down and data == Keys.Down.OKey
def is_copy(self, data: bytes) -> bool:
"""
Detect Ctrl+C or Cmd+C for copy.
"""
return (self.ctrl_is_down or self.cmd_is_down) and data == Keys.Down.CKey
def is_paste(self, data: bytes) -> bool:
"""
Detect Ctrl+V or Cmd+V for paste.
"""
return (self.ctrl_is_down or self.cmd_is_down) and data == Keys.Down.VKey
@dataclasses.dataclass
class Streaming:
"""
Streaming state.
"""
client_id: str
"""
Unique to frontend app instance.
"""
interactor: Interactor
"""
Whether the user or the agent are the interactor.
"""
organization_id: str
vnc_port: int
x_api_key: str
websocket: WebSocket
# --
browser_session: AddressablePersistentBrowserSession | None = None
key_state: KeyState = dataclasses.field(default_factory=KeyState)
task: Task | None = None
workflow_run: WorkflowRun | None = None
def __post_init__(self) -> None:
add_streaming_client(self)
@property
def is_open(self) -> bool:
if self.websocket.client_state not in (WebSocketState.CONNECTED, WebSocketState.CONNECTING):
return False
if not self.task and not self.workflow_run and not self.browser_session:
return False
if not get_streaming_client(self.client_id):
return False
return True
async def close(self, code: int = 1000, reason: str | None = None) -> "Streaming":
LOG.info("Closing Streaming.", reason=reason, code=code)
self.browser_session = None
self.task = None
self.workflow_run = None
try:
await self.websocket.close(code=code, reason=reason)
except Exception:
pass
del_streaming_client(self.client_id)
return self
def update_key_state(self, data: bytes) -> None:
if data == Keys.Down.Ctrl:
self.key_state.ctrl_is_down = True
elif data == Keys.Up.Ctrl:
self.key_state.ctrl_is_down = False
elif data == Keys.Down.Alt:
self.key_state.alt_is_down = True
elif data == Keys.Up.Alt:
self.key_state.alt_is_down = False
elif data == Keys.Down.Cmd:
self.key_state.cmd_is_down = True
elif data == Keys.Up.Cmd:
self.key_state.cmd_is_down = False

View File

@@ -0,0 +1,363 @@
"""
Streaming messages for WebSocket connections.
"""
import asyncio
import structlog
from fastapi import WebSocket, WebSocketDisconnect
from websockets.exceptions import ConnectionClosedError
import skyvern.forge.sdk.routes.streaming.clients as sc
from skyvern.forge.sdk.routes.routers import base_router, legacy_base_router
from skyvern.forge.sdk.routes.streaming.agent import connected_agent
from skyvern.forge.sdk.routes.streaming.auth import auth
from skyvern.forge.sdk.routes.streaming.verify import (
loop_verify_browser_session,
loop_verify_workflow_run,
verify_browser_session,
verify_workflow_run,
)
from skyvern.forge.sdk.utils.aio import collect
LOG = structlog.get_logger()
async def get_messages_for_browser_session(
client_id: str,
browser_session_id: str,
organization_id: str,
websocket: WebSocket,
) -> tuple[sc.MessageChannel, sc.Loops] | None:
"""
Return a message channel for a browser session, with a list of loops to run concurrently.
"""
LOG.info("Getting message channel for browser session.", browser_session_id=browser_session_id)
browser_session = await verify_browser_session(
browser_session_id=browser_session_id,
organization_id=organization_id,
)
if not browser_session:
LOG.info(
"Message channel: no initial browser session found.",
browser_session_id=browser_session_id,
organization_id=organization_id,
)
return None
message_channel = sc.MessageChannel(
client_id=client_id,
organization_id=organization_id,
browser_session=browser_session,
websocket=websocket,
)
LOG.info("Got message channel for browser session.", message_channel=message_channel)
loops = [
asyncio.create_task(loop_verify_browser_session(message_channel)),
asyncio.create_task(loop_channel(message_channel)),
]
return message_channel, loops
async def get_messages_for_workflow_run(
client_id: str,
workflow_run_id: str,
organization_id: str,
websocket: WebSocket,
) -> tuple[sc.MessageChannel, sc.Loops] | None:
"""
Return a message channel for a workflow run, with a list of loops to run concurrently.
"""
LOG.info("Getting message channel for workflow run.", workflow_run_id=workflow_run_id)
workflow_run, browser_session = await verify_workflow_run(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
if not workflow_run:
LOG.info(
"Message channel: no initial workflow run found.",
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
return None
if not browser_session:
LOG.info(
"Message channel: no initial browser session found for workflow run.",
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
return None
message_channel = sc.MessageChannel(
client_id=client_id,
organization_id=organization_id,
browser_session=browser_session,
workflow_run=workflow_run,
websocket=websocket,
)
LOG.info("Got message channel for workflow run.", message_channel=message_channel)
loops = [
asyncio.create_task(loop_verify_workflow_run(message_channel)),
asyncio.create_task(loop_channel(message_channel)),
]
return message_channel, loops
async def loop_channel(message_channel: sc.MessageChannel) -> None:
"""
Stream messages and their results back and forth.
Loops until the workflow run is cleared or the websocket is closed.
"""
if not message_channel.browser_session:
LOG.info(
"No browser session found for workflow run.",
workflow_run=message_channel.workflow_run,
organization_id=message_channel.organization_id,
)
return
async def frontend_to_backend() -> None:
LOG.info("Starting frontend-to-backend channel loop.", message_channel=message_channel)
while message_channel.is_open:
try:
data = await message_channel.websocket.receive_json()
if not isinstance(data, dict):
LOG.error(f"Cannot create channel message: expected dict, got {type(data)}")
continue
try:
message = sc.reify_channel_message(data)
except ValueError:
continue
message_kind = message.kind
match message_kind:
case "take-control":
streaming = sc.get_streaming_client(message_channel.client_id)
if not streaming:
LOG.error(
"No streaming client found for message.",
message_channel=message_channel,
message=message,
)
continue
streaming.interactor = "user"
case "cede-control":
streaming = sc.get_streaming_client(message_channel.client_id)
if not streaming:
LOG.error(
"No streaming client found for message.",
message_channel=message_channel,
message=message,
)
continue
streaming.interactor = "agent"
case "ask-for-clipboard-response":
if not isinstance(message, sc.MessageInAskForClipboardResponse):
LOG.error(
"Invalid message type for ask-for-clipboard-response.",
message_channel=message_channel,
message=message,
)
continue
streaming = sc.get_streaming_client(message_channel.client_id)
text = message.text
async with connected_agent(streaming) as agent:
await agent.paste_text(text)
case _:
LOG.error(f"Unknown message kind: '{message_kind}'")
continue
except WebSocketDisconnect:
LOG.info(
"Frontend disconnected.",
workflow_run=message_channel.workflow_run,
organization_id=message_channel.organization_id,
)
raise
except ConnectionClosedError:
LOG.info(
"Frontend closed the streaming session.",
workflow_run=message_channel.workflow_run,
organization_id=message_channel.organization_id,
)
raise
except asyncio.CancelledError:
pass
except Exception:
LOG.exception(
"An unexpected exception occurred.",
workflow_run=message_channel.workflow_run,
organization_id=message_channel.organization_id,
)
raise
loops = [
asyncio.create_task(frontend_to_backend()),
]
try:
await collect(loops)
except Exception:
LOG.exception(
"An exception occurred in loop channel stream.",
workflow_run=message_channel.workflow_run,
organization_id=message_channel.organization_id,
)
finally:
LOG.info(
"Closing the loop channel stream.",
workflow_run=message_channel.workflow_run,
organization_id=message_channel.organization_id,
)
await message_channel.close(reason="loop-channel-closed")
@base_router.websocket("/stream/messages/browser_session/{browser_session_id}")
@base_router.websocket("/stream/commands/browser_session/{browser_session_id}")
async def browser_session_messages(
websocket: WebSocket,
browser_session_id: str,
apikey: str | None = None,
client_id: str | None = None,
token: str | None = None,
) -> None:
LOG.info("Starting message stream for browser session.", browser_session_id=browser_session_id)
organization_id = await auth(apikey=apikey, token=token, websocket=websocket)
if not organization_id:
LOG.error("Authentication failed.", browser_session_id=browser_session_id)
return
if not client_id:
LOG.error("No client ID provided.", browser_session_id=browser_session_id)
return
message_channel: sc.MessageChannel
loops: list[asyncio.Task] = []
result = await get_messages_for_browser_session(
client_id=client_id,
browser_session_id=browser_session_id,
organization_id=organization_id,
websocket=websocket,
)
if not result:
LOG.error(
"No streaming context found for the browser session.",
browser_session_id=browser_session_id,
organization_id=organization_id,
)
await websocket.close(code=1013)
return
message_channel, loops = result
try:
LOG.info(
"Starting message stream loops for browser session.",
browser_session_id=browser_session_id,
organization_id=organization_id,
)
await collect(loops)
except Exception:
LOG.exception(
"An exception occurred in the message stream function for browser session.",
browser_session_id=browser_session_id,
organization_id=organization_id,
)
finally:
LOG.info(
"Closing the message stream session for browser session.",
browser_session_id=browser_session_id,
organization_id=organization_id,
)
await message_channel.close(reason="stream-closed")
@legacy_base_router.websocket("/stream/messages/workflow_run/{workflow_run_id}")
@legacy_base_router.websocket("/stream/commands/workflow_run/{workflow_run_id}")
async def workflow_run_messages(
websocket: WebSocket,
workflow_run_id: str,
apikey: str | None = None,
client_id: str | None = None,
token: str | None = None,
) -> None:
LOG.info("Starting message stream.", workflow_run_id=workflow_run_id)
organization_id = await auth(apikey=apikey, token=token, websocket=websocket)
if not organization_id:
LOG.error("Authentication failed.", workflow_run_id=workflow_run_id)
return
if not client_id:
LOG.error("No client ID provided.", workflow_run_id=workflow_run_id)
return
message_channel: sc.MessageChannel
loops: list[asyncio.Task] = []
result = await get_messages_for_workflow_run(
client_id=client_id,
workflow_run_id=workflow_run_id,
organization_id=organization_id,
websocket=websocket,
)
if not result:
LOG.error(
"No streaming context found for the workflow run.",
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
await websocket.close(code=1013)
return
message_channel, loops = result
try:
LOG.info(
"Starting message stream loops.",
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
await collect(loops)
except Exception:
LOG.exception(
"An exception occurred in the message stream function.",
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
finally:
LOG.info(
"Closing the message stream session.",
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
await message_channel.close(reason="stream-closed")

View File

@@ -0,0 +1,271 @@
import asyncio
import base64
from datetime import datetime
import structlog
from fastapi import WebSocket, WebSocketDisconnect
from pydantic import ValidationError
from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK
from skyvern.forge import app
from skyvern.forge.sdk.routes.routers import legacy_base_router
from skyvern.forge.sdk.schemas.tasks import TaskStatus
from skyvern.forge.sdk.services.org_auth_service import get_current_org
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus
LOG = structlog.get_logger()
STREAMING_TIMEOUT = 300
@legacy_base_router.websocket("/stream/tasks/{task_id}")
async def task_stream(
websocket: WebSocket,
task_id: str,
apikey: str | None = None,
token: str | None = None,
) -> None:
try:
await websocket.accept()
if not token and not apikey:
await websocket.send_text("No valid credential provided")
return
except ConnectionClosedOK:
LOG.info("ConnectionClosedOK error. Streaming won't start")
return
try:
organization = await get_current_org(x_api_key=apikey, authorization=token)
organization_id = organization.organization_id
except Exception:
LOG.exception("Error while getting organization", task_id=task_id)
try:
await websocket.send_text("Invalid credential provided")
except ConnectionClosedOK:
LOG.info("ConnectionClosedOK error while sending invalid credential message")
return
LOG.info("Started task streaming", task_id=task_id, organization_id=organization_id)
# timestamp last time when streaming activity happens
last_activity_timestamp = datetime.utcnow()
try:
while True:
# if no activity for 5 minutes, close the connection
if (datetime.utcnow() - last_activity_timestamp).total_seconds() > STREAMING_TIMEOUT:
LOG.info(
"No activity for 5 minutes. Closing connection", task_id=task_id, organization_id=organization_id
)
await websocket.send_json(
{
"task_id": task_id,
"status": "timeout",
}
)
return
task = await app.DATABASE.get_task(task_id=task_id, organization_id=organization_id)
if not task:
LOG.info("Task not found. Closing connection", task_id=task_id, organization_id=organization_id)
await websocket.send_json(
{
"task_id": task_id,
"status": "not_found",
}
)
return
if task.status.is_final():
LOG.info(
"Task is in a final state. Closing connection",
task_status=task.status,
task_id=task_id,
organization_id=organization_id,
)
await websocket.send_json(
{
"task_id": task_id,
"status": task.status,
}
)
return
if task.status == TaskStatus.running:
file_name = f"{task_id}.png"
if task.workflow_run_id:
file_name = f"{task.workflow_run_id}.png"
screenshot = await app.STORAGE.get_streaming_file(organization_id, file_name)
if screenshot:
encoded_screenshot = base64.b64encode(screenshot).decode("utf-8")
await websocket.send_json(
{
"task_id": task_id,
"status": task.status,
"screenshot": encoded_screenshot,
}
)
last_activity_timestamp = datetime.utcnow()
await asyncio.sleep(2)
except ValidationError as e:
await websocket.send_text(f"Invalid data: {e}")
except WebSocketDisconnect:
LOG.info("WebSocket connection closed", task_id=task_id, organization_id=organization_id)
except ConnectionClosedOK:
LOG.info("ConnectionClosedOK error while streaming", task_id=task_id, organization_id=organization_id)
return
except ConnectionClosedError:
LOG.warning(
"ConnectionClosedError while streaming (client likely disconnected)",
task_id=task_id,
organization_id=organization_id,
)
return
except Exception:
LOG.warning("Error while streaming", task_id=task_id, organization_id=organization_id, exc_info=True)
return
LOG.info("WebSocket connection closed successfully", task_id=task_id, organization_id=organization_id)
return
@legacy_base_router.websocket("/stream/workflow_runs/{workflow_run_id}")
async def workflow_run_streaming(
websocket: WebSocket,
workflow_run_id: str,
apikey: str | None = None,
token: str | None = None,
) -> None:
try:
await websocket.accept()
if not token and not apikey:
await websocket.send_text("No valid credential provided")
return
except ConnectionClosedOK:
LOG.info("WofklowRun Streaming: ConnectionClosedOK error. Streaming won't start")
return
try:
organization = await get_current_org(x_api_key=apikey, authorization=token)
organization_id = organization.organization_id
except Exception:
LOG.exception(
"WofklowRun Streaming: Error while getting organization",
workflow_run_id=workflow_run_id,
token=token,
)
try:
await websocket.send_text("Invalid credential provided")
except ConnectionClosedOK:
LOG.info("WofklowRun Streaming: ConnectionClosedOK error while sending invalid credential message")
return
LOG.info(
"WofklowRun Streaming: Started workflow run streaming",
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
# timestamp last time when streaming activity happens
last_activity_timestamp = datetime.utcnow()
try:
while True:
# if no activity for 5 minutes, close the connection
if (datetime.utcnow() - last_activity_timestamp).total_seconds() > STREAMING_TIMEOUT:
LOG.info(
"WofklowRun Streaming: No activity for 5 minutes. Closing connection",
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
await websocket.send_json(
{
"workflow_run_id": workflow_run_id,
"status": "timeout",
}
)
return
workflow_run = await app.DATABASE.get_workflow_run(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
if not workflow_run or workflow_run.organization_id != organization_id:
LOG.info(
"WofklowRun Streaming: Workflow not found",
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
await websocket.send_json(
{
"workflow_run_id": workflow_run_id,
"status": "not_found",
}
)
return
if workflow_run.status in [
WorkflowRunStatus.completed,
WorkflowRunStatus.failed,
WorkflowRunStatus.terminated,
]:
LOG.info(
"Workflow run is in a final state. Closing connection",
workflow_run_status=workflow_run.status,
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
await websocket.send_json(
{
"workflow_run_id": workflow_run_id,
"status": workflow_run.status,
}
)
return
if workflow_run.status == WorkflowRunStatus.running:
file_name = f"{workflow_run_id}.png"
screenshot = await app.STORAGE.get_streaming_file(organization_id, file_name)
if screenshot:
encoded_screenshot = base64.b64encode(screenshot).decode("utf-8")
await websocket.send_json(
{
"workflow_run_id": workflow_run_id,
"status": workflow_run.status,
"screenshot": encoded_screenshot,
}
)
last_activity_timestamp = datetime.utcnow()
await asyncio.sleep(2)
except ValidationError as e:
await websocket.send_text(f"Invalid data: {e}")
except WebSocketDisconnect:
LOG.info(
"WofklowRun Streaming: WebSocket connection closed",
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
except ConnectionClosedOK:
LOG.info(
"WofklowRun Streaming: ConnectionClosedOK error while streaming",
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
return
except ConnectionClosedError:
LOG.warning(
"WofklowRun Streaming: ConnectionClosedError while streaming (client likely disconnected)",
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
return
except Exception:
LOG.warning(
"WofklowRun Streaming: Error while streaming",
workflow_run_id=workflow_run_id,
organization_id=organization_id,
exc_info=True,
)
return
LOG.info(
"WofklowRun Streaming: WebSocket connection closed successfully",
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
return

View File

@@ -0,0 +1,288 @@
import asyncio
from datetime import datetime
import structlog
import skyvern.forge.sdk.routes.streaming.clients as sc
from skyvern.config import settings
from skyvern.forge import app
from skyvern.forge.sdk.schemas.persistent_browser_sessions import AddressablePersistentBrowserSession
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRun, WorkflowRunStatus
LOG = structlog.get_logger()
async def verify_browser_session(
browser_session_id: str,
organization_id: str,
) -> AddressablePersistentBrowserSession | None:
"""
Verify the browser session exists, and is usable.
"""
if settings.ENV == "local":
dummy_browser_session = AddressablePersistentBrowserSession(
persistent_browser_session_id=browser_session_id,
organization_id=organization_id,
browser_address="0.0.0.0:9223",
created_at=datetime.now(),
modified_at=datetime.now(),
)
return dummy_browser_session
browser_session = await app.DATABASE.get_persistent_browser_session(browser_session_id, organization_id)
if not browser_session:
LOG.info(
"No browser session found.",
browser_session_id=browser_session_id,
organization_id=organization_id,
)
return None
browser_address = browser_session.browser_address
if not browser_address:
LOG.info(
"Waiting for browser session address.",
browser_session_id=browser_session_id,
organization_id=organization_id,
)
try:
browser_address = await app.PERSISTENT_SESSIONS_MANAGER.get_browser_address(
session_id=browser_session_id,
organization_id=organization_id,
)
except Exception as ex:
LOG.info(
"Browser session address not found for browser session.",
browser_session_id=browser_session_id,
organization_id=organization_id,
ex=ex,
)
return None
try:
addressable_browser_session = AddressablePersistentBrowserSession(
**browser_session.model_dump() | {"browser_address": browser_address},
)
except Exception:
return None
return addressable_browser_session
async def verify_task(
task_id: str, organization_id: str
) -> tuple[Task | None, AddressablePersistentBrowserSession | None]:
"""
Verify the task is running, and that it has a browser session associated
with it.
"""
task = await app.DATABASE.get_task(task_id=task_id, organization_id=organization_id)
if not task:
LOG.info("Task not found.", task_id=task_id, organization_id=organization_id)
return None, None
if task.status.is_final():
LOG.info("Task is in a final state.", task_status=task.status, task_id=task_id, organization_id=organization_id)
return None, None
if task.status not in [TaskStatus.created, TaskStatus.queued, TaskStatus.running]:
LOG.info(
"Task is not created, queued, or running.",
task_status=task.status,
task_id=task_id,
organization_id=organization_id,
)
return None, None
browser_session = await app.PERSISTENT_SESSIONS_MANAGER.get_session_by_runnable_id(
organization_id=organization_id,
runnable_id=task_id,
)
if not browser_session:
LOG.info("No browser session found for task.", task_id=task_id, organization_id=organization_id)
return task, None
if not browser_session.browser_address:
LOG.info("Browser session address not found for task.", task_id=task_id, organization_id=organization_id)
return task, None
try:
addressable_browser_session = AddressablePersistentBrowserSession(
**browser_session.model_dump() | {"browser_address": browser_session.browser_address},
)
except Exception as e:
LOG.error(
"streaming-vnc.browser-session-reify-error", task_id=task_id, organization_id=organization_id, error=e
)
return task, None
return task, addressable_browser_session
async def verify_workflow_run(
workflow_run_id: str,
organization_id: str,
) -> tuple[WorkflowRun | None, AddressablePersistentBrowserSession | None]:
"""
Verify the workflow run is running, and that it has a browser session associated
with it.
"""
if settings.ENV == "local":
dummy_workflow_run = WorkflowRun(
workflow_id="123",
workflow_permanent_id="wpid_123",
workflow_run_id=workflow_run_id,
organization_id=organization_id,
status=WorkflowRunStatus.running,
created_at=datetime.now(),
modified_at=datetime.now(),
)
dummy_browser_session = AddressablePersistentBrowserSession(
persistent_browser_session_id=workflow_run_id,
organization_id=organization_id,
browser_address="0.0.0.0:9223",
created_at=datetime.now(),
modified_at=datetime.now(),
)
return dummy_workflow_run, dummy_browser_session
workflow_run = await app.DATABASE.get_workflow_run(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
if not workflow_run:
LOG.info("Workflow run not found.", workflow_run_id=workflow_run_id, organization_id=organization_id)
return None, None
if workflow_run.status.is_final():
LOG.info(
"Workflow run is in a final state. Closing connection.",
workflow_run_status=workflow_run.status,
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
return None, None
if workflow_run.status not in [
WorkflowRunStatus.created,
WorkflowRunStatus.queued,
WorkflowRunStatus.running,
WorkflowRunStatus.paused,
]:
LOG.info(
"Workflow run is not running.",
workflow_run_status=workflow_run.status,
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
return None, None
browser_session = await app.PERSISTENT_SESSIONS_MANAGER.get_session_by_runnable_id(
organization_id=organization_id,
runnable_id=workflow_run_id,
)
if not browser_session:
LOG.info(
"No browser session found for workflow run.",
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
return workflow_run, None
browser_address = browser_session.browser_address
if not browser_address:
LOG.info(
"Waiting for browser session address.", workflow_run_id=workflow_run_id, organization_id=organization_id
)
try:
browser_address = await app.PERSISTENT_SESSIONS_MANAGER.get_browser_address(
session_id=browser_session.persistent_browser_session_id,
organization_id=organization_id,
)
except Exception as ex:
LOG.info(
"Browser session address not found for workflow run.",
workflow_run_id=workflow_run_id,
organization_id=organization_id,
ex=ex,
)
return workflow_run, None
try:
addressable_browser_session = AddressablePersistentBrowserSession(
**browser_session.model_dump() | {"browser_address": browser_address},
)
except Exception:
return workflow_run, None
return workflow_run, addressable_browser_session
async def loop_verify_browser_session(verifiable: sc.MessageChannel | sc.Streaming) -> None:
"""
Loop until the browser session is cleared or the websocket is closed.
"""
while verifiable.browser_session and verifiable.is_open:
browser_session = await verify_browser_session(
browser_session_id=verifiable.browser_session.persistent_browser_session_id,
organization_id=verifiable.organization_id,
)
verifiable.browser_session = browser_session
await asyncio.sleep(2)
async def loop_verify_task(streaming: sc.Streaming) -> None:
"""
Loop until the task is cleared or the websocket is closed.
"""
while streaming.task and streaming.is_open:
task, browser_session = await verify_task(
task_id=streaming.task.task_id,
organization_id=streaming.organization_id,
)
streaming.task = task
streaming.browser_session = browser_session
await asyncio.sleep(2)
async def loop_verify_workflow_run(verifiable: sc.MessageChannel | sc.Streaming) -> None:
"""
Loop until the workflow run is cleared or the websocket is closed.
"""
while verifiable.workflow_run and verifiable.is_open:
workflow_run, browser_session = await verify_workflow_run(
workflow_run_id=verifiable.workflow_run.workflow_run_id,
organization_id=verifiable.organization_id,
)
verifiable.workflow_run = workflow_run
verifiable.browser_session = browser_session
await asyncio.sleep(2)

View File

@@ -0,0 +1,626 @@
"""
Streaming VNC WebSocket connections.
NOTE(jdo:streaming-local-dev)
-----------------------------
- grep the above for local development seams
- augment those seams as indicated, then
- stand up https://github.com/jomido/whyvern
"""
import asyncio
import typing as t
from urllib.parse import urlparse
import structlog
import websockets
from fastapi import WebSocket, WebSocketDisconnect
from websockets import Data
from websockets.exceptions import ConnectionClosedError
import skyvern.forge.sdk.routes.streaming.clients as sc
from skyvern.config import settings
from skyvern.forge import app
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
from skyvern.forge.sdk.routes.routers import base_router, legacy_base_router
from skyvern.forge.sdk.routes.streaming.agent import connected_agent
from skyvern.forge.sdk.routes.streaming.auth import auth
from skyvern.forge.sdk.routes.streaming.verify import (
loop_verify_browser_session,
loop_verify_task,
loop_verify_workflow_run,
verify_browser_session,
verify_task,
verify_workflow_run,
)
from skyvern.forge.sdk.utils.aio import collect
LOG = structlog.get_logger()
class Constants:
MissingXApiKey = "<missing-x-api-key>"
async def get_x_api_key(organization_id: str) -> str:
token = await app.DATABASE.get_valid_org_auth_token(
organization_id,
OrganizationAuthTokenType.api.value,
)
if not token:
LOG.warning(
"No valid API key found for organization when streaming.",
organization_id=organization_id,
)
x_api_key = Constants.MissingXApiKey
else:
x_api_key = token.token
return x_api_key
async def get_streaming_for_browser_session(
client_id: str,
browser_session_id: str,
organization_id: str,
websocket: WebSocket,
) -> tuple[sc.Streaming, sc.Loops] | None:
"""
Return a streaming context for a browser session, with a list of loops to run concurrently.
"""
LOG.info("Getting streaming context for browser session.", browser_session_id=browser_session_id)
browser_session = await verify_browser_session(
browser_session_id=browser_session_id,
organization_id=organization_id,
)
if not browser_session:
LOG.info(
"No initial browser session found.", browser_session_id=browser_session_id, organization_id=organization_id
)
return None
x_api_key = await get_x_api_key(organization_id)
streaming = sc.Streaming(
client_id=client_id,
interactor="agent",
organization_id=organization_id,
vnc_port=settings.SKYVERN_BROWSER_VNC_PORT,
browser_session=browser_session,
x_api_key=x_api_key,
websocket=websocket,
)
LOG.info("Got streaming context for browser session.", streaming=streaming)
loops = [
asyncio.create_task(loop_verify_browser_session(streaming)),
asyncio.create_task(loop_stream_vnc(streaming)),
]
return streaming, loops
async def get_streaming_for_task(
client_id: str,
task_id: str,
organization_id: str,
websocket: WebSocket,
) -> tuple[sc.Streaming, sc.Loops] | None:
"""
Return a streaming context for a task, with a list of loops to run concurrently.
"""
task, browser_session = await verify_task(task_id=task_id, organization_id=organization_id)
if not task:
LOG.info("No initial task found.", task_id=task_id, organization_id=organization_id)
return None
if not browser_session:
LOG.info("No initial browser session found for task.", task_id=task_id, organization_id=organization_id)
return None
x_api_key = await get_x_api_key(organization_id)
streaming = sc.Streaming(
client_id=client_id,
interactor="agent",
organization_id=organization_id,
vnc_port=settings.SKYVERN_BROWSER_VNC_PORT,
x_api_key=x_api_key,
websocket=websocket,
browser_session=browser_session,
task=task,
)
loops = [
asyncio.create_task(loop_verify_task(streaming)),
asyncio.create_task(loop_stream_vnc(streaming)),
]
return streaming, loops
async def get_streaming_for_workflow_run(
client_id: str,
workflow_run_id: str,
organization_id: str,
websocket: WebSocket,
) -> tuple[sc.Streaming, sc.Loops] | None:
"""
Return a streaming context for a workflow run, with a list of loops to run concurrently.
"""
LOG.info("Getting streaming context for workflow run.", workflow_run_id=workflow_run_id)
workflow_run, browser_session = await verify_workflow_run(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
if not workflow_run:
LOG.info("No initial workflow run found.", workflow_run_id=workflow_run_id, organization_id=organization_id)
return None
if not browser_session:
LOG.info(
"No initial browser session found for workflow run.",
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
return None
x_api_key = await get_x_api_key(organization_id)
streaming = sc.Streaming(
client_id=client_id,
interactor="agent",
organization_id=organization_id,
vnc_port=settings.SKYVERN_BROWSER_VNC_PORT,
browser_session=browser_session,
workflow_run=workflow_run,
x_api_key=x_api_key,
websocket=websocket,
)
LOG.info("Got streaming context for workflow run.", streaming=streaming)
loops = [
asyncio.create_task(loop_verify_workflow_run(streaming)),
asyncio.create_task(loop_stream_vnc(streaming)),
]
return streaming, loops
def verify_message_channel(
message_channel: sc.MessageChannel | None, streaming: sc.Streaming
) -> sc.MessageChannel | t.Literal[False]:
if message_channel and message_channel.is_open:
return message_channel
LOG.warning(
"No message channel found for client, or it is not open",
message_channel=message_channel,
client_id=streaming.client_id,
organization_id=streaming.organization_id,
)
return False
async def copy_text(streaming: sc.Streaming) -> None:
try:
async with connected_agent(streaming) as agent:
copied_text = await agent.get_selected_text()
LOG.info(
"Retrieved selected text via CDP",
organization_id=streaming.organization_id,
)
message_channel = sc.get_message_client(streaming.client_id)
if cc := verify_message_channel(message_channel, streaming):
await cc.send_copied_text(copied_text, streaming)
else:
LOG.warning(
"No message channel found for client, or it is not open",
message_channel=message_channel,
client_id=streaming.client_id,
organization_id=streaming.organization_id,
)
except Exception:
LOG.exception(
"Failed to retrieve selected text via CDP",
organization_id=streaming.organization_id,
)
async def ask_for_clipboard(streaming: sc.Streaming) -> None:
try:
LOG.info(
"Asking for clipboard data via CDP",
organization_id=streaming.organization_id,
)
message_channel = sc.get_message_client(streaming.client_id)
if cc := verify_message_channel(message_channel, streaming):
await cc.ask_for_clipboard(streaming)
except Exception:
LOG.exception(
"Failed to ask for clipboard via CDP",
organization_id=streaming.organization_id,
)
async def loop_stream_vnc(streaming: sc.Streaming) -> None:
"""
Actually stream the VNC session data between a frontend and a browser
session.
Loops until the task is cleared or the websocket is closed.
"""
if not streaming.browser_session:
LOG.info("No browser session found for task.", task=streaming.task, organization_id=streaming.organization_id)
return
vnc_url: str = ""
if streaming.browser_session.ip_address:
if ":" in streaming.browser_session.ip_address:
ip, _ = streaming.browser_session.ip_address.split(":")
vnc_url = f"ws://{ip}:{streaming.vnc_port}"
else:
vnc_url = f"ws://{streaming.browser_session.ip_address}:{streaming.vnc_port}"
else:
browser_address = streaming.browser_session.browser_address
parsed_browser_address = urlparse(browser_address)
host = parsed_browser_address.hostname
vnc_url = f"ws://{host}:{streaming.vnc_port}"
# NOTE(jdo:streaming-local-dev)
# vnc_url = "ws://localhost:9001/ws/novnc"
LOG.info(
"Connecting to VNC URL.",
vnc_url=vnc_url,
task=streaming.task,
workflow_run=streaming.workflow_run,
organization_id=streaming.organization_id,
)
async with websockets.connect(vnc_url) as novnc_ws:
async def frontend_to_browser() -> None:
LOG.info("Starting frontend-to-browser data transfer.", streaming=streaming)
data: Data | None = None
while streaming.is_open:
try:
data = await streaming.websocket.receive_bytes()
if data:
message_type = data[0]
if message_type == sc.MessageType.Keyboard.value:
streaming.update_key_state(data)
if streaming.key_state.is_copy(data):
await copy_text(streaming)
if streaming.key_state.is_paste(data):
await ask_for_clipboard(streaming)
if streaming.key_state.is_forbidden(data):
continue
if message_type == sc.MessageType.Mouse.value:
if sc.Mouse.Up.Right(data):
continue
if not streaming.interactor == "user" and message_type in (
sc.MessageType.Keyboard.value,
sc.MessageType.Mouse.value,
):
LOG.info(
"Blocking user message.", task=streaming.task, organization_id=streaming.organization_id
)
continue
except WebSocketDisconnect:
LOG.info("Frontend disconnected.", task=streaming.task, organization_id=streaming.organization_id)
raise
except ConnectionClosedError:
LOG.info(
"Frontend closed the streaming session.",
task=streaming.task,
organization_id=streaming.organization_id,
)
raise
except asyncio.CancelledError:
pass
except Exception:
LOG.exception(
"An unexpected exception occurred.",
task=streaming.task,
organization_id=streaming.organization_id,
)
raise
if not data:
continue
try:
await novnc_ws.send(data)
except WebSocketDisconnect:
LOG.info(
"Browser disconnected from the streaming session.",
task=streaming.task,
organization_id=streaming.organization_id,
)
raise
except ConnectionClosedError:
LOG.info(
"Browser closed the streaming session.",
task=streaming.task,
organization_id=streaming.organization_id,
)
raise
except asyncio.CancelledError:
pass
except Exception:
LOG.exception(
"An unexpected exception occurred in frontend-to-browser loop.",
task=streaming.task,
organization_id=streaming.organization_id,
)
raise
async def browser_to_frontend() -> None:
LOG.info("Starting browser-to-frontend data transfer.", streaming=streaming)
data: Data | None = None
while streaming.is_open:
try:
data = await novnc_ws.recv()
except WebSocketDisconnect:
LOG.info(
"Browser disconnected from the streaming session.",
task=streaming.task,
organization_id=streaming.organization_id,
)
await streaming.close(reason="browser-disconnected")
except ConnectionClosedError:
LOG.info(
"Browser closed the streaming session.",
task=streaming.task,
organization_id=streaming.organization_id,
)
await streaming.close(reason="browser-closed")
except asyncio.CancelledError:
pass
except Exception:
LOG.exception(
"An unexpected exception occurred in browser-to-frontend loop.",
task=streaming.task,
organization_id=streaming.organization_id,
)
raise
if not data:
continue
try:
await streaming.websocket.send_bytes(data)
except WebSocketDisconnect:
LOG.info(
"Frontend disconnected from the streaming session.",
task=streaming.task,
organization_id=streaming.organization_id,
)
await streaming.close(reason="frontend-disconnected")
except ConnectionClosedError:
LOG.info(
"Frontend closed the streaming session.",
task=streaming.task,
organization_id=streaming.organization_id,
)
await streaming.close(reason="frontend-closed")
except asyncio.CancelledError:
pass
except Exception:
LOG.exception(
"An unexpected exception occurred.",
task=streaming.task,
organization_id=streaming.organization_id,
)
raise
loops = [
asyncio.create_task(frontend_to_browser()),
asyncio.create_task(browser_to_frontend()),
]
try:
await collect(loops)
except Exception:
LOG.exception(
"An exception occurred in loop stream.", task=streaming.task, organization_id=streaming.organization_id
)
finally:
LOG.info("Closing the loop stream.", task=streaming.task, organization_id=streaming.organization_id)
await streaming.close(reason="loop-stream-vnc-closed")
@base_router.websocket("/stream/vnc/browser_session/{browser_session_id}")
async def browser_session_stream(
websocket: WebSocket,
browser_session_id: str,
apikey: str | None = None,
client_id: str | None = None,
token: str | None = None,
) -> None:
await stream(websocket, apikey=apikey, client_id=client_id, browser_session_id=browser_session_id, token=token)
@legacy_base_router.websocket("/stream/vnc/task/{task_id}")
async def task_stream(
websocket: WebSocket,
task_id: str,
apikey: str | None = None,
client_id: str | None = None,
token: str | None = None,
) -> None:
await stream(websocket, apikey=apikey, client_id=client_id, task_id=task_id, token=token)
@legacy_base_router.websocket("/stream/vnc/workflow_run/{workflow_run_id}")
async def workflow_run_stream(
websocket: WebSocket,
workflow_run_id: str,
apikey: str | None = None,
client_id: str | None = None,
token: str | None = None,
) -> None:
await stream(websocket, apikey=apikey, client_id=client_id, workflow_run_id=workflow_run_id, token=token)
async def stream(
websocket: WebSocket,
*,
apikey: str | None = None,
browser_session_id: str | None = None,
client_id: str | None = None,
task_id: str | None = None,
token: str | None = None,
workflow_run_id: str | None = None,
) -> None:
if not client_id:
LOG.error(
"Client ID not provided for VNC stream.",
browser_session_id=browser_session_id,
task_id=task_id,
workflow_run_id=workflow_run_id,
)
return
LOG.info(
"Starting VNC stream.",
browser_session_id=browser_session_id,
client_id=client_id,
task_id=task_id,
workflow_run_id=workflow_run_id,
)
organization_id = await auth(apikey=apikey, token=token, websocket=websocket)
if not organization_id:
LOG.error("Authentication failed.", task_id=task_id, workflow_run_id=workflow_run_id)
return
streaming: sc.Streaming
loops: list[asyncio.Task] = []
if browser_session_id:
result = await get_streaming_for_browser_session(
client_id=client_id,
browser_session_id=browser_session_id,
organization_id=organization_id,
websocket=websocket,
)
if not result:
LOG.error(
"No streaming context found for the browser session.",
browser_session_id=browser_session_id,
organization_id=organization_id,
)
await websocket.close(code=1013)
return
streaming, loops = result
LOG.info(
"Starting streaming for browser session.",
browser_session_id=browser_session_id,
organization_id=organization_id,
)
elif task_id:
result = await get_streaming_for_task(
client_id=client_id,
task_id=task_id,
organization_id=organization_id,
websocket=websocket,
)
if not result:
LOG.error("No streaming context found for the task.", task_id=task_id, organization_id=organization_id)
await websocket.close(code=1013)
return
streaming, loops = result
LOG.info("Starting streaming for task.", task_id=task_id, organization_id=organization_id)
elif workflow_run_id:
LOG.info(
"Starting streaming for workflow run.", workflow_run_id=workflow_run_id, organization_id=organization_id
)
result = await get_streaming_for_workflow_run(
client_id=client_id,
workflow_run_id=workflow_run_id,
organization_id=organization_id,
websocket=websocket,
)
if not result:
LOG.error(
"No streaming context found for the workflow run.",
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
await websocket.close(code=1013)
return
streaming, loops = result
LOG.info(
"Starting streaming for workflow run.", workflow_run_id=workflow_run_id, organization_id=organization_id
)
else:
LOG.error("Neither task ID nor workflow run ID was provided.")
return
try:
LOG.info(
"Starting streaming loops.",
task_id=task_id,
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
await collect(loops)
except Exception:
LOG.exception(
"An exception occurred in the stream function.",
task_id=task_id,
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
finally:
LOG.info(
"Closing the streaming session.",
task_id=task_id,
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
await streaming.close(reason="stream-closed")