From b52e88bd99a6f71b779fbea405403fcbee100fa3 Mon Sep 17 00:00:00 2001 From: Jonathan Dobson Date: Wed, 22 Oct 2025 11:57:50 -0400 Subject: [PATCH] BE portion of seamless clipboard transfer in browser stream (#3788) --- skyvern/forge/sdk/routes/__init__.py | 2 +- skyvern/forge/sdk/routes/debug_sessions.py | 2 + skyvern/forge/sdk/routes/streaming_agent.py | 169 ++++++++++++++++++ skyvern/forge/sdk/routes/streaming_auth.py | 15 ++ skyvern/forge/sdk/routes/streaming_clients.py | 109 ++++++++--- ...ming_commands.py => streaming_messages.py} | 161 ++++++++++------- skyvern/forge/sdk/routes/streaming_verify.py | 4 +- skyvern/forge/sdk/routes/streaming_vnc.py | 80 +++++++++ 8 files changed, 445 insertions(+), 97 deletions(-) create mode 100644 skyvern/forge/sdk/routes/streaming_agent.py rename skyvern/forge/sdk/routes/{streaming_commands.py => streaming_messages.py} (55%) diff --git a/skyvern/forge/sdk/routes/__init__.py b/skyvern/forge/sdk/routes/__init__.py index c20a5381..53188c8b 100644 --- a/skyvern/forge/sdk/routes/__init__.py +++ b/skyvern/forge/sdk/routes/__init__.py @@ -6,6 +6,6 @@ from skyvern.forge.sdk.routes import pylon # noqa: F401 from skyvern.forge.sdk.routes import run_blocks # noqa: F401 from skyvern.forge.sdk.routes import scripts # noqa: F401 from skyvern.forge.sdk.routes import streaming # noqa: F401 -from skyvern.forge.sdk.routes import streaming_commands # noqa: F401 +from skyvern.forge.sdk.routes import streaming_messages # noqa: F401 from skyvern.forge.sdk.routes import streaming_vnc # noqa: F401 from skyvern.forge.sdk.routes import webhooks # noqa: F401 diff --git a/skyvern/forge/sdk/routes/debug_sessions.py b/skyvern/forge/sdk/routes/debug_sessions.py index dea81fd7..2a502783 100644 --- a/skyvern/forge/sdk/routes/debug_sessions.py +++ b/skyvern/forge/sdk/routes/debug_sessions.py @@ -225,6 +225,8 @@ async def new_debug_session( user_id=current_user_id, workflow_permanent_id=workflow_permanent_id, vnc_streaming_supported=True if new_browser_session.ip_address else False, + # NOTE(jdo:streaming-local-dev) + # vnc_streaming_supported=True, ) LOG.info( diff --git a/skyvern/forge/sdk/routes/streaming_agent.py b/skyvern/forge/sdk/routes/streaming_agent.py new file mode 100644 index 00000000..a84f298e --- /dev/null +++ b/skyvern/forge/sdk/routes/streaming_agent.py @@ -0,0 +1,169 @@ +""" +A lightweight "agent" for interacting with the streaming browser over CDP. +""" + +import typing +from contextlib import asynccontextmanager + +import structlog +from playwright.async_api import Browser, BrowserContext, Page, Playwright, async_playwright + +import skyvern.forge.sdk.routes.streaming_clients as sc +from skyvern.config import settings + +LOG = structlog.get_logger() + + +class StreamingAgent: + """ + A minimal agent that can connect to a browser via CDP and execute JavaScript. + + Specifically for operations during streaming sessions (like copy/pasting selected text, etc.). + """ + + def __init__(self) -> None: + self.browser: Browser | None = None + self.browser_context: BrowserContext | None = None + self.page: Page | None = None + self.pw: Playwright | None = None + + async def connect(self, cdp_url: str | None = None) -> None: + url = cdp_url or settings.BROWSER_REMOTE_DEBUGGING_URL + + LOG.info("StreamingAgent connecting to CDP", cdp_url=url) + + pw = self.pw or await async_playwright().start() + + self.pw = pw + + self.browser = await pw.chromium.connect_over_cdp(url) + + contexts = self.browser.contexts + if contexts: + LOG.info("StreamingAgent using existing browser context") + self.browser_context = contexts[0] + else: + LOG.warning("No existing browser context found, creating new one") + self.browser_context = await self.browser.new_context() + + pages = self.browser_context.pages + if pages: + self.page = pages[0] + LOG.info("StreamingAgent connected to page", url=self.page.url) + else: + LOG.warning("No pages found in browser context") + + LOG.info("StreamingAgent connected successfully") + + async def evaluate_js( + self, expression: str, arg: str | int | float | bool | list | dict | None = None + ) -> str | int | float | bool | list | dict | None: + if not self.page: + raise RuntimeError("StreamingAgent is not connected to a page. Call connect() first.") + + LOG.info("StreamingAgent evaluating JS", expression=expression[:100]) + + try: + result = await self.page.evaluate(expression, arg) + LOG.info("StreamingAgent JS evaluation successful") + return result + except Exception as ex: + LOG.exception("StreamingAgent JS evaluation failed", expression=expression, ex=str(ex)) + raise + + async def get_selected_text(self) -> str: + LOG.info("StreamingAgent getting selected text") + + js_expression = """ + () => { + const selection = window.getSelection(); + return selection ? selection.toString() : ''; + } + """ + + selected_text = await self.evaluate_js(js_expression) + + if isinstance(selected_text, str) or selected_text is None: + LOG.info("StreamingAgent got selected text", length=len(selected_text) if selected_text else 0) + return selected_text or "" + + raise RuntimeError(f"StreamingAgent selected text is not a string, but a(n) '{type(selected_text)}'") + + async def paste_text(self, text: str) -> None: + LOG.info("StreamingAgent pasting text") + + js_expression = """ + (text) => { + const activeElement = document.activeElement; + if (activeElement && (activeElement.tagName === 'INPUT' || activeElement.tagName === 'TEXTAREA' || activeElement.isContentEditable)) { + const start = activeElement.selectionStart || 0; + const end = activeElement.selectionEnd || 0; + const value = activeElement.value || ''; + activeElement.value = value.slice(0, start) + text + value.slice(end); + const newCursorPos = start + text.length; + activeElement.setSelectionRange(newCursorPos, newCursorPos); + } + } + """ + + await self.evaluate_js(js_expression, text) + + LOG.info("StreamingAgent pasted text successfully") + + async def close(self) -> None: + LOG.info("StreamingAgent closing connection") + + if self.browser: + await self.browser.close() + self.browser = None + self.browser_context = None + self.page = None + + if self.pw: + await self.pw.stop() + self.pw = None + + LOG.info("StreamingAgent closed") + + +@asynccontextmanager +async def connected_agent(streaming: sc.Streaming | None) -> typing.AsyncIterator[StreamingAgent]: + """ + The first pass at this has us doing the following for every operation: + - creating a new agent + - connecting + - [doing smth] + - closing the agent + + This may add latency, but locally it is pretty fast. This keeps things stateless for now. + + If it turns out it's too slow, we can refactor to keep a persistent agent per streaming client. + """ + + if not streaming: + msg = "connected_agent: no streaming client provided." + LOG.error(msg) + + raise Exception(msg) + if not streaming.browser_session or not streaming.browser_session.browser_address: + msg = "connected_agent: no browser session or browser address found for streaming client." + + LOG.error( + msg, + client_id=streaming.client_id, + organization_id=streaming.organization_id, + ) + + raise Exception(msg) + + agent = StreamingAgent() + + try: + await agent.connect(streaming.browser_session.browser_address) + + # NOTE(jdo:streaming-local-dev): use BROWSER_REMOTE_DEBUGGING_URL from settings + # await agent.connect() + + yield agent + finally: + await agent.close() diff --git a/skyvern/forge/sdk/routes/streaming_auth.py b/skyvern/forge/sdk/routes/streaming_auth.py index 0e53b5ee..00250d24 100644 --- a/skyvern/forge/sdk/routes/streaming_auth.py +++ b/skyvern/forge/sdk/routes/streaming_auth.py @@ -44,3 +44,18 @@ async def auth(apikey: str | None, token: str | None, websocket: WebSocket) -> s return None return organization_id + + +# NOTE(jdo:streaming-local-dev): use this instead of the above `auth` +async def _auth(apikey: str | None, token: str | None, websocket: WebSocket) -> str | None: + """ + Dummy auth for local testing. + """ + + try: + await websocket.accept() + except ConnectionClosedOK: + LOG.info("WebSocket connection closed cleanly.") + return None + + return "o_temp123" diff --git a/skyvern/forge/sdk/routes/streaming_clients.py b/skyvern/forge/sdk/routes/streaming_clients.py index 8baf8c7b..dbb6e16e 100644 --- a/skyvern/forge/sdk/routes/streaming_clients.py +++ b/skyvern/forge/sdk/routes/streaming_clients.py @@ -22,30 +22,30 @@ Interactor = t.Literal["agent", "user"] Loops = list[asyncio.Task] # aka "queue-less actors"; or "programs" -# Commands +# Messages -# a global registry for WS command clients -command_channels: dict[str, "CommandChannel"] = {} +# a global registry for WS message clients +message_channels: dict[str, "MessageChannel"] = {} -def add_command_client(command_channel: "CommandChannel") -> None: - command_channels[command_channel.client_id] = command_channel +def add_message_client(message_channel: "MessageChannel") -> None: + message_channels[message_channel.client_id] = message_channel -def get_command_client(client_id: str) -> t.Union["CommandChannel", None]: - return command_channels.get(client_id, None) +def get_message_client(client_id: str) -> t.Union["MessageChannel", None]: + return message_channels.get(client_id, None) -def del_command_client(client_id: str) -> None: +def del_message_client(client_id: str) -> None: try: - del command_channels[client_id] + del message_channels[client_id] except KeyError: pass @dataclasses.dataclass -class CommandChannel: +class MessageChannel: client_id: str organization_id: str websocket: WebSocket @@ -56,10 +56,10 @@ class CommandChannel: workflow_run: WorkflowRun | None = None def __post_init__(self) -> None: - add_command_client(self) + add_message_client(self) - async def close(self, code: int = 1000, reason: str | None = None) -> "CommandChannel": - LOG.info("Closing command stream.", reason=reason, code=code) + async def close(self, code: int = 1000, reason: str | None = None) -> "MessageChannel": + LOG.info("Closing message stream.", reason=reason, code=code) self.browser_session = None self.workflow_run = None @@ -69,7 +69,7 @@ class CommandChannel: except Exception: pass - del_command_client(self.client_id) + del_message_client(self.client_id) return self @@ -81,43 +81,87 @@ class CommandChannel: if not self.workflow_run and not self.browser_session: return False - if not get_command_client(self.client_id): + if not get_message_client(self.client_id): return False return True + async def ask_for_clipboard(self, streaming: "Streaming") -> None: + try: + await self.websocket.send_json( + { + "kind": "ask-for-clipboard", + } + ) + LOG.info( + "Sent ask-for-clipboard to message channel", + organization_id=streaming.organization_id, + ) + except Exception: + LOG.exception( + "Failed to send ask-for-clipboard to message channel", + organization_id=streaming.organization_id, + ) -CommandKinds = t.Literal["take-control", "cede-control"] + async def send_copied_text(self, copied_text: str, streaming: "Streaming") -> None: + try: + await self.websocket.send_json( + { + "kind": "copied-text", + "text": copied_text, + } + ) + LOG.info( + "Sent copied text to message channel", + organization_id=streaming.organization_id, + ) + except Exception: + LOG.exception( + "Failed to send copied text to message channel", + organization_id=streaming.organization_id, + ) + + +MessageKinds = t.Literal["take-control", "cede-control", "ask-for-clipboard-response"] @dataclasses.dataclass -class Command: - kind: CommandKinds +class Message: + kind: MessageKinds @dataclasses.dataclass -class CommandTakeControl(Command): +class MessageTakeControl(Message): kind: t.Literal["take-control"] = "take-control" @dataclasses.dataclass -class CommandCedeControl(Command): +class MessageCedeControl(Message): kind: t.Literal["cede-control"] = "cede-control" -ChannelCommand = t.Union[CommandTakeControl, CommandCedeControl] +@dataclasses.dataclass +class MessageInAskForClipboardResponse(Message): + kind: t.Literal["ask-for-clipboard-response"] = "ask-for-clipboard-response" + text: str = "" -def reify_channel_command(data: dict) -> ChannelCommand: +ChannelMessage = t.Union[MessageTakeControl, MessageCedeControl, MessageInAskForClipboardResponse] + + +def reify_channel_message(data: dict) -> ChannelMessage: kind = data.get("kind", None) match kind: case "take-control": - return CommandTakeControl() + return MessageTakeControl() case "cede-control": - return CommandCedeControl() + return MessageCedeControl() + case "ask-for-clipboard-response": + text = data.get("text") or "" + return MessageInAskForClipboardResponse(text=text) case _: - raise ValueError(f"Unknown command kind: '{kind}'") + raise ValueError(f"Unknown message kind: '{kind}'") # Streaming @@ -159,7 +203,9 @@ class Keys: Ctrl = b"\x04\x01\x00\x00\x00\x00\xff\xe3" Cmd = b"\x04\x01\x00\x00\x00\x00\xff\xe9" Alt = b"\x04\x01\x00\x00\x00\x00\xff~" # option + CKey = b"\x04\x01\x00\x00\x00\x00\x00c" OKey = b"\x04\x01\x00\x00\x00\x00\x00o" + VKey = b"\x04\x01\x00\x00\x00\x00\x00v" class Up: Ctrl = b"\x04\x00\x00\x00\x00\x00\xff\xe3" @@ -181,7 +227,6 @@ class KeyState: ctrl_is_down: bool = False alt_is_down: bool = False cmd_is_down: bool = False - o_is_down: bool = False def is_forbidden(self, data: bytes) -> bool: """ @@ -195,6 +240,18 @@ class KeyState: """ return self.ctrl_is_down and data == Keys.Down.OKey + def is_copy(self, data: bytes) -> bool: + """ + Detect Ctrl+C or Cmd+C for copy. + """ + return (self.ctrl_is_down or self.cmd_is_down) and data == Keys.Down.CKey + + def is_paste(self, data: bytes) -> bool: + """ + Detect Ctrl+V or Cmd+V for paste. + """ + return (self.ctrl_is_down or self.cmd_is_down) and data == Keys.Down.VKey + @dataclasses.dataclass class Streaming: diff --git a/skyvern/forge/sdk/routes/streaming_commands.py b/skyvern/forge/sdk/routes/streaming_messages.py similarity index 55% rename from skyvern/forge/sdk/routes/streaming_commands.py rename to skyvern/forge/sdk/routes/streaming_messages.py index 7854357c..ac2e4c85 100644 --- a/skyvern/forge/sdk/routes/streaming_commands.py +++ b/skyvern/forge/sdk/routes/streaming_messages.py @@ -1,5 +1,5 @@ """ -Streaming commands WebSocket connections. +Streaming messages for WebSocket connections. """ import asyncio @@ -10,6 +10,7 @@ from websockets.exceptions import ConnectionClosedError import skyvern.forge.sdk.routes.streaming_clients as sc from skyvern.forge.sdk.routes.routers import base_router, legacy_base_router +from skyvern.forge.sdk.routes.streaming_agent import connected_agent from skyvern.forge.sdk.routes.streaming_auth import auth from skyvern.forge.sdk.routes.streaming_verify import ( loop_verify_browser_session, @@ -22,17 +23,17 @@ from skyvern.forge.sdk.utils.aio import collect LOG = structlog.get_logger() -async def get_commands_for_browser_session( +async def get_messages_for_browser_session( client_id: str, browser_session_id: str, organization_id: str, websocket: WebSocket, -) -> tuple[sc.CommandChannel, sc.Loops] | None: +) -> tuple[sc.MessageChannel, sc.Loops] | None: """ - Return a commands channel for a browser session, with a list of loops to run concurrently. + Return a message channel for a browser session, with a list of loops to run concurrently. """ - LOG.info("Getting commands channel for browser session.", browser_session_id=browser_session_id) + LOG.info("Getting message channel for browser session.", browser_session_id=browser_session_id) browser_session = await verify_browser_session( browser_session_id=browser_session_id, @@ -41,40 +42,40 @@ async def get_commands_for_browser_session( if not browser_session: LOG.info( - "Command channel: no initial browser session found.", + "Message channel: no initial browser session found.", browser_session_id=browser_session_id, organization_id=organization_id, ) return None - commands = sc.CommandChannel( + message_channel = sc.MessageChannel( client_id=client_id, organization_id=organization_id, browser_session=browser_session, websocket=websocket, ) - LOG.info("Got command channel for browser session.", commands=commands) + LOG.info("Got message channel for browser session.", message_channel=message_channel) loops = [ - asyncio.create_task(loop_verify_browser_session(commands)), - asyncio.create_task(loop_channel(commands)), + asyncio.create_task(loop_verify_browser_session(message_channel)), + asyncio.create_task(loop_channel(message_channel)), ] - return commands, loops + return message_channel, loops -async def get_commands_for_workflow_run( +async def get_messages_for_workflow_run( client_id: str, workflow_run_id: str, organization_id: str, websocket: WebSocket, -) -> tuple[sc.CommandChannel, sc.Loops] | None: +) -> tuple[sc.MessageChannel, sc.Loops] | None: """ - Return a commands channel for a workflow run, with a list of loops to run concurrently. + Return a message channel for a workflow run, with a list of loops to run concurrently. """ - LOG.info("Getting commands channel for workflow run.", workflow_run_id=workflow_run_id) + LOG.info("Getting message channel for workflow run.", workflow_run_id=workflow_run_id) workflow_run, browser_session = await verify_workflow_run( workflow_run_id=workflow_run_id, @@ -83,7 +84,7 @@ async def get_commands_for_workflow_run( if not workflow_run: LOG.info( - "Command channel: no initial workflow run found.", + "Message channel: no initial workflow run found.", workflow_run_id=workflow_run_id, organization_id=organization_id, ) @@ -91,13 +92,13 @@ async def get_commands_for_workflow_run( if not browser_session: LOG.info( - "Command channel: no initial browser session found for workflow run.", + "Message channel: no initial browser session found for workflow run.", workflow_run_id=workflow_run_id, organization_id=organization_id, ) return None - commands = sc.CommandChannel( + message_channel = sc.MessageChannel( client_id=client_id, organization_id=organization_id, browser_session=browser_session, @@ -105,78 +106,100 @@ async def get_commands_for_workflow_run( websocket=websocket, ) - LOG.info("Got command channel for workflow run.", commands=commands) + LOG.info("Got message channel for workflow run.", message_channel=message_channel) loops = [ - asyncio.create_task(loop_verify_workflow_run(commands)), - asyncio.create_task(loop_channel(commands)), + asyncio.create_task(loop_verify_workflow_run(message_channel)), + asyncio.create_task(loop_channel(message_channel)), ] - return commands, loops + return message_channel, loops -async def loop_channel(commands: sc.CommandChannel) -> None: +async def loop_channel(message_channel: sc.MessageChannel) -> None: """ - Stream commands and their results back and forth. + Stream messages and their results back and forth. Loops until the workflow run is cleared or the websocket is closed. """ - if not commands.browser_session: + if not message_channel.browser_session: LOG.info( "No browser session found for workflow run.", - workflow_run=commands.workflow_run, - organization_id=commands.organization_id, + workflow_run=message_channel.workflow_run, + organization_id=message_channel.organization_id, ) return async def frontend_to_backend() -> None: - LOG.info("Starting frontend-to-backend channel loop.", commands=commands) + LOG.info("Starting frontend-to-backend channel loop.", message_channel=message_channel) - while commands.is_open: + while message_channel.is_open: try: - data = await commands.websocket.receive_json() + data = await message_channel.websocket.receive_json() if not isinstance(data, dict): - LOG.error(f"Cannot create channel command: expected dict, got {type(data)}") + LOG.error(f"Cannot create channel message: expected dict, got {type(data)}") continue try: - command = sc.reify_channel_command(data) + message = sc.reify_channel_message(data) except ValueError: continue - command_kind = command.kind + message_kind = message.kind - match command_kind: + match message_kind: case "take-control": - streaming = sc.get_streaming_client(commands.client_id) + streaming = sc.get_streaming_client(message_channel.client_id) if not streaming: - LOG.error("No streaming client found for command.", commands=commands, command=command) + LOG.error( + "No streaming client found for message.", + message_channel=message_channel, + message=message, + ) continue streaming.interactor = "user" case "cede-control": - streaming = sc.get_streaming_client(commands.client_id) + streaming = sc.get_streaming_client(message_channel.client_id) if not streaming: - LOG.error("No streaming client found for command.", commands=commands, command=command) + LOG.error( + "No streaming client found for message.", + message_channel=message_channel, + message=message, + ) continue streaming.interactor = "agent" + case "ask-for-clipboard-response": + if not isinstance(message, sc.MessageInAskForClipboardResponse): + LOG.error( + "Invalid message type for ask-for-clipboard-response.", + message_channel=message_channel, + message=message, + ) + continue + + streaming = sc.get_streaming_client(message_channel.client_id) + text = message.text + + async with connected_agent(streaming) as agent: + await agent.paste_text(text) case _: - LOG.error(f"Unknown command kind: '{command_kind}'") + LOG.error(f"Unknown message kind: '{message_kind}'") continue except WebSocketDisconnect: LOG.info( "Frontend disconnected.", - workflow_run=commands.workflow_run, - organization_id=commands.organization_id, + workflow_run=message_channel.workflow_run, + organization_id=message_channel.organization_id, ) raise except ConnectionClosedError: LOG.info( "Frontend closed the streaming session.", - workflow_run=commands.workflow_run, - organization_id=commands.organization_id, + workflow_run=message_channel.workflow_run, + organization_id=message_channel.organization_id, ) raise except asyncio.CancelledError: @@ -184,8 +207,8 @@ async def loop_channel(commands: sc.CommandChannel) -> None: except Exception: LOG.exception( "An unexpected exception occurred.", - workflow_run=commands.workflow_run, - organization_id=commands.organization_id, + workflow_run=message_channel.workflow_run, + organization_id=message_channel.organization_id, ) raise @@ -198,27 +221,28 @@ async def loop_channel(commands: sc.CommandChannel) -> None: except Exception: LOG.exception( "An exception occurred in loop channel stream.", - workflow_run=commands.workflow_run, - organization_id=commands.organization_id, + workflow_run=message_channel.workflow_run, + organization_id=message_channel.organization_id, ) finally: LOG.info( "Closing the loop channel stream.", - workflow_run=commands.workflow_run, - organization_id=commands.organization_id, + workflow_run=message_channel.workflow_run, + organization_id=message_channel.organization_id, ) - await commands.close(reason="loop-channel-closed") + await message_channel.close(reason="loop-channel-closed") +@base_router.websocket("/stream/messages/browser_session/{browser_session_id}") @base_router.websocket("/stream/commands/browser_session/{browser_session_id}") -async def browser_session_commands( +async def browser_session_messages( websocket: WebSocket, browser_session_id: str, apikey: str | None = None, client_id: str | None = None, token: str | None = None, ) -> None: - LOG.info("Starting stream commands for browser session.", browser_session_id=browser_session_id) + LOG.info("Starting message stream for browser session.", browser_session_id=browser_session_id) organization_id = await auth(apikey=apikey, token=token, websocket=websocket) @@ -230,10 +254,10 @@ async def browser_session_commands( LOG.error("No client ID provided.", browser_session_id=browser_session_id) return - commands: sc.CommandChannel + message_channel: sc.MessageChannel loops: list[asyncio.Task] = [] - result = await get_commands_for_browser_session( + result = await get_messages_for_browser_session( client_id=client_id, browser_session_id=browser_session_id, organization_id=organization_id, @@ -249,40 +273,41 @@ async def browser_session_commands( await websocket.close(code=1013) return - commands, loops = result + message_channel, loops = result try: LOG.info( - "Starting command stream loops for browser session.", + "Starting message stream loops for browser session.", browser_session_id=browser_session_id, organization_id=organization_id, ) await collect(loops) except Exception: LOG.exception( - "An exception occurred in the command stream function for browser session.", + "An exception occurred in the message stream function for browser session.", browser_session_id=browser_session_id, organization_id=organization_id, ) finally: LOG.info( - "Closing the command stream session for browser session.", + "Closing the message stream session for browser session.", browser_session_id=browser_session_id, organization_id=organization_id, ) - await commands.close(reason="stream-closed") + await message_channel.close(reason="stream-closed") +@legacy_base_router.websocket("/stream/messages/workflow_run/{workflow_run_id}") @legacy_base_router.websocket("/stream/commands/workflow_run/{workflow_run_id}") -async def workflow_run_commands( +async def workflow_run_messages( websocket: WebSocket, workflow_run_id: str, apikey: str | None = None, client_id: str | None = None, token: str | None = None, ) -> None: - LOG.info("Starting stream commands.", workflow_run_id=workflow_run_id) + LOG.info("Starting message stream.", workflow_run_id=workflow_run_id) organization_id = await auth(apikey=apikey, token=token, websocket=websocket) @@ -294,10 +319,10 @@ async def workflow_run_commands( LOG.error("No client ID provided.", workflow_run_id=workflow_run_id) return - commands: sc.CommandChannel + message_channel: sc.MessageChannel loops: list[asyncio.Task] = [] - result = await get_commands_for_workflow_run( + result = await get_messages_for_workflow_run( client_id=client_id, workflow_run_id=workflow_run_id, organization_id=organization_id, @@ -313,26 +338,26 @@ async def workflow_run_commands( await websocket.close(code=1013) return - commands, loops = result + message_channel, loops = result try: LOG.info( - "Starting command stream loops.", + "Starting message stream loops.", workflow_run_id=workflow_run_id, organization_id=organization_id, ) await collect(loops) except Exception: LOG.exception( - "An exception occurred in the command stream function.", + "An exception occurred in the message stream function.", workflow_run_id=workflow_run_id, organization_id=organization_id, ) finally: LOG.info( - "Closing the command stream session.", + "Closing the message stream session.", workflow_run_id=workflow_run_id, organization_id=organization_id, ) - await commands.close(reason="stream-closed") + await message_channel.close(reason="stream-closed") diff --git a/skyvern/forge/sdk/routes/streaming_verify.py b/skyvern/forge/sdk/routes/streaming_verify.py index f6e47943..5f28542d 100644 --- a/skyvern/forge/sdk/routes/streaming_verify.py +++ b/skyvern/forge/sdk/routes/streaming_verify.py @@ -233,7 +233,7 @@ async def verify_workflow_run( return workflow_run, addressable_browser_session -async def loop_verify_browser_session(verifiable: sc.CommandChannel | sc.Streaming) -> None: +async def loop_verify_browser_session(verifiable: sc.MessageChannel | sc.Streaming) -> None: """ Loop until the browser session is cleared or the websocket is closed. """ @@ -266,7 +266,7 @@ async def loop_verify_task(streaming: sc.Streaming) -> None: await asyncio.sleep(2) -async def loop_verify_workflow_run(verifiable: sc.CommandChannel | sc.Streaming) -> None: +async def loop_verify_workflow_run(verifiable: sc.MessageChannel | sc.Streaming) -> None: """ Loop until the workflow run is cleared or the websocket is closed. """ diff --git a/skyvern/forge/sdk/routes/streaming_vnc.py b/skyvern/forge/sdk/routes/streaming_vnc.py index 577c6b3b..b46baf55 100644 --- a/skyvern/forge/sdk/routes/streaming_vnc.py +++ b/skyvern/forge/sdk/routes/streaming_vnc.py @@ -1,8 +1,16 @@ """ Streaming VNC WebSocket connections. + +NOTE(jdo:streaming-local-dev) +----------------------------- + - grep the above for local development seams + - augment those seams as indicated, then + - stand up https://github.com/jomido/whyvern + """ import asyncio +import typing as t from urllib.parse import urlparse import structlog @@ -14,6 +22,7 @@ from websockets.exceptions import ConnectionClosedError import skyvern.forge.sdk.routes.streaming_clients as sc from skyvern.config import settings from skyvern.forge.sdk.routes.routers import base_router, legacy_base_router +from skyvern.forge.sdk.routes.streaming_agent import connected_agent from skyvern.forge.sdk.routes.streaming_auth import auth from skyvern.forge.sdk.routes.streaming_verify import ( loop_verify_browser_session, @@ -157,6 +166,68 @@ async def get_streaming_for_workflow_run( return streaming, loops +def verify_message_channel( + message_channel: sc.MessageChannel | None, streaming: sc.Streaming +) -> sc.MessageChannel | t.Literal[False]: + if message_channel and message_channel.is_open: + return message_channel + + LOG.warning( + "No message channel found for client, or it is not open", + message_channel=message_channel, + client_id=streaming.client_id, + organization_id=streaming.organization_id, + ) + + return False + + +async def copy_text(streaming: sc.Streaming) -> None: + try: + async with connected_agent(streaming) as agent: + copied_text = await agent.get_selected_text() + + LOG.info( + "Retrieved selected text via CDP", + organization_id=streaming.organization_id, + ) + + message_channel = sc.get_message_client(streaming.client_id) + + if cc := verify_message_channel(message_channel, streaming): + await cc.send_copied_text(copied_text, streaming) + else: + LOG.warning( + "No message channel found for client, or it is not open", + message_channel=message_channel, + client_id=streaming.client_id, + organization_id=streaming.organization_id, + ) + except Exception: + LOG.exception( + "Failed to retrieve selected text via CDP", + organization_id=streaming.organization_id, + ) + + +async def ask_for_clipboard(streaming: sc.Streaming) -> None: + try: + LOG.info( + "Asking for clipboard data via CDP", + organization_id=streaming.organization_id, + ) + + message_channel = sc.get_message_client(streaming.client_id) + + if cc := verify_message_channel(message_channel, streaming): + await cc.ask_for_clipboard(streaming) + except Exception: + LOG.exception( + "Failed to ask for clipboard via CDP", + organization_id=streaming.organization_id, + ) + + async def loop_stream_vnc(streaming: sc.Streaming) -> None: """ Actually stream the VNC session data between a frontend and a browser @@ -183,6 +254,9 @@ async def loop_stream_vnc(streaming: sc.Streaming) -> None: host = parsed_browser_address.hostname vnc_url = f"ws://{host}:{streaming.vnc_port}" + # NOTE(jdo:streaming-local-dev) + # vnc_url = "ws://localhost:9001/ws/novnc" + LOG.info( "Connecting to VNC URL.", vnc_url=vnc_url, @@ -207,6 +281,12 @@ async def loop_stream_vnc(streaming: sc.Streaming) -> None: if message_type == sc.MessageType.Keyboard.value: streaming.update_key_state(data) + if streaming.key_state.is_copy(data): + await copy_text(streaming) + + if streaming.key_state.is_paste(data): + await ask_for_clipboard(streaming) + if streaming.key_state.is_forbidden(data): continue