WebSocket Command Channel (#2782)

This commit is contained in:
Shuchang Zheng
2025-06-25 02:37:26 +08:00
committed by GitHub
parent b8f560728b
commit 6b5699a98c
9 changed files with 938 additions and 384 deletions

View File

@@ -141,9 +141,7 @@ function WorkflowParametersPanel() {
</span>
) : (
<span className="text-sm text-slate-400">
{parameter.parameterType === "onepassword"
? "credential"
: parameter.parameterType}
{parameter.parameterType}
</span>
)}
</div>

View File

@@ -13,24 +13,40 @@ import { useQueryClient } from "@tanstack/react-query";
import RFB from "@novnc/novnc/lib/rfb.js";
import { environment } from "@/util/env";
import { cn } from "@/util/utils";
import { useClientIdStore } from "@/store/useClientIdStore";
import "./workflow-run-stream-vnc.css";
const wssBaseUrl = import.meta.env.VITE_WSS_BASE_URL;
interface CommandTakeControl {
kind: "take-control";
}
interface CommandCedeControl {
kind: "cede-control";
}
type Command = CommandTakeControl | CommandCedeControl;
function WorkflowRunStreamVnc() {
const { data: workflowRun } = useWorkflowRunQuery();
const { workflowRunId, workflowPermanentId } = useParams<{
workflowRunId: string;
workflowPermanentId: string;
}>();
const [commandSocket, setCommandSocket] = useState<WebSocket | null>(null);
const [userIsControlling, setUserIsControlling] = useState<boolean>(false);
const [vncDisconnectedTrigger, setVncDisconnectedTrigger] = useState(0);
const prevVncConnectedRef = useRef<boolean>(false);
const [isVncConnected, setIsVncConnected] = useState<boolean>(false);
const [commandDisconnectedTrigger, setCommandDisconnectedTrigger] =
useState(0);
const prevCommandConnectedRef = useRef<boolean>(false);
const [isCommandConnected, setIsCommandConnected] = useState<boolean>(false);
const showStream = workflowRun && statusIsNotFinalized(workflowRun);
const credentialGetter = useCredentialGetter();
const queryClient = useQueryClient();
const [canvasContainer, setCanvasContainer] = useState<HTMLDivElement | null>(
null,
@@ -38,10 +54,42 @@ function WorkflowRunStreamVnc() {
const setCanvasContainerRef = useCallback((node: HTMLDivElement | null) => {
setCanvasContainer(node);
}, []);
const rfbRef = useRef<RFB | null>(null);
const clientId = useClientIdStore((state) => state.clientId);
const credentialGetter = useCredentialGetter();
// effect for disconnects only
const getWebSocketParams = useCallback(async () => {
const clientIdQueryParam = `client_id=${clientId}`;
let credentialQueryParam = "";
if (environment === "local") {
credentialQueryParam = `apikey=${envCredential}`;
} else {
if (credentialGetter) {
const token = await credentialGetter();
credentialQueryParam = `token=Bearer ${token}`;
} else {
credentialQueryParam = `apikey=${envCredential}`;
}
}
const params = [credentialQueryParam, clientIdQueryParam].join("&");
return `${params}`;
}, [clientId, credentialGetter]);
const invalidateQueries = useCallback(() => {
queryClient.invalidateQueries({
queryKey: ["workflowRun", workflowPermanentId, workflowRunId],
});
queryClient.invalidateQueries({ queryKey: ["workflowRuns"] });
queryClient.invalidateQueries({
queryKey: ["workflowTasks", workflowRunId],
});
queryClient.invalidateQueries({ queryKey: ["runs"] });
}, [queryClient, workflowPermanentId, workflowRunId]);
// effect for vnc disconnects only
useEffect(() => {
if (prevVncConnectedRef.current && !isVncConnected) {
setVncDisconnectedTrigger((x) => x + 1);
@@ -49,6 +97,15 @@ function WorkflowRunStreamVnc() {
prevVncConnectedRef.current = isVncConnected;
}, [isVncConnected]);
// effect for command disconnects only
useEffect(() => {
if (prevCommandConnectedRef.current && !isCommandConnected) {
setCommandDisconnectedTrigger((x) => x + 1);
}
prevCommandConnectedRef.current = isCommandConnected;
}, [isCommandConnected]);
// vnc socket
useEffect(
() => {
if (!showStream || !canvasContainer || !workflowRunId) {
@@ -61,24 +118,12 @@ function WorkflowRunStreamVnc() {
}
async function setupVnc() {
let credentialQueryParam = "";
if (environment === "local") {
credentialQueryParam = `?apikey=${envCredential}`;
} else {
if (credentialGetter) {
const token = await credentialGetter();
credentialQueryParam = `?token=Bearer ${token}`;
} else {
credentialQueryParam = `?apikey=${envCredential}`;
}
}
if (rfbRef.current && isVncConnected) {
return;
}
const vncUrl = `${wssBaseUrl}/stream/vnc/workflow_run/${workflowRunId}${credentialQueryParam}`;
const wsParams = await getWebSocketParams();
const vncUrl = `${wssBaseUrl}/stream/vnc/workflow_run/${workflowRunId}?${wsParams}`;
if (rfbRef.current) {
rfbRef.current.disconnect();
@@ -102,15 +147,7 @@ function WorkflowRunStreamVnc() {
rfb.addEventListener("disconnect", async (/* e: RfbEvent */) => {
setIsVncConnected(false);
queryClient.invalidateQueries({
queryKey: ["workflowRun", workflowPermanentId, workflowRunId],
});
queryClient.invalidateQueries({ queryKey: ["workflowRuns"] });
queryClient.invalidateQueries({
queryKey: ["workflowTasks", workflowRunId],
});
queryClient.invalidateQueries({ queryKey: ["runs"] });
invalidateQueries();
});
}
@@ -127,16 +164,75 @@ function WorkflowRunStreamVnc() {
// cannot include isVncConnected in deps as it will cause infinite loop
// eslint-disable-next-line react-hooks/exhaustive-deps
[
credentialGetter,
workflowRunId,
workflowPermanentId,
showStream,
queryClient,
canvasContainer,
invalidateQueries,
showStream,
vncDisconnectedTrigger, // will re-run on disconnects
workflowRunId,
],
);
// command socket
useEffect(() => {
let ws: WebSocket | null = null;
const connect = async () => {
const wsParams = await getWebSocketParams();
const commandUrl = `${wssBaseUrl}/stream/commands/workflow_run/${workflowRunId}?${wsParams}`;
ws = new WebSocket(commandUrl);
ws.onopen = () => {
setIsCommandConnected(true);
setCommandSocket(ws);
};
ws.onclose = () => {
setIsCommandConnected(false);
invalidateQueries();
setCommandSocket(null);
};
};
connect();
return () => {
try {
ws && ws.close();
} catch (e) {
// pass
}
};
}, [
commandDisconnectedTrigger,
getWebSocketParams,
invalidateQueries,
workflowRunId,
]);
// effect to send a command when the user is controlling, vs not controlling
useEffect(() => {
if (!isCommandConnected) {
return;
}
const sendCommand = (command: Command) => {
if (!commandSocket) {
console.warn("Cannot send command, as command socket is closed.");
console.warn(command);
return;
}
commandSocket.send(JSON.stringify(command));
};
if (userIsControlling) {
sendCommand({ kind: "take-control" });
} else {
sendCommand({ kind: "cede-control" });
}
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [userIsControlling, isCommandConnected]);
// Effect to show toast when workflow reaches a final state based on hook updates
useEffect(() => {
if (workflowRun) {

View File

@@ -0,0 +1,11 @@
import { create } from "zustand";
type ClientIdStore = {
clientId: string;
};
const initialClientId = crypto.randomUUID();
export const useClientIdStore = create<ClientIdStore>(() => ({
clientId: initialClientId,
}));

View File

@@ -2,4 +2,5 @@ from skyvern.forge.sdk.routes import agent_protocol # noqa: F401
from skyvern.forge.sdk.routes import browser_sessions # noqa: F401
from skyvern.forge.sdk.routes import credentials # noqa: F401
from skyvern.forge.sdk.routes import streaming # noqa: F401
from skyvern.forge.sdk.routes import streaming_commands # noqa: F401
from skyvern.forge.sdk.routes import streaming_vnc # noqa: F401

View File

@@ -0,0 +1,46 @@
"""
Streaming auth.
"""
import structlog
from fastapi import WebSocket
from websockets.exceptions import ConnectionClosedOK
from skyvern.forge.sdk.services.org_auth_service import get_current_org
LOG = structlog.get_logger()
async def auth(apikey: str | None, token: str | None, websocket: WebSocket) -> str | None:
"""
Accepts the websocket connection.
Authenticates the user; cannot proceed with WS connection if an organization_id cannot be
determined.
"""
try:
await websocket.accept()
if not token and not apikey:
await websocket.close(code=1002)
return None
except ConnectionClosedOK:
LOG.info("WebSocket connection closed cleanly.")
return None
try:
organization = await get_current_org(x_api_key=apikey, authorization=token)
organization_id = organization.organization_id
if not organization_id:
await websocket.close(code=1002)
return None
except Exception:
LOG.exception("Error occurred while retrieving organization information.")
try:
await websocket.close(code=1002)
except ConnectionClosedOK:
LOG.info("WebSocket connection closed due to invalid credentials.")
return None
return organization_id

View File

@@ -0,0 +1,270 @@
"""
Streaming types.
"""
import asyncio
import dataclasses
import typing as t
from enum import IntEnum
import structlog
from fastapi import WebSocket
from starlette.websockets import WebSocketState
from skyvern.forge.sdk.schemas.persistent_browser_sessions import AddressablePersistentBrowserSession
from skyvern.forge.sdk.schemas.tasks import Task
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRun
LOG = structlog.get_logger()
Interactor = t.Literal["agent", "user"]
Loops = list[asyncio.Task] # aka "queue-less actors"; or "programs"
# Commands
# a global registry for WS command clients
command_channels: dict[str, "CommandChannel"] = {}
def add_command_client(command_channel: "CommandChannel") -> None:
command_channels[command_channel.client_id] = command_channel
def get_command_client(client_id: str) -> t.Union["CommandChannel", None]:
return command_channels.get(client_id, None)
def del_command_client(client_id: str) -> None:
try:
del command_channels[client_id]
except KeyError:
pass
@dataclasses.dataclass
class CommandChannel:
client_id: str
organization_id: str
websocket: WebSocket
# --
browser_session: AddressablePersistentBrowserSession | None = None
workflow_run: WorkflowRun | None = None
def __post_init__(self) -> None:
add_command_client(self)
async def close(self, code: int = 1000, reason: str | None = None) -> "CommandChannel":
LOG.info("Closing command stream.", reason=reason, code=code)
self.browser_session = None
self.workflow_run = None
try:
await self.websocket.close(code=code, reason=reason)
except Exception:
pass
del_command_client(self.client_id)
return self
@property
def is_open(self) -> bool:
if self.websocket.client_state not in (WebSocketState.CONNECTED, WebSocketState.CONNECTING):
return False
if not self.workflow_run:
return False
if not get_command_client(self.client_id):
return False
return True
CommandKinds = t.Literal["take-control", "cede-control"]
@dataclasses.dataclass
class Command:
kind: CommandKinds
@dataclasses.dataclass
class CommandTakeControl(Command):
kind: t.Literal["take-control"] = "take-control"
@dataclasses.dataclass
class CommandCedeControl(Command):
kind: t.Literal["cede-control"] = "cede-control"
ChannelCommand = t.Union[CommandTakeControl, CommandCedeControl]
def reify_channel_command(data: dict) -> ChannelCommand:
kind = data.get("kind", None)
match kind:
case "take-control":
return CommandTakeControl()
case "cede-control":
return CommandCedeControl()
case _:
raise ValueError(f"Unknown command kind: '{kind}'")
# Streaming
# a global registry for WS streaming VNC clients
streaming_clients: dict[str, "Streaming"] = {}
def add_streaming_client(streaming: "Streaming") -> None:
streaming_clients[streaming.client_id] = streaming
def get_streaming_client(client_id: str) -> t.Union["Streaming", None]:
return streaming_clients.get(client_id, None)
def del_streaming_client(client_id: str) -> None:
try:
del streaming_clients[client_id]
except KeyError:
pass
class MessageType(IntEnum):
Keyboard = 4
Mouse = 5
class Keys:
"""
VNC RFB keycodes. There's likely a pithier repr (indexes 6-7). This is ok for now.
ref: https://www.notion.so/References-21c426c42cd480fb9258ecc9eb8f09b4
ref: https://github.com/novnc/noVNC/blob/master/docs/rfbproto-3.8.pdf
"""
class Down:
Ctrl = b"\x04\x01\x00\x00\x00\x00\xff\xe3"
Cmd = b"\x04\x01\x00\x00\x00\x00\xff\xe9"
Alt = b"\x04\x01\x00\x00\x00\x00\xff~" # option
OKey = b"\x04\x01\x00\x00\x00\x00\x00o"
class Up:
Ctrl = b"\x04\x00\x00\x00\x00\x00\xff\xe3"
Cmd = b"\x04\x00\x00\x00\x00\x00\xff\xe9"
Alt = b"\x04\x00\x00\x00\x00\x00\xff\x7e" # option
def is_rmb(data: bytes) -> bool:
return data[0:2] == b"\x05\x04"
class Mouse:
class Up:
Right = is_rmb
@dataclasses.dataclass
class KeyState:
ctrl_is_down: bool = False
alt_is_down: bool = False
cmd_is_down: bool = False
o_is_down: bool = False
def is_forbidden(self, data: bytes) -> bool:
"""
:return: True if the key is forbidden, else False
"""
return self.is_ctrl_o(data)
def is_ctrl_o(self, data: bytes) -> bool:
"""
Do not allow the opening of files.
"""
return self.ctrl_is_down and data == Keys.Down.OKey
@dataclasses.dataclass
class Streaming:
"""
Streaming state.
"""
client_id: str
"""
Unique to frontend app instance.
"""
interactor: Interactor
"""
Whether the user or the agent are the interactor.
"""
organization_id: str
vnc_port: int
websocket: WebSocket
# --
browser_session: AddressablePersistentBrowserSession | None = None
key_state: KeyState = dataclasses.field(default_factory=KeyState)
task: Task | None = None
workflow_run: WorkflowRun | None = None
def __post_init__(self) -> None:
add_streaming_client(self)
@property
def is_open(self) -> bool:
if self.websocket.client_state not in (WebSocketState.CONNECTED, WebSocketState.CONNECTING):
return False
if not self.task and not self.workflow_run:
return False
if not get_streaming_client(self.client_id):
return False
return True
async def close(self, code: int = 1000, reason: str | None = None) -> "Streaming":
LOG.info("Closing Streaming.", reason=reason, code=code)
self.browser_session = None
self.task = None
self.workflow_run = None
try:
await self.websocket.close(code=code, reason=reason)
except Exception:
pass
del_streaming_client(self.client_id)
return self
def update_key_state(self, data: bytes) -> None:
if data == Keys.Down.Ctrl:
self.key_state.ctrl_is_down = True
elif data == Keys.Up.Ctrl:
self.key_state.ctrl_is_down = False
elif data == Keys.Down.Alt:
self.key_state.alt_is_down = True
elif data == Keys.Up.Alt:
self.key_state.alt_is_down = False
elif data == Keys.Down.Cmd:
self.key_state.cmd_is_down = True
elif data == Keys.Up.Cmd:
self.key_state.cmd_is_down = False

View File

@@ -0,0 +1,227 @@
"""
Streaming commands WebSocket connections.
"""
import asyncio
import structlog
from fastapi import WebSocket, WebSocketDisconnect
from websockets.exceptions import ConnectionClosedError
import skyvern.forge.sdk.routes.streaming_clients as sc
from skyvern.forge.sdk.routes.routers import legacy_base_router
from skyvern.forge.sdk.routes.streaming_auth import auth
from skyvern.forge.sdk.routes.streaming_verify import loop_verify_workflow_run, verify_workflow_run
from skyvern.forge.sdk.utils.aio import collect
LOG = structlog.get_logger()
async def get_commands_for_workflow_run(
client_id: str,
workflow_run_id: str,
organization_id: str,
websocket: WebSocket,
) -> tuple[sc.CommandChannel, sc.Loops] | None:
"""
Return a commands channel for a workflow run, with a list of loops to run concurrently.
"""
LOG.info("Getting commands channel for workflow run.", workflow_run_id=workflow_run_id)
workflow_run, browser_session = await verify_workflow_run(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
if not workflow_run:
LOG.info(
"Command channel: no initial workflow run found.",
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
return None
if not browser_session:
LOG.info(
"Command channel: no initial browser session found for workflow run.",
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
return None
commands = sc.CommandChannel(
client_id=client_id,
organization_id=organization_id,
browser_session=browser_session,
workflow_run=workflow_run,
websocket=websocket,
)
LOG.info("Got command channel for workflow run.", commands=commands)
loops = [
asyncio.create_task(loop_verify_workflow_run(commands)),
asyncio.create_task(loop_channel(commands)),
]
return commands, loops
async def loop_channel(commands: sc.CommandChannel) -> None:
"""
Stream commands and their results back and forth.
Loops until the workflow run is cleared or the websocket is closed.
"""
if not commands.browser_session:
LOG.info(
"No browser session found for workflow run.",
workflow_run=commands.workflow_run,
organization_id=commands.organization_id,
)
return
async def frontend_to_backend() -> None:
LOG.info("Starting frontend-to-backend channel loop.", commands=commands)
while commands.is_open:
try:
data = await commands.websocket.receive_json()
if not isinstance(data, dict):
LOG.error(f"Cannot create channel command: expected dict, got {type(data)}")
continue
try:
command = sc.reify_channel_command(data)
except ValueError:
continue
command_kind = command.kind
match command_kind:
case "take-control":
streaming = sc.get_streaming_client(commands.client_id)
if not streaming:
LOG.error("No streaming client found for command.", commands=commands, command=command)
continue
streaming.interactor = "user"
case "cede-control":
streaming = sc.get_streaming_client(commands.client_id)
if not streaming:
LOG.error("No streaming client found for command.", commands=commands, command=command)
continue
streaming.interactor = "agent"
case _:
LOG.error(f"Unknown command kind: '{command_kind}'")
continue
except WebSocketDisconnect:
LOG.info(
"Frontend disconnected.",
workflow_run=commands.workflow_run,
organization_id=commands.organization_id,
)
raise
except ConnectionClosedError:
LOG.info(
"Frontend closed the streaming session.",
workflow_run=commands.workflow_run,
organization_id=commands.organization_id,
)
raise
except asyncio.CancelledError:
pass
except Exception:
LOG.exception(
"An unexpected exception occurred.",
workflow_run=commands.workflow_run,
organization_id=commands.organization_id,
)
raise
loops = [
asyncio.create_task(frontend_to_backend()),
]
try:
await collect(loops)
except Exception:
LOG.exception(
"An exception occurred in loop channel stream.",
workflow_run=commands.workflow_run,
organization_id=commands.organization_id,
)
finally:
LOG.info(
"Closing the loop channel stream.",
workflow_run=commands.workflow_run,
organization_id=commands.organization_id,
)
await commands.close(reason="loop-channel-closed")
@legacy_base_router.websocket("/stream/commands/workflow_run/{workflow_run_id}")
async def workflow_run_commands(
websocket: WebSocket,
workflow_run_id: str,
apikey: str | None = None,
client_id: str | None = None,
token: str | None = None,
) -> None:
LOG.info("Starting stream commands.", workflow_run_id=workflow_run_id)
organization_id = await auth(apikey=apikey, token=token, websocket=websocket)
if not organization_id:
LOG.error("Authentication failed.", workflow_run_id=workflow_run_id)
return
if not client_id:
LOG.error("No client ID provided.", workflow_run_id=workflow_run_id)
return
commands: sc.CommandChannel
loops: list[asyncio.Task] = []
result = await get_commands_for_workflow_run(
client_id=client_id,
workflow_run_id=workflow_run_id,
organization_id=organization_id,
websocket=websocket,
)
if not result:
LOG.error(
"No streaming context found for the workflow run.",
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
await websocket.close(code=1013)
return
commands, loops = result
try:
LOG.info(
"Starting command stream loops.",
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
await collect(loops)
except Exception:
LOG.exception(
"An exception occurred in the command stream function.",
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
finally:
LOG.info(
"Closing the command stream session.",
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
await commands.close(reason="stream-closed")

View File

@@ -0,0 +1,207 @@
import asyncio
import structlog
import skyvern.forge.sdk.routes.streaming_clients as sc
from skyvern.config import settings
from skyvern.forge import app
from skyvern.forge.sdk.schemas.persistent_browser_sessions import AddressablePersistentBrowserSession
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRun, WorkflowRunStatus
LOG = structlog.get_logger()
async def verify_task(
task_id: str, organization_id: str
) -> tuple[Task | None, AddressablePersistentBrowserSession | None]:
"""
Verify the task is running, and that it has a browser session associated
with it.
"""
task = await app.DATABASE.get_task(task_id=task_id, organization_id=organization_id)
if not task:
LOG.info("Task not found.", task_id=task_id, organization_id=organization_id)
return None, None
if task.status.is_final():
LOG.info("Task is in a final state.", task_status=task.status, task_id=task_id, organization_id=organization_id)
return None, None
if task.status not in [TaskStatus.created, TaskStatus.queued, TaskStatus.running]:
LOG.info(
"Task is not created, queued, or running.",
task_status=task.status,
task_id=task_id,
organization_id=organization_id,
)
return None, None
browser_session = await app.PERSISTENT_SESSIONS_MANAGER.get_session_by_runnable_id(
organization_id=organization_id,
runnable_id=task_id,
)
if not browser_session:
LOG.info("No browser session found for task.", task_id=task_id, organization_id=organization_id)
return task, None
if not browser_session.browser_address:
LOG.info("Browser session address not found for task.", task_id=task_id, organization_id=organization_id)
return task, None
try:
addressable_browser_session = AddressablePersistentBrowserSession(
**browser_session.model_dump() | {"browser_address": browser_session.browser_address},
)
except Exception as e:
LOG.error(
"streaming-vnc.browser-session-reify-error", task_id=task_id, organization_id=organization_id, error=e
)
return task, None
return task, addressable_browser_session
async def verify_workflow_run(
workflow_run_id: str,
organization_id: str,
) -> tuple[WorkflowRun | None, AddressablePersistentBrowserSession | None]:
"""
Verify the workflow run is running, and that it has a browser session associated
with it.
"""
if settings.ENV == "local":
from datetime import datetime
dummy_workflow_run = WorkflowRun(
workflow_id="123",
workflow_permanent_id="wpid_123",
workflow_run_id=workflow_run_id,
organization_id=organization_id,
status=WorkflowRunStatus.running,
created_at=datetime.now(),
modified_at=datetime.now(),
)
dummy_browser_session = AddressablePersistentBrowserSession(
persistent_browser_session_id=workflow_run_id,
organization_id=organization_id,
browser_address="0.0.0.0:9223",
created_at=datetime.now(),
modified_at=datetime.now(),
)
return dummy_workflow_run, dummy_browser_session
workflow_run = await app.DATABASE.get_workflow_run(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
if not workflow_run:
LOG.info("Workflow run not found.", workflow_run_id=workflow_run_id, organization_id=organization_id)
return None, None
if workflow_run.status.is_final():
LOG.info(
"Workflow run is in a final state. Closing connection.",
workflow_run_status=workflow_run.status,
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
return None, None
if workflow_run.status not in [WorkflowRunStatus.created, WorkflowRunStatus.queued, WorkflowRunStatus.running]:
LOG.info(
"Workflow run is not running.",
workflow_run_status=workflow_run.status,
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
return None, None
browser_session = await app.PERSISTENT_SESSIONS_MANAGER.get_session_by_runnable_id(
organization_id=organization_id,
runnable_id=workflow_run_id,
)
if not browser_session:
LOG.info(
"No browser session found for workflow run.",
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
return workflow_run, None
browser_address = browser_session.browser_address
if not browser_address:
LOG.info(
"Waiting for browser session address.", workflow_run_id=workflow_run_id, organization_id=organization_id
)
try:
_, host, cdp_port = await app.PERSISTENT_SESSIONS_MANAGER.get_browser_address(
session_id=browser_session.persistent_browser_session_id,
organization_id=organization_id,
)
browser_address = f"{host}:{cdp_port}"
except Exception as ex:
LOG.info(
"Browser session address not found for workflow run.",
workflow_run_id=workflow_run_id,
organization_id=organization_id,
ex=ex,
)
return workflow_run, None
try:
addressable_browser_session = AddressablePersistentBrowserSession(
**browser_session.model_dump() | {"browser_address": browser_address},
)
except Exception:
return workflow_run, None
return workflow_run, addressable_browser_session
async def loop_verify_task(streaming: sc.Streaming) -> None:
"""
Loop until the task is cleared or the websocket is closed.
"""
while streaming.task and streaming.is_open:
task, browser_session = await verify_task(
task_id=streaming.task.task_id,
organization_id=streaming.organization_id,
)
streaming.task = task
streaming.browser_session = browser_session
await asyncio.sleep(2)
async def loop_verify_workflow_run(verifiable: sc.CommandChannel | sc.Streaming) -> None:
"""
Loop until the workflow run is cleared or the websocket is closed.
"""
while verifiable.workflow_run and verifiable.is_open:
workflow_run, browser_session = await verify_workflow_run(
workflow_run_id=verifiable.workflow_run.workflow_run_id,
organization_id=verifiable.organization_id,
)
verifiable.workflow_run = workflow_run
verifiable.browser_session = browser_session
await asyncio.sleep(2)

View File

@@ -1,233 +1,36 @@
"""
Streaming VNC WebSocket connections.
"""
import asyncio
import dataclasses
import typing as t
from enum import IntEnum
import structlog
import websockets
from fastapi import WebSocket, WebSocketDisconnect
from starlette.websockets import WebSocketState
from websockets import Data
from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK
from websockets.exceptions import ConnectionClosedError
import skyvern.forge.sdk.routes.streaming_clients as sc
from skyvern.config import settings
from skyvern.forge import app
from skyvern.forge.sdk.routes.routers import legacy_base_router
from skyvern.forge.sdk.schemas.persistent_browser_sessions import AddressablePersistentBrowserSession
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
from skyvern.forge.sdk.services.org_auth_service import get_current_org
from skyvern.forge.sdk.routes.streaming_auth import auth
from skyvern.forge.sdk.routes.streaming_verify import (
loop_verify_task,
loop_verify_workflow_run,
verify_task,
verify_workflow_run,
)
from skyvern.forge.sdk.utils.aio import collect
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRun, WorkflowRunStatus
Interactor = t.Literal["agent", "user"]
Loops = list[asyncio.Task] # aka "queue-less actors"; or "programs"
class MessageType(IntEnum):
Keyboard = 4
Mouse = 5
class Keys:
"""
VNC RFB keycodes. There's likely a pithier repr (indexes 6-7). This is ok for now.
"""
class Down:
Ctrl = b"\x04\x01\x00\x00\x00\x00\xff\xe3"
Cmd = b"\x04\x01\x00\x00\x00\x00\xff\xe9"
Alt = b"\x04\x01\x00\x00\x00\x00\xff~" # option
OKey = b"\x04\x01\x00\x00\x00\x00\x00o"
class Up:
Ctrl = b"\x04\x00\x00\x00\x00\x00\xff\xe3"
Cmd = b"\x04\x00\x00\x00\x00\x00\xff\xe9"
Alt = b"\x04\x00\x00\x00\x00\x00\xff\x7e" # option
def is_rmb(data: bytes) -> bool:
return data[0:2] == b"\x05\x04"
class Mouse:
class Up:
Right = is_rmb
LOG = structlog.get_logger()
@dataclasses.dataclass
class KeyState:
ctrl_is_down: bool = False
alt_is_down: bool = False
cmd_is_down: bool = False
o_is_down: bool = False
def is_forbidden(self, data: bytes) -> bool:
"""
:return: True if the key is forbidden, else False
"""
return self.is_ctrl_o(data)
def is_ctrl_o(self, data: bytes) -> bool:
"""
Do not allow the opening of files.
"""
return self.ctrl_is_down and data == Keys.Down.OKey
@dataclasses.dataclass
class Streaming:
"""
Streaming state.
"""
interactor: Interactor
"""
Whether the user or the agent are the interactor.
"""
organization_id: str
vnc_port: int
websocket: WebSocket
# --
browser_session: AddressablePersistentBrowserSession | None = None
key_state: KeyState = dataclasses.field(default_factory=KeyState)
task: Task | None = None
workflow_run: WorkflowRun | None = None
@property
def is_open(self) -> bool:
if self.websocket.client_state not in (WebSocketState.CONNECTED, WebSocketState.CONNECTING):
return False
if not self.task and not self.workflow_run:
return False
return True
async def close(self, code: int = 1000, reason: str | None = None) -> "Streaming":
LOG.info("Closing Streaming.", reason=reason, code=code)
self.browser_session = None
self.task = None
self.workflow_run = None
try:
await self.websocket.close(code=code, reason=reason)
except Exception:
pass
return self
def update_key_state(self, data: bytes) -> None:
if data == Keys.Down.Ctrl:
self.key_state.ctrl_is_down = True
elif data == Keys.Up.Ctrl:
self.key_state.ctrl_is_down = False
elif data == Keys.Down.Alt:
self.key_state.alt_is_down = True
elif data == Keys.Up.Alt:
self.key_state.alt_is_down = False
elif data == Keys.Down.Cmd:
self.key_state.cmd_is_down = True
elif data == Keys.Up.Cmd:
self.key_state.cmd_is_down = False
async def auth(apikey: str | None, token: str | None, websocket: WebSocket) -> str | None:
"""
Accepts the websocket connection.
Authenticates the user; cannot proceed with WS connection if an organization_id cannot be
determined.
"""
try:
await websocket.accept()
if not token and not apikey:
await websocket.close(code=1002)
return None
except ConnectionClosedOK:
LOG.info("WebSocket connection closed cleanly.")
return None
try:
organization = await get_current_org(x_api_key=apikey, authorization=token)
organization_id = organization.organization_id
if not organization_id:
await websocket.close(code=1002)
return None
except Exception:
LOG.exception("Error occurred while retrieving organization information.")
try:
await websocket.close(code=1002)
except ConnectionClosedOK:
LOG.info("WebSocket connection closed due to invalid credentials.")
return None
return organization_id
async def verify_task(
task_id: str, organization_id: str
) -> tuple[Task | None, AddressablePersistentBrowserSession | None]:
"""
Verify the task is running, and that it has a browser session associated
with it.
"""
task = await app.DATABASE.get_task(task_id=task_id, organization_id=organization_id)
if not task:
LOG.info("Task not found.", task_id=task_id, organization_id=organization_id)
return None, None
if task.status.is_final():
LOG.info("Task is in a final state.", task_status=task.status, task_id=task_id, organization_id=organization_id)
return None, None
if not task.status == TaskStatus.running:
LOG.info("Task is not running.", task_status=task.status, task_id=task_id, organization_id=organization_id)
return None, None
browser_session = await app.PERSISTENT_SESSIONS_MANAGER.get_session_by_runnable_id(
organization_id=organization_id,
runnable_id=task_id, # is this correct; is there a task_run_id?
)
if not browser_session:
LOG.info("No browser session found for task.", task_id=task_id, organization_id=organization_id)
return task, None
if not browser_session.browser_address:
LOG.info("Browser session address not found for task.", task_id=task_id, organization_id=organization_id)
return task, None
try:
addressable_browser_session = AddressablePersistentBrowserSession(
**browser_session.model_dump() | {"browser_address": browser_session.browser_address},
)
except Exception as e:
LOG.error(
"streaming-vnc.browser-session-reify-error", task_id=task_id, organization_id=organization_id, error=e
)
return task, None
return task, addressable_browser_session
async def get_streaming_for_task(
client_id: str,
task_id: str,
organization_id: str,
websocket: WebSocket,
) -> tuple[Streaming, Loops] | None:
) -> tuple[sc.Streaming, sc.Loops] | None:
"""
Return a streaming context for a task, with a list of loops to run concurrently.
"""
@@ -242,8 +45,9 @@ async def get_streaming_for_task(
LOG.info("No initial browser session found for task.", task_id=task_id, organization_id=organization_id)
return None
streaming = Streaming(
interactor="user",
streaming = sc.Streaming(
client_id=client_id,
interactor="agent",
organization_id=organization_id,
vnc_port=settings.SKYVERN_BROWSER_VNC_PORT,
websocket=websocket,
@@ -260,10 +64,11 @@ async def get_streaming_for_task(
async def get_streaming_for_workflow_run(
client_id: str,
workflow_run_id: str,
organization_id: str,
websocket: WebSocket,
) -> tuple[Streaming, Loops] | None:
) -> tuple[sc.Streaming, sc.Loops] | None:
"""
Return a streaming context for a workflow run, with a list of loops to run concurrently.
"""
@@ -287,8 +92,9 @@ async def get_streaming_for_workflow_run(
)
return None
streaming = Streaming(
interactor="user",
streaming = sc.Streaming(
client_id=client_id,
interactor="agent",
organization_id=organization_id,
vnc_port=settings.SKYVERN_BROWSER_VNC_PORT,
browser_session=browser_session,
@@ -306,128 +112,7 @@ async def get_streaming_for_workflow_run(
return streaming, loops
async def verify_workflow_run(
workflow_run_id: str,
organization_id: str,
) -> tuple[WorkflowRun | None, AddressablePersistentBrowserSession | None]:
"""
Verify the workflow run is running, and that it has a browser session associated
with it.
"""
workflow_run = await app.DATABASE.get_workflow_run(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
if not workflow_run:
LOG.info("Workflow run not found.", workflow_run_id=workflow_run_id, organization_id=organization_id)
return None, None
if workflow_run.status in [
WorkflowRunStatus.completed,
WorkflowRunStatus.failed,
WorkflowRunStatus.terminated,
]:
LOG.info(
"Workflow run is in a final state. Closing connection.",
workflow_run_status=workflow_run.status,
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
return None, None
if workflow_run.status not in [WorkflowRunStatus.created, WorkflowRunStatus.queued, WorkflowRunStatus.running]:
LOG.info(
"Workflow run is not running.",
workflow_run_status=workflow_run.status,
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
return None, None
browser_session = await app.PERSISTENT_SESSIONS_MANAGER.get_session_by_runnable_id(
organization_id=organization_id,
runnable_id=workflow_run_id,
)
if not browser_session:
LOG.info(
"No browser session found for workflow run.",
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
return workflow_run, None
browser_address = browser_session.browser_address
if not browser_address:
LOG.info(
"Waiting for browser session address.", workflow_run_id=workflow_run_id, organization_id=organization_id
)
try:
_, host, cdp_port = await app.PERSISTENT_SESSIONS_MANAGER.get_browser_address(
session_id=browser_session.persistent_browser_session_id,
organization_id=organization_id,
)
browser_address = f"{host}:{cdp_port}"
except Exception as ex:
LOG.info(
"Browser session address not found for workflow run.",
workflow_run_id=workflow_run_id,
organization_id=organization_id,
ex=ex,
)
return workflow_run, None
try:
addressable_browser_session = AddressablePersistentBrowserSession(
**browser_session.model_dump() | {"browser_address": browser_address},
)
except Exception:
return workflow_run, None
return workflow_run, addressable_browser_session
async def loop_verify_task(streaming: Streaming) -> None:
"""
Loop until the task is cleared or the websocket is closed.
"""
while streaming.task and streaming.is_open:
task, browser_session = await verify_task(
task_id=streaming.task.task_id,
organization_id=streaming.organization_id,
)
streaming.task = task
streaming.browser_session = browser_session
await asyncio.sleep(2)
async def loop_verify_workflow_run(streaming: Streaming) -> None:
"""
Loop until the workflow run is cleared or the websocket is closed.
"""
while streaming.workflow_run and streaming.is_open:
workflow_run, browser_session = await verify_workflow_run(
workflow_run_id=streaming.workflow_run.workflow_run_id,
organization_id=streaming.organization_id,
)
streaming.workflow_run = workflow_run
streaming.browser_session = browser_session
await asyncio.sleep(2)
async def loop_stream_vnc(streaming: Streaming) -> None:
async def loop_stream_vnc(streaming: sc.Streaming) -> None:
"""
Actually stream the VNC session data between a frontend and a browser
session.
@@ -465,19 +150,19 @@ async def loop_stream_vnc(streaming: Streaming) -> None:
if data:
message_type = data[0]
if message_type == MessageType.Keyboard.value:
if message_type == sc.MessageType.Keyboard.value:
streaming.update_key_state(data)
if streaming.key_state.is_forbidden(data):
continue
if message_type == MessageType.Mouse.value:
if Mouse.Up.Right(data):
if message_type == sc.MessageType.Mouse.value:
if sc.Mouse.Up.Right(data):
continue
if not streaming.interactor == "user" and message_type in (
MessageType.Keyboard.value,
MessageType.Mouse.value,
sc.MessageType.Keyboard.value,
sc.MessageType.Mouse.value,
):
LOG.info(
"Blocking user message.", task=streaming.task, organization_id=streaming.organization_id
@@ -615,9 +300,10 @@ async def task_stream(
websocket: WebSocket,
task_id: str,
apikey: str | None = None,
client_id: str | None = None,
token: str | None = None,
) -> None:
await stream(websocket, apikey=apikey, task_id=task_id, token=token)
await stream(websocket, apikey=apikey, client_id=client_id, task_id=task_id, token=token)
@legacy_base_router.websocket("/stream/vnc/workflow_run/{workflow_run_id}")
@@ -625,32 +311,43 @@ async def workflow_run_stream(
websocket: WebSocket,
workflow_run_id: str,
apikey: str | None = None,
client_id: str | None = None,
token: str | None = None,
) -> None:
await stream(websocket, apikey=apikey, workflow_run_id=workflow_run_id, token=token)
await stream(websocket, apikey=apikey, client_id=client_id, workflow_run_id=workflow_run_id, token=token)
async def stream(
websocket: WebSocket,
*,
apikey: str | None = None,
client_id: str | None = None,
task_id: str | None = None,
token: str | None = None,
workflow_run_id: str | None = None,
) -> None:
LOG.info("Starting VNC stream.", task_id=task_id, workflow_run_id=workflow_run_id)
if not client_id:
LOG.error("Client ID not provided for VNC stream.", task_id=task_id, workflow_run_id=workflow_run_id)
return
LOG.info("Starting VNC stream.", client_id=client_id, task_id=task_id, workflow_run_id=workflow_run_id)
organization_id = await auth(apikey=apikey, token=token, websocket=websocket)
if not organization_id:
LOG.info("Authentication failed.", task_id=task_id, workflow_run_id=workflow_run_id)
LOG.error("Authentication failed.", task_id=task_id, workflow_run_id=workflow_run_id)
return
streaming: Streaming
streaming: sc.Streaming
loops: list[asyncio.Task] = []
if task_id:
result = await get_streaming_for_task(task_id=task_id, organization_id=organization_id, websocket=websocket)
result = await get_streaming_for_task(
client_id=client_id,
task_id=task_id,
organization_id=organization_id,
websocket=websocket,
)
if not result:
LOG.error("No streaming context found for the task.", task_id=task_id, organization_id=organization_id)
@@ -666,6 +363,7 @@ async def stream(
"Starting streaming for workflow run.", workflow_run_id=workflow_run_id, organization_id=organization_id
)
result = await get_streaming_for_workflow_run(
client_id=client_id,
workflow_run_id=workflow_run_id,
organization_id=organization_id,
websocket=websocket,