diff --git a/skyvern-frontend/src/routes/workflows/editor/panels/WorkflowParametersPanel.tsx b/skyvern-frontend/src/routes/workflows/editor/panels/WorkflowParametersPanel.tsx
index a292ac34..2f6c86cc 100644
--- a/skyvern-frontend/src/routes/workflows/editor/panels/WorkflowParametersPanel.tsx
+++ b/skyvern-frontend/src/routes/workflows/editor/panels/WorkflowParametersPanel.tsx
@@ -141,9 +141,7 @@ function WorkflowParametersPanel() {
) : (
- {parameter.parameterType === "onepassword"
- ? "credential"
- : parameter.parameterType}
+ {parameter.parameterType}
)}
diff --git a/skyvern-frontend/src/routes/workflows/workflowRun/WorkflowRunStreamVnc.tsx b/skyvern-frontend/src/routes/workflows/workflowRun/WorkflowRunStreamVnc.tsx
index aa25a244..bcf91c80 100644
--- a/skyvern-frontend/src/routes/workflows/workflowRun/WorkflowRunStreamVnc.tsx
+++ b/skyvern-frontend/src/routes/workflows/workflowRun/WorkflowRunStreamVnc.tsx
@@ -13,24 +13,40 @@ import { useQueryClient } from "@tanstack/react-query";
import RFB from "@novnc/novnc/lib/rfb.js";
import { environment } from "@/util/env";
import { cn } from "@/util/utils";
+import { useClientIdStore } from "@/store/useClientIdStore";
import "./workflow-run-stream-vnc.css";
const wssBaseUrl = import.meta.env.VITE_WSS_BASE_URL;
+interface CommandTakeControl {
+ kind: "take-control";
+}
+
+interface CommandCedeControl {
+ kind: "cede-control";
+}
+
+type Command = CommandTakeControl | CommandCedeControl;
+
function WorkflowRunStreamVnc() {
const { data: workflowRun } = useWorkflowRunQuery();
+
const { workflowRunId, workflowPermanentId } = useParams<{
workflowRunId: string;
workflowPermanentId: string;
}>();
+ const [commandSocket, setCommandSocket] = useState(null);
const [userIsControlling, setUserIsControlling] = useState(false);
const [vncDisconnectedTrigger, setVncDisconnectedTrigger] = useState(0);
const prevVncConnectedRef = useRef(false);
const [isVncConnected, setIsVncConnected] = useState(false);
+ const [commandDisconnectedTrigger, setCommandDisconnectedTrigger] =
+ useState(0);
+ const prevCommandConnectedRef = useRef(false);
+ const [isCommandConnected, setIsCommandConnected] = useState(false);
const showStream = workflowRun && statusIsNotFinalized(workflowRun);
- const credentialGetter = useCredentialGetter();
const queryClient = useQueryClient();
const [canvasContainer, setCanvasContainer] = useState(
null,
@@ -38,10 +54,42 @@ function WorkflowRunStreamVnc() {
const setCanvasContainerRef = useCallback((node: HTMLDivElement | null) => {
setCanvasContainer(node);
}, []);
-
const rfbRef = useRef(null);
+ const clientId = useClientIdStore((state) => state.clientId);
+ const credentialGetter = useCredentialGetter();
- // effect for disconnects only
+ const getWebSocketParams = useCallback(async () => {
+ const clientIdQueryParam = `client_id=${clientId}`;
+ let credentialQueryParam = "";
+
+ if (environment === "local") {
+ credentialQueryParam = `apikey=${envCredential}`;
+ } else {
+ if (credentialGetter) {
+ const token = await credentialGetter();
+ credentialQueryParam = `token=Bearer ${token}`;
+ } else {
+ credentialQueryParam = `apikey=${envCredential}`;
+ }
+ }
+
+ const params = [credentialQueryParam, clientIdQueryParam].join("&");
+
+ return `${params}`;
+ }, [clientId, credentialGetter]);
+
+ const invalidateQueries = useCallback(() => {
+ queryClient.invalidateQueries({
+ queryKey: ["workflowRun", workflowPermanentId, workflowRunId],
+ });
+ queryClient.invalidateQueries({ queryKey: ["workflowRuns"] });
+ queryClient.invalidateQueries({
+ queryKey: ["workflowTasks", workflowRunId],
+ });
+ queryClient.invalidateQueries({ queryKey: ["runs"] });
+ }, [queryClient, workflowPermanentId, workflowRunId]);
+
+ // effect for vnc disconnects only
useEffect(() => {
if (prevVncConnectedRef.current && !isVncConnected) {
setVncDisconnectedTrigger((x) => x + 1);
@@ -49,6 +97,15 @@ function WorkflowRunStreamVnc() {
prevVncConnectedRef.current = isVncConnected;
}, [isVncConnected]);
+ // effect for command disconnects only
+ useEffect(() => {
+ if (prevCommandConnectedRef.current && !isCommandConnected) {
+ setCommandDisconnectedTrigger((x) => x + 1);
+ }
+ prevCommandConnectedRef.current = isCommandConnected;
+ }, [isCommandConnected]);
+
+ // vnc socket
useEffect(
() => {
if (!showStream || !canvasContainer || !workflowRunId) {
@@ -61,24 +118,12 @@ function WorkflowRunStreamVnc() {
}
async function setupVnc() {
- let credentialQueryParam = "";
-
- if (environment === "local") {
- credentialQueryParam = `?apikey=${envCredential}`;
- } else {
- if (credentialGetter) {
- const token = await credentialGetter();
- credentialQueryParam = `?token=Bearer ${token}`;
- } else {
- credentialQueryParam = `?apikey=${envCredential}`;
- }
- }
-
if (rfbRef.current && isVncConnected) {
return;
}
- const vncUrl = `${wssBaseUrl}/stream/vnc/workflow_run/${workflowRunId}${credentialQueryParam}`;
+ const wsParams = await getWebSocketParams();
+ const vncUrl = `${wssBaseUrl}/stream/vnc/workflow_run/${workflowRunId}?${wsParams}`;
if (rfbRef.current) {
rfbRef.current.disconnect();
@@ -102,15 +147,7 @@ function WorkflowRunStreamVnc() {
rfb.addEventListener("disconnect", async (/* e: RfbEvent */) => {
setIsVncConnected(false);
-
- queryClient.invalidateQueries({
- queryKey: ["workflowRun", workflowPermanentId, workflowRunId],
- });
- queryClient.invalidateQueries({ queryKey: ["workflowRuns"] });
- queryClient.invalidateQueries({
- queryKey: ["workflowTasks", workflowRunId],
- });
- queryClient.invalidateQueries({ queryKey: ["runs"] });
+ invalidateQueries();
});
}
@@ -127,16 +164,75 @@ function WorkflowRunStreamVnc() {
// cannot include isVncConnected in deps as it will cause infinite loop
// eslint-disable-next-line react-hooks/exhaustive-deps
[
- credentialGetter,
- workflowRunId,
- workflowPermanentId,
- showStream,
- queryClient,
canvasContainer,
+ invalidateQueries,
+ showStream,
vncDisconnectedTrigger, // will re-run on disconnects
+ workflowRunId,
],
);
+ // command socket
+ useEffect(() => {
+ let ws: WebSocket | null = null;
+
+ const connect = async () => {
+ const wsParams = await getWebSocketParams();
+ const commandUrl = `${wssBaseUrl}/stream/commands/workflow_run/${workflowRunId}?${wsParams}`;
+ ws = new WebSocket(commandUrl);
+
+ ws.onopen = () => {
+ setIsCommandConnected(true);
+ setCommandSocket(ws);
+ };
+
+ ws.onclose = () => {
+ setIsCommandConnected(false);
+ invalidateQueries();
+ setCommandSocket(null);
+ };
+ };
+
+ connect();
+
+ return () => {
+ try {
+ ws && ws.close();
+ } catch (e) {
+ // pass
+ }
+ };
+ }, [
+ commandDisconnectedTrigger,
+ getWebSocketParams,
+ invalidateQueries,
+ workflowRunId,
+ ]);
+
+ // effect to send a command when the user is controlling, vs not controlling
+ useEffect(() => {
+ if (!isCommandConnected) {
+ return;
+ }
+
+ const sendCommand = (command: Command) => {
+ if (!commandSocket) {
+ console.warn("Cannot send command, as command socket is closed.");
+ console.warn(command);
+ return;
+ }
+
+ commandSocket.send(JSON.stringify(command));
+ };
+
+ if (userIsControlling) {
+ sendCommand({ kind: "take-control" });
+ } else {
+ sendCommand({ kind: "cede-control" });
+ }
+ // eslint-disable-next-line react-hooks/exhaustive-deps
+ }, [userIsControlling, isCommandConnected]);
+
// Effect to show toast when workflow reaches a final state based on hook updates
useEffect(() => {
if (workflowRun) {
diff --git a/skyvern-frontend/src/store/useClientIdStore.ts b/skyvern-frontend/src/store/useClientIdStore.ts
new file mode 100644
index 00000000..b31e4dd9
--- /dev/null
+++ b/skyvern-frontend/src/store/useClientIdStore.ts
@@ -0,0 +1,11 @@
+import { create } from "zustand";
+
+type ClientIdStore = {
+ clientId: string;
+};
+
+const initialClientId = crypto.randomUUID();
+
+export const useClientIdStore = create(() => ({
+ clientId: initialClientId,
+}));
diff --git a/skyvern/forge/sdk/routes/__init__.py b/skyvern/forge/sdk/routes/__init__.py
index 83ba7d5d..40527991 100644
--- a/skyvern/forge/sdk/routes/__init__.py
+++ b/skyvern/forge/sdk/routes/__init__.py
@@ -2,4 +2,5 @@ from skyvern.forge.sdk.routes import agent_protocol # noqa: F401
from skyvern.forge.sdk.routes import browser_sessions # noqa: F401
from skyvern.forge.sdk.routes import credentials # noqa: F401
from skyvern.forge.sdk.routes import streaming # noqa: F401
+from skyvern.forge.sdk.routes import streaming_commands # noqa: F401
from skyvern.forge.sdk.routes import streaming_vnc # noqa: F401
diff --git a/skyvern/forge/sdk/routes/streaming_auth.py b/skyvern/forge/sdk/routes/streaming_auth.py
new file mode 100644
index 00000000..0e53b5ee
--- /dev/null
+++ b/skyvern/forge/sdk/routes/streaming_auth.py
@@ -0,0 +1,46 @@
+"""
+Streaming auth.
+"""
+
+import structlog
+from fastapi import WebSocket
+from websockets.exceptions import ConnectionClosedOK
+
+from skyvern.forge.sdk.services.org_auth_service import get_current_org
+
+LOG = structlog.get_logger()
+
+
+async def auth(apikey: str | None, token: str | None, websocket: WebSocket) -> str | None:
+ """
+ Accepts the websocket connection.
+
+ Authenticates the user; cannot proceed with WS connection if an organization_id cannot be
+ determined.
+ """
+
+ try:
+ await websocket.accept()
+ if not token and not apikey:
+ await websocket.close(code=1002)
+ return None
+ except ConnectionClosedOK:
+ LOG.info("WebSocket connection closed cleanly.")
+ return None
+
+ try:
+ organization = await get_current_org(x_api_key=apikey, authorization=token)
+ organization_id = organization.organization_id
+
+ if not organization_id:
+ await websocket.close(code=1002)
+ return None
+ except Exception:
+ LOG.exception("Error occurred while retrieving organization information.")
+ try:
+ await websocket.close(code=1002)
+ except ConnectionClosedOK:
+ LOG.info("WebSocket connection closed due to invalid credentials.")
+ return None
+
+ return organization_id
diff --git a/skyvern/forge/sdk/routes/streaming_clients.py b/skyvern/forge/sdk/routes/streaming_clients.py
new file mode 100644
index 00000000..3d54e3a7
--- /dev/null
+++ b/skyvern/forge/sdk/routes/streaming_clients.py
@@ -0,0 +1,270 @@
+"""
+Streaming types.
+"""
+
+import asyncio
+import dataclasses
+import typing as t
+from enum import IntEnum
+
+import structlog
+from fastapi import WebSocket
+from starlette.websockets import WebSocketState
+
+from skyvern.forge.sdk.schemas.persistent_browser_sessions import AddressablePersistentBrowserSession
+from skyvern.forge.sdk.schemas.tasks import Task
+from skyvern.forge.sdk.workflow.models.workflow import WorkflowRun
+
+LOG = structlog.get_logger()
+
+
+Interactor = t.Literal["agent", "user"]
+Loops = list[asyncio.Task] # aka "queue-less actors"; or "programs"
+
+
+# Commands
+
+
+# a global registry for WS command clients
+command_channels: dict[str, "CommandChannel"] = {}
+
+
+def add_command_client(command_channel: "CommandChannel") -> None:
+ command_channels[command_channel.client_id] = command_channel
+
+
+def get_command_client(client_id: str) -> t.Union["CommandChannel", None]:
+ return command_channels.get(client_id, None)
+
+
+def del_command_client(client_id: str) -> None:
+ try:
+ del command_channels[client_id]
+ except KeyError:
+ pass
+
+
+@dataclasses.dataclass
+class CommandChannel:
+ client_id: str
+ organization_id: str
+ websocket: WebSocket
+
+ # --
+
+ browser_session: AddressablePersistentBrowserSession | None = None
+ workflow_run: WorkflowRun | None = None
+
+ def __post_init__(self) -> None:
+ add_command_client(self)
+
+ async def close(self, code: int = 1000, reason: str | None = None) -> "CommandChannel":
+ LOG.info("Closing command stream.", reason=reason, code=code)
+
+ self.browser_session = None
+ self.workflow_run = None
+
+ try:
+ await self.websocket.close(code=code, reason=reason)
+ except Exception:
+ pass
+
+ del_command_client(self.client_id)
+
+ return self
+
+ @property
+ def is_open(self) -> bool:
+ if self.websocket.client_state not in (WebSocketState.CONNECTED, WebSocketState.CONNECTING):
+ return False
+
+ if not self.workflow_run:
+ return False
+
+ if not get_command_client(self.client_id):
+ return False
+
+ return True
+
+
+CommandKinds = t.Literal["take-control", "cede-control"]
+
+
+@dataclasses.dataclass
+class Command:
+ kind: CommandKinds
+
+
+@dataclasses.dataclass
+class CommandTakeControl(Command):
+ kind: t.Literal["take-control"] = "take-control"
+
+
+@dataclasses.dataclass
+class CommandCedeControl(Command):
+ kind: t.Literal["cede-control"] = "cede-control"
+
+
+ChannelCommand = t.Union[CommandTakeControl, CommandCedeControl]
+
+
+def reify_channel_command(data: dict) -> ChannelCommand:
+ kind = data.get("kind", None)
+
+ match kind:
+ case "take-control":
+ return CommandTakeControl()
+ case "cede-control":
+ return CommandCedeControl()
+ case _:
+ raise ValueError(f"Unknown command kind: '{kind}'")
+
+
+# Streaming
+
+
+# a global registry for WS streaming VNC clients
+streaming_clients: dict[str, "Streaming"] = {}
+
+
+def add_streaming_client(streaming: "Streaming") -> None:
+ streaming_clients[streaming.client_id] = streaming
+
+
+def get_streaming_client(client_id: str) -> t.Union["Streaming", None]:
+ return streaming_clients.get(client_id, None)
+
+
+def del_streaming_client(client_id: str) -> None:
+ try:
+ del streaming_clients[client_id]
+ except KeyError:
+ pass
+
+
+class MessageType(IntEnum):
+ Keyboard = 4
+ Mouse = 5
+
+
+class Keys:
+ """
+ VNC RFB keycodes. There's likely a pithier repr (indexes 6-7). This is ok for now.
+
+ ref: https://www.notion.so/References-21c426c42cd480fb9258ecc9eb8f09b4
+ ref: https://github.com/novnc/noVNC/blob/master/docs/rfbproto-3.8.pdf
+ """
+
+ class Down:
+ Ctrl = b"\x04\x01\x00\x00\x00\x00\xff\xe3"
+ Cmd = b"\x04\x01\x00\x00\x00\x00\xff\xe9"
+ Alt = b"\x04\x01\x00\x00\x00\x00\xff~" # option
+ OKey = b"\x04\x01\x00\x00\x00\x00\x00o"
+
+ class Up:
+ Ctrl = b"\x04\x00\x00\x00\x00\x00\xff\xe3"
+ Cmd = b"\x04\x00\x00\x00\x00\x00\xff\xe9"
+ Alt = b"\x04\x00\x00\x00\x00\x00\xff\x7e" # option
+
+
+def is_rmb(data: bytes) -> bool:
+ return data[0:2] == b"\x05\x04"
+
+
+class Mouse:
+ class Up:
+ Right = is_rmb
+
+
+@dataclasses.dataclass
+class KeyState:
+ ctrl_is_down: bool = False
+ alt_is_down: bool = False
+ cmd_is_down: bool = False
+ o_is_down: bool = False
+
+ def is_forbidden(self, data: bytes) -> bool:
+ """
+ :return: True if the key is forbidden, else False
+ """
+ return self.is_ctrl_o(data)
+
+ def is_ctrl_o(self, data: bytes) -> bool:
+ """
+ Do not allow the opening of files.
+ """
+ return self.ctrl_is_down and data == Keys.Down.OKey
+
+
+@dataclasses.dataclass
+class Streaming:
+ """
+ Streaming state.
+ """
+
+ client_id: str
+ """
+ Unique to frontend app instance.
+ """
+
+ interactor: Interactor
+ """
+ Whether the user or the agent are the interactor.
+ """
+
+ organization_id: str
+ vnc_port: int
+ websocket: WebSocket
+
+ # --
+
+ browser_session: AddressablePersistentBrowserSession | None = None
+ key_state: KeyState = dataclasses.field(default_factory=KeyState)
+ task: Task | None = None
+ workflow_run: WorkflowRun | None = None
+
+ def __post_init__(self) -> None:
+ add_streaming_client(self)
+
+ @property
+ def is_open(self) -> bool:
+ if self.websocket.client_state not in (WebSocketState.CONNECTED, WebSocketState.CONNECTING):
+ return False
+
+ if not self.task and not self.workflow_run:
+ return False
+
+ if not get_streaming_client(self.client_id):
+ return False
+
+ return True
+
+ async def close(self, code: int = 1000, reason: str | None = None) -> "Streaming":
+ LOG.info("Closing Streaming.", reason=reason, code=code)
+
+ self.browser_session = None
+ self.task = None
+ self.workflow_run = None
+
+ try:
+ await self.websocket.close(code=code, reason=reason)
+ except Exception:
+ pass
+
+ del_streaming_client(self.client_id)
+
+ return self
+
+ def update_key_state(self, data: bytes) -> None:
+ if data == Keys.Down.Ctrl:
+ self.key_state.ctrl_is_down = True
+ elif data == Keys.Up.Ctrl:
+ self.key_state.ctrl_is_down = False
+ elif data == Keys.Down.Alt:
+ self.key_state.alt_is_down = True
+ elif data == Keys.Up.Alt:
+ self.key_state.alt_is_down = False
+ elif data == Keys.Down.Cmd:
+ self.key_state.cmd_is_down = True
+ elif data == Keys.Up.Cmd:
+ self.key_state.cmd_is_down = False
diff --git a/skyvern/forge/sdk/routes/streaming_commands.py b/skyvern/forge/sdk/routes/streaming_commands.py
new file mode 100644
index 00000000..fa2f8b6b
--- /dev/null
+++ b/skyvern/forge/sdk/routes/streaming_commands.py
@@ -0,0 +1,227 @@
+"""
+Streaming commands WebSocket connections.
+"""
+
+import asyncio
+
+import structlog
+from fastapi import WebSocket, WebSocketDisconnect
+from websockets.exceptions import ConnectionClosedError
+
+import skyvern.forge.sdk.routes.streaming_clients as sc
+from skyvern.forge.sdk.routes.routers import legacy_base_router
+from skyvern.forge.sdk.routes.streaming_auth import auth
+from skyvern.forge.sdk.routes.streaming_verify import loop_verify_workflow_run, verify_workflow_run
+from skyvern.forge.sdk.utils.aio import collect
+
+LOG = structlog.get_logger()
+
+
+async def get_commands_for_workflow_run(
+ client_id: str,
+ workflow_run_id: str,
+ organization_id: str,
+ websocket: WebSocket,
+) -> tuple[sc.CommandChannel, sc.Loops] | None:
+ """
+ Return a commands channel for a workflow run, with a list of loops to run concurrently.
+ """
+
+ LOG.info("Getting commands channel for workflow run.", workflow_run_id=workflow_run_id)
+
+ workflow_run, browser_session = await verify_workflow_run(
+ workflow_run_id=workflow_run_id,
+ organization_id=organization_id,
+ )
+
+ if not workflow_run:
+ LOG.info(
+ "Command channel: no initial workflow run found.",
+ workflow_run_id=workflow_run_id,
+ organization_id=organization_id,
+ )
+ return None
+
+ if not browser_session:
+ LOG.info(
+ "Command channel: no initial browser session found for workflow run.",
+ workflow_run_id=workflow_run_id,
+ organization_id=organization_id,
+ )
+ return None
+
+ commands = sc.CommandChannel(
+ client_id=client_id,
+ organization_id=organization_id,
+ browser_session=browser_session,
+ workflow_run=workflow_run,
+ websocket=websocket,
+ )
+
+ LOG.info("Got command channel for workflow run.", commands=commands)
+
+ loops = [
+ asyncio.create_task(loop_verify_workflow_run(commands)),
+ asyncio.create_task(loop_channel(commands)),
+ ]
+
+ return commands, loops
+
+
+async def loop_channel(commands: sc.CommandChannel) -> None:
+ """
+ Stream commands and their results back and forth.
+
+ Loops until the workflow run is cleared or the websocket is closed.
+ """
+
+ if not commands.browser_session:
+ LOG.info(
+ "No browser session found for workflow run.",
+ workflow_run=commands.workflow_run,
+ organization_id=commands.organization_id,
+ )
+ return
+
+ async def frontend_to_backend() -> None:
+ LOG.info("Starting frontend-to-backend channel loop.", commands=commands)
+
+ while commands.is_open:
+ try:
+ data = await commands.websocket.receive_json()
+
+ if not isinstance(data, dict):
+ LOG.error(f"Cannot create channel command: expected dict, got {type(data)}")
+ continue
+
+ try:
+ command = sc.reify_channel_command(data)
+ except ValueError:
+ continue
+
+ command_kind = command.kind
+
+ match command_kind:
+ case "take-control":
+ streaming = sc.get_streaming_client(commands.client_id)
+ if not streaming:
+ LOG.error("No streaming client found for command.", commands=commands, command=command)
+ continue
+ streaming.interactor = "user"
+ case "cede-control":
+ streaming = sc.get_streaming_client(commands.client_id)
+ if not streaming:
+ LOG.error("No streaming client found for command.", commands=commands, command=command)
+ continue
+ streaming.interactor = "agent"
+ case _:
+ LOG.error(f"Unknown command kind: '{command_kind}'")
+ continue
+
+ except WebSocketDisconnect:
+ LOG.info(
+ "Frontend disconnected.",
+ workflow_run=commands.workflow_run,
+ organization_id=commands.organization_id,
+ )
+ raise
+ except ConnectionClosedError:
+ LOG.info(
+ "Frontend closed the streaming session.",
+ workflow_run=commands.workflow_run,
+ organization_id=commands.organization_id,
+ )
+ raise
+ except asyncio.CancelledError:
+ pass
+ except Exception:
+ LOG.exception(
+ "An unexpected exception occurred.",
+ workflow_run=commands.workflow_run,
+ organization_id=commands.organization_id,
+ )
+ raise
+
+ loops = [
+ asyncio.create_task(frontend_to_backend()),
+ ]
+
+ try:
+ await collect(loops)
+ except Exception:
+ LOG.exception(
+ "An exception occurred in loop channel stream.",
+ workflow_run=commands.workflow_run,
+ organization_id=commands.organization_id,
+ )
+ finally:
+ LOG.info(
+ "Closing the loop channel stream.",
+ workflow_run=commands.workflow_run,
+ organization_id=commands.organization_id,
+ )
+ await commands.close(reason="loop-channel-closed")
+
+
+@legacy_base_router.websocket("/stream/commands/workflow_run/{workflow_run_id}")
+async def workflow_run_commands(
+ websocket: WebSocket,
+ workflow_run_id: str,
+ apikey: str | None = None,
+ client_id: str | None = None,
+ token: str | None = None,
+) -> None:
+ LOG.info("Starting stream commands.", workflow_run_id=workflow_run_id)
+
+ organization_id = await auth(apikey=apikey, token=token, websocket=websocket)
+
+ if not organization_id:
+ LOG.error("Authentication failed.", workflow_run_id=workflow_run_id)
+ return
+
+ if not client_id:
+ LOG.error("No client ID provided.", workflow_run_id=workflow_run_id)
+ return
+
+ commands: sc.CommandChannel
+ loops: list[asyncio.Task] = []
+
+ result = await get_commands_for_workflow_run(
+ client_id=client_id,
+ workflow_run_id=workflow_run_id,
+ organization_id=organization_id,
+ websocket=websocket,
+ )
+
+ if not result:
+ LOG.error(
+ "No streaming context found for the workflow run.",
+ workflow_run_id=workflow_run_id,
+ organization_id=organization_id,
+ )
+ await websocket.close(code=1013)
+ return
+
+ commands, loops = result
+
+ try:
+ LOG.info(
+ "Starting command stream loops.",
+ workflow_run_id=workflow_run_id,
+ organization_id=organization_id,
+ )
+ await collect(loops)
+ except Exception:
+ LOG.exception(
+ "An exception occurred in the command stream function.",
+ workflow_run_id=workflow_run_id,
+ organization_id=organization_id,
+ )
+ finally:
+ LOG.info(
+ "Closing the command stream session.",
+ workflow_run_id=workflow_run_id,
+ organization_id=organization_id,
+ )
+
+ await commands.close(reason="stream-closed")
diff --git a/skyvern/forge/sdk/routes/streaming_verify.py b/skyvern/forge/sdk/routes/streaming_verify.py
new file mode 100644
index 00000000..21ed6296
--- /dev/null
+++ b/skyvern/forge/sdk/routes/streaming_verify.py
@@ -0,0 +1,207 @@
+import asyncio
+
+import structlog
+
+import skyvern.forge.sdk.routes.streaming_clients as sc
+from skyvern.config import settings
+from skyvern.forge import app
+from skyvern.forge.sdk.schemas.persistent_browser_sessions import AddressablePersistentBrowserSession
+from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
+from skyvern.forge.sdk.workflow.models.workflow import WorkflowRun, WorkflowRunStatus
+
+LOG = structlog.get_logger()
+
+
+async def verify_task(
+ task_id: str, organization_id: str
+) -> tuple[Task | None, AddressablePersistentBrowserSession | None]:
+ """
+ Verify the task is running, and that it has a browser session associated
+ with it.
+ """
+
+ task = await app.DATABASE.get_task(task_id=task_id, organization_id=organization_id)
+
+ if not task:
+ LOG.info("Task not found.", task_id=task_id, organization_id=organization_id)
+ return None, None
+
+ if task.status.is_final():
+ LOG.info("Task is in a final state.", task_status=task.status, task_id=task_id, organization_id=organization_id)
+
+ return None, None
+
+ if task.status not in [TaskStatus.created, TaskStatus.queued, TaskStatus.running]:
+ LOG.info(
+ "Task is not created, queued, or running.",
+ task_status=task.status,
+ task_id=task_id,
+ organization_id=organization_id,
+ )
+
+ return None, None
+
+ browser_session = await app.PERSISTENT_SESSIONS_MANAGER.get_session_by_runnable_id(
+ organization_id=organization_id,
+ runnable_id=task_id,
+ )
+
+ if not browser_session:
+ LOG.info("No browser session found for task.", task_id=task_id, organization_id=organization_id)
+ return task, None
+
+ if not browser_session.browser_address:
+ LOG.info("Browser session address not found for task.", task_id=task_id, organization_id=organization_id)
+ return task, None
+
+ try:
+ addressable_browser_session = AddressablePersistentBrowserSession(
+ **browser_session.model_dump() | {"browser_address": browser_session.browser_address},
+ )
+ except Exception as e:
+ LOG.error(
+ "streaming-vnc.browser-session-reify-error", task_id=task_id, organization_id=organization_id, error=e
+ )
+ return task, None
+
+ return task, addressable_browser_session
+
+
+async def verify_workflow_run(
+ workflow_run_id: str,
+ organization_id: str,
+) -> tuple[WorkflowRun | None, AddressablePersistentBrowserSession | None]:
+ """
+ Verify the workflow run is running, and that it has a browser session associated
+ with it.
+ """
+
+ if settings.ENV == "local":
+ from datetime import datetime
+
+ dummy_workflow_run = WorkflowRun(
+ workflow_id="123",
+ workflow_permanent_id="wpid_123",
+ workflow_run_id=workflow_run_id,
+ organization_id=organization_id,
+ status=WorkflowRunStatus.running,
+ created_at=datetime.now(),
+ modified_at=datetime.now(),
+ )
+
+ dummy_browser_session = AddressablePersistentBrowserSession(
+ persistent_browser_session_id=workflow_run_id,
+ organization_id=organization_id,
+ browser_address="0.0.0.0:9223",
+ created_at=datetime.now(),
+ modified_at=datetime.now(),
+ )
+
+ return dummy_workflow_run, dummy_browser_session
+
+ workflow_run = await app.DATABASE.get_workflow_run(
+ workflow_run_id=workflow_run_id,
+ organization_id=organization_id,
+ )
+
+ if not workflow_run:
+ LOG.info("Workflow run not found.", workflow_run_id=workflow_run_id, organization_id=organization_id)
+ return None, None
+
+ if workflow_run.status.is_final():
+ LOG.info(
+ "Workflow run is in a final state. Closing connection.",
+ workflow_run_status=workflow_run.status,
+ workflow_run_id=workflow_run_id,
+ organization_id=organization_id,
+ )
+
+ return None, None
+
+ if workflow_run.status not in [WorkflowRunStatus.created, WorkflowRunStatus.queued, WorkflowRunStatus.running]:
+ LOG.info(
+ "Workflow run is not running.",
+ workflow_run_status=workflow_run.status,
+ workflow_run_id=workflow_run_id,
+ organization_id=organization_id,
+ )
+
+ return None, None
+
+ browser_session = await app.PERSISTENT_SESSIONS_MANAGER.get_session_by_runnable_id(
+ organization_id=organization_id,
+ runnable_id=workflow_run_id,
+ )
+
+ if not browser_session:
+ LOG.info(
+ "No browser session found for workflow run.",
+ workflow_run_id=workflow_run_id,
+ organization_id=organization_id,
+ )
+ return workflow_run, None
+
+ browser_address = browser_session.browser_address
+
+ if not browser_address:
+ LOG.info(
+ "Waiting for browser session address.", workflow_run_id=workflow_run_id, organization_id=organization_id
+ )
+
+ try:
+ _, host, cdp_port = await app.PERSISTENT_SESSIONS_MANAGER.get_browser_address(
+ session_id=browser_session.persistent_browser_session_id,
+ organization_id=organization_id,
+ )
+ browser_address = f"{host}:{cdp_port}"
+ except Exception as ex:
+ LOG.info(
+ "Browser session address not found for workflow run.",
+ workflow_run_id=workflow_run_id,
+ organization_id=organization_id,
+ ex=ex,
+ )
+ return workflow_run, None
+
+ try:
+ addressable_browser_session = AddressablePersistentBrowserSession(
+ **browser_session.model_dump() | {"browser_address": browser_address},
+ )
+ except Exception:
+ return workflow_run, None
+
+ return workflow_run, addressable_browser_session
+
+
+async def loop_verify_task(streaming: sc.Streaming) -> None:
+ """
+ Loop until the task is cleared or the websocket is closed.
+ """
+
+ while streaming.task and streaming.is_open:
+ task, browser_session = await verify_task(
+ task_id=streaming.task.task_id,
+ organization_id=streaming.organization_id,
+ )
+
+ streaming.task = task
+ streaming.browser_session = browser_session
+
+ await asyncio.sleep(2)
+
+
+async def loop_verify_workflow_run(verifiable: sc.CommandChannel | sc.Streaming) -> None:
+ """
+ Loop until the workflow run is cleared or the websocket is closed.
+ """
+
+ while verifiable.workflow_run and verifiable.is_open:
+ workflow_run, browser_session = await verify_workflow_run(
+ workflow_run_id=verifiable.workflow_run.workflow_run_id,
+ organization_id=verifiable.organization_id,
+ )
+
+ verifiable.workflow_run = workflow_run
+ verifiable.browser_session = browser_session
+
+ await asyncio.sleep(2)
diff --git a/skyvern/forge/sdk/routes/streaming_vnc.py b/skyvern/forge/sdk/routes/streaming_vnc.py
index ac7712fb..0e065d13 100644
--- a/skyvern/forge/sdk/routes/streaming_vnc.py
+++ b/skyvern/forge/sdk/routes/streaming_vnc.py
@@ -1,233 +1,36 @@
+"""
+Streaming VNC WebSocket connections.
+"""
+
import asyncio
-import dataclasses
-import typing as t
-from enum import IntEnum
import structlog
import websockets
from fastapi import WebSocket, WebSocketDisconnect
-from starlette.websockets import WebSocketState
from websockets import Data
-from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK
+from websockets.exceptions import ConnectionClosedError
+import skyvern.forge.sdk.routes.streaming_clients as sc
from skyvern.config import settings
-from skyvern.forge import app
from skyvern.forge.sdk.routes.routers import legacy_base_router
-from skyvern.forge.sdk.schemas.persistent_browser_sessions import AddressablePersistentBrowserSession
-from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
-from skyvern.forge.sdk.services.org_auth_service import get_current_org
+from skyvern.forge.sdk.routes.streaming_auth import auth
+from skyvern.forge.sdk.routes.streaming_verify import (
+ loop_verify_task,
+ loop_verify_workflow_run,
+ verify_task,
+ verify_workflow_run,
+)
from skyvern.forge.sdk.utils.aio import collect
-from skyvern.forge.sdk.workflow.models.workflow import WorkflowRun, WorkflowRunStatus
-
-Interactor = t.Literal["agent", "user"]
-Loops = list[asyncio.Task] # aka "queue-less actors"; or "programs"
-
-
-class MessageType(IntEnum):
- Keyboard = 4
- Mouse = 5
-
-
-class Keys:
- """
- VNC RFB keycodes. There's likely a pithier repr (indexes 6-7). This is ok for now.
- """
-
- class Down:
- Ctrl = b"\x04\x01\x00\x00\x00\x00\xff\xe3"
- Cmd = b"\x04\x01\x00\x00\x00\x00\xff\xe9"
- Alt = b"\x04\x01\x00\x00\x00\x00\xff~" # option
- OKey = b"\x04\x01\x00\x00\x00\x00\x00o"
-
- class Up:
- Ctrl = b"\x04\x00\x00\x00\x00\x00\xff\xe3"
- Cmd = b"\x04\x00\x00\x00\x00\x00\xff\xe9"
- Alt = b"\x04\x00\x00\x00\x00\x00\xff\x7e" # option
-
-
-def is_rmb(data: bytes) -> bool:
- return data[0:2] == b"\x05\x04"
-
-
-class Mouse:
- class Up:
- Right = is_rmb
-
LOG = structlog.get_logger()
-@dataclasses.dataclass
-class KeyState:
- ctrl_is_down: bool = False
- alt_is_down: bool = False
- cmd_is_down: bool = False
- o_is_down: bool = False
-
- def is_forbidden(self, data: bytes) -> bool:
- """
- :return: True if the key is forbidden, else False
- """
- return self.is_ctrl_o(data)
-
- def is_ctrl_o(self, data: bytes) -> bool:
- """
- Do not allow the opening of files.
- """
- return self.ctrl_is_down and data == Keys.Down.OKey
-
-
-@dataclasses.dataclass
-class Streaming:
- """
- Streaming state.
- """
-
- interactor: Interactor
- """
- Whether the user or the agent are the interactor.
- """
-
- organization_id: str
- vnc_port: int
- websocket: WebSocket
-
- # --
-
- browser_session: AddressablePersistentBrowserSession | None = None
- key_state: KeyState = dataclasses.field(default_factory=KeyState)
- task: Task | None = None
- workflow_run: WorkflowRun | None = None
-
- @property
- def is_open(self) -> bool:
- if self.websocket.client_state not in (WebSocketState.CONNECTED, WebSocketState.CONNECTING):
- return False
-
- if not self.task and not self.workflow_run:
- return False
-
- return True
-
- async def close(self, code: int = 1000, reason: str | None = None) -> "Streaming":
- LOG.info("Closing Streaming.", reason=reason, code=code)
-
- self.browser_session = None
- self.task = None
- self.workflow_run = None
-
- try:
- await self.websocket.close(code=code, reason=reason)
- except Exception:
- pass
-
- return self
-
- def update_key_state(self, data: bytes) -> None:
- if data == Keys.Down.Ctrl:
- self.key_state.ctrl_is_down = True
- elif data == Keys.Up.Ctrl:
- self.key_state.ctrl_is_down = False
- elif data == Keys.Down.Alt:
- self.key_state.alt_is_down = True
- elif data == Keys.Up.Alt:
- self.key_state.alt_is_down = False
- elif data == Keys.Down.Cmd:
- self.key_state.cmd_is_down = True
- elif data == Keys.Up.Cmd:
- self.key_state.cmd_is_down = False
-
-
-async def auth(apikey: str | None, token: str | None, websocket: WebSocket) -> str | None:
- """
- Accepts the websocket connection.
-
- Authenticates the user; cannot proceed with WS connection if an organization_id cannot be
- determined.
- """
-
- try:
- await websocket.accept()
- if not token and not apikey:
- await websocket.close(code=1002)
- return None
- except ConnectionClosedOK:
- LOG.info("WebSocket connection closed cleanly.")
- return None
-
- try:
- organization = await get_current_org(x_api_key=apikey, authorization=token)
- organization_id = organization.organization_id
-
- if not organization_id:
- await websocket.close(code=1002)
- return None
- except Exception:
- LOG.exception("Error occurred while retrieving organization information.")
- try:
- await websocket.close(code=1002)
- except ConnectionClosedOK:
- LOG.info("WebSocket connection closed due to invalid credentials.")
- return None
-
- return organization_id
-
-
-async def verify_task(
- task_id: str, organization_id: str
-) -> tuple[Task | None, AddressablePersistentBrowserSession | None]:
- """
- Verify the task is running, and that it has a browser session associated
- with it.
- """
-
- task = await app.DATABASE.get_task(task_id=task_id, organization_id=organization_id)
-
- if not task:
- LOG.info("Task not found.", task_id=task_id, organization_id=organization_id)
- return None, None
-
- if task.status.is_final():
- LOG.info("Task is in a final state.", task_status=task.status, task_id=task_id, organization_id=organization_id)
-
- return None, None
-
- if not task.status == TaskStatus.running:
- LOG.info("Task is not running.", task_status=task.status, task_id=task_id, organization_id=organization_id)
-
- return None, None
-
- browser_session = await app.PERSISTENT_SESSIONS_MANAGER.get_session_by_runnable_id(
- organization_id=organization_id,
- runnable_id=task_id, # is this correct; is there a task_run_id?
- )
-
- if not browser_session:
- LOG.info("No browser session found for task.", task_id=task_id, organization_id=organization_id)
- return task, None
-
- if not browser_session.browser_address:
- LOG.info("Browser session address not found for task.", task_id=task_id, organization_id=organization_id)
- return task, None
-
- try:
- addressable_browser_session = AddressablePersistentBrowserSession(
- **browser_session.model_dump() | {"browser_address": browser_session.browser_address},
- )
- except Exception as e:
- LOG.error(
- "streaming-vnc.browser-session-reify-error", task_id=task_id, organization_id=organization_id, error=e
- )
- return task, None
-
- return task, addressable_browser_session
-
-
async def get_streaming_for_task(
+ client_id: str,
task_id: str,
organization_id: str,
websocket: WebSocket,
-) -> tuple[Streaming, Loops] | None:
+) -> tuple[sc.Streaming, sc.Loops] | None:
"""
Return a streaming context for a task, with a list of loops to run concurrently.
"""
@@ -242,8 +45,9 @@ async def get_streaming_for_task(
LOG.info("No initial browser session found for task.", task_id=task_id, organization_id=organization_id)
return None
- streaming = Streaming(
- interactor="user",
+ streaming = sc.Streaming(
+ client_id=client_id,
+ interactor="agent",
organization_id=organization_id,
vnc_port=settings.SKYVERN_BROWSER_VNC_PORT,
websocket=websocket,
@@ -260,10 +64,11 @@ async def get_streaming_for_task(
async def get_streaming_for_workflow_run(
+ client_id: str,
workflow_run_id: str,
organization_id: str,
websocket: WebSocket,
-) -> tuple[Streaming, Loops] | None:
+) -> tuple[sc.Streaming, sc.Loops] | None:
"""
Return a streaming context for a workflow run, with a list of loops to run concurrently.
"""
@@ -287,8 +92,9 @@ async def get_streaming_for_workflow_run(
)
return None
- streaming = Streaming(
- interactor="user",
+ streaming = sc.Streaming(
+ client_id=client_id,
+ interactor="agent",
organization_id=organization_id,
vnc_port=settings.SKYVERN_BROWSER_VNC_PORT,
browser_session=browser_session,
@@ -306,128 +112,7 @@ async def get_streaming_for_workflow_run(
return streaming, loops
-async def verify_workflow_run(
- workflow_run_id: str,
- organization_id: str,
-) -> tuple[WorkflowRun | None, AddressablePersistentBrowserSession | None]:
- """
- Verify the workflow run is running, and that it has a browser session associated
- with it.
- """
-
- workflow_run = await app.DATABASE.get_workflow_run(
- workflow_run_id=workflow_run_id,
- organization_id=organization_id,
- )
-
- if not workflow_run:
- LOG.info("Workflow run not found.", workflow_run_id=workflow_run_id, organization_id=organization_id)
- return None, None
-
- if workflow_run.status in [
- WorkflowRunStatus.completed,
- WorkflowRunStatus.failed,
- WorkflowRunStatus.terminated,
- ]:
- LOG.info(
- "Workflow run is in a final state. Closing connection.",
- workflow_run_status=workflow_run.status,
- workflow_run_id=workflow_run_id,
- organization_id=organization_id,
- )
-
- return None, None
-
- if workflow_run.status not in [WorkflowRunStatus.created, WorkflowRunStatus.queued, WorkflowRunStatus.running]:
- LOG.info(
- "Workflow run is not running.",
- workflow_run_status=workflow_run.status,
- workflow_run_id=workflow_run_id,
- organization_id=organization_id,
- )
-
- return None, None
-
- browser_session = await app.PERSISTENT_SESSIONS_MANAGER.get_session_by_runnable_id(
- organization_id=organization_id,
- runnable_id=workflow_run_id,
- )
-
- if not browser_session:
- LOG.info(
- "No browser session found for workflow run.",
- workflow_run_id=workflow_run_id,
- organization_id=organization_id,
- )
- return workflow_run, None
-
- browser_address = browser_session.browser_address
-
- if not browser_address:
- LOG.info(
- "Waiting for browser session address.", workflow_run_id=workflow_run_id, organization_id=organization_id
- )
-
- try:
- _, host, cdp_port = await app.PERSISTENT_SESSIONS_MANAGER.get_browser_address(
- session_id=browser_session.persistent_browser_session_id,
- organization_id=organization_id,
- )
- browser_address = f"{host}:{cdp_port}"
- except Exception as ex:
- LOG.info(
- "Browser session address not found for workflow run.",
- workflow_run_id=workflow_run_id,
- organization_id=organization_id,
- ex=ex,
- )
- return workflow_run, None
-
- try:
- addressable_browser_session = AddressablePersistentBrowserSession(
- **browser_session.model_dump() | {"browser_address": browser_address},
- )
- except Exception:
- return workflow_run, None
-
- return workflow_run, addressable_browser_session
-
-
-async def loop_verify_task(streaming: Streaming) -> None:
- """
- Loop until the task is cleared or the websocket is closed.
- """
-
- while streaming.task and streaming.is_open:
- task, browser_session = await verify_task(
- task_id=streaming.task.task_id,
- organization_id=streaming.organization_id,
- )
-
- streaming.task = task
- streaming.browser_session = browser_session
-
- await asyncio.sleep(2)
-
-
-async def loop_verify_workflow_run(streaming: Streaming) -> None:
- """
- Loop until the workflow run is cleared or the websocket is closed.
- """
-
- while streaming.workflow_run and streaming.is_open:
- workflow_run, browser_session = await verify_workflow_run(
- workflow_run_id=streaming.workflow_run.workflow_run_id,
- organization_id=streaming.organization_id,
- )
-
- streaming.workflow_run = workflow_run
- streaming.browser_session = browser_session
-
- await asyncio.sleep(2)
-
-
-async def loop_stream_vnc(streaming: Streaming) -> None:
+async def loop_stream_vnc(streaming: sc.Streaming) -> None:
"""
Actually stream the VNC session data between a frontend and a browser
session.
@@ -465,19 +150,19 @@ async def loop_stream_vnc(streaming: Streaming) -> None:
if data:
message_type = data[0]
- if message_type == MessageType.Keyboard.value:
+ if message_type == sc.MessageType.Keyboard.value:
streaming.update_key_state(data)
if streaming.key_state.is_forbidden(data):
continue
- if message_type == MessageType.Mouse.value:
- if Mouse.Up.Right(data):
+ if message_type == sc.MessageType.Mouse.value:
+ if sc.Mouse.Up.Right(data):
continue
if not streaming.interactor == "user" and message_type in (
- MessageType.Keyboard.value,
- MessageType.Mouse.value,
+ sc.MessageType.Keyboard.value,
+ sc.MessageType.Mouse.value,
):
LOG.info(
"Blocking user message.", task=streaming.task, organization_id=streaming.organization_id
@@ -615,9 +300,10 @@ async def task_stream(
websocket: WebSocket,
task_id: str,
apikey: str | None = None,
+ client_id: str | None = None,
token: str | None = None,
) -> None:
- await stream(websocket, apikey=apikey, task_id=task_id, token=token)
+ await stream(websocket, apikey=apikey, client_id=client_id, task_id=task_id, token=token)
@legacy_base_router.websocket("/stream/vnc/workflow_run/{workflow_run_id}")
@@ -625,32 +311,43 @@ async def workflow_run_stream(
websocket: WebSocket,
workflow_run_id: str,
apikey: str | None = None,
+ client_id: str | None = None,
token: str | None = None,
) -> None:
- await stream(websocket, apikey=apikey, workflow_run_id=workflow_run_id, token=token)
+ await stream(websocket, apikey=apikey, client_id=client_id, workflow_run_id=workflow_run_id, token=token)
async def stream(
websocket: WebSocket,
*,
apikey: str | None = None,
+ client_id: str | None = None,
task_id: str | None = None,
token: str | None = None,
workflow_run_id: str | None = None,
) -> None:
- LOG.info("Starting VNC stream.", task_id=task_id, workflow_run_id=workflow_run_id)
+ if not client_id:
+ LOG.error("Client ID not provided for VNC stream.", task_id=task_id, workflow_run_id=workflow_run_id)
+ return
+
+ LOG.info("Starting VNC stream.", client_id=client_id, task_id=task_id, workflow_run_id=workflow_run_id)
organization_id = await auth(apikey=apikey, token=token, websocket=websocket)
if not organization_id:
- LOG.info("Authentication failed.", task_id=task_id, workflow_run_id=workflow_run_id)
+ LOG.error("Authentication failed.", task_id=task_id, workflow_run_id=workflow_run_id)
return
- streaming: Streaming
+ streaming: sc.Streaming
loops: list[asyncio.Task] = []
if task_id:
- result = await get_streaming_for_task(task_id=task_id, organization_id=organization_id, websocket=websocket)
+ result = await get_streaming_for_task(
+ client_id=client_id,
+ task_id=task_id,
+ organization_id=organization_id,
+ websocket=websocket,
+ )
if not result:
LOG.error("No streaming context found for the task.", task_id=task_id, organization_id=organization_id)
@@ -666,6 +363,7 @@ async def stream(
"Starting streaming for workflow run.", workflow_run_id=workflow_run_id, organization_id=organization_id
)
result = await get_streaming_for_workflow_run(
+ client_id=client_id,
workflow_run_id=workflow_run_id,
organization_id=organization_id,
websocket=websocket,