WebSocket Command Channel (#2782)
This commit is contained in:
@@ -141,9 +141,7 @@ function WorkflowParametersPanel() {
|
||||
</span>
|
||||
) : (
|
||||
<span className="text-sm text-slate-400">
|
||||
{parameter.parameterType === "onepassword"
|
||||
? "credential"
|
||||
: parameter.parameterType}
|
||||
{parameter.parameterType}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
|
||||
@@ -13,24 +13,40 @@ import { useQueryClient } from "@tanstack/react-query";
|
||||
import RFB from "@novnc/novnc/lib/rfb.js";
|
||||
import { environment } from "@/util/env";
|
||||
import { cn } from "@/util/utils";
|
||||
import { useClientIdStore } from "@/store/useClientIdStore";
|
||||
|
||||
import "./workflow-run-stream-vnc.css";
|
||||
|
||||
const wssBaseUrl = import.meta.env.VITE_WSS_BASE_URL;
|
||||
|
||||
/** Command sent when the user starts driving the streamed browser. */
type CommandTakeControl = { kind: "take-control" };

/** Command sent when the user stops driving the streamed browser. */
type CommandCedeControl = { kind: "cede-control" };

/** Every message this component serializes onto the command socket, discriminated by `kind`. */
type Command = CommandTakeControl | CommandCedeControl;
|
||||
|
||||
function WorkflowRunStreamVnc() {
|
||||
const { data: workflowRun } = useWorkflowRunQuery();
|
||||
|
||||
const { workflowRunId, workflowPermanentId } = useParams<{
|
||||
workflowRunId: string;
|
||||
workflowPermanentId: string;
|
||||
}>();
|
||||
|
||||
const [commandSocket, setCommandSocket] = useState<WebSocket | null>(null);
|
||||
const [userIsControlling, setUserIsControlling] = useState<boolean>(false);
|
||||
const [vncDisconnectedTrigger, setVncDisconnectedTrigger] = useState(0);
|
||||
const prevVncConnectedRef = useRef<boolean>(false);
|
||||
const [isVncConnected, setIsVncConnected] = useState<boolean>(false);
|
||||
const [commandDisconnectedTrigger, setCommandDisconnectedTrigger] =
|
||||
useState(0);
|
||||
const prevCommandConnectedRef = useRef<boolean>(false);
|
||||
const [isCommandConnected, setIsCommandConnected] = useState<boolean>(false);
|
||||
const showStream = workflowRun && statusIsNotFinalized(workflowRun);
|
||||
const credentialGetter = useCredentialGetter();
|
||||
const queryClient = useQueryClient();
|
||||
const [canvasContainer, setCanvasContainer] = useState<HTMLDivElement | null>(
|
||||
null,
|
||||
@@ -38,10 +54,42 @@ function WorkflowRunStreamVnc() {
|
||||
const setCanvasContainerRef = useCallback((node: HTMLDivElement | null) => {
|
||||
setCanvasContainer(node);
|
||||
}, []);
|
||||
|
||||
const rfbRef = useRef<RFB | null>(null);
|
||||
const clientId = useClientIdStore((state) => state.clientId);
|
||||
const credentialGetter = useCredentialGetter();
|
||||
|
||||
// effect for disconnects only
|
||||
const getWebSocketParams = useCallback(async () => {
|
||||
const clientIdQueryParam = `client_id=${clientId}`;
|
||||
let credentialQueryParam = "";
|
||||
|
||||
if (environment === "local") {
|
||||
credentialQueryParam = `apikey=${envCredential}`;
|
||||
} else {
|
||||
if (credentialGetter) {
|
||||
const token = await credentialGetter();
|
||||
credentialQueryParam = `token=Bearer ${token}`;
|
||||
} else {
|
||||
credentialQueryParam = `apikey=${envCredential}`;
|
||||
}
|
||||
}
|
||||
|
||||
const params = [credentialQueryParam, clientIdQueryParam].join("&");
|
||||
|
||||
return `${params}`;
|
||||
}, [clientId, credentialGetter]);
|
||||
|
||||
const invalidateQueries = useCallback(() => {
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: ["workflowRun", workflowPermanentId, workflowRunId],
|
||||
});
|
||||
queryClient.invalidateQueries({ queryKey: ["workflowRuns"] });
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: ["workflowTasks", workflowRunId],
|
||||
});
|
||||
queryClient.invalidateQueries({ queryKey: ["runs"] });
|
||||
}, [queryClient, workflowPermanentId, workflowRunId]);
|
||||
|
||||
// effect for vnc disconnects only
|
||||
useEffect(() => {
|
||||
if (prevVncConnectedRef.current && !isVncConnected) {
|
||||
setVncDisconnectedTrigger((x) => x + 1);
|
||||
@@ -49,6 +97,15 @@ function WorkflowRunStreamVnc() {
|
||||
prevVncConnectedRef.current = isVncConnected;
|
||||
}, [isVncConnected]);
|
||||
|
||||
// effect for command disconnects only
|
||||
useEffect(() => {
|
||||
if (prevCommandConnectedRef.current && !isCommandConnected) {
|
||||
setCommandDisconnectedTrigger((x) => x + 1);
|
||||
}
|
||||
prevCommandConnectedRef.current = isCommandConnected;
|
||||
}, [isCommandConnected]);
|
||||
|
||||
// vnc socket
|
||||
useEffect(
|
||||
() => {
|
||||
if (!showStream || !canvasContainer || !workflowRunId) {
|
||||
@@ -61,24 +118,12 @@ function WorkflowRunStreamVnc() {
|
||||
}
|
||||
|
||||
async function setupVnc() {
|
||||
let credentialQueryParam = "";
|
||||
|
||||
if (environment === "local") {
|
||||
credentialQueryParam = `?apikey=${envCredential}`;
|
||||
} else {
|
||||
if (credentialGetter) {
|
||||
const token = await credentialGetter();
|
||||
credentialQueryParam = `?token=Bearer ${token}`;
|
||||
} else {
|
||||
credentialQueryParam = `?apikey=${envCredential}`;
|
||||
}
|
||||
}
|
||||
|
||||
if (rfbRef.current && isVncConnected) {
|
||||
return;
|
||||
}
|
||||
|
||||
const vncUrl = `${wssBaseUrl}/stream/vnc/workflow_run/${workflowRunId}${credentialQueryParam}`;
|
||||
const wsParams = await getWebSocketParams();
|
||||
const vncUrl = `${wssBaseUrl}/stream/vnc/workflow_run/${workflowRunId}?${wsParams}`;
|
||||
|
||||
if (rfbRef.current) {
|
||||
rfbRef.current.disconnect();
|
||||
@@ -102,15 +147,7 @@ function WorkflowRunStreamVnc() {
|
||||
|
||||
rfb.addEventListener("disconnect", async (/* e: RfbEvent */) => {
|
||||
setIsVncConnected(false);
|
||||
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: ["workflowRun", workflowPermanentId, workflowRunId],
|
||||
});
|
||||
queryClient.invalidateQueries({ queryKey: ["workflowRuns"] });
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: ["workflowTasks", workflowRunId],
|
||||
});
|
||||
queryClient.invalidateQueries({ queryKey: ["runs"] });
|
||||
invalidateQueries();
|
||||
});
|
||||
}
|
||||
|
||||
@@ -127,16 +164,75 @@ function WorkflowRunStreamVnc() {
|
||||
// cannot include isVncConnected in deps as it will cause infinite loop
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
[
|
||||
credentialGetter,
|
||||
workflowRunId,
|
||||
workflowPermanentId,
|
||||
showStream,
|
||||
queryClient,
|
||||
canvasContainer,
|
||||
invalidateQueries,
|
||||
showStream,
|
||||
vncDisconnectedTrigger, // will re-run on disconnects
|
||||
workflowRunId,
|
||||
],
|
||||
);
|
||||
|
||||
// command socket
|
||||
useEffect(() => {
|
||||
let ws: WebSocket | null = null;
|
||||
|
||||
const connect = async () => {
|
||||
const wsParams = await getWebSocketParams();
|
||||
const commandUrl = `${wssBaseUrl}/stream/commands/workflow_run/${workflowRunId}?${wsParams}`;
|
||||
ws = new WebSocket(commandUrl);
|
||||
|
||||
ws.onopen = () => {
|
||||
setIsCommandConnected(true);
|
||||
setCommandSocket(ws);
|
||||
};
|
||||
|
||||
ws.onclose = () => {
|
||||
setIsCommandConnected(false);
|
||||
invalidateQueries();
|
||||
setCommandSocket(null);
|
||||
};
|
||||
};
|
||||
|
||||
connect();
|
||||
|
||||
return () => {
|
||||
try {
|
||||
ws && ws.close();
|
||||
} catch (e) {
|
||||
// pass
|
||||
}
|
||||
};
|
||||
}, [
|
||||
commandDisconnectedTrigger,
|
||||
getWebSocketParams,
|
||||
invalidateQueries,
|
||||
workflowRunId,
|
||||
]);
|
||||
|
||||
// effect to send a command when the user is controlling, vs not controlling
|
||||
useEffect(() => {
|
||||
if (!isCommandConnected) {
|
||||
return;
|
||||
}
|
||||
|
||||
const sendCommand = (command: Command) => {
|
||||
if (!commandSocket) {
|
||||
console.warn("Cannot send command, as command socket is closed.");
|
||||
console.warn(command);
|
||||
return;
|
||||
}
|
||||
|
||||
commandSocket.send(JSON.stringify(command));
|
||||
};
|
||||
|
||||
if (userIsControlling) {
|
||||
sendCommand({ kind: "take-control" });
|
||||
} else {
|
||||
sendCommand({ kind: "cede-control" });
|
||||
}
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [userIsControlling, isCommandConnected]);
|
||||
|
||||
// Effect to show toast when workflow reaches a final state based on hook updates
|
||||
useEffect(() => {
|
||||
if (workflowRun) {
|
||||
|
||||
11
skyvern-frontend/src/store/useClientIdStore.ts
Normal file
11
skyvern-frontend/src/store/useClientIdStore.ts
Normal file
@@ -0,0 +1,11 @@
|
||||
import { create } from "zustand";
|
||||
|
||||
type ClientIdStore = {
  clientId: string;
};

/** The subset of the Web Crypto API needed to mint a client id. */
type RandomSource = {
  randomUUID?: () => string;
  getRandomValues: (array: Uint8Array) => Uint8Array;
};

/**
 * Generate a version-4 UUID string.
 *
 * `crypto.randomUUID` is only exposed in secure contexts (HTTPS or localhost),
 * so calling it unconditionally at module load throws when the app is served
 * over plain HTTP. When it is missing, the UUID is assembled from
 * `getRandomValues`, which does not have that restriction.
 *
 * @param source - random source to use; defaults to the global `crypto`.
 * @returns a lowercase, hyphenated UUID string.
 */
function generateClientId(source: RandomSource = crypto): string {
  if (typeof source.randomUUID === "function") {
    return source.randomUUID();
  }

  const bytes = source.getRandomValues(new Uint8Array(16));
  const hex = Array.from(bytes, (byte, index) => {
    let value = byte;
    if (index === 6) {
      value = (byte & 0x0f) | 0x40; // version nibble: 4
    } else if (index === 8) {
      value = (byte & 0x3f) | 0x80; // variant bits: 10xx
    }
    return value.toString(16).padStart(2, "0");
  }).join("");

  return [
    hex.slice(0, 8),
    hex.slice(8, 12),
    hex.slice(12, 16),
    hex.slice(16, 20),
    hex.slice(20),
  ].join("-");
}

// Generated once per module load, so every consumer of the store sees the same id.
const initialClientId = generateClientId();

/** Store exposing the id that the stream component sends as the `client_id` query parameter. */
export const useClientIdStore = create<ClientIdStore>(() => ({
  clientId: initialClientId,
}));
|
||||
@@ -2,4 +2,5 @@ from skyvern.forge.sdk.routes import agent_protocol # noqa: F401
|
||||
from skyvern.forge.sdk.routes import browser_sessions # noqa: F401
|
||||
from skyvern.forge.sdk.routes import credentials # noqa: F401
|
||||
from skyvern.forge.sdk.routes import streaming # noqa: F401
|
||||
from skyvern.forge.sdk.routes import streaming_commands # noqa: F401
|
||||
from skyvern.forge.sdk.routes import streaming_vnc # noqa: F401
|
||||
|
||||
46
skyvern/forge/sdk/routes/streaming_auth.py
Normal file
46
skyvern/forge/sdk/routes/streaming_auth.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""
|
||||
Streaming auth.
|
||||
"""
|
||||
|
||||
import structlog
|
||||
from fastapi import WebSocket
|
||||
from websockets.exceptions import ConnectionClosedOK
|
||||
|
||||
from skyvern.forge.sdk.services.org_auth_service import get_current_org
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
async def auth(apikey: str | None, token: str | None, websocket: WebSocket) -> str | None:
    """
    Accept the websocket connection, then authenticate the caller.

    The socket is accepted first. It is then closed with code 1002 when neither
    credential is supplied, when the resolved organization has no id, or when
    the organization lookup raises.

    :param apikey: API key credential, passed to `get_current_org` as `x_api_key`.
    :param token: token credential, passed to `get_current_org` as `authorization`.
    :param websocket: the incoming connection; accepted (and possibly closed) here.
    :return: the organization id on success, otherwise None.
    """

    try:
        await websocket.accept()

        has_credentials = bool(token or apikey)
        if not has_credentials:
            await websocket.close(code=1002)
            return None
    except ConnectionClosedOK:
        LOG.info("WebSocket connection closed cleanly.")
        return None

    try:
        org = await get_current_org(x_api_key=apikey, authorization=token)
        org_id = org.organization_id

        if org_id:
            return org_id

        await websocket.close(code=1002)
        return None
    except Exception:
        LOG.exception("Error occurred while retrieving organization information.")

        # The peer may already be gone; only a clean close is tolerated here.
        try:
            await websocket.close(code=1002)
        except ConnectionClosedOK:
            LOG.info("WebSocket connection closed due to invalid credentials.")

        return None
|
||||
270
skyvern/forge/sdk/routes/streaming_clients.py
Normal file
270
skyvern/forge/sdk/routes/streaming_clients.py
Normal file
@@ -0,0 +1,270 @@
|
||||
"""
|
||||
Streaming types.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import typing as t
|
||||
from enum import IntEnum
|
||||
|
||||
import structlog
|
||||
from fastapi import WebSocket
|
||||
from starlette.websockets import WebSocketState
|
||||
|
||||
from skyvern.forge.sdk.schemas.persistent_browser_sessions import AddressablePersistentBrowserSession
|
||||
from skyvern.forge.sdk.schemas.tasks import Task
|
||||
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRun
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
Interactor = t.Literal["agent", "user"]
|
||||
Loops = list[asyncio.Task] # aka "queue-less actors"; or "programs"
|
||||
|
||||
|
||||
# Commands
|
||||
|
||||
|
||||
# Module-level registry of WS command channels, keyed by client id.
command_channels: dict[str, "CommandChannel"] = {}


def add_command_client(command_channel: "CommandChannel") -> None:
    """Register the channel under its `client_id`, replacing any existing entry."""
    command_channels[command_channel.client_id] = command_channel


def get_command_client(client_id: str) -> t.Union["CommandChannel", None]:
    """Return the channel registered for `client_id`, or None when there is none."""
    return command_channels.get(client_id)


def del_command_client(client_id: str) -> None:
    """Drop the registry entry for `client_id`; a missing entry is not an error."""
    command_channels.pop(client_id, None)
|
||||
|
||||
|
||||
@dataclasses.dataclass
class CommandChannel:
    """
    State for one command websocket.

    Creating an instance registers it in the module-level command registry
    under `client_id` (replacing any channel already registered for that id);
    `close()` removes it again.
    """

    client_id: str
    organization_id: str
    websocket: WebSocket

    # --

    browser_session: AddressablePersistentBrowserSession | None = None
    workflow_run: WorkflowRun | None = None

    def __post_init__(self) -> None:
        add_command_client(self)

    async def close(self, code: int = 1000, reason: str | None = None) -> "CommandChannel":
        """
        Clear the channel state, close the websocket, and unregister the channel.

        Safe to call more than once.

        :param code: websocket close code.
        :param reason: optional close reason, also logged.
        :return: this channel.
        """

        LOG.info("Closing command stream.", reason=reason, code=code)

        self.browser_session = None
        self.workflow_run = None

        try:
            await self.websocket.close(code=code, reason=reason)
        except Exception:
            # Best-effort: the socket may already be closed (e.g. on a second close()).
            pass

        # Only evict our own registration. A reconnect with the same client id
        # registers a new channel under this key; a late (or repeated) close of
        # the old channel must not remove the new one.
        if get_command_client(self.client_id) is self:
            del_command_client(self.client_id)

        return self

    @property
    def is_open(self) -> bool:
        """
        True while the websocket is connecting/connected, a workflow run is
        attached, and a channel is registered for this client id.
        """

        if self.websocket.client_state not in (WebSocketState.CONNECTED, WebSocketState.CONNECTING):
            return False

        if not self.workflow_run:
            return False

        if not get_command_client(self.client_id):
            return False

        return True
|
||||
|
||||
|
||||
CommandKinds = t.Literal["take-control", "cede-control"]


@dataclasses.dataclass
class Command:
    """Base shape of a channel command; `kind` is the discriminator."""

    kind: CommandKinds


@dataclasses.dataclass
class CommandTakeControl(Command):
    """Command whose `kind` is "take-control"."""

    kind: t.Literal["take-control"] = "take-control"


@dataclasses.dataclass
class CommandCedeControl(Command):
    """Command whose `kind` is "cede-control"."""

    kind: t.Literal["cede-control"] = "cede-control"


ChannelCommand = t.Union[CommandTakeControl, CommandCedeControl]


def reify_channel_command(data: dict) -> ChannelCommand:
    """
    Build the typed command for a decoded message.

    :param data: decoded message; only its "kind" entry is consulted.
    :return: the command instance matching `data["kind"]`.
    :raises ValueError: when "kind" is missing or not a known command kind.
    """

    kind = data.get("kind")

    if kind == "take-control":
        return CommandTakeControl()

    if kind == "cede-control":
        return CommandCedeControl()

    raise ValueError(f"Unknown command kind: '{kind}'")
|
||||
|
||||
|
||||
# Streaming
|
||||
|
||||
|
||||
# Module-level registry of WS streaming VNC clients, keyed by client id.
streaming_clients: dict[str, "Streaming"] = {}


def add_streaming_client(streaming: "Streaming") -> None:
    """Register the streaming state under its `client_id`, replacing any existing entry."""
    streaming_clients[streaming.client_id] = streaming


def get_streaming_client(client_id: str) -> t.Union["Streaming", None]:
    """Return the streaming state registered for `client_id`, or None when there is none."""
    return streaming_clients.get(client_id)


def del_streaming_client(client_id: str) -> None:
    """Drop the registry entry for `client_id`; a missing entry is not an error."""
    streaming_clients.pop(client_id, None)
|
||||
|
||||
|
||||
class MessageType(IntEnum):
    """
    Message type codes.

    The values equal the first byte of the key messages in `Keys` (0x04) and
    of the prefix checked by `is_rmb` (0x05).
    """

    Keyboard = 4
    Mouse = 5
|
||||
|
||||
|
||||
class Keys:
    """
    VNC RFB key messages as raw bytes, grouped by key-down / key-up.

    Each literal is 8 bytes long. Matching Down and Up entries differ only in
    byte 1 (0x01 for Down, 0x00 for Up), and the key itself is told apart by
    the trailing bytes. A more compact representation (indexes 6-7) is likely
    possible; this is fine for now.

    ref: https://www.notion.so/References-21c426c42cd480fb9258ecc9eb8f09b4
    ref: https://github.com/novnc/noVNC/blob/master/docs/rfbproto-3.8.pdf
    """

    class Down:
        # Key-press messages.
        Ctrl = b"\x04\x01\x00\x00\x00\x00\xff\xe3"
        Cmd = b"\x04\x01\x00\x00\x00\x00\xff\xe9"
        Alt = b"\x04\x01\x00\x00\x00\x00\xff~"  # option; "~" is byte 0x7e, same tail as Up.Alt
        OKey = b"\x04\x01\x00\x00\x00\x00\x00o"

    class Up:
        # Key-release messages.
        Ctrl = b"\x04\x00\x00\x00\x00\x00\xff\xe3"
        Cmd = b"\x04\x00\x00\x00\x00\x00\xff\xe9"
        Alt = b"\x04\x00\x00\x00\x00\x00\xff\x7e"  # option
|
||||
|
||||
|
||||
def is_rmb(data: bytes) -> bool:
    """Return True when the first two bytes of `data` are 0x05 0x04."""
    prefix = data[:2]
    return prefix == b"\x05\x04"


class Mouse:
    """Namespace of mouse-message predicates."""

    class Up:
        # Exposed as a plain function attribute: call it as Mouse.Up.Right(data).
        Right = is_rmb
|
||||
|
||||
|
||||
@dataclasses.dataclass
class KeyState:
    """Which tracked keys are currently held down; every flag starts as False."""

    ctrl_is_down: bool = False
    alt_is_down: bool = False
    cmd_is_down: bool = False
    o_is_down: bool = False

    def is_forbidden(self, data: bytes) -> bool:
        """
        :return: True if the key is forbidden, else False
        """
        # Ctrl+O is currently the only forbidden combination.
        return self.is_ctrl_o(data)

    def is_ctrl_o(self, data: bytes) -> bool:
        """
        Do not allow the opening of files.

        :return: True when Ctrl is held and `data` is the "o" key-down message.
        """
        ctrl_held = self.ctrl_is_down
        return ctrl_held and data == Keys.Down.OKey
|
||||
|
||||
|
||||
@dataclasses.dataclass
class Streaming:
    """
    Streaming state.

    Creating an instance registers it in the module-level streaming registry
    under `client_id` (replacing any entry already registered for that id);
    `close()` removes it again.
    """

    # Unique to frontend app instance.
    client_id: str

    # Whether the user or the agent are the interactor.
    interactor: Interactor

    organization_id: str
    vnc_port: int
    websocket: WebSocket

    # --

    browser_session: AddressablePersistentBrowserSession | None = None
    key_state: KeyState = dataclasses.field(default_factory=KeyState)
    task: Task | None = None
    workflow_run: WorkflowRun | None = None

    def __post_init__(self) -> None:
        add_streaming_client(self)

    @property
    def is_open(self) -> bool:
        """
        True while the websocket is connecting/connected, a task or workflow
        run is attached, and a streaming client is registered for this client id.
        """

        if self.websocket.client_state not in (WebSocketState.CONNECTED, WebSocketState.CONNECTING):
            return False

        if not self.task and not self.workflow_run:
            return False

        if not get_streaming_client(self.client_id):
            return False

        return True

    async def close(self, code: int = 1000, reason: str | None = None) -> "Streaming":
        """
        Clear the streaming state, close the websocket, and unregister.

        Safe to call more than once.

        :param code: websocket close code.
        :param reason: optional close reason, also logged.
        :return: this instance.
        """

        LOG.info("Closing Streaming.", reason=reason, code=code)

        self.browser_session = None
        self.task = None
        self.workflow_run = None

        try:
            await self.websocket.close(code=code, reason=reason)
        except Exception:
            # Best-effort: the socket may already be closed.
            pass

        # Only evict our own registration. A reconnect with the same client id
        # registers a new Streaming under this key; a late close of the old
        # one must not remove the new one.
        if get_streaming_client(self.client_id) is self:
            del_streaming_client(self.client_id)

        return self

    def update_key_state(self, data: bytes) -> None:
        """
        Track modifier keys: set or clear the Ctrl/Alt/Cmd flags on `key_state`
        when `data` is exactly one of the corresponding `Keys` messages.
        Any other message leaves the state untouched.
        """

        if data == Keys.Down.Ctrl:
            self.key_state.ctrl_is_down = True
        elif data == Keys.Up.Ctrl:
            self.key_state.ctrl_is_down = False
        elif data == Keys.Down.Alt:
            self.key_state.alt_is_down = True
        elif data == Keys.Up.Alt:
            self.key_state.alt_is_down = False
        elif data == Keys.Down.Cmd:
            self.key_state.cmd_is_down = True
        elif data == Keys.Up.Cmd:
            self.key_state.cmd_is_down = False
|
||||
227
skyvern/forge/sdk/routes/streaming_commands.py
Normal file
227
skyvern/forge/sdk/routes/streaming_commands.py
Normal file
@@ -0,0 +1,227 @@
|
||||
"""
|
||||
Streaming commands WebSocket connections.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
import structlog
|
||||
from fastapi import WebSocket, WebSocketDisconnect
|
||||
from websockets.exceptions import ConnectionClosedError
|
||||
|
||||
import skyvern.forge.sdk.routes.streaming_clients as sc
|
||||
from skyvern.forge.sdk.routes.routers import legacy_base_router
|
||||
from skyvern.forge.sdk.routes.streaming_auth import auth
|
||||
from skyvern.forge.sdk.routes.streaming_verify import loop_verify_workflow_run, verify_workflow_run
|
||||
from skyvern.forge.sdk.utils.aio import collect
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
async def get_commands_for_workflow_run(
    client_id: str,
    workflow_run_id: str,
    organization_id: str,
    websocket: WebSocket,
) -> tuple[sc.CommandChannel, sc.Loops] | None:
    """
    Build the command channel for a workflow run, plus the loops to run for it.

    :return: `(channel, loops)` once the workflow run and its browser session
        have both been verified; None when either is missing. The loops are
        already scheduled as tasks: the workflow-run verification loop and the
        channel loop, in that order.
    """

    LOG.info("Getting commands channel for workflow run.", workflow_run_id=workflow_run_id)

    workflow_run, browser_session = await verify_workflow_run(
        workflow_run_id=workflow_run_id,
        organization_id=organization_id,
    )

    log_context = {"workflow_run_id": workflow_run_id, "organization_id": organization_id}

    if not workflow_run:
        LOG.info("Command channel: no initial workflow run found.", **log_context)
        return None

    if not browser_session:
        LOG.info("Command channel: no initial browser session found for workflow run.", **log_context)
        return None

    channel = sc.CommandChannel(
        client_id=client_id,
        organization_id=organization_id,
        browser_session=browser_session,
        workflow_run=workflow_run,
        websocket=websocket,
    )

    LOG.info("Got command channel for workflow run.", commands=channel)

    verify_loop = asyncio.create_task(loop_verify_workflow_run(channel))
    channel_loop = asyncio.create_task(loop_channel(channel))

    return channel, [verify_loop, channel_loop]
|
||||
|
||||
|
||||
async def loop_channel(commands: sc.CommandChannel) -> None:
    """
    Stream commands and their results back and forth.

    Loops until the workflow run is cleared or the websocket is closed, then
    closes the command channel.

    :param commands: the command channel to service; returns immediately when
        it has no browser session.
    """

    if not commands.browser_session:
        LOG.info(
            "No browser session found for workflow run.",
            workflow_run=commands.workflow_run,
            organization_id=commands.organization_id,
        )
        return

    # Which interactor each command kind hands control to.
    interactor_for_kind: dict[str, sc.Interactor] = {
        "take-control": "user",
        "cede-control": "agent",
    }

    async def frontend_to_backend() -> None:
        LOG.info("Starting frontend-to-backend channel loop.", commands=commands)

        while commands.is_open:
            # asyncio.CancelledError is deliberately NOT caught here: swallowing
            # it inside the loop would make a cancelled task go straight back to
            # receive_json() instead of stopping.
            try:
                data = await commands.websocket.receive_json()

                if not isinstance(data, dict):
                    LOG.error(f"Cannot create channel command: expected dict, got {type(data)}")
                    continue

                try:
                    command = sc.reify_channel_command(data)
                except ValueError:
                    # Bad input from the frontend: skip it, but leave a trace.
                    LOG.warning(
                        "Ignoring unrecognized channel command.",
                        kind=data.get("kind"),
                        organization_id=commands.organization_id,
                    )
                    continue

                command_kind = command.kind
                interactor = interactor_for_kind.get(command_kind)

                if interactor is None:
                    LOG.error(f"Unknown command kind: '{command_kind}'")
                    continue

                streaming = sc.get_streaming_client(commands.client_id)

                if not streaming:
                    LOG.error("No streaming client found for command.", commands=commands, command=command)
                    continue

                streaming.interactor = interactor

            except WebSocketDisconnect:
                LOG.info(
                    "Frontend disconnected.",
                    workflow_run=commands.workflow_run,
                    organization_id=commands.organization_id,
                )
                raise
            except ConnectionClosedError:
                LOG.info(
                    "Frontend closed the streaming session.",
                    workflow_run=commands.workflow_run,
                    organization_id=commands.organization_id,
                )
                raise
            except Exception:
                LOG.exception(
                    "An unexpected exception occurred.",
                    workflow_run=commands.workflow_run,
                    organization_id=commands.organization_id,
                )
                raise

    loops = [
        asyncio.create_task(frontend_to_backend()),
    ]

    try:
        await collect(loops)
    except Exception:
        LOG.exception(
            "An exception occurred in loop channel stream.",
            workflow_run=commands.workflow_run,
            organization_id=commands.organization_id,
        )
    finally:
        LOG.info(
            "Closing the loop channel stream.",
            workflow_run=commands.workflow_run,
            organization_id=commands.organization_id,
        )
        await commands.close(reason="loop-channel-closed")
|
||||
|
||||
|
||||
@legacy_base_router.websocket("/stream/commands/workflow_run/{workflow_run_id}")
async def workflow_run_commands(
    websocket: WebSocket,
    workflow_run_id: str,
    apikey: str | None = None,
    client_id: str | None = None,
    token: str | None = None,
) -> None:
    """
    WebSocket endpoint for the command channel of a workflow run.

    Authenticates the caller, builds the command channel, and runs its loops
    until they finish; the channel is closed on the way out.

    :param websocket: the incoming connection (accepted inside `auth`).
    :param workflow_run_id: workflow run the commands apply to.
    :param apikey: API key credential.
    :param client_id: id of the frontend instance; required.
    :param token: token credential.
    """

    LOG.info("Starting stream commands.", workflow_run_id=workflow_run_id)

    organization_id = await auth(apikey=apikey, token=token, websocket=websocket)

    if not organization_id:
        LOG.error("Authentication failed.", workflow_run_id=workflow_run_id)
        return

    if not client_id:
        LOG.error("No client ID provided.", workflow_run_id=workflow_run_id)
        # `auth` has already accepted the socket; close it explicitly like the
        # other early exits do instead of just returning from the handler.
        await websocket.close(code=1002)
        return

    commands: sc.CommandChannel
    loops: list[asyncio.Task] = []

    result = await get_commands_for_workflow_run(
        client_id=client_id,
        workflow_run_id=workflow_run_id,
        organization_id=organization_id,
        websocket=websocket,
    )

    if not result:
        LOG.error(
            "No streaming context found for the workflow run.",
            workflow_run_id=workflow_run_id,
            organization_id=organization_id,
        )
        await websocket.close(code=1013)
        return

    commands, loops = result

    try:
        LOG.info(
            "Starting command stream loops.",
            workflow_run_id=workflow_run_id,
            organization_id=organization_id,
        )
        await collect(loops)
    except Exception:
        LOG.exception(
            "An exception occurred in the command stream function.",
            workflow_run_id=workflow_run_id,
            organization_id=organization_id,
        )
    finally:
        LOG.info(
            "Closing the command stream session.",
            workflow_run_id=workflow_run_id,
            organization_id=organization_id,
        )

        await commands.close(reason="stream-closed")
|
||||
207
skyvern/forge/sdk/routes/streaming_verify.py
Normal file
207
skyvern/forge/sdk/routes/streaming_verify.py
Normal file
@@ -0,0 +1,207 @@
|
||||
import asyncio
|
||||
|
||||
import structlog
|
||||
|
||||
import skyvern.forge.sdk.routes.streaming_clients as sc
|
||||
from skyvern.config import settings
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.schemas.persistent_browser_sessions import AddressablePersistentBrowserSession
|
||||
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
|
||||
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRun, WorkflowRunStatus
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
async def verify_task(
    task_id: str, organization_id: str
) -> tuple[Task | None, AddressablePersistentBrowserSession | None]:
    """
    Check that the task is still streamable and find its browser session.

    :return: `(None, None)` when the task is missing, final, or not in the
        created/queued/running states; `(task, None)` when no browser session
        with an address can be built for it; otherwise `(task, session)`.
    """

    task = await app.DATABASE.get_task(task_id=task_id, organization_id=organization_id)

    if not task:
        LOG.info("Task not found.", task_id=task_id, organization_id=organization_id)
        return None, None

    status = task.status

    if status.is_final():
        LOG.info("Task is in a final state.", task_status=status, task_id=task_id, organization_id=organization_id)
        return None, None

    streamable_statuses = (TaskStatus.created, TaskStatus.queued, TaskStatus.running)

    if status not in streamable_statuses:
        LOG.info(
            "Task is not created, queued, or running.",
            task_status=status,
            task_id=task_id,
            organization_id=organization_id,
        )
        return None, None

    browser_session = await app.PERSISTENT_SESSIONS_MANAGER.get_session_by_runnable_id(
        organization_id=organization_id,
        runnable_id=task_id,
    )

    if not browser_session:
        LOG.info("No browser session found for task.", task_id=task_id, organization_id=organization_id)
        return task, None

    if not browser_session.browser_address:
        LOG.info("Browser session address not found for task.", task_id=task_id, organization_id=organization_id)
        return task, None

    try:
        # Re-validate the session as the addressable variant.
        session_fields = browser_session.model_dump()
        session_fields["browser_address"] = browser_session.browser_address
        addressable = AddressablePersistentBrowserSession(**session_fields)
    except Exception as e:
        LOG.error(
            "streaming-vnc.browser-session-reify-error", task_id=task_id, organization_id=organization_id, error=e
        )
        return task, None

    return task, addressable
|
||||
|
||||
|
||||
async def verify_workflow_run(
    workflow_run_id: str,
    organization_id: str,
) -> tuple[WorkflowRun | None, AddressablePersistentBrowserSession | None]:
    """
    Verify the workflow run is running, and that it has a browser session associated
    with it.

    When `settings.ENV` is "local", no lookup is performed: a placeholder
    running workflow run and a placeholder browser session are returned.

    :return: `(None, None)` when the run is missing, final, or not in the
        created/queued/running states; `(workflow_run, None)` when no browser
        session with an address can be built for it; otherwise
        `(workflow_run, session)`.
    """

    if settings.ENV == "local":
        from datetime import datetime

        dummy_workflow_run = WorkflowRun(
            workflow_id="123",
            workflow_permanent_id="wpid_123",
            workflow_run_id=workflow_run_id,
            organization_id=organization_id,
            status=WorkflowRunStatus.running,
            created_at=datetime.now(),
            modified_at=datetime.now(),
        )

        dummy_browser_session = AddressablePersistentBrowserSession(
            persistent_browser_session_id=workflow_run_id,
            organization_id=organization_id,
            browser_address="0.0.0.0:9223",
            created_at=datetime.now(),
            modified_at=datetime.now(),
        )

        return dummy_workflow_run, dummy_browser_session

    workflow_run = await app.DATABASE.get_workflow_run(
        workflow_run_id=workflow_run_id,
        organization_id=organization_id,
    )

    if not workflow_run:
        LOG.info("Workflow run not found.", workflow_run_id=workflow_run_id, organization_id=organization_id)
        return None, None

    if workflow_run.status.is_final():
        LOG.info(
            "Workflow run is in a final state. Closing connection.",
            workflow_run_status=workflow_run.status,
            workflow_run_id=workflow_run_id,
            organization_id=organization_id,
        )

        return None, None

    if workflow_run.status not in [WorkflowRunStatus.created, WorkflowRunStatus.queued, WorkflowRunStatus.running]:
        LOG.info(
            "Workflow run is not running.",
            workflow_run_status=workflow_run.status,
            workflow_run_id=workflow_run_id,
            organization_id=organization_id,
        )

        return None, None

    browser_session = await app.PERSISTENT_SESSIONS_MANAGER.get_session_by_runnable_id(
        organization_id=organization_id,
        runnable_id=workflow_run_id,
    )

    if not browser_session:
        LOG.info(
            "No browser session found for workflow run.",
            workflow_run_id=workflow_run_id,
            organization_id=organization_id,
        )
        return workflow_run, None

    browser_address = browser_session.browser_address

    if not browser_address:
        LOG.info(
            "Waiting for browser session address.", workflow_run_id=workflow_run_id, organization_id=organization_id
        )

        # The session exists but has no address yet: ask the manager for it.
        try:
            _, host, cdp_port = await app.PERSISTENT_SESSIONS_MANAGER.get_browser_address(
                session_id=browser_session.persistent_browser_session_id,
                organization_id=organization_id,
            )
            browser_address = f"{host}:{cdp_port}"
        except Exception as ex:
            LOG.info(
                "Browser session address not found for workflow run.",
                workflow_run_id=workflow_run_id,
                organization_id=organization_id,
                ex=ex,
            )
            return workflow_run, None

    try:
        addressable_browser_session = AddressablePersistentBrowserSession(
            **browser_session.model_dump() | {"browser_address": browser_address},
        )
    except Exception as e:
        # Log the failure (as verify_task does) instead of dropping it silently.
        LOG.error(
            "streaming-vnc.browser-session-reify-error",
            workflow_run_id=workflow_run_id,
            organization_id=organization_id,
            error=e,
        )
        return workflow_run, None

    return workflow_run, addressable_browser_session
|
||||
|
||||
|
||||
async def loop_verify_task(streaming: sc.Streaming) -> None:
    """
    Re-verify the streamed task every two seconds.

    On each pass the task is looked up again via `verify_task`, and the
    returned task and browser session are written back onto `streaming`.
    The loop ends as soon as `streaming.task` is falsy or `streaming.is_open`
    is false.
    """

    while True:
        current_task = streaming.task

        if not current_task or not streaming.is_open:
            break

        refreshed_task, refreshed_session = await verify_task(
            task_id=current_task.task_id,
            organization_id=streaming.organization_id,
        )

        # `verify_task` may return None for either value; storing a None task
        # is what terminates the loop on the next pass.
        streaming.task = refreshed_task
        streaming.browser_session = refreshed_session

        await asyncio.sleep(2)
|
||||
|
||||
|
||||
async def loop_verify_workflow_run(verifiable: sc.CommandChannel | sc.Streaming) -> None:
    """
    Re-verify the attached workflow run every two seconds.

    On each pass the workflow run is looked up again via `verify_workflow_run`,
    and the returned workflow run and browser session are written back onto
    `verifiable`. The loop ends as soon as `verifiable.workflow_run` is falsy
    or `verifiable.is_open` is false.
    """

    while True:
        current_run = verifiable.workflow_run

        if not current_run or not verifiable.is_open:
            break

        refreshed_run, refreshed_session = await verify_workflow_run(
            workflow_run_id=current_run.workflow_run_id,
            organization_id=verifiable.organization_id,
        )

        # A None workflow run stored here stops the loop on the next pass.
        verifiable.workflow_run = refreshed_run
        verifiable.browser_session = refreshed_session

        await asyncio.sleep(2)
|
||||
@@ -1,233 +1,36 @@
|
||||
"""
|
||||
Streaming VNC WebSocket connections.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import typing as t
|
||||
from enum import IntEnum
|
||||
|
||||
import structlog
|
||||
import websockets
|
||||
from fastapi import WebSocket, WebSocketDisconnect
|
||||
from starlette.websockets import WebSocketState
|
||||
from websockets import Data
|
||||
from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK
|
||||
from websockets.exceptions import ConnectionClosedError
|
||||
|
||||
import skyvern.forge.sdk.routes.streaming_clients as sc
|
||||
from skyvern.config import settings
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.routes.routers import legacy_base_router
|
||||
from skyvern.forge.sdk.schemas.persistent_browser_sessions import AddressablePersistentBrowserSession
|
||||
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
|
||||
from skyvern.forge.sdk.services.org_auth_service import get_current_org
|
||||
from skyvern.forge.sdk.routes.streaming_auth import auth
|
||||
from skyvern.forge.sdk.routes.streaming_verify import (
|
||||
loop_verify_task,
|
||||
loop_verify_workflow_run,
|
||||
verify_task,
|
||||
verify_workflow_run,
|
||||
)
|
||||
from skyvern.forge.sdk.utils.aio import collect
|
||||
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRun, WorkflowRunStatus
|
||||
|
||||
Interactor = t.Literal["agent", "user"]
|
||||
Loops = list[asyncio.Task] # aka "queue-less actors"; or "programs"
|
||||
|
||||
|
||||
class MessageType(IntEnum):
    """
    Values of the leading byte of the client messages this module inspects.
    """

    # Key press / release messages.
    Keyboard = 4

    # Mouse (pointer) messages.
    Mouse = 5
|
||||
|
||||
|
||||
def _key_event(pressed: bool, keycode: int) -> bytes:
    """
    Encode an 8-byte key message: the type byte 4, a pressed flag (1 or 0),
    two zero bytes, then `keycode` as a 4-byte big-endian integer.
    """
    return bytes((4, int(pressed), 0, 0)) + keycode.to_bytes(4, "big")


class Keys:
    """
    VNC RFB key messages, pre-encoded as the exact bytes to compare against.

    `Down` holds the key-press messages and `Up` the matching key-release
    messages; the two differ only in the second byte.
    """

    class Down:
        Ctrl = _key_event(True, 0xFFE3)
        Cmd = _key_event(True, 0xFFE9)
        Alt = _key_event(True, 0xFF7E)  # option
        OKey = _key_event(True, 0x006F)  # ASCII "o"

    class Up:
        Ctrl = _key_event(False, 0xFFE3)
        Cmd = _key_event(False, 0xFFE9)
        Alt = _key_event(False, 0xFF7E)  # option
|
||||
|
||||
|
||||
def is_rmb(data: bytes) -> bool:
    """
    Report whether `data` begins with the two bytes 0x05 0x04.

    NOTE(review): 0x05 matches `MessageType.Mouse`; the 0x04 that follows is
    presumably the button mask with only the right button set - confirm.
    """
    leading_pair = data[:2]
    return leading_pair == b"\x05\x04"
|
||||
|
||||
|
||||
class Mouse:
    """
    Predicates for recognising mouse messages, grouped by event.
    """

    class Up:
        # Right-button predicate; this is `is_rmb` under another name and is
        # called through the class, e.g. `Mouse.Up.Right(data)`.
        Right = is_rmb
|
||||
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
@dataclasses.dataclass
class KeyState:
    """
    Flags recording which keys are currently held down, plus checks for key
    messages that must not be forwarded.
    """

    ctrl_is_down: bool = False
    alt_is_down: bool = False
    cmd_is_down: bool = False
    o_is_down: bool = False

    def is_forbidden(self, data: bytes) -> bool:
        """
        :return: True if the key is forbidden, else False
        """
        # Ctrl+O is currently the only forbidden combination.
        return self.is_ctrl_o(data)

    def is_ctrl_o(self, data: bytes) -> bool:
        """
        Do not allow the opening of files.

        :return: True when Ctrl is held and `data` is the "o" key-press message.
        """
        if not self.ctrl_is_down:
            return False

        return data == Keys.Down.OKey
|
||||
|
||||
|
||||
@dataclasses.dataclass
class Streaming:
    """
    Streaming state.

    Bundles the frontend websocket with the task or workflow run being
    streamed, the browser session serving it, and the tracked key state.
    """

    # Whether the user or the agent are the interactor.
    interactor: Interactor

    organization_id: str
    vnc_port: int
    websocket: WebSocket

    # --

    browser_session: AddressablePersistentBrowserSession | None = None
    key_state: KeyState = dataclasses.field(default_factory=KeyState)
    task: Task | None = None
    workflow_run: WorkflowRun | None = None

    @property
    def is_open(self) -> bool:
        """
        True while the websocket is connected (or connecting) and a task or
        workflow run is still attached.
        """
        if self.websocket.client_state not in (WebSocketState.CONNECTED, WebSocketState.CONNECTING):
            return False

        if not self.task and not self.workflow_run:
            return False

        return True

    async def close(self, code: int = 1000, reason: str | None = None) -> "Streaming":
        """
        Detach the browser session, task and workflow run, then close the
        websocket.

        Closing is best-effort: a failure to close the websocket is logged
        and never raised to the caller.

        :param code: websocket close code
        :param reason: optional close reason passed to the websocket
        :return: this instance
        """
        LOG.info("Closing Streaming.", reason=reason, code=code)

        # Clearing these first makes `is_open` false even if the close below
        # fails.
        self.browser_session = None
        self.task = None
        self.workflow_run = None

        try:
            await self.websocket.close(code=code, reason=reason)
        except Exception as ex:
            # Still best-effort, but leave a trace instead of swallowing the
            # failure silently.
            LOG.info("Failed to close Streaming websocket.", reason=reason, code=code, ex=ex)

        return self

    def update_key_state(self, data: bytes) -> None:
        """
        Update the Ctrl / Alt / Cmd flags on `key_state` when `data` is exactly
        one of the known key-press or key-release messages; any other message
        leaves the state untouched.
        """
        if data == Keys.Down.Ctrl:
            self.key_state.ctrl_is_down = True
        elif data == Keys.Up.Ctrl:
            self.key_state.ctrl_is_down = False
        elif data == Keys.Down.Alt:
            self.key_state.alt_is_down = True
        elif data == Keys.Up.Alt:
            self.key_state.alt_is_down = False
        elif data == Keys.Down.Cmd:
            self.key_state.cmd_is_down = True
        elif data == Keys.Up.Cmd:
            self.key_state.cmd_is_down = False
|
||||
|
||||
|
||||
async def auth(apikey: str | None, token: str | None, websocket: WebSocket) -> str | None:
    """
    Accept the websocket connection and resolve the caller's organization.

    The connection is closed with code 1002 and None is returned when neither
    credential is supplied, when no organization id can be determined, or when
    the organization lookup raises.

    :return: the organization id, or None if the WS connection cannot proceed
    """

    has_credentials = bool(token or apikey)

    try:
        await websocket.accept()

        if not has_credentials:
            await websocket.close(code=1002)
            return None
    except ConnectionClosedOK:
        LOG.info("WebSocket connection closed cleanly.")
        return None

    try:
        org = await get_current_org(x_api_key=apikey, authorization=token)
        org_id = org.organization_id

        if org_id:
            return org_id

        # Authenticated, but without a usable organization id.
        await websocket.close(code=1002)
        return None
    except Exception:
        LOG.exception("Error occurred while retrieving organization information.")

        try:
            await websocket.close(code=1002)
        except ConnectionClosedOK:
            LOG.info("WebSocket connection closed due to invalid credentials.")

        return None
|
||||
|
||||
|
||||
async def verify_task(
    task_id: str, organization_id: str
) -> tuple[Task | None, AddressablePersistentBrowserSession | None]:
    """
    Verify the task is running, and that it has a browser session associated
    with it.

    :return: `(None, None)` when the task is missing, in a final state, or not
        running; `(task, None)` when no addressable browser session is
        available for it yet; otherwise `(task, addressable_browser_session)`.
    """

    log_context = {"task_id": task_id, "organization_id": organization_id}

    task = await app.DATABASE.get_task(task_id=task_id, organization_id=organization_id)

    if not task:
        LOG.info("Task not found.", **log_context)
        return None, None

    if task.status.is_final():
        LOG.info("Task is in a final state.", task_status=task.status, **log_context)
        return None, None

    if not task.status == TaskStatus.running:
        LOG.info("Task is not running.", task_status=task.status, **log_context)
        return None, None

    session = await app.PERSISTENT_SESSIONS_MANAGER.get_session_by_runnable_id(
        organization_id=organization_id,
        runnable_id=task_id,  # is this correct; is there a task_run_id?
    )

    if not session:
        LOG.info("No browser session found for task.", **log_context)
        return task, None

    address = session.browser_address

    if not address:
        LOG.info("Browser session address not found for task.", **log_context)
        return task, None

    try:
        # Rebuild the session as the "addressable" variant, with the address
        # key supplied explicitly.
        session_fields = session.model_dump() | {"browser_address": address}
        addressable_browser_session = AddressablePersistentBrowserSession(**session_fields)
    except Exception as e:
        LOG.error("streaming-vnc.browser-session-reify-error", error=e, **log_context)
        return task, None

    return task, addressable_browser_session
|
||||
|
||||
|
||||
async def get_streaming_for_task(
|
||||
client_id: str,
|
||||
task_id: str,
|
||||
organization_id: str,
|
||||
websocket: WebSocket,
|
||||
) -> tuple[Streaming, Loops] | None:
|
||||
) -> tuple[sc.Streaming, sc.Loops] | None:
|
||||
"""
|
||||
Return a streaming context for a task, with a list of loops to run concurrently.
|
||||
"""
|
||||
@@ -242,8 +45,9 @@ async def get_streaming_for_task(
|
||||
LOG.info("No initial browser session found for task.", task_id=task_id, organization_id=organization_id)
|
||||
return None
|
||||
|
||||
streaming = Streaming(
|
||||
interactor="user",
|
||||
streaming = sc.Streaming(
|
||||
client_id=client_id,
|
||||
interactor="agent",
|
||||
organization_id=organization_id,
|
||||
vnc_port=settings.SKYVERN_BROWSER_VNC_PORT,
|
||||
websocket=websocket,
|
||||
@@ -260,10 +64,11 @@ async def get_streaming_for_task(
|
||||
|
||||
|
||||
async def get_streaming_for_workflow_run(
|
||||
client_id: str,
|
||||
workflow_run_id: str,
|
||||
organization_id: str,
|
||||
websocket: WebSocket,
|
||||
) -> tuple[Streaming, Loops] | None:
|
||||
) -> tuple[sc.Streaming, sc.Loops] | None:
|
||||
"""
|
||||
Return a streaming context for a workflow run, with a list of loops to run concurrently.
|
||||
"""
|
||||
@@ -287,8 +92,9 @@ async def get_streaming_for_workflow_run(
|
||||
)
|
||||
return None
|
||||
|
||||
streaming = Streaming(
|
||||
interactor="user",
|
||||
streaming = sc.Streaming(
|
||||
client_id=client_id,
|
||||
interactor="agent",
|
||||
organization_id=organization_id,
|
||||
vnc_port=settings.SKYVERN_BROWSER_VNC_PORT,
|
||||
browser_session=browser_session,
|
||||
@@ -306,128 +112,7 @@ async def get_streaming_for_workflow_run(
|
||||
return streaming, loops
|
||||
|
||||
|
||||
async def verify_workflow_run(
    workflow_run_id: str,
    organization_id: str,
) -> tuple[WorkflowRun | None, AddressablePersistentBrowserSession | None]:
    """
    Verify the workflow run is running, and that it has a browser session associated
    with it.

    :return: `(None, None)` when the workflow run is missing, in a final state,
        or not in a created/queued/running state; `(workflow_run, None)` when
        no addressable browser session is available for it yet; otherwise
        `(workflow_run, addressable_browser_session)`.
    """

    workflow_run = await app.DATABASE.get_workflow_run(
        workflow_run_id=workflow_run_id,
        organization_id=organization_id,
    )

    if not workflow_run:
        LOG.info("Workflow run not found.", workflow_run_id=workflow_run_id, organization_id=organization_id)
        return None, None

    if workflow_run.status in [
        WorkflowRunStatus.completed,
        WorkflowRunStatus.failed,
        WorkflowRunStatus.terminated,
    ]:
        LOG.info(
            "Workflow run is in a final state. Closing connection.",
            workflow_run_status=workflow_run.status,
            workflow_run_id=workflow_run_id,
            organization_id=organization_id,
        )

        return None, None

    if workflow_run.status not in [WorkflowRunStatus.created, WorkflowRunStatus.queued, WorkflowRunStatus.running]:
        LOG.info(
            "Workflow run is not running.",
            workflow_run_status=workflow_run.status,
            workflow_run_id=workflow_run_id,
            organization_id=organization_id,
        )

        return None, None

    browser_session = await app.PERSISTENT_SESSIONS_MANAGER.get_session_by_runnable_id(
        organization_id=organization_id,
        runnable_id=workflow_run_id,
    )

    if not browser_session:
        LOG.info(
            "No browser session found for workflow run.",
            workflow_run_id=workflow_run_id,
            organization_id=organization_id,
        )
        return workflow_run, None

    browser_address = browser_session.browser_address

    if not browser_address:
        LOG.info(
            "Waiting for browser session address.", workflow_run_id=workflow_run_id, organization_id=organization_id
        )

        # The session record has no address yet; ask the sessions manager for
        # one and build it from the returned host and port.
        try:
            _, host, cdp_port = await app.PERSISTENT_SESSIONS_MANAGER.get_browser_address(
                session_id=browser_session.persistent_browser_session_id,
                organization_id=organization_id,
            )
            browser_address = f"{host}:{cdp_port}"
        except Exception as ex:
            LOG.info(
                "Browser session address not found for workflow run.",
                workflow_run_id=workflow_run_id,
                organization_id=organization_id,
                ex=ex,
            )
            return workflow_run, None

    try:
        addressable_browser_session = AddressablePersistentBrowserSession(
            **browser_session.model_dump() | {"browser_address": browser_address},
        )
    except Exception as e:
        # Log the failure, as `verify_task` does, rather than dropping it
        # silently; the caller still just sees "no session yet".
        LOG.error(
            "streaming-vnc.browser-session-reify-error",
            workflow_run_id=workflow_run_id,
            organization_id=organization_id,
            error=e,
        )
        return workflow_run, None

    return workflow_run, addressable_browser_session
|
||||
|
||||
|
||||
async def loop_verify_task(streaming: Streaming) -> None:
    """
    Re-verify the streamed task every two seconds.

    On each pass the task is looked up again via `verify_task`, and the
    returned task and browser session are written back onto `streaming`.
    The loop ends as soon as `streaming.task` is falsy or `streaming.is_open`
    is false.
    """

    while True:
        current_task = streaming.task

        if not current_task or not streaming.is_open:
            break

        refreshed_task, refreshed_session = await verify_task(
            task_id=current_task.task_id,
            organization_id=streaming.organization_id,
        )

        # A None task stored here stops the loop on the next pass.
        streaming.task = refreshed_task
        streaming.browser_session = refreshed_session

        await asyncio.sleep(2)
|
||||
|
||||
|
||||
async def loop_verify_workflow_run(streaming: Streaming) -> None:
    """
    Re-verify the streamed workflow run every two seconds.

    On each pass the workflow run is looked up again via `verify_workflow_run`,
    and the returned workflow run and browser session are written back onto
    `streaming`. The loop ends as soon as `streaming.workflow_run` is falsy or
    `streaming.is_open` is false.
    """

    while True:
        current_run = streaming.workflow_run

        if not current_run or not streaming.is_open:
            break

        refreshed_run, refreshed_session = await verify_workflow_run(
            workflow_run_id=current_run.workflow_run_id,
            organization_id=streaming.organization_id,
        )

        # A None workflow run stored here stops the loop on the next pass.
        streaming.workflow_run = refreshed_run
        streaming.browser_session = refreshed_session

        await asyncio.sleep(2)
|
||||
|
||||
|
||||
async def loop_stream_vnc(streaming: Streaming) -> None:
|
||||
async def loop_stream_vnc(streaming: sc.Streaming) -> None:
|
||||
"""
|
||||
Actually stream the VNC session data between a frontend and a browser
|
||||
session.
|
||||
@@ -465,19 +150,19 @@ async def loop_stream_vnc(streaming: Streaming) -> None:
|
||||
if data:
|
||||
message_type = data[0]
|
||||
|
||||
if message_type == MessageType.Keyboard.value:
|
||||
if message_type == sc.MessageType.Keyboard.value:
|
||||
streaming.update_key_state(data)
|
||||
|
||||
if streaming.key_state.is_forbidden(data):
|
||||
continue
|
||||
|
||||
if message_type == MessageType.Mouse.value:
|
||||
if Mouse.Up.Right(data):
|
||||
if message_type == sc.MessageType.Mouse.value:
|
||||
if sc.Mouse.Up.Right(data):
|
||||
continue
|
||||
|
||||
if not streaming.interactor == "user" and message_type in (
|
||||
MessageType.Keyboard.value,
|
||||
MessageType.Mouse.value,
|
||||
sc.MessageType.Keyboard.value,
|
||||
sc.MessageType.Mouse.value,
|
||||
):
|
||||
LOG.info(
|
||||
"Blocking user message.", task=streaming.task, organization_id=streaming.organization_id
|
||||
@@ -615,9 +300,10 @@ async def task_stream(
|
||||
websocket: WebSocket,
|
||||
task_id: str,
|
||||
apikey: str | None = None,
|
||||
client_id: str | None = None,
|
||||
token: str | None = None,
|
||||
) -> None:
|
||||
await stream(websocket, apikey=apikey, task_id=task_id, token=token)
|
||||
await stream(websocket, apikey=apikey, client_id=client_id, task_id=task_id, token=token)
|
||||
|
||||
|
||||
@legacy_base_router.websocket("/stream/vnc/workflow_run/{workflow_run_id}")
|
||||
@@ -625,32 +311,43 @@ async def workflow_run_stream(
|
||||
websocket: WebSocket,
|
||||
workflow_run_id: str,
|
||||
apikey: str | None = None,
|
||||
client_id: str | None = None,
|
||||
token: str | None = None,
|
||||
) -> None:
|
||||
await stream(websocket, apikey=apikey, workflow_run_id=workflow_run_id, token=token)
|
||||
await stream(websocket, apikey=apikey, client_id=client_id, workflow_run_id=workflow_run_id, token=token)
|
||||
|
||||
|
||||
async def stream(
|
||||
websocket: WebSocket,
|
||||
*,
|
||||
apikey: str | None = None,
|
||||
client_id: str | None = None,
|
||||
task_id: str | None = None,
|
||||
token: str | None = None,
|
||||
workflow_run_id: str | None = None,
|
||||
) -> None:
|
||||
LOG.info("Starting VNC stream.", task_id=task_id, workflow_run_id=workflow_run_id)
|
||||
if not client_id:
|
||||
LOG.error("Client ID not provided for VNC stream.", task_id=task_id, workflow_run_id=workflow_run_id)
|
||||
return
|
||||
|
||||
LOG.info("Starting VNC stream.", client_id=client_id, task_id=task_id, workflow_run_id=workflow_run_id)
|
||||
|
||||
organization_id = await auth(apikey=apikey, token=token, websocket=websocket)
|
||||
|
||||
if not organization_id:
|
||||
LOG.info("Authentication failed.", task_id=task_id, workflow_run_id=workflow_run_id)
|
||||
LOG.error("Authentication failed.", task_id=task_id, workflow_run_id=workflow_run_id)
|
||||
return
|
||||
|
||||
streaming: Streaming
|
||||
streaming: sc.Streaming
|
||||
loops: list[asyncio.Task] = []
|
||||
|
||||
if task_id:
|
||||
result = await get_streaming_for_task(task_id=task_id, organization_id=organization_id, websocket=websocket)
|
||||
result = await get_streaming_for_task(
|
||||
client_id=client_id,
|
||||
task_id=task_id,
|
||||
organization_id=organization_id,
|
||||
websocket=websocket,
|
||||
)
|
||||
|
||||
if not result:
|
||||
LOG.error("No streaming context found for the task.", task_id=task_id, organization_id=organization_id)
|
||||
@@ -666,6 +363,7 @@ async def stream(
|
||||
"Starting streaming for workflow run.", workflow_run_id=workflow_run_id, organization_id=organization_id
|
||||
)
|
||||
result = await get_streaming_for_workflow_run(
|
||||
client_id=client_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
websocket=websocket,
|
||||
|
||||
Reference in New Issue
Block a user