From 6b5699a98cffc680977466e23960a42de6f54f74 Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Wed, 25 Jun 2025 02:37:26 +0800 Subject: [PATCH] WebSocket Command Channel (#2782) --- .../editor/panels/WorkflowParametersPanel.tsx | 4 +- .../workflowRun/WorkflowRunStreamVnc.tsx | 158 +++++-- .../src/store/useClientIdStore.ts | 11 + skyvern/forge/sdk/routes/__init__.py | 1 + skyvern/forge/sdk/routes/streaming_auth.py | 46 ++ skyvern/forge/sdk/routes/streaming_clients.py | 270 ++++++++++++ .../forge/sdk/routes/streaming_commands.py | 227 ++++++++++ skyvern/forge/sdk/routes/streaming_verify.py | 207 +++++++++ skyvern/forge/sdk/routes/streaming_vnc.py | 398 +++--------------- 9 files changed, 938 insertions(+), 384 deletions(-) create mode 100644 skyvern-frontend/src/store/useClientIdStore.ts create mode 100644 skyvern/forge/sdk/routes/streaming_auth.py create mode 100644 skyvern/forge/sdk/routes/streaming_clients.py create mode 100644 skyvern/forge/sdk/routes/streaming_commands.py create mode 100644 skyvern/forge/sdk/routes/streaming_verify.py diff --git a/skyvern-frontend/src/routes/workflows/editor/panels/WorkflowParametersPanel.tsx b/skyvern-frontend/src/routes/workflows/editor/panels/WorkflowParametersPanel.tsx index a292ac34..2f6c86cc 100644 --- a/skyvern-frontend/src/routes/workflows/editor/panels/WorkflowParametersPanel.tsx +++ b/skyvern-frontend/src/routes/workflows/editor/panels/WorkflowParametersPanel.tsx @@ -141,9 +141,7 @@ function WorkflowParametersPanel() { ) : ( - {parameter.parameterType === "onepassword" - ? "credential" - : parameter.parameterType} + {parameter.parameterType} )} diff --git a/skyvern-frontend/src/routes/workflows/workflowRun/WorkflowRunStreamVnc.tsx b/skyvern-frontend/src/routes/workflows/workflowRun/WorkflowRunStreamVnc.tsx index aa25a244..bcf91c80 100644 --- a/skyvern-frontend/src/routes/workflows/workflowRun/WorkflowRunStreamVnc.tsx +++ b/skyvern-frontend/src/routes/workflows/workflowRun/WorkflowRunStreamVnc.tsx @@ -13,24 +13,40 @@ import { useQueryClient } from "@tanstack/react-query"; import RFB from "@novnc/novnc/lib/rfb.js"; import { environment } from "@/util/env"; import { cn } from "@/util/utils"; +import { useClientIdStore } from "@/store/useClientIdStore"; import "./workflow-run-stream-vnc.css"; const wssBaseUrl = import.meta.env.VITE_WSS_BASE_URL; +interface CommandTakeControl { + kind: "take-control"; +} + +interface CommandCedeControl { + kind: "cede-control"; +} + +type Command = CommandTakeControl | CommandCedeControl; + function WorkflowRunStreamVnc() { const { data: workflowRun } = useWorkflowRunQuery(); + const { workflowRunId, workflowPermanentId } = useParams<{ workflowRunId: string; workflowPermanentId: string; }>(); + const [commandSocket, setCommandSocket] = useState(null); const [userIsControlling, setUserIsControlling] = useState(false); const [vncDisconnectedTrigger, setVncDisconnectedTrigger] = useState(0); const prevVncConnectedRef = useRef(false); const [isVncConnected, setIsVncConnected] = useState(false); + const [commandDisconnectedTrigger, setCommandDisconnectedTrigger] = + useState(0); + const prevCommandConnectedRef = useRef(false); + const [isCommandConnected, setIsCommandConnected] = useState(false); const showStream = workflowRun && statusIsNotFinalized(workflowRun); - const credentialGetter = useCredentialGetter(); const queryClient = useQueryClient(); const [canvasContainer, setCanvasContainer] = useState( null, @@ -38,10 +54,42 @@ function WorkflowRunStreamVnc() { const setCanvasContainerRef = useCallback((node: HTMLDivElement | null) => { setCanvasContainer(node); }, []); - const rfbRef = useRef(null); + const clientId = useClientIdStore((state) => state.clientId); + const credentialGetter = useCredentialGetter(); - // effect for disconnects only + const getWebSocketParams = useCallback(async () => { + const clientIdQueryParam = `client_id=${clientId}`; + let credentialQueryParam = ""; + + if (environment === "local") { + credentialQueryParam = `apikey=${envCredential}`; + } else { + if (credentialGetter) { + const token = await credentialGetter(); + credentialQueryParam = `token=Bearer ${token}`; + } else { + credentialQueryParam = `apikey=${envCredential}`; + } + } + + const params = [credentialQueryParam, clientIdQueryParam].join("&"); + + return `${params}`; + }, [clientId, credentialGetter]); + + const invalidateQueries = useCallback(() => { + queryClient.invalidateQueries({ + queryKey: ["workflowRun", workflowPermanentId, workflowRunId], + }); + queryClient.invalidateQueries({ queryKey: ["workflowRuns"] }); + queryClient.invalidateQueries({ + queryKey: ["workflowTasks", workflowRunId], + }); + queryClient.invalidateQueries({ queryKey: ["runs"] }); + }, [queryClient, workflowPermanentId, workflowRunId]); + + // effect for vnc disconnects only useEffect(() => { if (prevVncConnectedRef.current && !isVncConnected) { setVncDisconnectedTrigger((x) => x + 1); @@ -49,6 +97,15 @@ function WorkflowRunStreamVnc() { prevVncConnectedRef.current = isVncConnected; }, [isVncConnected]); + // effect for command disconnects only + useEffect(() => { + if (prevCommandConnectedRef.current && !isCommandConnected) { + setCommandDisconnectedTrigger((x) => x + 1); + } + prevCommandConnectedRef.current = isCommandConnected; + }, [isCommandConnected]); + + // vnc socket useEffect( () => { if (!showStream || !canvasContainer || !workflowRunId) { @@ -61,24 +118,12 @@ function WorkflowRunStreamVnc() { } async function setupVnc() { - let credentialQueryParam = ""; - - if (environment === "local") { - credentialQueryParam = `?apikey=${envCredential}`; - } else { - if (credentialGetter) { - const token = await credentialGetter(); - credentialQueryParam = `?token=Bearer ${token}`; - } else { - credentialQueryParam = `?apikey=${envCredential}`; - } - } - if (rfbRef.current && isVncConnected) { return; } - const vncUrl = `${wssBaseUrl}/stream/vnc/workflow_run/${workflowRunId}${credentialQueryParam}`; + const wsParams = await getWebSocketParams(); + const vncUrl = `${wssBaseUrl}/stream/vnc/workflow_run/${workflowRunId}?${wsParams}`; if (rfbRef.current) { rfbRef.current.disconnect(); @@ -102,15 +147,7 @@ function WorkflowRunStreamVnc() { rfb.addEventListener("disconnect", async (/* e: RfbEvent */) => { setIsVncConnected(false); - - queryClient.invalidateQueries({ - queryKey: ["workflowRun", workflowPermanentId, workflowRunId], - }); - queryClient.invalidateQueries({ queryKey: ["workflowRuns"] }); - queryClient.invalidateQueries({ - queryKey: ["workflowTasks", workflowRunId], - }); - queryClient.invalidateQueries({ queryKey: ["runs"] }); + invalidateQueries(); }); } @@ -127,16 +164,75 @@ function WorkflowRunStreamVnc() { // cannot include isVncConnected in deps as it will cause infinite loop // eslint-disable-next-line react-hooks/exhaustive-deps [ - credentialGetter, - workflowRunId, - workflowPermanentId, - showStream, - queryClient, canvasContainer, + invalidateQueries, + showStream, vncDisconnectedTrigger, // will re-run on disconnects + workflowRunId, ], ); + // command socket + useEffect(() => { + let ws: WebSocket | null = null; + + const connect = async () => { + const wsParams = await getWebSocketParams(); + const commandUrl = `${wssBaseUrl}/stream/commands/workflow_run/${workflowRunId}?${wsParams}`; + ws = new WebSocket(commandUrl); + + ws.onopen = () => { + setIsCommandConnected(true); + setCommandSocket(ws); + }; + + ws.onclose = () => { + setIsCommandConnected(false); + invalidateQueries(); + setCommandSocket(null); + }; + }; + + connect(); + + return () => { + try { + ws && ws.close(); + } catch (e) { + // pass + } + }; + }, [ + commandDisconnectedTrigger, + getWebSocketParams, + invalidateQueries, + workflowRunId, + ]); + + // effect to send a command when the user is controlling, vs not controlling + useEffect(() => { + if (!isCommandConnected) { + return; + } + + const sendCommand = (command: Command) => { + if (!commandSocket) { + console.warn("Cannot send command, as command socket is closed."); + console.warn(command); + return; + } + + commandSocket.send(JSON.stringify(command)); + }; + + if (userIsControlling) { + sendCommand({ kind: "take-control" }); + } else { + sendCommand({ kind: "cede-control" }); + } + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [userIsControlling, isCommandConnected]); + // Effect to show toast when workflow reaches a final state based on hook updates useEffect(() => { if (workflowRun) { diff --git a/skyvern-frontend/src/store/useClientIdStore.ts b/skyvern-frontend/src/store/useClientIdStore.ts new file mode 100644 index 00000000..b31e4dd9 --- /dev/null +++ b/skyvern-frontend/src/store/useClientIdStore.ts @@ -0,0 +1,11 @@ +import { create } from "zustand"; + +type ClientIdStore = { + clientId: string; +}; + +const initialClientId = crypto.randomUUID(); + +export const useClientIdStore = create(() => ({ + clientId: initialClientId, +})); diff --git a/skyvern/forge/sdk/routes/__init__.py b/skyvern/forge/sdk/routes/__init__.py index 83ba7d5d..40527991 100644 --- a/skyvern/forge/sdk/routes/__init__.py +++ b/skyvern/forge/sdk/routes/__init__.py @@ -2,4 +2,5 @@ from skyvern.forge.sdk.routes import agent_protocol # noqa: F401 from skyvern.forge.sdk.routes import browser_sessions # noqa: F401 from skyvern.forge.sdk.routes import credentials # noqa: F401 from skyvern.forge.sdk.routes import streaming # noqa: F401 +from skyvern.forge.sdk.routes import streaming_commands # noqa: F401 from skyvern.forge.sdk.routes import streaming_vnc # noqa: F401 diff --git a/skyvern/forge/sdk/routes/streaming_auth.py b/skyvern/forge/sdk/routes/streaming_auth.py new file mode 100644 index 00000000..0e53b5ee --- /dev/null +++ b/skyvern/forge/sdk/routes/streaming_auth.py @@ -0,0 +1,46 @@ +""" +Streaming auth. +""" + +import structlog +from fastapi import WebSocket +from websockets.exceptions import ConnectionClosedOK + +from skyvern.forge.sdk.services.org_auth_service import get_current_org + +LOG = structlog.get_logger() + + +async def auth(apikey: str | None, token: str | None, websocket: WebSocket) -> str | None: + """ + Accepts the websocket connection. + + Authenticates the user; cannot proceed with WS connection if an organization_id cannot be + determined. + """ + + try: + await websocket.accept() + if not token and not apikey: + await websocket.close(code=1002) + return None + except ConnectionClosedOK: + LOG.info("WebSocket connection closed cleanly.") + return None + + try: + organization = await get_current_org(x_api_key=apikey, authorization=token) + organization_id = organization.organization_id + + if not organization_id: + await websocket.close(code=1002) + return None + except Exception: + LOG.exception("Error occurred while retrieving organization information.") + try: + await websocket.close(code=1002) + except ConnectionClosedOK: + LOG.info("WebSocket connection closed due to invalid credentials.") + return None + + return organization_id diff --git a/skyvern/forge/sdk/routes/streaming_clients.py b/skyvern/forge/sdk/routes/streaming_clients.py new file mode 100644 index 00000000..3d54e3a7 --- /dev/null +++ b/skyvern/forge/sdk/routes/streaming_clients.py @@ -0,0 +1,270 @@ +""" +Streaming types. +""" + +import asyncio +import dataclasses +import typing as t +from enum import IntEnum + +import structlog +from fastapi import WebSocket +from starlette.websockets import WebSocketState + +from skyvern.forge.sdk.schemas.persistent_browser_sessions import AddressablePersistentBrowserSession +from skyvern.forge.sdk.schemas.tasks import Task +from skyvern.forge.sdk.workflow.models.workflow import WorkflowRun + +LOG = structlog.get_logger() + + +Interactor = t.Literal["agent", "user"] +Loops = list[asyncio.Task] # aka "queue-less actors"; or "programs" + + +# Commands + + +# a global registry for WS command clients +command_channels: dict[str, "CommandChannel"] = {} + + +def add_command_client(command_channel: "CommandChannel") -> None: + command_channels[command_channel.client_id] = command_channel + + +def get_command_client(client_id: str) -> t.Union["CommandChannel", None]: + return command_channels.get(client_id, None) + + +def del_command_client(client_id: str) -> None: + try: + del command_channels[client_id] + except KeyError: + pass + + +@dataclasses.dataclass +class CommandChannel: + client_id: str + organization_id: str + websocket: WebSocket + + # -- + + browser_session: AddressablePersistentBrowserSession | None = None + workflow_run: WorkflowRun | None = None + + def __post_init__(self) -> None: + add_command_client(self) + + async def close(self, code: int = 1000, reason: str | None = None) -> "CommandChannel": + LOG.info("Closing command stream.", reason=reason, code=code) + + self.browser_session = None + self.workflow_run = None + + try: + await self.websocket.close(code=code, reason=reason) + except Exception: + pass + + del_command_client(self.client_id) + + return self + + @property + def is_open(self) -> bool: + if self.websocket.client_state not in (WebSocketState.CONNECTED, WebSocketState.CONNECTING): + return False + + if not self.workflow_run: + return False + + if not get_command_client(self.client_id): + return False + + return True + + +CommandKinds = t.Literal["take-control", "cede-control"] + + +@dataclasses.dataclass +class Command: + kind: CommandKinds + + +@dataclasses.dataclass +class CommandTakeControl(Command): + kind: t.Literal["take-control"] = "take-control" + + +@dataclasses.dataclass +class CommandCedeControl(Command): + kind: t.Literal["cede-control"] = "cede-control" + + +ChannelCommand = t.Union[CommandTakeControl, CommandCedeControl] + + +def reify_channel_command(data: dict) -> ChannelCommand: + kind = data.get("kind", None) + + match kind: + case "take-control": + return CommandTakeControl() + case "cede-control": + return CommandCedeControl() + case _: + raise ValueError(f"Unknown command kind: '{kind}'") + + +# Streaming + + +# a global registry for WS streaming VNC clients +streaming_clients: dict[str, "Streaming"] = {} + + +def add_streaming_client(streaming: "Streaming") -> None: + streaming_clients[streaming.client_id] = streaming + + +def get_streaming_client(client_id: str) -> t.Union["Streaming", None]: + return streaming_clients.get(client_id, None) + + +def del_streaming_client(client_id: str) -> None: + try: + del streaming_clients[client_id] + except KeyError: + pass + + +class MessageType(IntEnum): + Keyboard = 4 + Mouse = 5 + + +class Keys: + """ + VNC RFB keycodes. There's likely a pithier repr (indexes 6-7). This is ok for now. + + ref: https://www.notion.so/References-21c426c42cd480fb9258ecc9eb8f09b4 + ref: https://github.com/novnc/noVNC/blob/master/docs/rfbproto-3.8.pdf + """ + + class Down: + Ctrl = b"\x04\x01\x00\x00\x00\x00\xff\xe3" + Cmd = b"\x04\x01\x00\x00\x00\x00\xff\xe9" + Alt = b"\x04\x01\x00\x00\x00\x00\xff~" # option + OKey = b"\x04\x01\x00\x00\x00\x00\x00o" + + class Up: + Ctrl = b"\x04\x00\x00\x00\x00\x00\xff\xe3" + Cmd = b"\x04\x00\x00\x00\x00\x00\xff\xe9" + Alt = b"\x04\x00\x00\x00\x00\x00\xff\x7e" # option + + +def is_rmb(data: bytes) -> bool: + return data[0:2] == b"\x05\x04" + + +class Mouse: + class Up: + Right = is_rmb + + +@dataclasses.dataclass +class KeyState: + ctrl_is_down: bool = False + alt_is_down: bool = False + cmd_is_down: bool = False + o_is_down: bool = False + + def is_forbidden(self, data: bytes) -> bool: + """ + :return: True if the key is forbidden, else False + """ + return self.is_ctrl_o(data) + + def is_ctrl_o(self, data: bytes) -> bool: + """ + Do not allow the opening of files. + """ + return self.ctrl_is_down and data == Keys.Down.OKey + + +@dataclasses.dataclass +class Streaming: + """ + Streaming state. + """ + + client_id: str + """ + Unique to frontend app instance. + """ + + interactor: Interactor + """ + Whether the user or the agent are the interactor. + """ + + organization_id: str + vnc_port: int + websocket: WebSocket + + # -- + + browser_session: AddressablePersistentBrowserSession | None = None + key_state: KeyState = dataclasses.field(default_factory=KeyState) + task: Task | None = None + workflow_run: WorkflowRun | None = None + + def __post_init__(self) -> None: + add_streaming_client(self) + + @property + def is_open(self) -> bool: + if self.websocket.client_state not in (WebSocketState.CONNECTED, WebSocketState.CONNECTING): + return False + + if not self.task and not self.workflow_run: + return False + + if not get_streaming_client(self.client_id): + return False + + return True + + async def close(self, code: int = 1000, reason: str | None = None) -> "Streaming": + LOG.info("Closing Streaming.", reason=reason, code=code) + + self.browser_session = None + self.task = None + self.workflow_run = None + + try: + await self.websocket.close(code=code, reason=reason) + except Exception: + pass + + del_streaming_client(self.client_id) + + return self + + def update_key_state(self, data: bytes) -> None: + if data == Keys.Down.Ctrl: + self.key_state.ctrl_is_down = True + elif data == Keys.Up.Ctrl: + self.key_state.ctrl_is_down = False + elif data == Keys.Down.Alt: + self.key_state.alt_is_down = True + elif data == Keys.Up.Alt: + self.key_state.alt_is_down = False + elif data == Keys.Down.Cmd: + self.key_state.cmd_is_down = True + elif data == Keys.Up.Cmd: + self.key_state.cmd_is_down = False diff --git a/skyvern/forge/sdk/routes/streaming_commands.py b/skyvern/forge/sdk/routes/streaming_commands.py new file mode 100644 index 00000000..fa2f8b6b --- /dev/null +++ b/skyvern/forge/sdk/routes/streaming_commands.py @@ -0,0 +1,227 @@ +""" +Streaming commands WebSocket connections. +""" + +import asyncio + +import structlog +from fastapi import WebSocket, WebSocketDisconnect +from websockets.exceptions import ConnectionClosedError + +import skyvern.forge.sdk.routes.streaming_clients as sc +from skyvern.forge.sdk.routes.routers import legacy_base_router +from skyvern.forge.sdk.routes.streaming_auth import auth +from skyvern.forge.sdk.routes.streaming_verify import loop_verify_workflow_run, verify_workflow_run +from skyvern.forge.sdk.utils.aio import collect + +LOG = structlog.get_logger() + + +async def get_commands_for_workflow_run( + client_id: str, + workflow_run_id: str, + organization_id: str, + websocket: WebSocket, +) -> tuple[sc.CommandChannel, sc.Loops] | None: + """ + Return a commands channel for a workflow run, with a list of loops to run concurrently. + """ + + LOG.info("Getting commands channel for workflow run.", workflow_run_id=workflow_run_id) + + workflow_run, browser_session = await verify_workflow_run( + workflow_run_id=workflow_run_id, + organization_id=organization_id, + ) + + if not workflow_run: + LOG.info( + "Command channel: no initial workflow run found.", + workflow_run_id=workflow_run_id, + organization_id=organization_id, + ) + return None + + if not browser_session: + LOG.info( + "Command channel: no initial browser session found for workflow run.", + workflow_run_id=workflow_run_id, + organization_id=organization_id, + ) + return None + + commands = sc.CommandChannel( + client_id=client_id, + organization_id=organization_id, + browser_session=browser_session, + workflow_run=workflow_run, + websocket=websocket, + ) + + LOG.info("Got command channel for workflow run.", commands=commands) + + loops = [ + asyncio.create_task(loop_verify_workflow_run(commands)), + asyncio.create_task(loop_channel(commands)), + ] + + return commands, loops + + +async def loop_channel(commands: sc.CommandChannel) -> None: + """ + Stream commands and their results back and forth. + + Loops until the workflow run is cleared or the websocket is closed. + """ + + if not commands.browser_session: + LOG.info( + "No browser session found for workflow run.", + workflow_run=commands.workflow_run, + organization_id=commands.organization_id, + ) + return + + async def frontend_to_backend() -> None: + LOG.info("Starting frontend-to-backend channel loop.", commands=commands) + + while commands.is_open: + try: + data = await commands.websocket.receive_json() + + if not isinstance(data, dict): + LOG.error(f"Cannot create channel command: expected dict, got {type(data)}") + continue + + try: + command = sc.reify_channel_command(data) + except ValueError: + continue + + command_kind = command.kind + + match command_kind: + case "take-control": + streaming = sc.get_streaming_client(commands.client_id) + if not streaming: + LOG.error("No streaming client found for command.", commands=commands, command=command) + continue + streaming.interactor = "user" + case "cede-control": + streaming = sc.get_streaming_client(commands.client_id) + if not streaming: + LOG.error("No streaming client found for command.", commands=commands, command=command) + continue + streaming.interactor = "agent" + case _: + LOG.error(f"Unknown command kind: '{command_kind}'") + continue + + except WebSocketDisconnect: + LOG.info( + "Frontend disconnected.", + workflow_run=commands.workflow_run, + organization_id=commands.organization_id, + ) + raise + except ConnectionClosedError: + LOG.info( + "Frontend closed the streaming session.", + workflow_run=commands.workflow_run, + organization_id=commands.organization_id, + ) + raise + except asyncio.CancelledError: + pass + except Exception: + LOG.exception( + "An unexpected exception occurred.", + workflow_run=commands.workflow_run, + organization_id=commands.organization_id, + ) + raise + + loops = [ + asyncio.create_task(frontend_to_backend()), + ] + + try: + await collect(loops) + except Exception: + LOG.exception( + "An exception occurred in loop channel stream.", + workflow_run=commands.workflow_run, + organization_id=commands.organization_id, + ) + finally: + LOG.info( + "Closing the loop channel stream.", + workflow_run=commands.workflow_run, + organization_id=commands.organization_id, + ) + await commands.close(reason="loop-channel-closed") + + +@legacy_base_router.websocket("/stream/commands/workflow_run/{workflow_run_id}") +async def workflow_run_commands( + websocket: WebSocket, + workflow_run_id: str, + apikey: str | None = None, + client_id: str | None = None, + token: str | None = None, +) -> None: + LOG.info("Starting stream commands.", workflow_run_id=workflow_run_id) + + organization_id = await auth(apikey=apikey, token=token, websocket=websocket) + + if not organization_id: + LOG.error("Authentication failed.", workflow_run_id=workflow_run_id) + return + + if not client_id: + LOG.error("No client ID provided.", workflow_run_id=workflow_run_id) + return + + commands: sc.CommandChannel + loops: list[asyncio.Task] = [] + + result = await get_commands_for_workflow_run( + client_id=client_id, + workflow_run_id=workflow_run_id, + organization_id=organization_id, + websocket=websocket, + ) + + if not result: + LOG.error( + "No streaming context found for the workflow run.", + workflow_run_id=workflow_run_id, + organization_id=organization_id, + ) + await websocket.close(code=1013) + return + + commands, loops = result + + try: + LOG.info( + "Starting command stream loops.", + workflow_run_id=workflow_run_id, + organization_id=organization_id, + ) + await collect(loops) + except Exception: + LOG.exception( + "An exception occurred in the command stream function.", + workflow_run_id=workflow_run_id, + organization_id=organization_id, + ) + finally: + LOG.info( + "Closing the command stream session.", + workflow_run_id=workflow_run_id, + organization_id=organization_id, + ) + + await commands.close(reason="stream-closed") diff --git a/skyvern/forge/sdk/routes/streaming_verify.py b/skyvern/forge/sdk/routes/streaming_verify.py new file mode 100644 index 00000000..21ed6296 --- /dev/null +++ b/skyvern/forge/sdk/routes/streaming_verify.py @@ -0,0 +1,207 @@ +import asyncio + +import structlog + +import skyvern.forge.sdk.routes.streaming_clients as sc +from skyvern.config import settings +from skyvern.forge import app +from skyvern.forge.sdk.schemas.persistent_browser_sessions import AddressablePersistentBrowserSession +from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus +from skyvern.forge.sdk.workflow.models.workflow import WorkflowRun, WorkflowRunStatus + +LOG = structlog.get_logger() + + +async def verify_task( + task_id: str, organization_id: str +) -> tuple[Task | None, AddressablePersistentBrowserSession | None]: + """ + Verify the task is running, and that it has a browser session associated + with it. + """ + + task = await app.DATABASE.get_task(task_id=task_id, organization_id=organization_id) + + if not task: + LOG.info("Task not found.", task_id=task_id, organization_id=organization_id) + return None, None + + if task.status.is_final(): + LOG.info("Task is in a final state.", task_status=task.status, task_id=task_id, organization_id=organization_id) + + return None, None + + if task.status not in [TaskStatus.created, TaskStatus.queued, TaskStatus.running]: + LOG.info( + "Task is not created, queued, or running.", + task_status=task.status, + task_id=task_id, + organization_id=organization_id, + ) + + return None, None + + browser_session = await app.PERSISTENT_SESSIONS_MANAGER.get_session_by_runnable_id( + organization_id=organization_id, + runnable_id=task_id, + ) + + if not browser_session: + LOG.info("No browser session found for task.", task_id=task_id, organization_id=organization_id) + return task, None + + if not browser_session.browser_address: + LOG.info("Browser session address not found for task.", task_id=task_id, organization_id=organization_id) + return task, None + + try: + addressable_browser_session = AddressablePersistentBrowserSession( + **browser_session.model_dump() | {"browser_address": browser_session.browser_address}, + ) + except Exception as e: + LOG.error( + "streaming-vnc.browser-session-reify-error", task_id=task_id, organization_id=organization_id, error=e + ) + return task, None + + return task, addressable_browser_session + + +async def verify_workflow_run( + workflow_run_id: str, + organization_id: str, +) -> tuple[WorkflowRun | None, AddressablePersistentBrowserSession | None]: + """ + Verify the workflow run is running, and that it has a browser session associated + with it. + """ + + if settings.ENV == "local": + from datetime import datetime + + dummy_workflow_run = WorkflowRun( + workflow_id="123", + workflow_permanent_id="wpid_123", + workflow_run_id=workflow_run_id, + organization_id=organization_id, + status=WorkflowRunStatus.running, + created_at=datetime.now(), + modified_at=datetime.now(), + ) + + dummy_browser_session = AddressablePersistentBrowserSession( + persistent_browser_session_id=workflow_run_id, + organization_id=organization_id, + browser_address="0.0.0.0:9223", + created_at=datetime.now(), + modified_at=datetime.now(), + ) + + return dummy_workflow_run, dummy_browser_session + + workflow_run = await app.DATABASE.get_workflow_run( + workflow_run_id=workflow_run_id, + organization_id=organization_id, + ) + + if not workflow_run: + LOG.info("Workflow run not found.", workflow_run_id=workflow_run_id, organization_id=organization_id) + return None, None + + if workflow_run.status.is_final(): + LOG.info( + "Workflow run is in a final state. Closing connection.", + workflow_run_status=workflow_run.status, + workflow_run_id=workflow_run_id, + organization_id=organization_id, + ) + + return None, None + + if workflow_run.status not in [WorkflowRunStatus.created, WorkflowRunStatus.queued, WorkflowRunStatus.running]: + LOG.info( + "Workflow run is not running.", + workflow_run_status=workflow_run.status, + workflow_run_id=workflow_run_id, + organization_id=organization_id, + ) + + return None, None + + browser_session = await app.PERSISTENT_SESSIONS_MANAGER.get_session_by_runnable_id( + organization_id=organization_id, + runnable_id=workflow_run_id, + ) + + if not browser_session: + LOG.info( + "No browser session found for workflow run.", + workflow_run_id=workflow_run_id, + organization_id=organization_id, + ) + return workflow_run, None + + browser_address = browser_session.browser_address + + if not browser_address: + LOG.info( + "Waiting for browser session address.", workflow_run_id=workflow_run_id, organization_id=organization_id + ) + + try: + _, host, cdp_port = await app.PERSISTENT_SESSIONS_MANAGER.get_browser_address( + session_id=browser_session.persistent_browser_session_id, + organization_id=organization_id, + ) + browser_address = f"{host}:{cdp_port}" + except Exception as ex: + LOG.info( + "Browser session address not found for workflow run.", + workflow_run_id=workflow_run_id, + organization_id=organization_id, + ex=ex, + ) + return workflow_run, None + + try: + addressable_browser_session = AddressablePersistentBrowserSession( + **browser_session.model_dump() | {"browser_address": browser_address}, + ) + except Exception: + return workflow_run, None + + return workflow_run, addressable_browser_session + + +async def loop_verify_task(streaming: sc.Streaming) -> None: + """ + Loop until the task is cleared or the websocket is closed. + """ + + while streaming.task and streaming.is_open: + task, browser_session = await verify_task( + task_id=streaming.task.task_id, + organization_id=streaming.organization_id, + ) + + streaming.task = task + streaming.browser_session = browser_session + + await asyncio.sleep(2) + + +async def loop_verify_workflow_run(verifiable: sc.CommandChannel | sc.Streaming) -> None: + """ + Loop until the workflow run is cleared or the websocket is closed. + """ + + while verifiable.workflow_run and verifiable.is_open: + workflow_run, browser_session = await verify_workflow_run( + workflow_run_id=verifiable.workflow_run.workflow_run_id, + organization_id=verifiable.organization_id, + ) + + verifiable.workflow_run = workflow_run + verifiable.browser_session = browser_session + + await asyncio.sleep(2) diff --git a/skyvern/forge/sdk/routes/streaming_vnc.py b/skyvern/forge/sdk/routes/streaming_vnc.py index ac7712fb..0e065d13 100644 --- a/skyvern/forge/sdk/routes/streaming_vnc.py +++ b/skyvern/forge/sdk/routes/streaming_vnc.py @@ -1,233 +1,36 @@ +""" +Streaming VNC WebSocket connections. +""" + import asyncio -import dataclasses -import typing as t -from enum import IntEnum import structlog import websockets from fastapi import WebSocket, WebSocketDisconnect -from starlette.websockets import WebSocketState from websockets import Data -from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK +from websockets.exceptions import ConnectionClosedError +import skyvern.forge.sdk.routes.streaming_clients as sc from skyvern.config import settings -from skyvern.forge import app from skyvern.forge.sdk.routes.routers import legacy_base_router -from skyvern.forge.sdk.schemas.persistent_browser_sessions import AddressablePersistentBrowserSession -from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus -from skyvern.forge.sdk.services.org_auth_service import get_current_org +from skyvern.forge.sdk.routes.streaming_auth import auth +from skyvern.forge.sdk.routes.streaming_verify import ( + loop_verify_task, + loop_verify_workflow_run, + verify_task, + verify_workflow_run, +) from skyvern.forge.sdk.utils.aio import collect -from skyvern.forge.sdk.workflow.models.workflow import WorkflowRun, WorkflowRunStatus - -Interactor = t.Literal["agent", "user"] -Loops = list[asyncio.Task] # aka "queue-less actors"; or "programs" - - -class MessageType(IntEnum): - Keyboard = 4 - Mouse = 5 - - -class Keys: - """ - VNC RFB keycodes. There's likely a pithier repr (indexes 6-7). This is ok for now. - """ - - class Down: - Ctrl = b"\x04\x01\x00\x00\x00\x00\xff\xe3" - Cmd = b"\x04\x01\x00\x00\x00\x00\xff\xe9" - Alt = b"\x04\x01\x00\x00\x00\x00\xff~" # option - OKey = b"\x04\x01\x00\x00\x00\x00\x00o" - - class Up: - Ctrl = b"\x04\x00\x00\x00\x00\x00\xff\xe3" - Cmd = b"\x04\x00\x00\x00\x00\x00\xff\xe9" - Alt = b"\x04\x00\x00\x00\x00\x00\xff\x7e" # option - - -def is_rmb(data: bytes) -> bool: - return data[0:2] == b"\x05\x04" - - -class Mouse: - class Up: - Right = is_rmb - LOG = structlog.get_logger() -@dataclasses.dataclass -class KeyState: - ctrl_is_down: bool = False - alt_is_down: bool = False - cmd_is_down: bool = False - o_is_down: bool = False - - def is_forbidden(self, data: bytes) -> bool: - """ - :return: True if the key is forbidden, else False - """ - return self.is_ctrl_o(data) - - def is_ctrl_o(self, data: bytes) -> bool: - """ - Do not allow the opening of files. - """ - return self.ctrl_is_down and data == Keys.Down.OKey - - -@dataclasses.dataclass -class Streaming: - """ - Streaming state. - """ - - interactor: Interactor - """ - Whether the user or the agent are the interactor. - """ - - organization_id: str - vnc_port: int - websocket: WebSocket - - # -- - - browser_session: AddressablePersistentBrowserSession | None = None - key_state: KeyState = dataclasses.field(default_factory=KeyState) - task: Task | None = None - workflow_run: WorkflowRun | None = None - - @property - def is_open(self) -> bool: - if self.websocket.client_state not in (WebSocketState.CONNECTED, WebSocketState.CONNECTING): - return False - - if not self.task and not self.workflow_run: - return False - - return True - - async def close(self, code: int = 1000, reason: str | None = None) -> "Streaming": - LOG.info("Closing Streaming.", reason=reason, code=code) - - self.browser_session = None - self.task = None - self.workflow_run = None - - try: - await self.websocket.close(code=code, reason=reason) - except Exception: - pass - - return self - - def update_key_state(self, data: bytes) -> None: - if data == Keys.Down.Ctrl: - self.key_state.ctrl_is_down = True - elif data == Keys.Up.Ctrl: - self.key_state.ctrl_is_down = False - elif data == Keys.Down.Alt: - self.key_state.alt_is_down = True - elif data == Keys.Up.Alt: - self.key_state.alt_is_down = False - elif data == Keys.Down.Cmd: - self.key_state.cmd_is_down = True - elif data == Keys.Up.Cmd: - self.key_state.cmd_is_down = False - - -async def auth(apikey: str | None, token: str | None, websocket: WebSocket) -> str | None: - """ - Accepts the websocket connection. - - Authenticates the user; cannot proceed with WS connection if an organization_id cannot be - determined. - """ - - try: - await websocket.accept() - if not token and not apikey: - await websocket.close(code=1002) - return None - except ConnectionClosedOK: - LOG.info("WebSocket connection closed cleanly.") - return None - - try: - organization = await get_current_org(x_api_key=apikey, authorization=token) - organization_id = organization.organization_id - - if not organization_id: - await websocket.close(code=1002) - return None - except Exception: - LOG.exception("Error occurred while retrieving organization information.") - try: - await websocket.close(code=1002) - except ConnectionClosedOK: - LOG.info("WebSocket connection closed due to invalid credentials.") - return None - - return organization_id - - -async def verify_task( - task_id: str, organization_id: str -) -> tuple[Task | None, AddressablePersistentBrowserSession | None]: - """ - Verify the task is running, and that it has a browser session associated - with it. - """ - - task = await app.DATABASE.get_task(task_id=task_id, organization_id=organization_id) - - if not task: - LOG.info("Task not found.", task_id=task_id, organization_id=organization_id) - return None, None - - if task.status.is_final(): - LOG.info("Task is in a final state.", task_status=task.status, task_id=task_id, organization_id=organization_id) - - return None, None - - if not task.status == TaskStatus.running: - LOG.info("Task is not running.", task_status=task.status, task_id=task_id, organization_id=organization_id) - - return None, None - - browser_session = await app.PERSISTENT_SESSIONS_MANAGER.get_session_by_runnable_id( - organization_id=organization_id, - runnable_id=task_id, # is this correct; is there a task_run_id? - ) - - if not browser_session: - LOG.info("No browser session found for task.", task_id=task_id, organization_id=organization_id) - return task, None - - if not browser_session.browser_address: - LOG.info("Browser session address not found for task.", task_id=task_id, organization_id=organization_id) - return task, None - - try: - addressable_browser_session = AddressablePersistentBrowserSession( - **browser_session.model_dump() | {"browser_address": browser_session.browser_address}, - ) - except Exception as e: - LOG.error( - "streaming-vnc.browser-session-reify-error", task_id=task_id, organization_id=organization_id, error=e - ) - return task, None - - return task, addressable_browser_session - - async def get_streaming_for_task( + client_id: str, task_id: str, organization_id: str, websocket: WebSocket, -) -> tuple[Streaming, Loops] | None: +) -> tuple[sc.Streaming, sc.Loops] | None: """ Return a streaming context for a task, with a list of loops to run concurrently. """ @@ -242,8 +45,9 @@ async def get_streaming_for_task( LOG.info("No initial browser session found for task.", task_id=task_id, organization_id=organization_id) return None - streaming = Streaming( - interactor="user", + streaming = sc.Streaming( + client_id=client_id, + interactor="agent", organization_id=organization_id, vnc_port=settings.SKYVERN_BROWSER_VNC_PORT, websocket=websocket, @@ -260,10 +64,11 @@ async def get_streaming_for_task( async def get_streaming_for_workflow_run( + client_id: str, workflow_run_id: str, organization_id: str, websocket: WebSocket, -) -> tuple[Streaming, Loops] | None: +) -> tuple[sc.Streaming, sc.Loops] | None: """ Return a streaming context for a workflow run, with a list of loops to run concurrently. """ @@ -287,8 +92,9 @@ async def get_streaming_for_workflow_run( ) return None - streaming = Streaming( - interactor="user", + streaming = sc.Streaming( + client_id=client_id, + interactor="agent", organization_id=organization_id, vnc_port=settings.SKYVERN_BROWSER_VNC_PORT, browser_session=browser_session, @@ -306,128 +112,7 @@ async def get_streaming_for_workflow_run( return streaming, loops -async def verify_workflow_run( - workflow_run_id: str, - organization_id: str, -) -> tuple[WorkflowRun | None, AddressablePersistentBrowserSession | None]: - """ - Verify the workflow run is running, and that it has a browser session associated - with it. - """ - - workflow_run = await app.DATABASE.get_workflow_run( - workflow_run_id=workflow_run_id, - organization_id=organization_id, - ) - - if not workflow_run: - LOG.info("Workflow run not found.", workflow_run_id=workflow_run_id, organization_id=organization_id) - return None, None - - if workflow_run.status in [ - WorkflowRunStatus.completed, - WorkflowRunStatus.failed, - WorkflowRunStatus.terminated, - ]: - LOG.info( - "Workflow run is in a final state. Closing connection.", - workflow_run_status=workflow_run.status, - workflow_run_id=workflow_run_id, - organization_id=organization_id, - ) - - return None, None - - if workflow_run.status not in [WorkflowRunStatus.created, WorkflowRunStatus.queued, WorkflowRunStatus.running]: - LOG.info( - "Workflow run is not running.", - workflow_run_status=workflow_run.status, - workflow_run_id=workflow_run_id, - organization_id=organization_id, - ) - - return None, None - - browser_session = await app.PERSISTENT_SESSIONS_MANAGER.get_session_by_runnable_id( - organization_id=organization_id, - runnable_id=workflow_run_id, - ) - - if not browser_session: - LOG.info( - "No browser session found for workflow run.", - workflow_run_id=workflow_run_id, - organization_id=organization_id, - ) - return workflow_run, None - - browser_address = browser_session.browser_address - - if not browser_address: - LOG.info( - "Waiting for browser session address.", workflow_run_id=workflow_run_id, organization_id=organization_id - ) - - try: - _, host, cdp_port = await app.PERSISTENT_SESSIONS_MANAGER.get_browser_address( - session_id=browser_session.persistent_browser_session_id, - organization_id=organization_id, - ) - browser_address = f"{host}:{cdp_port}" - except Exception as ex: - LOG.info( - "Browser session address not found for workflow run.", - workflow_run_id=workflow_run_id, - organization_id=organization_id, - ex=ex, - ) - return workflow_run, None - - try: - addressable_browser_session = AddressablePersistentBrowserSession( - **browser_session.model_dump() | {"browser_address": browser_address}, - ) - except Exception: - return workflow_run, None - - return workflow_run, addressable_browser_session - - -async def loop_verify_task(streaming: Streaming) -> None: - """ - Loop until the task is cleared or the websocket is closed. - """ - - while streaming.task and streaming.is_open: - task, browser_session = await verify_task( - task_id=streaming.task.task_id, - organization_id=streaming.organization_id, - ) - - streaming.task = task - streaming.browser_session = browser_session - - await asyncio.sleep(2) - - -async def loop_verify_workflow_run(streaming: Streaming) -> None: - """ - Loop until the workflow run is cleared or the websocket is closed. - """ - - while streaming.workflow_run and streaming.is_open: - workflow_run, browser_session = await verify_workflow_run( - workflow_run_id=streaming.workflow_run.workflow_run_id, - organization_id=streaming.organization_id, - ) - - streaming.workflow_run = workflow_run - streaming.browser_session = browser_session - - await asyncio.sleep(2) - - -async def loop_stream_vnc(streaming: Streaming) -> None: +async def loop_stream_vnc(streaming: sc.Streaming) -> None: """ Actually stream the VNC session data between a frontend and a browser session. @@ -465,19 +150,19 @@ async def loop_stream_vnc(streaming: Streaming) -> None: if data: message_type = data[0] - if message_type == MessageType.Keyboard.value: + if message_type == sc.MessageType.Keyboard.value: streaming.update_key_state(data) if streaming.key_state.is_forbidden(data): continue - if message_type == MessageType.Mouse.value: - if Mouse.Up.Right(data): + if message_type == sc.MessageType.Mouse.value: + if sc.Mouse.Up.Right(data): continue if not streaming.interactor == "user" and message_type in ( - MessageType.Keyboard.value, - MessageType.Mouse.value, + sc.MessageType.Keyboard.value, + sc.MessageType.Mouse.value, ): LOG.info( "Blocking user message.", task=streaming.task, organization_id=streaming.organization_id @@ -615,9 +300,10 @@ async def task_stream( websocket: WebSocket, task_id: str, apikey: str | None = None, + client_id: str | None = None, token: str | None = None, ) -> None: - await stream(websocket, apikey=apikey, task_id=task_id, token=token) + await stream(websocket, apikey=apikey, client_id=client_id, task_id=task_id, token=token) @legacy_base_router.websocket("/stream/vnc/workflow_run/{workflow_run_id}") @@ -625,32 +311,43 @@ async def workflow_run_stream( websocket: WebSocket, workflow_run_id: str, apikey: str | None = None, + client_id: str | None = None, token: str | None = None, ) -> None: - await stream(websocket, apikey=apikey, workflow_run_id=workflow_run_id, token=token) + await stream(websocket, apikey=apikey, client_id=client_id, workflow_run_id=workflow_run_id, token=token) async def stream( websocket: WebSocket, *, apikey: str | None = None, + client_id: str | None = None, task_id: str | None = None, token: str | None = None, workflow_run_id: str | None = None, ) -> None: - LOG.info("Starting VNC stream.", task_id=task_id, workflow_run_id=workflow_run_id) + if not client_id: + LOG.error("Client ID not provided for VNC stream.", task_id=task_id, workflow_run_id=workflow_run_id) + return + + LOG.info("Starting VNC stream.", client_id=client_id, task_id=task_id, workflow_run_id=workflow_run_id) organization_id = await auth(apikey=apikey, token=token, websocket=websocket) if not organization_id: - LOG.info("Authentication failed.", task_id=task_id, workflow_run_id=workflow_run_id) + LOG.error("Authentication failed.", task_id=task_id, workflow_run_id=workflow_run_id) return - streaming: Streaming + streaming: sc.Streaming loops: list[asyncio.Task] = [] if task_id: - result = await get_streaming_for_task(task_id=task_id, organization_id=organization_id, websocket=websocket) + result = await get_streaming_for_task( + client_id=client_id, + task_id=task_id, + organization_id=organization_id, + websocket=websocket, + ) if not result: LOG.error("No streaming context found for the task.", task_id=task_id, organization_id=organization_id) @@ -666,6 +363,7 @@ async def stream( "Starting streaming for workflow run.", workflow_run_id=workflow_run_id, organization_id=organization_id ) result = await get_streaming_for_workflow_run( + client_id=client_id, workflow_run_id=workflow_run_id, organization_id=organization_id, websocket=websocket,