BE portion of seamless clipboard transfer in browser stream (#3788)

This commit is contained in:
Jonathan Dobson
2025-10-22 11:57:50 -04:00
committed by GitHub
parent 24763b6a5a
commit b52e88bd99
8 changed files with 445 additions and 97 deletions

View File

@@ -6,6 +6,6 @@ from skyvern.forge.sdk.routes import pylon # noqa: F401
from skyvern.forge.sdk.routes import run_blocks # noqa: F401
from skyvern.forge.sdk.routes import scripts # noqa: F401
from skyvern.forge.sdk.routes import streaming # noqa: F401
from skyvern.forge.sdk.routes import streaming_commands # noqa: F401
from skyvern.forge.sdk.routes import streaming_messages # noqa: F401
from skyvern.forge.sdk.routes import streaming_vnc # noqa: F401
from skyvern.forge.sdk.routes import webhooks # noqa: F401

View File

@@ -225,6 +225,8 @@ async def new_debug_session(
user_id=current_user_id,
workflow_permanent_id=workflow_permanent_id,
vnc_streaming_supported=True if new_browser_session.ip_address else False,
# NOTE(jdo:streaming-local-dev)
# vnc_streaming_supported=True,
)
LOG.info(

View File

@@ -0,0 +1,169 @@
"""
A lightweight "agent" for interacting with the streaming browser over CDP.
"""
import typing
from contextlib import asynccontextmanager
import structlog
from playwright.async_api import Browser, BrowserContext, Page, Playwright, async_playwright
import skyvern.forge.sdk.routes.streaming_clients as sc
from skyvern.config import settings
LOG = structlog.get_logger()
class StreamingAgent:
"""
A minimal agent that can connect to a browser via CDP and execute JavaScript.
Specifically for operations during streaming sessions (like copy/pasting selected text, etc.).
"""
def __init__(self) -> None:
self.browser: Browser | None = None
self.browser_context: BrowserContext | None = None
self.page: Page | None = None
self.pw: Playwright | None = None
async def connect(self, cdp_url: str | None = None) -> None:
url = cdp_url or settings.BROWSER_REMOTE_DEBUGGING_URL
LOG.info("StreamingAgent connecting to CDP", cdp_url=url)
pw = self.pw or await async_playwright().start()
self.pw = pw
self.browser = await pw.chromium.connect_over_cdp(url)
contexts = self.browser.contexts
if contexts:
LOG.info("StreamingAgent using existing browser context")
self.browser_context = contexts[0]
else:
LOG.warning("No existing browser context found, creating new one")
self.browser_context = await self.browser.new_context()
pages = self.browser_context.pages
if pages:
self.page = pages[0]
LOG.info("StreamingAgent connected to page", url=self.page.url)
else:
LOG.warning("No pages found in browser context")
LOG.info("StreamingAgent connected successfully")
async def evaluate_js(
self, expression: str, arg: str | int | float | bool | list | dict | None = None
) -> str | int | float | bool | list | dict | None:
if not self.page:
raise RuntimeError("StreamingAgent is not connected to a page. Call connect() first.")
LOG.info("StreamingAgent evaluating JS", expression=expression[:100])
try:
result = await self.page.evaluate(expression, arg)
LOG.info("StreamingAgent JS evaluation successful")
return result
except Exception as ex:
LOG.exception("StreamingAgent JS evaluation failed", expression=expression, ex=str(ex))
raise
async def get_selected_text(self) -> str:
LOG.info("StreamingAgent getting selected text")
js_expression = """
() => {
const selection = window.getSelection();
return selection ? selection.toString() : '';
}
"""
selected_text = await self.evaluate_js(js_expression)
if isinstance(selected_text, str) or selected_text is None:
LOG.info("StreamingAgent got selected text", length=len(selected_text) if selected_text else 0)
return selected_text or ""
raise RuntimeError(f"StreamingAgent selected text is not a string, but a(n) '{type(selected_text)}'")
async def paste_text(self, text: str) -> None:
LOG.info("StreamingAgent pasting text")
js_expression = """
(text) => {
const activeElement = document.activeElement;
if (activeElement && (activeElement.tagName === 'INPUT' || activeElement.tagName === 'TEXTAREA' || activeElement.isContentEditable)) {
const start = activeElement.selectionStart || 0;
const end = activeElement.selectionEnd || 0;
const value = activeElement.value || '';
activeElement.value = value.slice(0, start) + text + value.slice(end);
const newCursorPos = start + text.length;
activeElement.setSelectionRange(newCursorPos, newCursorPos);
}
}
"""
await self.evaluate_js(js_expression, text)
LOG.info("StreamingAgent pasted text successfully")
async def close(self) -> None:
LOG.info("StreamingAgent closing connection")
if self.browser:
await self.browser.close()
self.browser = None
self.browser_context = None
self.page = None
if self.pw:
await self.pw.stop()
self.pw = None
LOG.info("StreamingAgent closed")
@asynccontextmanager
async def connected_agent(streaming: sc.Streaming | None) -> typing.AsyncIterator[StreamingAgent]:
"""
The first pass at this has us doing the following for every operation:
- creating a new agent
- connecting
- [doing smth]
- closing the agent
This may add latency, but locally it is pretty fast. This keeps things stateless for now.
If it turns out it's too slow, we can refactor to keep a persistent agent per streaming client.
"""
if not streaming:
msg = "connected_agent: no streaming client provided."
LOG.error(msg)
raise Exception(msg)
if not streaming.browser_session or not streaming.browser_session.browser_address:
msg = "connected_agent: no browser session or browser address found for streaming client."
LOG.error(
msg,
client_id=streaming.client_id,
organization_id=streaming.organization_id,
)
raise Exception(msg)
agent = StreamingAgent()
try:
await agent.connect(streaming.browser_session.browser_address)
# NOTE(jdo:streaming-local-dev): use BROWSER_REMOTE_DEBUGGING_URL from settings
# await agent.connect()
yield agent
finally:
await agent.close()

View File

@@ -44,3 +44,18 @@ async def auth(apikey: str | None, token: str | None, websocket: WebSocket) -> s
return None
return organization_id
# NOTE(jdo:streaming-local-dev): use this instead of the above `auth`
async def _auth(apikey: str | None, token: str | None, websocket: WebSocket) -> str | None:
"""
Dummy auth for local testing.
"""
try:
await websocket.accept()
except ConnectionClosedOK:
LOG.info("WebSocket connection closed cleanly.")
return None
return "o_temp123"

View File

@@ -22,30 +22,30 @@ Interactor = t.Literal["agent", "user"]
Loops = list[asyncio.Task] # aka "queue-less actors"; or "programs"
# Commands
# Messages
# a global registry for WS command clients
command_channels: dict[str, "CommandChannel"] = {}
# a global registry for WS message clients
message_channels: dict[str, "MessageChannel"] = {}
def add_command_client(command_channel: "CommandChannel") -> None:
command_channels[command_channel.client_id] = command_channel
def add_message_client(message_channel: "MessageChannel") -> None:
message_channels[message_channel.client_id] = message_channel
def get_command_client(client_id: str) -> t.Union["CommandChannel", None]:
return command_channels.get(client_id, None)
def get_message_client(client_id: str) -> t.Union["MessageChannel", None]:
return message_channels.get(client_id, None)
def del_command_client(client_id: str) -> None:
def del_message_client(client_id: str) -> None:
try:
del command_channels[client_id]
del message_channels[client_id]
except KeyError:
pass
@dataclasses.dataclass
class CommandChannel:
class MessageChannel:
client_id: str
organization_id: str
websocket: WebSocket
@@ -56,10 +56,10 @@ class CommandChannel:
workflow_run: WorkflowRun | None = None
def __post_init__(self) -> None:
add_command_client(self)
add_message_client(self)
async def close(self, code: int = 1000, reason: str | None = None) -> "CommandChannel":
LOG.info("Closing command stream.", reason=reason, code=code)
async def close(self, code: int = 1000, reason: str | None = None) -> "MessageChannel":
LOG.info("Closing message stream.", reason=reason, code=code)
self.browser_session = None
self.workflow_run = None
@@ -69,7 +69,7 @@ class CommandChannel:
except Exception:
pass
del_command_client(self.client_id)
del_message_client(self.client_id)
return self
@@ -81,43 +81,87 @@ class CommandChannel:
if not self.workflow_run and not self.browser_session:
return False
if not get_command_client(self.client_id):
if not get_message_client(self.client_id):
return False
return True
async def ask_for_clipboard(self, streaming: "Streaming") -> None:
try:
await self.websocket.send_json(
{
"kind": "ask-for-clipboard",
}
)
LOG.info(
"Sent ask-for-clipboard to message channel",
organization_id=streaming.organization_id,
)
except Exception:
LOG.exception(
"Failed to send ask-for-clipboard to message channel",
organization_id=streaming.organization_id,
)
CommandKinds = t.Literal["take-control", "cede-control"]
async def send_copied_text(self, copied_text: str, streaming: "Streaming") -> None:
try:
await self.websocket.send_json(
{
"kind": "copied-text",
"text": copied_text,
}
)
LOG.info(
"Sent copied text to message channel",
organization_id=streaming.organization_id,
)
except Exception:
LOG.exception(
"Failed to send copied text to message channel",
organization_id=streaming.organization_id,
)
MessageKinds = t.Literal["take-control", "cede-control", "ask-for-clipboard-response"]
@dataclasses.dataclass
class Command:
kind: CommandKinds
class Message:
kind: MessageKinds
@dataclasses.dataclass
class CommandTakeControl(Command):
class MessageTakeControl(Message):
kind: t.Literal["take-control"] = "take-control"
@dataclasses.dataclass
class CommandCedeControl(Command):
class MessageCedeControl(Message):
kind: t.Literal["cede-control"] = "cede-control"
ChannelCommand = t.Union[CommandTakeControl, CommandCedeControl]
@dataclasses.dataclass
class MessageInAskForClipboardResponse(Message):
kind: t.Literal["ask-for-clipboard-response"] = "ask-for-clipboard-response"
text: str = ""
def reify_channel_command(data: dict) -> ChannelCommand:
ChannelMessage = t.Union[MessageTakeControl, MessageCedeControl, MessageInAskForClipboardResponse]
def reify_channel_message(data: dict) -> ChannelMessage:
kind = data.get("kind", None)
match kind:
case "take-control":
return CommandTakeControl()
return MessageTakeControl()
case "cede-control":
return CommandCedeControl()
return MessageCedeControl()
case "ask-for-clipboard-response":
text = data.get("text") or ""
return MessageInAskForClipboardResponse(text=text)
case _:
raise ValueError(f"Unknown command kind: '{kind}'")
raise ValueError(f"Unknown message kind: '{kind}'")
# Streaming
@@ -159,7 +203,9 @@ class Keys:
Ctrl = b"\x04\x01\x00\x00\x00\x00\xff\xe3"
Cmd = b"\x04\x01\x00\x00\x00\x00\xff\xe9"
Alt = b"\x04\x01\x00\x00\x00\x00\xff~" # option
CKey = b"\x04\x01\x00\x00\x00\x00\x00c"
OKey = b"\x04\x01\x00\x00\x00\x00\x00o"
VKey = b"\x04\x01\x00\x00\x00\x00\x00v"
class Up:
Ctrl = b"\x04\x00\x00\x00\x00\x00\xff\xe3"
@@ -181,7 +227,6 @@ class KeyState:
ctrl_is_down: bool = False
alt_is_down: bool = False
cmd_is_down: bool = False
o_is_down: bool = False
def is_forbidden(self, data: bytes) -> bool:
"""
@@ -195,6 +240,18 @@ class KeyState:
"""
return self.ctrl_is_down and data == Keys.Down.OKey
def is_copy(self, data: bytes) -> bool:
"""
Detect Ctrl+C or Cmd+C for copy.
"""
return (self.ctrl_is_down or self.cmd_is_down) and data == Keys.Down.CKey
def is_paste(self, data: bytes) -> bool:
"""
Detect Ctrl+V or Cmd+V for paste.
"""
return (self.ctrl_is_down or self.cmd_is_down) and data == Keys.Down.VKey
@dataclasses.dataclass
class Streaming:

View File

@@ -1,5 +1,5 @@
"""
Streaming commands WebSocket connections.
Streaming messages for WebSocket connections.
"""
import asyncio
@@ -10,6 +10,7 @@ from websockets.exceptions import ConnectionClosedError
import skyvern.forge.sdk.routes.streaming_clients as sc
from skyvern.forge.sdk.routes.routers import base_router, legacy_base_router
from skyvern.forge.sdk.routes.streaming_agent import connected_agent
from skyvern.forge.sdk.routes.streaming_auth import auth
from skyvern.forge.sdk.routes.streaming_verify import (
loop_verify_browser_session,
@@ -22,17 +23,17 @@ from skyvern.forge.sdk.utils.aio import collect
LOG = structlog.get_logger()
async def get_commands_for_browser_session(
async def get_messages_for_browser_session(
client_id: str,
browser_session_id: str,
organization_id: str,
websocket: WebSocket,
) -> tuple[sc.CommandChannel, sc.Loops] | None:
) -> tuple[sc.MessageChannel, sc.Loops] | None:
"""
Return a commands channel for a browser session, with a list of loops to run concurrently.
Return a message channel for a browser session, with a list of loops to run concurrently.
"""
LOG.info("Getting commands channel for browser session.", browser_session_id=browser_session_id)
LOG.info("Getting message channel for browser session.", browser_session_id=browser_session_id)
browser_session = await verify_browser_session(
browser_session_id=browser_session_id,
@@ -41,40 +42,40 @@ async def get_commands_for_browser_session(
if not browser_session:
LOG.info(
"Command channel: no initial browser session found.",
"Message channel: no initial browser session found.",
browser_session_id=browser_session_id,
organization_id=organization_id,
)
return None
commands = sc.CommandChannel(
message_channel = sc.MessageChannel(
client_id=client_id,
organization_id=organization_id,
browser_session=browser_session,
websocket=websocket,
)
LOG.info("Got command channel for browser session.", commands=commands)
LOG.info("Got message channel for browser session.", message_channel=message_channel)
loops = [
asyncio.create_task(loop_verify_browser_session(commands)),
asyncio.create_task(loop_channel(commands)),
asyncio.create_task(loop_verify_browser_session(message_channel)),
asyncio.create_task(loop_channel(message_channel)),
]
return commands, loops
return message_channel, loops
async def get_commands_for_workflow_run(
async def get_messages_for_workflow_run(
client_id: str,
workflow_run_id: str,
organization_id: str,
websocket: WebSocket,
) -> tuple[sc.CommandChannel, sc.Loops] | None:
) -> tuple[sc.MessageChannel, sc.Loops] | None:
"""
Return a commands channel for a workflow run, with a list of loops to run concurrently.
Return a message channel for a workflow run, with a list of loops to run concurrently.
"""
LOG.info("Getting commands channel for workflow run.", workflow_run_id=workflow_run_id)
LOG.info("Getting message channel for workflow run.", workflow_run_id=workflow_run_id)
workflow_run, browser_session = await verify_workflow_run(
workflow_run_id=workflow_run_id,
@@ -83,7 +84,7 @@ async def get_commands_for_workflow_run(
if not workflow_run:
LOG.info(
"Command channel: no initial workflow run found.",
"Message channel: no initial workflow run found.",
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
@@ -91,13 +92,13 @@ async def get_commands_for_workflow_run(
if not browser_session:
LOG.info(
"Command channel: no initial browser session found for workflow run.",
"Message channel: no initial browser session found for workflow run.",
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
return None
commands = sc.CommandChannel(
message_channel = sc.MessageChannel(
client_id=client_id,
organization_id=organization_id,
browser_session=browser_session,
@@ -105,78 +106,100 @@ async def get_commands_for_workflow_run(
websocket=websocket,
)
LOG.info("Got command channel for workflow run.", commands=commands)
LOG.info("Got message channel for workflow run.", message_channel=message_channel)
loops = [
asyncio.create_task(loop_verify_workflow_run(commands)),
asyncio.create_task(loop_channel(commands)),
asyncio.create_task(loop_verify_workflow_run(message_channel)),
asyncio.create_task(loop_channel(message_channel)),
]
return commands, loops
return message_channel, loops
async def loop_channel(commands: sc.CommandChannel) -> None:
async def loop_channel(message_channel: sc.MessageChannel) -> None:
"""
Stream commands and their results back and forth.
Stream messages and their results back and forth.
Loops until the workflow run is cleared or the websocket is closed.
"""
if not commands.browser_session:
if not message_channel.browser_session:
LOG.info(
"No browser session found for workflow run.",
workflow_run=commands.workflow_run,
organization_id=commands.organization_id,
workflow_run=message_channel.workflow_run,
organization_id=message_channel.organization_id,
)
return
async def frontend_to_backend() -> None:
LOG.info("Starting frontend-to-backend channel loop.", commands=commands)
LOG.info("Starting frontend-to-backend channel loop.", message_channel=message_channel)
while commands.is_open:
while message_channel.is_open:
try:
data = await commands.websocket.receive_json()
data = await message_channel.websocket.receive_json()
if not isinstance(data, dict):
LOG.error(f"Cannot create channel command: expected dict, got {type(data)}")
LOG.error(f"Cannot create channel message: expected dict, got {type(data)}")
continue
try:
command = sc.reify_channel_command(data)
message = sc.reify_channel_message(data)
except ValueError:
continue
command_kind = command.kind
message_kind = message.kind
match command_kind:
match message_kind:
case "take-control":
streaming = sc.get_streaming_client(commands.client_id)
streaming = sc.get_streaming_client(message_channel.client_id)
if not streaming:
LOG.error("No streaming client found for command.", commands=commands, command=command)
LOG.error(
"No streaming client found for message.",
message_channel=message_channel,
message=message,
)
continue
streaming.interactor = "user"
case "cede-control":
streaming = sc.get_streaming_client(commands.client_id)
streaming = sc.get_streaming_client(message_channel.client_id)
if not streaming:
LOG.error("No streaming client found for command.", commands=commands, command=command)
LOG.error(
"No streaming client found for message.",
message_channel=message_channel,
message=message,
)
continue
streaming.interactor = "agent"
case "ask-for-clipboard-response":
if not isinstance(message, sc.MessageInAskForClipboardResponse):
LOG.error(
"Invalid message type for ask-for-clipboard-response.",
message_channel=message_channel,
message=message,
)
continue
streaming = sc.get_streaming_client(message_channel.client_id)
text = message.text
async with connected_agent(streaming) as agent:
await agent.paste_text(text)
case _:
LOG.error(f"Unknown command kind: '{command_kind}'")
LOG.error(f"Unknown message kind: '{message_kind}'")
continue
except WebSocketDisconnect:
LOG.info(
"Frontend disconnected.",
workflow_run=commands.workflow_run,
organization_id=commands.organization_id,
workflow_run=message_channel.workflow_run,
organization_id=message_channel.organization_id,
)
raise
except ConnectionClosedError:
LOG.info(
"Frontend closed the streaming session.",
workflow_run=commands.workflow_run,
organization_id=commands.organization_id,
workflow_run=message_channel.workflow_run,
organization_id=message_channel.organization_id,
)
raise
except asyncio.CancelledError:
@@ -184,8 +207,8 @@ async def loop_channel(commands: sc.CommandChannel) -> None:
except Exception:
LOG.exception(
"An unexpected exception occurred.",
workflow_run=commands.workflow_run,
organization_id=commands.organization_id,
workflow_run=message_channel.workflow_run,
organization_id=message_channel.organization_id,
)
raise
@@ -198,27 +221,28 @@ async def loop_channel(commands: sc.CommandChannel) -> None:
except Exception:
LOG.exception(
"An exception occurred in loop channel stream.",
workflow_run=commands.workflow_run,
organization_id=commands.organization_id,
workflow_run=message_channel.workflow_run,
organization_id=message_channel.organization_id,
)
finally:
LOG.info(
"Closing the loop channel stream.",
workflow_run=commands.workflow_run,
organization_id=commands.organization_id,
workflow_run=message_channel.workflow_run,
organization_id=message_channel.organization_id,
)
await commands.close(reason="loop-channel-closed")
await message_channel.close(reason="loop-channel-closed")
@base_router.websocket("/stream/messages/browser_session/{browser_session_id}")
@base_router.websocket("/stream/commands/browser_session/{browser_session_id}")
async def browser_session_commands(
async def browser_session_messages(
websocket: WebSocket,
browser_session_id: str,
apikey: str | None = None,
client_id: str | None = None,
token: str | None = None,
) -> None:
LOG.info("Starting stream commands for browser session.", browser_session_id=browser_session_id)
LOG.info("Starting message stream for browser session.", browser_session_id=browser_session_id)
organization_id = await auth(apikey=apikey, token=token, websocket=websocket)
@@ -230,10 +254,10 @@ async def browser_session_commands(
LOG.error("No client ID provided.", browser_session_id=browser_session_id)
return
commands: sc.CommandChannel
message_channel: sc.MessageChannel
loops: list[asyncio.Task] = []
result = await get_commands_for_browser_session(
result = await get_messages_for_browser_session(
client_id=client_id,
browser_session_id=browser_session_id,
organization_id=organization_id,
@@ -249,40 +273,41 @@ async def browser_session_commands(
await websocket.close(code=1013)
return
commands, loops = result
message_channel, loops = result
try:
LOG.info(
"Starting command stream loops for browser session.",
"Starting message stream loops for browser session.",
browser_session_id=browser_session_id,
organization_id=organization_id,
)
await collect(loops)
except Exception:
LOG.exception(
"An exception occurred in the command stream function for browser session.",
"An exception occurred in the message stream function for browser session.",
browser_session_id=browser_session_id,
organization_id=organization_id,
)
finally:
LOG.info(
"Closing the command stream session for browser session.",
"Closing the message stream session for browser session.",
browser_session_id=browser_session_id,
organization_id=organization_id,
)
await commands.close(reason="stream-closed")
await message_channel.close(reason="stream-closed")
@legacy_base_router.websocket("/stream/messages/workflow_run/{workflow_run_id}")
@legacy_base_router.websocket("/stream/commands/workflow_run/{workflow_run_id}")
async def workflow_run_commands(
async def workflow_run_messages(
websocket: WebSocket,
workflow_run_id: str,
apikey: str | None = None,
client_id: str | None = None,
token: str | None = None,
) -> None:
LOG.info("Starting stream commands.", workflow_run_id=workflow_run_id)
LOG.info("Starting message stream.", workflow_run_id=workflow_run_id)
organization_id = await auth(apikey=apikey, token=token, websocket=websocket)
@@ -294,10 +319,10 @@ async def workflow_run_commands(
LOG.error("No client ID provided.", workflow_run_id=workflow_run_id)
return
commands: sc.CommandChannel
message_channel: sc.MessageChannel
loops: list[asyncio.Task] = []
result = await get_commands_for_workflow_run(
result = await get_messages_for_workflow_run(
client_id=client_id,
workflow_run_id=workflow_run_id,
organization_id=organization_id,
@@ -313,26 +338,26 @@ async def workflow_run_commands(
await websocket.close(code=1013)
return
commands, loops = result
message_channel, loops = result
try:
LOG.info(
"Starting command stream loops.",
"Starting message stream loops.",
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
await collect(loops)
except Exception:
LOG.exception(
"An exception occurred in the command stream function.",
"An exception occurred in the message stream function.",
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
finally:
LOG.info(
"Closing the command stream session.",
"Closing the message stream session.",
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
await commands.close(reason="stream-closed")
await message_channel.close(reason="stream-closed")

View File

@@ -233,7 +233,7 @@ async def verify_workflow_run(
return workflow_run, addressable_browser_session
async def loop_verify_browser_session(verifiable: sc.CommandChannel | sc.Streaming) -> None:
async def loop_verify_browser_session(verifiable: sc.MessageChannel | sc.Streaming) -> None:
"""
Loop until the browser session is cleared or the websocket is closed.
"""
@@ -266,7 +266,7 @@ async def loop_verify_task(streaming: sc.Streaming) -> None:
await asyncio.sleep(2)
async def loop_verify_workflow_run(verifiable: sc.CommandChannel | sc.Streaming) -> None:
async def loop_verify_workflow_run(verifiable: sc.MessageChannel | sc.Streaming) -> None:
"""
Loop until the workflow run is cleared or the websocket is closed.
"""

View File

@@ -1,8 +1,16 @@
"""
Streaming VNC WebSocket connections.
NOTE(jdo:streaming-local-dev)
-----------------------------
- grep the above for local development seams
- augment those seams as indicated, then
- stand up https://github.com/jomido/whyvern
"""
import asyncio
import typing as t
from urllib.parse import urlparse
import structlog
@@ -14,6 +22,7 @@ from websockets.exceptions import ConnectionClosedError
import skyvern.forge.sdk.routes.streaming_clients as sc
from skyvern.config import settings
from skyvern.forge.sdk.routes.routers import base_router, legacy_base_router
from skyvern.forge.sdk.routes.streaming_agent import connected_agent
from skyvern.forge.sdk.routes.streaming_auth import auth
from skyvern.forge.sdk.routes.streaming_verify import (
loop_verify_browser_session,
@@ -157,6 +166,68 @@ async def get_streaming_for_workflow_run(
return streaming, loops
def verify_message_channel(
message_channel: sc.MessageChannel | None, streaming: sc.Streaming
) -> sc.MessageChannel | t.Literal[False]:
if message_channel and message_channel.is_open:
return message_channel
LOG.warning(
"No message channel found for client, or it is not open",
message_channel=message_channel,
client_id=streaming.client_id,
organization_id=streaming.organization_id,
)
return False
async def copy_text(streaming: sc.Streaming) -> None:
try:
async with connected_agent(streaming) as agent:
copied_text = await agent.get_selected_text()
LOG.info(
"Retrieved selected text via CDP",
organization_id=streaming.organization_id,
)
message_channel = sc.get_message_client(streaming.client_id)
if cc := verify_message_channel(message_channel, streaming):
await cc.send_copied_text(copied_text, streaming)
else:
LOG.warning(
"No message channel found for client, or it is not open",
message_channel=message_channel,
client_id=streaming.client_id,
organization_id=streaming.organization_id,
)
except Exception:
LOG.exception(
"Failed to retrieve selected text via CDP",
organization_id=streaming.organization_id,
)
async def ask_for_clipboard(streaming: sc.Streaming) -> None:
try:
LOG.info(
"Asking for clipboard data via CDP",
organization_id=streaming.organization_id,
)
message_channel = sc.get_message_client(streaming.client_id)
if cc := verify_message_channel(message_channel, streaming):
await cc.ask_for_clipboard(streaming)
except Exception:
LOG.exception(
"Failed to ask for clipboard via CDP",
organization_id=streaming.organization_id,
)
async def loop_stream_vnc(streaming: sc.Streaming) -> None:
"""
Actually stream the VNC session data between a frontend and a browser
@@ -183,6 +254,9 @@ async def loop_stream_vnc(streaming: sc.Streaming) -> None:
host = parsed_browser_address.hostname
vnc_url = f"ws://{host}:{streaming.vnc_port}"
# NOTE(jdo:streaming-local-dev)
# vnc_url = "ws://localhost:9001/ws/novnc"
LOG.info(
"Connecting to VNC URL.",
vnc_url=vnc_url,
@@ -207,6 +281,12 @@ async def loop_stream_vnc(streaming: sc.Streaming) -> None:
if message_type == sc.MessageType.Keyboard.value:
streaming.update_key_state(data)
if streaming.key_state.is_copy(data):
await copy_text(streaming)
if streaming.key_state.is_paste(data):
await ask_for_clipboard(streaming)
if streaming.key_state.is_forbidden(data):
continue