BE portion of seamless clipboard transfer in browser stream (#3788)
This commit is contained in:
@@ -6,6 +6,6 @@ from skyvern.forge.sdk.routes import pylon # noqa: F401
|
||||
from skyvern.forge.sdk.routes import run_blocks # noqa: F401
|
||||
from skyvern.forge.sdk.routes import scripts # noqa: F401
|
||||
from skyvern.forge.sdk.routes import streaming # noqa: F401
|
||||
from skyvern.forge.sdk.routes import streaming_commands # noqa: F401
|
||||
from skyvern.forge.sdk.routes import streaming_messages # noqa: F401
|
||||
from skyvern.forge.sdk.routes import streaming_vnc # noqa: F401
|
||||
from skyvern.forge.sdk.routes import webhooks # noqa: F401
|
||||
|
||||
@@ -225,6 +225,8 @@ async def new_debug_session(
|
||||
user_id=current_user_id,
|
||||
workflow_permanent_id=workflow_permanent_id,
|
||||
vnc_streaming_supported=True if new_browser_session.ip_address else False,
|
||||
# NOTE(jdo:streaming-local-dev)
|
||||
# vnc_streaming_supported=True,
|
||||
)
|
||||
|
||||
LOG.info(
|
||||
|
||||
169
skyvern/forge/sdk/routes/streaming_agent.py
Normal file
169
skyvern/forge/sdk/routes/streaming_agent.py
Normal file
@@ -0,0 +1,169 @@
|
||||
"""
|
||||
A lightweight "agent" for interacting with the streaming browser over CDP.
|
||||
"""
|
||||
|
||||
import typing
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import structlog
|
||||
from playwright.async_api import Browser, BrowserContext, Page, Playwright, async_playwright
|
||||
|
||||
import skyvern.forge.sdk.routes.streaming_clients as sc
|
||||
from skyvern.config import settings
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class StreamingAgent:
|
||||
"""
|
||||
A minimal agent that can connect to a browser via CDP and execute JavaScript.
|
||||
|
||||
Specifically for operations during streaming sessions (like copy/pasting selected text, etc.).
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.browser: Browser | None = None
|
||||
self.browser_context: BrowserContext | None = None
|
||||
self.page: Page | None = None
|
||||
self.pw: Playwright | None = None
|
||||
|
||||
async def connect(self, cdp_url: str | None = None) -> None:
|
||||
url = cdp_url or settings.BROWSER_REMOTE_DEBUGGING_URL
|
||||
|
||||
LOG.info("StreamingAgent connecting to CDP", cdp_url=url)
|
||||
|
||||
pw = self.pw or await async_playwright().start()
|
||||
|
||||
self.pw = pw
|
||||
|
||||
self.browser = await pw.chromium.connect_over_cdp(url)
|
||||
|
||||
contexts = self.browser.contexts
|
||||
if contexts:
|
||||
LOG.info("StreamingAgent using existing browser context")
|
||||
self.browser_context = contexts[0]
|
||||
else:
|
||||
LOG.warning("No existing browser context found, creating new one")
|
||||
self.browser_context = await self.browser.new_context()
|
||||
|
||||
pages = self.browser_context.pages
|
||||
if pages:
|
||||
self.page = pages[0]
|
||||
LOG.info("StreamingAgent connected to page", url=self.page.url)
|
||||
else:
|
||||
LOG.warning("No pages found in browser context")
|
||||
|
||||
LOG.info("StreamingAgent connected successfully")
|
||||
|
||||
async def evaluate_js(
|
||||
self, expression: str, arg: str | int | float | bool | list | dict | None = None
|
||||
) -> str | int | float | bool | list | dict | None:
|
||||
if not self.page:
|
||||
raise RuntimeError("StreamingAgent is not connected to a page. Call connect() first.")
|
||||
|
||||
LOG.info("StreamingAgent evaluating JS", expression=expression[:100])
|
||||
|
||||
try:
|
||||
result = await self.page.evaluate(expression, arg)
|
||||
LOG.info("StreamingAgent JS evaluation successful")
|
||||
return result
|
||||
except Exception as ex:
|
||||
LOG.exception("StreamingAgent JS evaluation failed", expression=expression, ex=str(ex))
|
||||
raise
|
||||
|
||||
async def get_selected_text(self) -> str:
|
||||
LOG.info("StreamingAgent getting selected text")
|
||||
|
||||
js_expression = """
|
||||
() => {
|
||||
const selection = window.getSelection();
|
||||
return selection ? selection.toString() : '';
|
||||
}
|
||||
"""
|
||||
|
||||
selected_text = await self.evaluate_js(js_expression)
|
||||
|
||||
if isinstance(selected_text, str) or selected_text is None:
|
||||
LOG.info("StreamingAgent got selected text", length=len(selected_text) if selected_text else 0)
|
||||
return selected_text or ""
|
||||
|
||||
raise RuntimeError(f"StreamingAgent selected text is not a string, but a(n) '{type(selected_text)}'")
|
||||
|
||||
async def paste_text(self, text: str) -> None:
|
||||
LOG.info("StreamingAgent pasting text")
|
||||
|
||||
js_expression = """
|
||||
(text) => {
|
||||
const activeElement = document.activeElement;
|
||||
if (activeElement && (activeElement.tagName === 'INPUT' || activeElement.tagName === 'TEXTAREA' || activeElement.isContentEditable)) {
|
||||
const start = activeElement.selectionStart || 0;
|
||||
const end = activeElement.selectionEnd || 0;
|
||||
const value = activeElement.value || '';
|
||||
activeElement.value = value.slice(0, start) + text + value.slice(end);
|
||||
const newCursorPos = start + text.length;
|
||||
activeElement.setSelectionRange(newCursorPos, newCursorPos);
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
await self.evaluate_js(js_expression, text)
|
||||
|
||||
LOG.info("StreamingAgent pasted text successfully")
|
||||
|
||||
async def close(self) -> None:
|
||||
LOG.info("StreamingAgent closing connection")
|
||||
|
||||
if self.browser:
|
||||
await self.browser.close()
|
||||
self.browser = None
|
||||
self.browser_context = None
|
||||
self.page = None
|
||||
|
||||
if self.pw:
|
||||
await self.pw.stop()
|
||||
self.pw = None
|
||||
|
||||
LOG.info("StreamingAgent closed")
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def connected_agent(streaming: sc.Streaming | None) -> typing.AsyncIterator[StreamingAgent]:
|
||||
"""
|
||||
The first pass at this has us doing the following for every operation:
|
||||
- creating a new agent
|
||||
- connecting
|
||||
- [doing smth]
|
||||
- closing the agent
|
||||
|
||||
This may add latency, but locally it is pretty fast. This keeps things stateless for now.
|
||||
|
||||
If it turns out it's too slow, we can refactor to keep a persistent agent per streaming client.
|
||||
"""
|
||||
|
||||
if not streaming:
|
||||
msg = "connected_agent: no streaming client provided."
|
||||
LOG.error(msg)
|
||||
|
||||
raise Exception(msg)
|
||||
if not streaming.browser_session or not streaming.browser_session.browser_address:
|
||||
msg = "connected_agent: no browser session or browser address found for streaming client."
|
||||
|
||||
LOG.error(
|
||||
msg,
|
||||
client_id=streaming.client_id,
|
||||
organization_id=streaming.organization_id,
|
||||
)
|
||||
|
||||
raise Exception(msg)
|
||||
|
||||
agent = StreamingAgent()
|
||||
|
||||
try:
|
||||
await agent.connect(streaming.browser_session.browser_address)
|
||||
|
||||
# NOTE(jdo:streaming-local-dev): use BROWSER_REMOTE_DEBUGGING_URL from settings
|
||||
# await agent.connect()
|
||||
|
||||
yield agent
|
||||
finally:
|
||||
await agent.close()
|
||||
@@ -44,3 +44,18 @@ async def auth(apikey: str | None, token: str | None, websocket: WebSocket) -> s
|
||||
return None
|
||||
|
||||
return organization_id
|
||||
|
||||
|
||||
# NOTE(jdo:streaming-local-dev): use this instead of the above `auth`
|
||||
async def _auth(apikey: str | None, token: str | None, websocket: WebSocket) -> str | None:
|
||||
"""
|
||||
Dummy auth for local testing.
|
||||
"""
|
||||
|
||||
try:
|
||||
await websocket.accept()
|
||||
except ConnectionClosedOK:
|
||||
LOG.info("WebSocket connection closed cleanly.")
|
||||
return None
|
||||
|
||||
return "o_temp123"
|
||||
|
||||
@@ -22,30 +22,30 @@ Interactor = t.Literal["agent", "user"]
|
||||
Loops = list[asyncio.Task] # aka "queue-less actors"; or "programs"
|
||||
|
||||
|
||||
# Commands
|
||||
# Messages
|
||||
|
||||
|
||||
# a global registry for WS command clients
|
||||
command_channels: dict[str, "CommandChannel"] = {}
|
||||
# a global registry for WS message clients
|
||||
message_channels: dict[str, "MessageChannel"] = {}
|
||||
|
||||
|
||||
def add_command_client(command_channel: "CommandChannel") -> None:
|
||||
command_channels[command_channel.client_id] = command_channel
|
||||
def add_message_client(message_channel: "MessageChannel") -> None:
|
||||
message_channels[message_channel.client_id] = message_channel
|
||||
|
||||
|
||||
def get_command_client(client_id: str) -> t.Union["CommandChannel", None]:
|
||||
return command_channels.get(client_id, None)
|
||||
def get_message_client(client_id: str) -> t.Union["MessageChannel", None]:
|
||||
return message_channels.get(client_id, None)
|
||||
|
||||
|
||||
def del_command_client(client_id: str) -> None:
|
||||
def del_message_client(client_id: str) -> None:
|
||||
try:
|
||||
del command_channels[client_id]
|
||||
del message_channels[client_id]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CommandChannel:
|
||||
class MessageChannel:
|
||||
client_id: str
|
||||
organization_id: str
|
||||
websocket: WebSocket
|
||||
@@ -56,10 +56,10 @@ class CommandChannel:
|
||||
workflow_run: WorkflowRun | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
add_command_client(self)
|
||||
add_message_client(self)
|
||||
|
||||
async def close(self, code: int = 1000, reason: str | None = None) -> "CommandChannel":
|
||||
LOG.info("Closing command stream.", reason=reason, code=code)
|
||||
async def close(self, code: int = 1000, reason: str | None = None) -> "MessageChannel":
|
||||
LOG.info("Closing message stream.", reason=reason, code=code)
|
||||
|
||||
self.browser_session = None
|
||||
self.workflow_run = None
|
||||
@@ -69,7 +69,7 @@ class CommandChannel:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
del_command_client(self.client_id)
|
||||
del_message_client(self.client_id)
|
||||
|
||||
return self
|
||||
|
||||
@@ -81,43 +81,87 @@ class CommandChannel:
|
||||
if not self.workflow_run and not self.browser_session:
|
||||
return False
|
||||
|
||||
if not get_command_client(self.client_id):
|
||||
if not get_message_client(self.client_id):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def ask_for_clipboard(self, streaming: "Streaming") -> None:
|
||||
try:
|
||||
await self.websocket.send_json(
|
||||
{
|
||||
"kind": "ask-for-clipboard",
|
||||
}
|
||||
)
|
||||
LOG.info(
|
||||
"Sent ask-for-clipboard to message channel",
|
||||
organization_id=streaming.organization_id,
|
||||
)
|
||||
except Exception:
|
||||
LOG.exception(
|
||||
"Failed to send ask-for-clipboard to message channel",
|
||||
organization_id=streaming.organization_id,
|
||||
)
|
||||
|
||||
CommandKinds = t.Literal["take-control", "cede-control"]
|
||||
async def send_copied_text(self, copied_text: str, streaming: "Streaming") -> None:
|
||||
try:
|
||||
await self.websocket.send_json(
|
||||
{
|
||||
"kind": "copied-text",
|
||||
"text": copied_text,
|
||||
}
|
||||
)
|
||||
LOG.info(
|
||||
"Sent copied text to message channel",
|
||||
organization_id=streaming.organization_id,
|
||||
)
|
||||
except Exception:
|
||||
LOG.exception(
|
||||
"Failed to send copied text to message channel",
|
||||
organization_id=streaming.organization_id,
|
||||
)
|
||||
|
||||
|
||||
MessageKinds = t.Literal["take-control", "cede-control", "ask-for-clipboard-response"]
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Command:
|
||||
kind: CommandKinds
|
||||
class Message:
|
||||
kind: MessageKinds
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CommandTakeControl(Command):
|
||||
class MessageTakeControl(Message):
|
||||
kind: t.Literal["take-control"] = "take-control"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CommandCedeControl(Command):
|
||||
class MessageCedeControl(Message):
|
||||
kind: t.Literal["cede-control"] = "cede-control"
|
||||
|
||||
|
||||
ChannelCommand = t.Union[CommandTakeControl, CommandCedeControl]
|
||||
@dataclasses.dataclass
|
||||
class MessageInAskForClipboardResponse(Message):
|
||||
kind: t.Literal["ask-for-clipboard-response"] = "ask-for-clipboard-response"
|
||||
text: str = ""
|
||||
|
||||
|
||||
def reify_channel_command(data: dict) -> ChannelCommand:
|
||||
ChannelMessage = t.Union[MessageTakeControl, MessageCedeControl, MessageInAskForClipboardResponse]
|
||||
|
||||
|
||||
def reify_channel_message(data: dict) -> ChannelMessage:
|
||||
kind = data.get("kind", None)
|
||||
|
||||
match kind:
|
||||
case "take-control":
|
||||
return CommandTakeControl()
|
||||
return MessageTakeControl()
|
||||
case "cede-control":
|
||||
return CommandCedeControl()
|
||||
return MessageCedeControl()
|
||||
case "ask-for-clipboard-response":
|
||||
text = data.get("text") or ""
|
||||
return MessageInAskForClipboardResponse(text=text)
|
||||
case _:
|
||||
raise ValueError(f"Unknown command kind: '{kind}'")
|
||||
raise ValueError(f"Unknown message kind: '{kind}'")
|
||||
|
||||
|
||||
# Streaming
|
||||
@@ -159,7 +203,9 @@ class Keys:
|
||||
Ctrl = b"\x04\x01\x00\x00\x00\x00\xff\xe3"
|
||||
Cmd = b"\x04\x01\x00\x00\x00\x00\xff\xe9"
|
||||
Alt = b"\x04\x01\x00\x00\x00\x00\xff~" # option
|
||||
CKey = b"\x04\x01\x00\x00\x00\x00\x00c"
|
||||
OKey = b"\x04\x01\x00\x00\x00\x00\x00o"
|
||||
VKey = b"\x04\x01\x00\x00\x00\x00\x00v"
|
||||
|
||||
class Up:
|
||||
Ctrl = b"\x04\x00\x00\x00\x00\x00\xff\xe3"
|
||||
@@ -181,7 +227,6 @@ class KeyState:
|
||||
ctrl_is_down: bool = False
|
||||
alt_is_down: bool = False
|
||||
cmd_is_down: bool = False
|
||||
o_is_down: bool = False
|
||||
|
||||
def is_forbidden(self, data: bytes) -> bool:
|
||||
"""
|
||||
@@ -195,6 +240,18 @@ class KeyState:
|
||||
"""
|
||||
return self.ctrl_is_down and data == Keys.Down.OKey
|
||||
|
||||
def is_copy(self, data: bytes) -> bool:
|
||||
"""
|
||||
Detect Ctrl+C or Cmd+C for copy.
|
||||
"""
|
||||
return (self.ctrl_is_down or self.cmd_is_down) and data == Keys.Down.CKey
|
||||
|
||||
def is_paste(self, data: bytes) -> bool:
|
||||
"""
|
||||
Detect Ctrl+V or Cmd+V for paste.
|
||||
"""
|
||||
return (self.ctrl_is_down or self.cmd_is_down) and data == Keys.Down.VKey
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Streaming:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""
|
||||
Streaming commands WebSocket connections.
|
||||
Streaming messages for WebSocket connections.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
@@ -10,6 +10,7 @@ from websockets.exceptions import ConnectionClosedError
|
||||
|
||||
import skyvern.forge.sdk.routes.streaming_clients as sc
|
||||
from skyvern.forge.sdk.routes.routers import base_router, legacy_base_router
|
||||
from skyvern.forge.sdk.routes.streaming_agent import connected_agent
|
||||
from skyvern.forge.sdk.routes.streaming_auth import auth
|
||||
from skyvern.forge.sdk.routes.streaming_verify import (
|
||||
loop_verify_browser_session,
|
||||
@@ -22,17 +23,17 @@ from skyvern.forge.sdk.utils.aio import collect
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
async def get_commands_for_browser_session(
|
||||
async def get_messages_for_browser_session(
|
||||
client_id: str,
|
||||
browser_session_id: str,
|
||||
organization_id: str,
|
||||
websocket: WebSocket,
|
||||
) -> tuple[sc.CommandChannel, sc.Loops] | None:
|
||||
) -> tuple[sc.MessageChannel, sc.Loops] | None:
|
||||
"""
|
||||
Return a commands channel for a browser session, with a list of loops to run concurrently.
|
||||
Return a message channel for a browser session, with a list of loops to run concurrently.
|
||||
"""
|
||||
|
||||
LOG.info("Getting commands channel for browser session.", browser_session_id=browser_session_id)
|
||||
LOG.info("Getting message channel for browser session.", browser_session_id=browser_session_id)
|
||||
|
||||
browser_session = await verify_browser_session(
|
||||
browser_session_id=browser_session_id,
|
||||
@@ -41,40 +42,40 @@ async def get_commands_for_browser_session(
|
||||
|
||||
if not browser_session:
|
||||
LOG.info(
|
||||
"Command channel: no initial browser session found.",
|
||||
"Message channel: no initial browser session found.",
|
||||
browser_session_id=browser_session_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
return None
|
||||
|
||||
commands = sc.CommandChannel(
|
||||
message_channel = sc.MessageChannel(
|
||||
client_id=client_id,
|
||||
organization_id=organization_id,
|
||||
browser_session=browser_session,
|
||||
websocket=websocket,
|
||||
)
|
||||
|
||||
LOG.info("Got command channel for browser session.", commands=commands)
|
||||
LOG.info("Got message channel for browser session.", message_channel=message_channel)
|
||||
|
||||
loops = [
|
||||
asyncio.create_task(loop_verify_browser_session(commands)),
|
||||
asyncio.create_task(loop_channel(commands)),
|
||||
asyncio.create_task(loop_verify_browser_session(message_channel)),
|
||||
asyncio.create_task(loop_channel(message_channel)),
|
||||
]
|
||||
|
||||
return commands, loops
|
||||
return message_channel, loops
|
||||
|
||||
|
||||
async def get_commands_for_workflow_run(
|
||||
async def get_messages_for_workflow_run(
|
||||
client_id: str,
|
||||
workflow_run_id: str,
|
||||
organization_id: str,
|
||||
websocket: WebSocket,
|
||||
) -> tuple[sc.CommandChannel, sc.Loops] | None:
|
||||
) -> tuple[sc.MessageChannel, sc.Loops] | None:
|
||||
"""
|
||||
Return a commands channel for a workflow run, with a list of loops to run concurrently.
|
||||
Return a message channel for a workflow run, with a list of loops to run concurrently.
|
||||
"""
|
||||
|
||||
LOG.info("Getting commands channel for workflow run.", workflow_run_id=workflow_run_id)
|
||||
LOG.info("Getting message channel for workflow run.", workflow_run_id=workflow_run_id)
|
||||
|
||||
workflow_run, browser_session = await verify_workflow_run(
|
||||
workflow_run_id=workflow_run_id,
|
||||
@@ -83,7 +84,7 @@ async def get_commands_for_workflow_run(
|
||||
|
||||
if not workflow_run:
|
||||
LOG.info(
|
||||
"Command channel: no initial workflow run found.",
|
||||
"Message channel: no initial workflow run found.",
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
@@ -91,13 +92,13 @@ async def get_commands_for_workflow_run(
|
||||
|
||||
if not browser_session:
|
||||
LOG.info(
|
||||
"Command channel: no initial browser session found for workflow run.",
|
||||
"Message channel: no initial browser session found for workflow run.",
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
return None
|
||||
|
||||
commands = sc.CommandChannel(
|
||||
message_channel = sc.MessageChannel(
|
||||
client_id=client_id,
|
||||
organization_id=organization_id,
|
||||
browser_session=browser_session,
|
||||
@@ -105,78 +106,100 @@ async def get_commands_for_workflow_run(
|
||||
websocket=websocket,
|
||||
)
|
||||
|
||||
LOG.info("Got command channel for workflow run.", commands=commands)
|
||||
LOG.info("Got message channel for workflow run.", message_channel=message_channel)
|
||||
|
||||
loops = [
|
||||
asyncio.create_task(loop_verify_workflow_run(commands)),
|
||||
asyncio.create_task(loop_channel(commands)),
|
||||
asyncio.create_task(loop_verify_workflow_run(message_channel)),
|
||||
asyncio.create_task(loop_channel(message_channel)),
|
||||
]
|
||||
|
||||
return commands, loops
|
||||
return message_channel, loops
|
||||
|
||||
|
||||
async def loop_channel(commands: sc.CommandChannel) -> None:
|
||||
async def loop_channel(message_channel: sc.MessageChannel) -> None:
|
||||
"""
|
||||
Stream commands and their results back and forth.
|
||||
Stream messages and their results back and forth.
|
||||
|
||||
Loops until the workflow run is cleared or the websocket is closed.
|
||||
"""
|
||||
|
||||
if not commands.browser_session:
|
||||
if not message_channel.browser_session:
|
||||
LOG.info(
|
||||
"No browser session found for workflow run.",
|
||||
workflow_run=commands.workflow_run,
|
||||
organization_id=commands.organization_id,
|
||||
workflow_run=message_channel.workflow_run,
|
||||
organization_id=message_channel.organization_id,
|
||||
)
|
||||
return
|
||||
|
||||
async def frontend_to_backend() -> None:
|
||||
LOG.info("Starting frontend-to-backend channel loop.", commands=commands)
|
||||
LOG.info("Starting frontend-to-backend channel loop.", message_channel=message_channel)
|
||||
|
||||
while commands.is_open:
|
||||
while message_channel.is_open:
|
||||
try:
|
||||
data = await commands.websocket.receive_json()
|
||||
data = await message_channel.websocket.receive_json()
|
||||
|
||||
if not isinstance(data, dict):
|
||||
LOG.error(f"Cannot create channel command: expected dict, got {type(data)}")
|
||||
LOG.error(f"Cannot create channel message: expected dict, got {type(data)}")
|
||||
continue
|
||||
|
||||
try:
|
||||
command = sc.reify_channel_command(data)
|
||||
message = sc.reify_channel_message(data)
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
command_kind = command.kind
|
||||
message_kind = message.kind
|
||||
|
||||
match command_kind:
|
||||
match message_kind:
|
||||
case "take-control":
|
||||
streaming = sc.get_streaming_client(commands.client_id)
|
||||
streaming = sc.get_streaming_client(message_channel.client_id)
|
||||
if not streaming:
|
||||
LOG.error("No streaming client found for command.", commands=commands, command=command)
|
||||
LOG.error(
|
||||
"No streaming client found for message.",
|
||||
message_channel=message_channel,
|
||||
message=message,
|
||||
)
|
||||
continue
|
||||
streaming.interactor = "user"
|
||||
case "cede-control":
|
||||
streaming = sc.get_streaming_client(commands.client_id)
|
||||
streaming = sc.get_streaming_client(message_channel.client_id)
|
||||
if not streaming:
|
||||
LOG.error("No streaming client found for command.", commands=commands, command=command)
|
||||
LOG.error(
|
||||
"No streaming client found for message.",
|
||||
message_channel=message_channel,
|
||||
message=message,
|
||||
)
|
||||
continue
|
||||
streaming.interactor = "agent"
|
||||
case "ask-for-clipboard-response":
|
||||
if not isinstance(message, sc.MessageInAskForClipboardResponse):
|
||||
LOG.error(
|
||||
"Invalid message type for ask-for-clipboard-response.",
|
||||
message_channel=message_channel,
|
||||
message=message,
|
||||
)
|
||||
continue
|
||||
|
||||
streaming = sc.get_streaming_client(message_channel.client_id)
|
||||
text = message.text
|
||||
|
||||
async with connected_agent(streaming) as agent:
|
||||
await agent.paste_text(text)
|
||||
case _:
|
||||
LOG.error(f"Unknown command kind: '{command_kind}'")
|
||||
LOG.error(f"Unknown message kind: '{message_kind}'")
|
||||
continue
|
||||
|
||||
except WebSocketDisconnect:
|
||||
LOG.info(
|
||||
"Frontend disconnected.",
|
||||
workflow_run=commands.workflow_run,
|
||||
organization_id=commands.organization_id,
|
||||
workflow_run=message_channel.workflow_run,
|
||||
organization_id=message_channel.organization_id,
|
||||
)
|
||||
raise
|
||||
except ConnectionClosedError:
|
||||
LOG.info(
|
||||
"Frontend closed the streaming session.",
|
||||
workflow_run=commands.workflow_run,
|
||||
organization_id=commands.organization_id,
|
||||
workflow_run=message_channel.workflow_run,
|
||||
organization_id=message_channel.organization_id,
|
||||
)
|
||||
raise
|
||||
except asyncio.CancelledError:
|
||||
@@ -184,8 +207,8 @@ async def loop_channel(commands: sc.CommandChannel) -> None:
|
||||
except Exception:
|
||||
LOG.exception(
|
||||
"An unexpected exception occurred.",
|
||||
workflow_run=commands.workflow_run,
|
||||
organization_id=commands.organization_id,
|
||||
workflow_run=message_channel.workflow_run,
|
||||
organization_id=message_channel.organization_id,
|
||||
)
|
||||
raise
|
||||
|
||||
@@ -198,27 +221,28 @@ async def loop_channel(commands: sc.CommandChannel) -> None:
|
||||
except Exception:
|
||||
LOG.exception(
|
||||
"An exception occurred in loop channel stream.",
|
||||
workflow_run=commands.workflow_run,
|
||||
organization_id=commands.organization_id,
|
||||
workflow_run=message_channel.workflow_run,
|
||||
organization_id=message_channel.organization_id,
|
||||
)
|
||||
finally:
|
||||
LOG.info(
|
||||
"Closing the loop channel stream.",
|
||||
workflow_run=commands.workflow_run,
|
||||
organization_id=commands.organization_id,
|
||||
workflow_run=message_channel.workflow_run,
|
||||
organization_id=message_channel.organization_id,
|
||||
)
|
||||
await commands.close(reason="loop-channel-closed")
|
||||
await message_channel.close(reason="loop-channel-closed")
|
||||
|
||||
|
||||
@base_router.websocket("/stream/messages/browser_session/{browser_session_id}")
|
||||
@base_router.websocket("/stream/commands/browser_session/{browser_session_id}")
|
||||
async def browser_session_commands(
|
||||
async def browser_session_messages(
|
||||
websocket: WebSocket,
|
||||
browser_session_id: str,
|
||||
apikey: str | None = None,
|
||||
client_id: str | None = None,
|
||||
token: str | None = None,
|
||||
) -> None:
|
||||
LOG.info("Starting stream commands for browser session.", browser_session_id=browser_session_id)
|
||||
LOG.info("Starting message stream for browser session.", browser_session_id=browser_session_id)
|
||||
|
||||
organization_id = await auth(apikey=apikey, token=token, websocket=websocket)
|
||||
|
||||
@@ -230,10 +254,10 @@ async def browser_session_commands(
|
||||
LOG.error("No client ID provided.", browser_session_id=browser_session_id)
|
||||
return
|
||||
|
||||
commands: sc.CommandChannel
|
||||
message_channel: sc.MessageChannel
|
||||
loops: list[asyncio.Task] = []
|
||||
|
||||
result = await get_commands_for_browser_session(
|
||||
result = await get_messages_for_browser_session(
|
||||
client_id=client_id,
|
||||
browser_session_id=browser_session_id,
|
||||
organization_id=organization_id,
|
||||
@@ -249,40 +273,41 @@ async def browser_session_commands(
|
||||
await websocket.close(code=1013)
|
||||
return
|
||||
|
||||
commands, loops = result
|
||||
message_channel, loops = result
|
||||
|
||||
try:
|
||||
LOG.info(
|
||||
"Starting command stream loops for browser session.",
|
||||
"Starting message stream loops for browser session.",
|
||||
browser_session_id=browser_session_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
await collect(loops)
|
||||
except Exception:
|
||||
LOG.exception(
|
||||
"An exception occurred in the command stream function for browser session.",
|
||||
"An exception occurred in the message stream function for browser session.",
|
||||
browser_session_id=browser_session_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
finally:
|
||||
LOG.info(
|
||||
"Closing the command stream session for browser session.",
|
||||
"Closing the message stream session for browser session.",
|
||||
browser_session_id=browser_session_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
|
||||
await commands.close(reason="stream-closed")
|
||||
await message_channel.close(reason="stream-closed")
|
||||
|
||||
|
||||
@legacy_base_router.websocket("/stream/messages/workflow_run/{workflow_run_id}")
|
||||
@legacy_base_router.websocket("/stream/commands/workflow_run/{workflow_run_id}")
|
||||
async def workflow_run_commands(
|
||||
async def workflow_run_messages(
|
||||
websocket: WebSocket,
|
||||
workflow_run_id: str,
|
||||
apikey: str | None = None,
|
||||
client_id: str | None = None,
|
||||
token: str | None = None,
|
||||
) -> None:
|
||||
LOG.info("Starting stream commands.", workflow_run_id=workflow_run_id)
|
||||
LOG.info("Starting message stream.", workflow_run_id=workflow_run_id)
|
||||
|
||||
organization_id = await auth(apikey=apikey, token=token, websocket=websocket)
|
||||
|
||||
@@ -294,10 +319,10 @@ async def workflow_run_commands(
|
||||
LOG.error("No client ID provided.", workflow_run_id=workflow_run_id)
|
||||
return
|
||||
|
||||
commands: sc.CommandChannel
|
||||
message_channel: sc.MessageChannel
|
||||
loops: list[asyncio.Task] = []
|
||||
|
||||
result = await get_commands_for_workflow_run(
|
||||
result = await get_messages_for_workflow_run(
|
||||
client_id=client_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
@@ -313,26 +338,26 @@ async def workflow_run_commands(
|
||||
await websocket.close(code=1013)
|
||||
return
|
||||
|
||||
commands, loops = result
|
||||
message_channel, loops = result
|
||||
|
||||
try:
|
||||
LOG.info(
|
||||
"Starting command stream loops.",
|
||||
"Starting message stream loops.",
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
await collect(loops)
|
||||
except Exception:
|
||||
LOG.exception(
|
||||
"An exception occurred in the command stream function.",
|
||||
"An exception occurred in the message stream function.",
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
finally:
|
||||
LOG.info(
|
||||
"Closing the command stream session.",
|
||||
"Closing the message stream session.",
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
|
||||
await commands.close(reason="stream-closed")
|
||||
await message_channel.close(reason="stream-closed")
|
||||
@@ -233,7 +233,7 @@ async def verify_workflow_run(
|
||||
return workflow_run, addressable_browser_session
|
||||
|
||||
|
||||
async def loop_verify_browser_session(verifiable: sc.CommandChannel | sc.Streaming) -> None:
|
||||
async def loop_verify_browser_session(verifiable: sc.MessageChannel | sc.Streaming) -> None:
|
||||
"""
|
||||
Loop until the browser session is cleared or the websocket is closed.
|
||||
"""
|
||||
@@ -266,7 +266,7 @@ async def loop_verify_task(streaming: sc.Streaming) -> None:
|
||||
await asyncio.sleep(2)
|
||||
|
||||
|
||||
async def loop_verify_workflow_run(verifiable: sc.CommandChannel | sc.Streaming) -> None:
|
||||
async def loop_verify_workflow_run(verifiable: sc.MessageChannel | sc.Streaming) -> None:
|
||||
"""
|
||||
Loop until the workflow run is cleared or the websocket is closed.
|
||||
"""
|
||||
|
||||
@@ -1,8 +1,16 @@
|
||||
"""
|
||||
Streaming VNC WebSocket connections.
|
||||
|
||||
NOTE(jdo:streaming-local-dev)
|
||||
-----------------------------
|
||||
- grep the above for local development seams
|
||||
- augment those seams as indicated, then
|
||||
- stand up https://github.com/jomido/whyvern
|
||||
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import typing as t
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import structlog
|
||||
@@ -14,6 +22,7 @@ from websockets.exceptions import ConnectionClosedError
|
||||
import skyvern.forge.sdk.routes.streaming_clients as sc
|
||||
from skyvern.config import settings
|
||||
from skyvern.forge.sdk.routes.routers import base_router, legacy_base_router
|
||||
from skyvern.forge.sdk.routes.streaming_agent import connected_agent
|
||||
from skyvern.forge.sdk.routes.streaming_auth import auth
|
||||
from skyvern.forge.sdk.routes.streaming_verify import (
|
||||
loop_verify_browser_session,
|
||||
@@ -157,6 +166,68 @@ async def get_streaming_for_workflow_run(
|
||||
return streaming, loops
|
||||
|
||||
|
||||
def verify_message_channel(
|
||||
message_channel: sc.MessageChannel | None, streaming: sc.Streaming
|
||||
) -> sc.MessageChannel | t.Literal[False]:
|
||||
if message_channel and message_channel.is_open:
|
||||
return message_channel
|
||||
|
||||
LOG.warning(
|
||||
"No message channel found for client, or it is not open",
|
||||
message_channel=message_channel,
|
||||
client_id=streaming.client_id,
|
||||
organization_id=streaming.organization_id,
|
||||
)
|
||||
|
||||
return False
|
||||
|
||||
|
||||
async def copy_text(streaming: sc.Streaming) -> None:
|
||||
try:
|
||||
async with connected_agent(streaming) as agent:
|
||||
copied_text = await agent.get_selected_text()
|
||||
|
||||
LOG.info(
|
||||
"Retrieved selected text via CDP",
|
||||
organization_id=streaming.organization_id,
|
||||
)
|
||||
|
||||
message_channel = sc.get_message_client(streaming.client_id)
|
||||
|
||||
if cc := verify_message_channel(message_channel, streaming):
|
||||
await cc.send_copied_text(copied_text, streaming)
|
||||
else:
|
||||
LOG.warning(
|
||||
"No message channel found for client, or it is not open",
|
||||
message_channel=message_channel,
|
||||
client_id=streaming.client_id,
|
||||
organization_id=streaming.organization_id,
|
||||
)
|
||||
except Exception:
|
||||
LOG.exception(
|
||||
"Failed to retrieve selected text via CDP",
|
||||
organization_id=streaming.organization_id,
|
||||
)
|
||||
|
||||
|
||||
async def ask_for_clipboard(streaming: sc.Streaming) -> None:
|
||||
try:
|
||||
LOG.info(
|
||||
"Asking for clipboard data via CDP",
|
||||
organization_id=streaming.organization_id,
|
||||
)
|
||||
|
||||
message_channel = sc.get_message_client(streaming.client_id)
|
||||
|
||||
if cc := verify_message_channel(message_channel, streaming):
|
||||
await cc.ask_for_clipboard(streaming)
|
||||
except Exception:
|
||||
LOG.exception(
|
||||
"Failed to ask for clipboard via CDP",
|
||||
organization_id=streaming.organization_id,
|
||||
)
|
||||
|
||||
|
||||
async def loop_stream_vnc(streaming: sc.Streaming) -> None:
|
||||
"""
|
||||
Actually stream the VNC session data between a frontend and a browser
|
||||
@@ -183,6 +254,9 @@ async def loop_stream_vnc(streaming: sc.Streaming) -> None:
|
||||
host = parsed_browser_address.hostname
|
||||
vnc_url = f"ws://{host}:{streaming.vnc_port}"
|
||||
|
||||
# NOTE(jdo:streaming-local-dev)
|
||||
# vnc_url = "ws://localhost:9001/ws/novnc"
|
||||
|
||||
LOG.info(
|
||||
"Connecting to VNC URL.",
|
||||
vnc_url=vnc_url,
|
||||
@@ -207,6 +281,12 @@ async def loop_stream_vnc(streaming: sc.Streaming) -> None:
|
||||
if message_type == sc.MessageType.Keyboard.value:
|
||||
streaming.update_key_state(data)
|
||||
|
||||
if streaming.key_state.is_copy(data):
|
||||
await copy_text(streaming)
|
||||
|
||||
if streaming.key_state.is_paste(data):
|
||||
await ask_for_clipboard(streaming)
|
||||
|
||||
if streaming.key_state.is_forbidden(data):
|
||||
continue
|
||||
|
||||
|
||||
Reference in New Issue
Block a user