Browser streaming: reorganize & rename (#4033)
This commit is contained in:
191
skyvern/forge/sdk/routes/streaming/agent.py
Normal file
191
skyvern/forge/sdk/routes/streaming/agent.py
Normal file
@@ -0,0 +1,191 @@
|
||||
"""
|
||||
A lightweight "agent" for interacting with the streaming browser over CDP.
|
||||
"""
|
||||
|
||||
import typing
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import structlog
|
||||
from playwright.async_api import Browser, BrowserContext, Page, Playwright, async_playwright
|
||||
|
||||
import skyvern.forge.sdk.routes.streaming.clients as sc
|
||||
from skyvern.config import settings
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class StreamingAgent:
|
||||
"""
|
||||
A minimal agent that can connect to a browser via CDP and execute JavaScript.
|
||||
|
||||
Specifically for operations during streaming sessions (like copy/pasting selected text, etc.).
|
||||
"""
|
||||
|
||||
def __init__(self, streaming: sc.Streaming) -> None:
|
||||
self.streaming = streaming
|
||||
self.browser: Browser | None = None
|
||||
self.browser_context: BrowserContext | None = None
|
||||
self.page: Page | None = None
|
||||
self.pw: Playwright | None = None
|
||||
|
||||
async def connect(self, cdp_url: str | None = None) -> None:
|
||||
url = cdp_url or settings.BROWSER_REMOTE_DEBUGGING_URL
|
||||
|
||||
LOG.info("StreamingAgent connecting to CDP", cdp_url=url)
|
||||
|
||||
pw = self.pw or await async_playwright().start()
|
||||
|
||||
self.pw = pw
|
||||
|
||||
headers = {
|
||||
"x-api-key": self.streaming.x_api_key,
|
||||
}
|
||||
|
||||
self.browser = await pw.chromium.connect_over_cdp(url, headers=headers)
|
||||
|
||||
org_id = self.streaming.organization_id
|
||||
browser_session_id = (
|
||||
self.streaming.browser_session.persistent_browser_session_id if self.streaming.browser_session else None
|
||||
)
|
||||
|
||||
if browser_session_id:
|
||||
cdp_session = await self.browser.new_browser_cdp_session()
|
||||
await cdp_session.send(
|
||||
"Browser.setDownloadBehavior",
|
||||
{
|
||||
"behavior": "allow",
|
||||
"downloadPath": f"/app/downloads/{org_id}/{browser_session_id}",
|
||||
"eventsEnabled": True,
|
||||
},
|
||||
)
|
||||
|
||||
contexts = self.browser.contexts
|
||||
if contexts:
|
||||
LOG.info("StreamingAgent using existing browser context")
|
||||
self.browser_context = contexts[0]
|
||||
else:
|
||||
LOG.warning("No existing browser context found, creating new one")
|
||||
self.browser_context = await self.browser.new_context()
|
||||
|
||||
pages = self.browser_context.pages
|
||||
if pages:
|
||||
self.page = pages[0]
|
||||
LOG.info("StreamingAgent connected to page", url=self.page.url)
|
||||
else:
|
||||
LOG.warning("No pages found in browser context")
|
||||
|
||||
LOG.info("StreamingAgent connected successfully")
|
||||
|
||||
async def evaluate_js(
|
||||
self, expression: str, arg: str | int | float | bool | list | dict | None = None
|
||||
) -> str | int | float | bool | list | dict | None:
|
||||
if not self.page:
|
||||
raise RuntimeError("StreamingAgent is not connected to a page. Call connect() first.")
|
||||
|
||||
LOG.info("StreamingAgent evaluating JS", expression=expression[:100])
|
||||
|
||||
try:
|
||||
result = await self.page.evaluate(expression, arg)
|
||||
LOG.info("StreamingAgent JS evaluation successful")
|
||||
return result
|
||||
except Exception as ex:
|
||||
LOG.exception("StreamingAgent JS evaluation failed", expression=expression, ex=str(ex))
|
||||
raise
|
||||
|
||||
async def get_selected_text(self) -> str:
|
||||
LOG.info("StreamingAgent getting selected text")
|
||||
|
||||
js_expression = """
|
||||
() => {
|
||||
const selection = window.getSelection();
|
||||
return selection ? selection.toString() : '';
|
||||
}
|
||||
"""
|
||||
|
||||
selected_text = await self.evaluate_js(js_expression)
|
||||
|
||||
if isinstance(selected_text, str) or selected_text is None:
|
||||
LOG.info("StreamingAgent got selected text", length=len(selected_text) if selected_text else 0)
|
||||
return selected_text or ""
|
||||
|
||||
raise RuntimeError(f"StreamingAgent selected text is not a string, but a(n) '{type(selected_text)}'")
|
||||
|
||||
async def paste_text(self, text: str) -> None:
|
||||
LOG.info("StreamingAgent pasting text")
|
||||
|
||||
js_expression = """
|
||||
(text) => {
|
||||
const activeElement = document.activeElement;
|
||||
if (activeElement && (activeElement.tagName === 'INPUT' || activeElement.tagName === 'TEXTAREA' || activeElement.isContentEditable)) {
|
||||
const start = activeElement.selectionStart || 0;
|
||||
const end = activeElement.selectionEnd || 0;
|
||||
const value = activeElement.value || '';
|
||||
activeElement.value = value.slice(0, start) + text + value.slice(end);
|
||||
const newCursorPos = start + text.length;
|
||||
activeElement.setSelectionRange(newCursorPos, newCursorPos);
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
await self.evaluate_js(js_expression, text)
|
||||
|
||||
LOG.info("StreamingAgent pasted text successfully")
|
||||
|
||||
async def close(self) -> None:
|
||||
LOG.info("StreamingAgent closing connection")
|
||||
|
||||
if self.browser:
|
||||
await self.browser.close()
|
||||
self.browser = None
|
||||
self.browser_context = None
|
||||
self.page = None
|
||||
|
||||
if self.pw:
|
||||
await self.pw.stop()
|
||||
self.pw = None
|
||||
|
||||
LOG.info("StreamingAgent closed")
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def connected_agent(streaming: sc.Streaming | None) -> typing.AsyncIterator[StreamingAgent]:
|
||||
"""
|
||||
The first pass at this has us doing the following for every operation:
|
||||
- creating a new agent
|
||||
- connecting
|
||||
- [doing smth]
|
||||
- closing the agent
|
||||
|
||||
This may add latency, but locally it is pretty fast. This keeps things stateless for now.
|
||||
|
||||
If it turns out it's too slow, we can refactor to keep a persistent agent per streaming client.
|
||||
"""
|
||||
|
||||
if not streaming:
|
||||
msg = "connected_agent: no streaming client provided."
|
||||
LOG.error(msg)
|
||||
|
||||
raise Exception(msg)
|
||||
|
||||
if not streaming.browser_session or not streaming.browser_session.browser_address:
|
||||
msg = "connected_agent: no browser session or browser address found for streaming client."
|
||||
|
||||
LOG.error(
|
||||
msg,
|
||||
client_id=streaming.client_id,
|
||||
organization_id=streaming.organization_id,
|
||||
)
|
||||
|
||||
raise Exception(msg)
|
||||
|
||||
agent = StreamingAgent(streaming=streaming)
|
||||
|
||||
try:
|
||||
await agent.connect(streaming.browser_session.browser_address)
|
||||
|
||||
# NOTE(jdo:streaming-local-dev): use BROWSER_REMOTE_DEBUGGING_URL from settings
|
||||
# await agent.connect()
|
||||
|
||||
yield agent
|
||||
finally:
|
||||
await agent.close()
|
||||
61
skyvern/forge/sdk/routes/streaming/auth.py
Normal file
61
skyvern/forge/sdk/routes/streaming/auth.py
Normal file
@@ -0,0 +1,61 @@
|
||||
"""
|
||||
Streaming auth.
|
||||
"""
|
||||
|
||||
import structlog
|
||||
from fastapi import WebSocket
|
||||
from websockets.exceptions import ConnectionClosedOK
|
||||
|
||||
from skyvern.forge.sdk.services.org_auth_service import get_current_org
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
async def auth(apikey: str | None, token: str | None, websocket: WebSocket) -> str | None:
|
||||
"""
|
||||
Accepts the websocket connection.
|
||||
|
||||
Authenticates the user; cannot proceed with WS connection if an organization_id cannot be
|
||||
determined.
|
||||
"""
|
||||
|
||||
try:
|
||||
await websocket.accept()
|
||||
if not token and not apikey:
|
||||
await websocket.close(code=1002)
|
||||
return None
|
||||
except ConnectionClosedOK:
|
||||
LOG.info("WebSocket connection closed cleanly.")
|
||||
return None
|
||||
|
||||
try:
|
||||
organization = await get_current_org(x_api_key=apikey, authorization=token)
|
||||
organization_id = organization.organization_id
|
||||
|
||||
if not organization_id:
|
||||
await websocket.close(code=1002)
|
||||
return None
|
||||
except Exception:
|
||||
LOG.exception("Error occurred while retrieving organization information.")
|
||||
try:
|
||||
await websocket.close(code=1002)
|
||||
except ConnectionClosedOK:
|
||||
LOG.info("WebSocket connection closed due to invalid credentials.")
|
||||
return None
|
||||
|
||||
return organization_id
|
||||
|
||||
|
||||
# NOTE(jdo:streaming-local-dev): use this instead of the above `auth`
|
||||
async def _auth(apikey: str | None, token: str | None, websocket: WebSocket) -> str | None:
|
||||
"""
|
||||
Dummy auth for local testing.
|
||||
"""
|
||||
|
||||
try:
|
||||
await websocket.accept()
|
||||
except ConnectionClosedOK:
|
||||
LOG.info("WebSocket connection closed cleanly.")
|
||||
return None
|
||||
|
||||
return "o_temp123"
|
||||
328
skyvern/forge/sdk/routes/streaming/clients.py
Normal file
328
skyvern/forge/sdk/routes/streaming/clients.py
Normal file
@@ -0,0 +1,328 @@
|
||||
"""
|
||||
Streaming types.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import typing as t
|
||||
from enum import IntEnum
|
||||
|
||||
import structlog
|
||||
from fastapi import WebSocket
|
||||
from starlette.websockets import WebSocketState
|
||||
|
||||
from skyvern.forge.sdk.schemas.persistent_browser_sessions import AddressablePersistentBrowserSession
|
||||
from skyvern.forge.sdk.schemas.tasks import Task
|
||||
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRun
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
Interactor = t.Literal["agent", "user"]
|
||||
Loops = list[asyncio.Task] # aka "queue-less actors"; or "programs"
|
||||
|
||||
|
||||
# Messages
|
||||
|
||||
|
||||
# a global registry for WS message clients
|
||||
message_channels: dict[str, "MessageChannel"] = {}
|
||||
|
||||
|
||||
def add_message_client(message_channel: "MessageChannel") -> None:
|
||||
message_channels[message_channel.client_id] = message_channel
|
||||
|
||||
|
||||
def get_message_client(client_id: str) -> t.Union["MessageChannel", None]:
|
||||
return message_channels.get(client_id, None)
|
||||
|
||||
|
||||
def del_message_client(client_id: str) -> None:
|
||||
try:
|
||||
del message_channels[client_id]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MessageChannel:
|
||||
client_id: str
|
||||
organization_id: str
|
||||
websocket: WebSocket
|
||||
|
||||
# --
|
||||
|
||||
browser_session: AddressablePersistentBrowserSession | None = None
|
||||
workflow_run: WorkflowRun | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
add_message_client(self)
|
||||
|
||||
async def close(self, code: int = 1000, reason: str | None = None) -> "MessageChannel":
|
||||
LOG.info("Closing message stream.", reason=reason, code=code)
|
||||
|
||||
self.browser_session = None
|
||||
self.workflow_run = None
|
||||
|
||||
try:
|
||||
await self.websocket.close(code=code, reason=reason)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
del_message_client(self.client_id)
|
||||
|
||||
return self
|
||||
|
||||
@property
|
||||
def is_open(self) -> bool:
|
||||
if self.websocket.client_state not in (WebSocketState.CONNECTED, WebSocketState.CONNECTING):
|
||||
return False
|
||||
|
||||
if not self.workflow_run and not self.browser_session:
|
||||
return False
|
||||
|
||||
if not get_message_client(self.client_id):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def ask_for_clipboard(self, streaming: "Streaming") -> None:
|
||||
try:
|
||||
await self.websocket.send_json(
|
||||
{
|
||||
"kind": "ask-for-clipboard",
|
||||
}
|
||||
)
|
||||
LOG.info(
|
||||
"Sent ask-for-clipboard to message channel",
|
||||
organization_id=streaming.organization_id,
|
||||
)
|
||||
except Exception:
|
||||
LOG.exception(
|
||||
"Failed to send ask-for-clipboard to message channel",
|
||||
organization_id=streaming.organization_id,
|
||||
)
|
||||
|
||||
async def send_copied_text(self, copied_text: str, streaming: "Streaming") -> None:
|
||||
try:
|
||||
await self.websocket.send_json(
|
||||
{
|
||||
"kind": "copied-text",
|
||||
"text": copied_text,
|
||||
}
|
||||
)
|
||||
LOG.info(
|
||||
"Sent copied text to message channel",
|
||||
organization_id=streaming.organization_id,
|
||||
)
|
||||
except Exception:
|
||||
LOG.exception(
|
||||
"Failed to send copied text to message channel",
|
||||
organization_id=streaming.organization_id,
|
||||
)
|
||||
|
||||
|
||||
MessageKinds = t.Literal["take-control", "cede-control", "ask-for-clipboard-response"]
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Message:
|
||||
kind: MessageKinds
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MessageTakeControl(Message):
|
||||
kind: t.Literal["take-control"] = "take-control"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MessageCedeControl(Message):
|
||||
kind: t.Literal["cede-control"] = "cede-control"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MessageInAskForClipboardResponse(Message):
|
||||
kind: t.Literal["ask-for-clipboard-response"] = "ask-for-clipboard-response"
|
||||
text: str = ""
|
||||
|
||||
|
||||
ChannelMessage = t.Union[MessageTakeControl, MessageCedeControl, MessageInAskForClipboardResponse]
|
||||
|
||||
|
||||
def reify_channel_message(data: dict) -> ChannelMessage:
|
||||
kind = data.get("kind", None)
|
||||
|
||||
match kind:
|
||||
case "take-control":
|
||||
return MessageTakeControl()
|
||||
case "cede-control":
|
||||
return MessageCedeControl()
|
||||
case "ask-for-clipboard-response":
|
||||
text = data.get("text") or ""
|
||||
return MessageInAskForClipboardResponse(text=text)
|
||||
case _:
|
||||
raise ValueError(f"Unknown message kind: '{kind}'")
|
||||
|
||||
|
||||
# Streaming
|
||||
|
||||
|
||||
# a global registry for WS streaming VNC clients
|
||||
streaming_clients: dict[str, "Streaming"] = {}
|
||||
|
||||
|
||||
def add_streaming_client(streaming: "Streaming") -> None:
|
||||
streaming_clients[streaming.client_id] = streaming
|
||||
|
||||
|
||||
def get_streaming_client(client_id: str) -> t.Union["Streaming", None]:
|
||||
return streaming_clients.get(client_id, None)
|
||||
|
||||
|
||||
def del_streaming_client(client_id: str) -> None:
|
||||
try:
|
||||
del streaming_clients[client_id]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
|
||||
class MessageType(IntEnum):
|
||||
Keyboard = 4
|
||||
Mouse = 5
|
||||
|
||||
|
||||
class Keys:
|
||||
"""
|
||||
VNC RFB keycodes. There's likely a pithier repr (indexes 6-7). This is ok for now.
|
||||
|
||||
ref: https://www.notion.so/References-21c426c42cd480fb9258ecc9eb8f09b4
|
||||
ref: https://github.com/novnc/noVNC/blob/master/docs/rfbproto-3.8.pdf
|
||||
"""
|
||||
|
||||
class Down:
|
||||
Ctrl = b"\x04\x01\x00\x00\x00\x00\xff\xe3"
|
||||
Cmd = b"\x04\x01\x00\x00\x00\x00\xff\xe9"
|
||||
Alt = b"\x04\x01\x00\x00\x00\x00\xff~" # option
|
||||
CKey = b"\x04\x01\x00\x00\x00\x00\x00c"
|
||||
OKey = b"\x04\x01\x00\x00\x00\x00\x00o"
|
||||
VKey = b"\x04\x01\x00\x00\x00\x00\x00v"
|
||||
|
||||
class Up:
|
||||
Ctrl = b"\x04\x00\x00\x00\x00\x00\xff\xe3"
|
||||
Cmd = b"\x04\x00\x00\x00\x00\x00\xff\xe9"
|
||||
Alt = b"\x04\x00\x00\x00\x00\x00\xff\x7e" # option
|
||||
|
||||
|
||||
def is_rmb(data: bytes) -> bool:
|
||||
return data[0:2] == b"\x05\x04"
|
||||
|
||||
|
||||
class Mouse:
|
||||
class Up:
|
||||
Right = is_rmb
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class KeyState:
|
||||
ctrl_is_down: bool = False
|
||||
alt_is_down: bool = False
|
||||
cmd_is_down: bool = False
|
||||
|
||||
def is_forbidden(self, data: bytes) -> bool:
|
||||
"""
|
||||
:return: True if the key is forbidden, else False
|
||||
"""
|
||||
return self.is_ctrl_o(data)
|
||||
|
||||
def is_ctrl_o(self, data: bytes) -> bool:
|
||||
"""
|
||||
Do not allow the opening of files.
|
||||
"""
|
||||
return self.ctrl_is_down and data == Keys.Down.OKey
|
||||
|
||||
def is_copy(self, data: bytes) -> bool:
|
||||
"""
|
||||
Detect Ctrl+C or Cmd+C for copy.
|
||||
"""
|
||||
return (self.ctrl_is_down or self.cmd_is_down) and data == Keys.Down.CKey
|
||||
|
||||
def is_paste(self, data: bytes) -> bool:
|
||||
"""
|
||||
Detect Ctrl+V or Cmd+V for paste.
|
||||
"""
|
||||
return (self.ctrl_is_down or self.cmd_is_down) and data == Keys.Down.VKey
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Streaming:
|
||||
"""
|
||||
Streaming state.
|
||||
"""
|
||||
|
||||
client_id: str
|
||||
"""
|
||||
Unique to frontend app instance.
|
||||
"""
|
||||
|
||||
interactor: Interactor
|
||||
"""
|
||||
Whether the user or the agent are the interactor.
|
||||
"""
|
||||
|
||||
organization_id: str
|
||||
vnc_port: int
|
||||
x_api_key: str
|
||||
websocket: WebSocket
|
||||
|
||||
# --
|
||||
|
||||
browser_session: AddressablePersistentBrowserSession | None = None
|
||||
key_state: KeyState = dataclasses.field(default_factory=KeyState)
|
||||
task: Task | None = None
|
||||
workflow_run: WorkflowRun | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
add_streaming_client(self)
|
||||
|
||||
@property
|
||||
def is_open(self) -> bool:
|
||||
if self.websocket.client_state not in (WebSocketState.CONNECTED, WebSocketState.CONNECTING):
|
||||
return False
|
||||
|
||||
if not self.task and not self.workflow_run and not self.browser_session:
|
||||
return False
|
||||
|
||||
if not get_streaming_client(self.client_id):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def close(self, code: int = 1000, reason: str | None = None) -> "Streaming":
|
||||
LOG.info("Closing Streaming.", reason=reason, code=code)
|
||||
|
||||
self.browser_session = None
|
||||
self.task = None
|
||||
self.workflow_run = None
|
||||
|
||||
try:
|
||||
await self.websocket.close(code=code, reason=reason)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
del_streaming_client(self.client_id)
|
||||
|
||||
return self
|
||||
|
||||
def update_key_state(self, data: bytes) -> None:
|
||||
if data == Keys.Down.Ctrl:
|
||||
self.key_state.ctrl_is_down = True
|
||||
elif data == Keys.Up.Ctrl:
|
||||
self.key_state.ctrl_is_down = False
|
||||
elif data == Keys.Down.Alt:
|
||||
self.key_state.alt_is_down = True
|
||||
elif data == Keys.Up.Alt:
|
||||
self.key_state.alt_is_down = False
|
||||
elif data == Keys.Down.Cmd:
|
||||
self.key_state.cmd_is_down = True
|
||||
elif data == Keys.Up.Cmd:
|
||||
self.key_state.cmd_is_down = False
|
||||
363
skyvern/forge/sdk/routes/streaming/messages.py
Normal file
363
skyvern/forge/sdk/routes/streaming/messages.py
Normal file
@@ -0,0 +1,363 @@
|
||||
"""
|
||||
Streaming messages for WebSocket connections.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
import structlog
|
||||
from fastapi import WebSocket, WebSocketDisconnect
|
||||
from websockets.exceptions import ConnectionClosedError
|
||||
|
||||
import skyvern.forge.sdk.routes.streaming.clients as sc
|
||||
from skyvern.forge.sdk.routes.routers import base_router, legacy_base_router
|
||||
from skyvern.forge.sdk.routes.streaming.agent import connected_agent
|
||||
from skyvern.forge.sdk.routes.streaming.auth import auth
|
||||
from skyvern.forge.sdk.routes.streaming.verify import (
|
||||
loop_verify_browser_session,
|
||||
loop_verify_workflow_run,
|
||||
verify_browser_session,
|
||||
verify_workflow_run,
|
||||
)
|
||||
from skyvern.forge.sdk.utils.aio import collect
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
async def get_messages_for_browser_session(
|
||||
client_id: str,
|
||||
browser_session_id: str,
|
||||
organization_id: str,
|
||||
websocket: WebSocket,
|
||||
) -> tuple[sc.MessageChannel, sc.Loops] | None:
|
||||
"""
|
||||
Return a message channel for a browser session, with a list of loops to run concurrently.
|
||||
"""
|
||||
|
||||
LOG.info("Getting message channel for browser session.", browser_session_id=browser_session_id)
|
||||
|
||||
browser_session = await verify_browser_session(
|
||||
browser_session_id=browser_session_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
|
||||
if not browser_session:
|
||||
LOG.info(
|
||||
"Message channel: no initial browser session found.",
|
||||
browser_session_id=browser_session_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
return None
|
||||
|
||||
message_channel = sc.MessageChannel(
|
||||
client_id=client_id,
|
||||
organization_id=organization_id,
|
||||
browser_session=browser_session,
|
||||
websocket=websocket,
|
||||
)
|
||||
|
||||
LOG.info("Got message channel for browser session.", message_channel=message_channel)
|
||||
|
||||
loops = [
|
||||
asyncio.create_task(loop_verify_browser_session(message_channel)),
|
||||
asyncio.create_task(loop_channel(message_channel)),
|
||||
]
|
||||
|
||||
return message_channel, loops
|
||||
|
||||
|
||||
async def get_messages_for_workflow_run(
|
||||
client_id: str,
|
||||
workflow_run_id: str,
|
||||
organization_id: str,
|
||||
websocket: WebSocket,
|
||||
) -> tuple[sc.MessageChannel, sc.Loops] | None:
|
||||
"""
|
||||
Return a message channel for a workflow run, with a list of loops to run concurrently.
|
||||
"""
|
||||
|
||||
LOG.info("Getting message channel for workflow run.", workflow_run_id=workflow_run_id)
|
||||
|
||||
workflow_run, browser_session = await verify_workflow_run(
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
|
||||
if not workflow_run:
|
||||
LOG.info(
|
||||
"Message channel: no initial workflow run found.",
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
return None
|
||||
|
||||
if not browser_session:
|
||||
LOG.info(
|
||||
"Message channel: no initial browser session found for workflow run.",
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
return None
|
||||
|
||||
message_channel = sc.MessageChannel(
|
||||
client_id=client_id,
|
||||
organization_id=organization_id,
|
||||
browser_session=browser_session,
|
||||
workflow_run=workflow_run,
|
||||
websocket=websocket,
|
||||
)
|
||||
|
||||
LOG.info("Got message channel for workflow run.", message_channel=message_channel)
|
||||
|
||||
loops = [
|
||||
asyncio.create_task(loop_verify_workflow_run(message_channel)),
|
||||
asyncio.create_task(loop_channel(message_channel)),
|
||||
]
|
||||
|
||||
return message_channel, loops
|
||||
|
||||
|
||||
async def loop_channel(message_channel: sc.MessageChannel) -> None:
|
||||
"""
|
||||
Stream messages and their results back and forth.
|
||||
|
||||
Loops until the workflow run is cleared or the websocket is closed.
|
||||
"""
|
||||
|
||||
if not message_channel.browser_session:
|
||||
LOG.info(
|
||||
"No browser session found for workflow run.",
|
||||
workflow_run=message_channel.workflow_run,
|
||||
organization_id=message_channel.organization_id,
|
||||
)
|
||||
return
|
||||
|
||||
async def frontend_to_backend() -> None:
|
||||
LOG.info("Starting frontend-to-backend channel loop.", message_channel=message_channel)
|
||||
|
||||
while message_channel.is_open:
|
||||
try:
|
||||
data = await message_channel.websocket.receive_json()
|
||||
|
||||
if not isinstance(data, dict):
|
||||
LOG.error(f"Cannot create channel message: expected dict, got {type(data)}")
|
||||
continue
|
||||
|
||||
try:
|
||||
message = sc.reify_channel_message(data)
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
message_kind = message.kind
|
||||
|
||||
match message_kind:
|
||||
case "take-control":
|
||||
streaming = sc.get_streaming_client(message_channel.client_id)
|
||||
if not streaming:
|
||||
LOG.error(
|
||||
"No streaming client found for message.",
|
||||
message_channel=message_channel,
|
||||
message=message,
|
||||
)
|
||||
continue
|
||||
streaming.interactor = "user"
|
||||
case "cede-control":
|
||||
streaming = sc.get_streaming_client(message_channel.client_id)
|
||||
if not streaming:
|
||||
LOG.error(
|
||||
"No streaming client found for message.",
|
||||
message_channel=message_channel,
|
||||
message=message,
|
||||
)
|
||||
continue
|
||||
streaming.interactor = "agent"
|
||||
case "ask-for-clipboard-response":
|
||||
if not isinstance(message, sc.MessageInAskForClipboardResponse):
|
||||
LOG.error(
|
||||
"Invalid message type for ask-for-clipboard-response.",
|
||||
message_channel=message_channel,
|
||||
message=message,
|
||||
)
|
||||
continue
|
||||
|
||||
streaming = sc.get_streaming_client(message_channel.client_id)
|
||||
text = message.text
|
||||
|
||||
async with connected_agent(streaming) as agent:
|
||||
await agent.paste_text(text)
|
||||
case _:
|
||||
LOG.error(f"Unknown message kind: '{message_kind}'")
|
||||
continue
|
||||
|
||||
except WebSocketDisconnect:
|
||||
LOG.info(
|
||||
"Frontend disconnected.",
|
||||
workflow_run=message_channel.workflow_run,
|
||||
organization_id=message_channel.organization_id,
|
||||
)
|
||||
raise
|
||||
except ConnectionClosedError:
|
||||
LOG.info(
|
||||
"Frontend closed the streaming session.",
|
||||
workflow_run=message_channel.workflow_run,
|
||||
organization_id=message_channel.organization_id,
|
||||
)
|
||||
raise
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception:
|
||||
LOG.exception(
|
||||
"An unexpected exception occurred.",
|
||||
workflow_run=message_channel.workflow_run,
|
||||
organization_id=message_channel.organization_id,
|
||||
)
|
||||
raise
|
||||
|
||||
loops = [
|
||||
asyncio.create_task(frontend_to_backend()),
|
||||
]
|
||||
|
||||
try:
|
||||
await collect(loops)
|
||||
except Exception:
|
||||
LOG.exception(
|
||||
"An exception occurred in loop channel stream.",
|
||||
workflow_run=message_channel.workflow_run,
|
||||
organization_id=message_channel.organization_id,
|
||||
)
|
||||
finally:
|
||||
LOG.info(
|
||||
"Closing the loop channel stream.",
|
||||
workflow_run=message_channel.workflow_run,
|
||||
organization_id=message_channel.organization_id,
|
||||
)
|
||||
await message_channel.close(reason="loop-channel-closed")
|
||||
|
||||
|
||||
@base_router.websocket("/stream/messages/browser_session/{browser_session_id}")
|
||||
@base_router.websocket("/stream/commands/browser_session/{browser_session_id}")
|
||||
async def browser_session_messages(
|
||||
websocket: WebSocket,
|
||||
browser_session_id: str,
|
||||
apikey: str | None = None,
|
||||
client_id: str | None = None,
|
||||
token: str | None = None,
|
||||
) -> None:
|
||||
LOG.info("Starting message stream for browser session.", browser_session_id=browser_session_id)
|
||||
|
||||
organization_id = await auth(apikey=apikey, token=token, websocket=websocket)
|
||||
|
||||
if not organization_id:
|
||||
LOG.error("Authentication failed.", browser_session_id=browser_session_id)
|
||||
return
|
||||
|
||||
if not client_id:
|
||||
LOG.error("No client ID provided.", browser_session_id=browser_session_id)
|
||||
return
|
||||
|
||||
message_channel: sc.MessageChannel
|
||||
loops: list[asyncio.Task] = []
|
||||
|
||||
result = await get_messages_for_browser_session(
|
||||
client_id=client_id,
|
||||
browser_session_id=browser_session_id,
|
||||
organization_id=organization_id,
|
||||
websocket=websocket,
|
||||
)
|
||||
|
||||
if not result:
|
||||
LOG.error(
|
||||
"No streaming context found for the browser session.",
|
||||
browser_session_id=browser_session_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
await websocket.close(code=1013)
|
||||
return
|
||||
|
||||
message_channel, loops = result
|
||||
|
||||
try:
|
||||
LOG.info(
|
||||
"Starting message stream loops for browser session.",
|
||||
browser_session_id=browser_session_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
await collect(loops)
|
||||
except Exception:
|
||||
LOG.exception(
|
||||
"An exception occurred in the message stream function for browser session.",
|
||||
browser_session_id=browser_session_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
finally:
|
||||
LOG.info(
|
||||
"Closing the message stream session for browser session.",
|
||||
browser_session_id=browser_session_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
|
||||
await message_channel.close(reason="stream-closed")
|
||||
|
||||
|
||||
@legacy_base_router.websocket("/stream/messages/workflow_run/{workflow_run_id}")
|
||||
@legacy_base_router.websocket("/stream/commands/workflow_run/{workflow_run_id}")
|
||||
async def workflow_run_messages(
|
||||
websocket: WebSocket,
|
||||
workflow_run_id: str,
|
||||
apikey: str | None = None,
|
||||
client_id: str | None = None,
|
||||
token: str | None = None,
|
||||
) -> None:
|
||||
LOG.info("Starting message stream.", workflow_run_id=workflow_run_id)
|
||||
|
||||
organization_id = await auth(apikey=apikey, token=token, websocket=websocket)
|
||||
|
||||
if not organization_id:
|
||||
LOG.error("Authentication failed.", workflow_run_id=workflow_run_id)
|
||||
return
|
||||
|
||||
if not client_id:
|
||||
LOG.error("No client ID provided.", workflow_run_id=workflow_run_id)
|
||||
return
|
||||
|
||||
message_channel: sc.MessageChannel
|
||||
loops: list[asyncio.Task] = []
|
||||
|
||||
result = await get_messages_for_workflow_run(
|
||||
client_id=client_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
websocket=websocket,
|
||||
)
|
||||
|
||||
if not result:
|
||||
LOG.error(
|
||||
"No streaming context found for the workflow run.",
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
await websocket.close(code=1013)
|
||||
return
|
||||
|
||||
message_channel, loops = result
|
||||
|
||||
try:
|
||||
LOG.info(
|
||||
"Starting message stream loops.",
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
await collect(loops)
|
||||
except Exception:
|
||||
LOG.exception(
|
||||
"An exception occurred in the message stream function.",
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
finally:
|
||||
LOG.info(
|
||||
"Closing the message stream session.",
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
|
||||
await message_channel.close(reason="stream-closed")
|
||||
271
skyvern/forge/sdk/routes/streaming/screenshot.py
Normal file
271
skyvern/forge/sdk/routes/streaming/screenshot.py
Normal file
@@ -0,0 +1,271 @@
|
||||
import asyncio
|
||||
import base64
|
||||
from datetime import datetime
|
||||
|
||||
import structlog
|
||||
from fastapi import WebSocket, WebSocketDisconnect
|
||||
from pydantic import ValidationError
|
||||
from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK
|
||||
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.routes.routers import legacy_base_router
|
||||
from skyvern.forge.sdk.schemas.tasks import TaskStatus
|
||||
from skyvern.forge.sdk.services.org_auth_service import get_current_org
|
||||
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
STREAMING_TIMEOUT = 300
|
||||
|
||||
|
||||
@legacy_base_router.websocket("/stream/tasks/{task_id}")
|
||||
async def task_stream(
|
||||
websocket: WebSocket,
|
||||
task_id: str,
|
||||
apikey: str | None = None,
|
||||
token: str | None = None,
|
||||
) -> None:
|
||||
try:
|
||||
await websocket.accept()
|
||||
if not token and not apikey:
|
||||
await websocket.send_text("No valid credential provided")
|
||||
return
|
||||
except ConnectionClosedOK:
|
||||
LOG.info("ConnectionClosedOK error. Streaming won't start")
|
||||
return
|
||||
|
||||
try:
|
||||
organization = await get_current_org(x_api_key=apikey, authorization=token)
|
||||
organization_id = organization.organization_id
|
||||
except Exception:
|
||||
LOG.exception("Error while getting organization", task_id=task_id)
|
||||
try:
|
||||
await websocket.send_text("Invalid credential provided")
|
||||
except ConnectionClosedOK:
|
||||
LOG.info("ConnectionClosedOK error while sending invalid credential message")
|
||||
return
|
||||
|
||||
LOG.info("Started task streaming", task_id=task_id, organization_id=organization_id)
|
||||
# timestamp last time when streaming activity happens
|
||||
last_activity_timestamp = datetime.utcnow()
|
||||
|
||||
try:
|
||||
while True:
|
||||
# if no activity for 5 minutes, close the connection
|
||||
if (datetime.utcnow() - last_activity_timestamp).total_seconds() > STREAMING_TIMEOUT:
|
||||
LOG.info(
|
||||
"No activity for 5 minutes. Closing connection", task_id=task_id, organization_id=organization_id
|
||||
)
|
||||
await websocket.send_json(
|
||||
{
|
||||
"task_id": task_id,
|
||||
"status": "timeout",
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
task = await app.DATABASE.get_task(task_id=task_id, organization_id=organization_id)
|
||||
if not task:
|
||||
LOG.info("Task not found. Closing connection", task_id=task_id, organization_id=organization_id)
|
||||
await websocket.send_json(
|
||||
{
|
||||
"task_id": task_id,
|
||||
"status": "not_found",
|
||||
}
|
||||
)
|
||||
return
|
||||
if task.status.is_final():
|
||||
LOG.info(
|
||||
"Task is in a final state. Closing connection",
|
||||
task_status=task.status,
|
||||
task_id=task_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
await websocket.send_json(
|
||||
{
|
||||
"task_id": task_id,
|
||||
"status": task.status,
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
if task.status == TaskStatus.running:
|
||||
file_name = f"{task_id}.png"
|
||||
if task.workflow_run_id:
|
||||
file_name = f"{task.workflow_run_id}.png"
|
||||
screenshot = await app.STORAGE.get_streaming_file(organization_id, file_name)
|
||||
if screenshot:
|
||||
encoded_screenshot = base64.b64encode(screenshot).decode("utf-8")
|
||||
await websocket.send_json(
|
||||
{
|
||||
"task_id": task_id,
|
||||
"status": task.status,
|
||||
"screenshot": encoded_screenshot,
|
||||
}
|
||||
)
|
||||
last_activity_timestamp = datetime.utcnow()
|
||||
await asyncio.sleep(2)
|
||||
|
||||
except ValidationError as e:
|
||||
await websocket.send_text(f"Invalid data: {e}")
|
||||
except WebSocketDisconnect:
|
||||
LOG.info("WebSocket connection closed", task_id=task_id, organization_id=organization_id)
|
||||
except ConnectionClosedOK:
|
||||
LOG.info("ConnectionClosedOK error while streaming", task_id=task_id, organization_id=organization_id)
|
||||
return
|
||||
except ConnectionClosedError:
|
||||
LOG.warning(
|
||||
"ConnectionClosedError while streaming (client likely disconnected)",
|
||||
task_id=task_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
return
|
||||
except Exception:
|
||||
LOG.warning("Error while streaming", task_id=task_id, organization_id=organization_id, exc_info=True)
|
||||
return
|
||||
LOG.info("WebSocket connection closed successfully", task_id=task_id, organization_id=organization_id)
|
||||
return
|
||||
|
||||
|
||||
@legacy_base_router.websocket("/stream/workflow_runs/{workflow_run_id}")
|
||||
async def workflow_run_streaming(
|
||||
websocket: WebSocket,
|
||||
workflow_run_id: str,
|
||||
apikey: str | None = None,
|
||||
token: str | None = None,
|
||||
) -> None:
|
||||
try:
|
||||
await websocket.accept()
|
||||
if not token and not apikey:
|
||||
await websocket.send_text("No valid credential provided")
|
||||
return
|
||||
except ConnectionClosedOK:
|
||||
LOG.info("WofklowRun Streaming: ConnectionClosedOK error. Streaming won't start")
|
||||
return
|
||||
|
||||
try:
|
||||
organization = await get_current_org(x_api_key=apikey, authorization=token)
|
||||
organization_id = organization.organization_id
|
||||
except Exception:
|
||||
LOG.exception(
|
||||
"WofklowRun Streaming: Error while getting organization",
|
||||
workflow_run_id=workflow_run_id,
|
||||
token=token,
|
||||
)
|
||||
try:
|
||||
await websocket.send_text("Invalid credential provided")
|
||||
except ConnectionClosedOK:
|
||||
LOG.info("WofklowRun Streaming: ConnectionClosedOK error while sending invalid credential message")
|
||||
return
|
||||
|
||||
LOG.info(
|
||||
"WofklowRun Streaming: Started workflow run streaming",
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
# timestamp last time when streaming activity happens
|
||||
last_activity_timestamp = datetime.utcnow()
|
||||
|
||||
try:
|
||||
while True:
|
||||
# if no activity for 5 minutes, close the connection
|
||||
if (datetime.utcnow() - last_activity_timestamp).total_seconds() > STREAMING_TIMEOUT:
|
||||
LOG.info(
|
||||
"WofklowRun Streaming: No activity for 5 minutes. Closing connection",
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
await websocket.send_json(
|
||||
{
|
||||
"workflow_run_id": workflow_run_id,
|
||||
"status": "timeout",
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
workflow_run = await app.DATABASE.get_workflow_run(
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
if not workflow_run or workflow_run.organization_id != organization_id:
|
||||
LOG.info(
|
||||
"WofklowRun Streaming: Workflow not found",
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
await websocket.send_json(
|
||||
{
|
||||
"workflow_run_id": workflow_run_id,
|
||||
"status": "not_found",
|
||||
}
|
||||
)
|
||||
return
|
||||
if workflow_run.status in [
|
||||
WorkflowRunStatus.completed,
|
||||
WorkflowRunStatus.failed,
|
||||
WorkflowRunStatus.terminated,
|
||||
]:
|
||||
LOG.info(
|
||||
"Workflow run is in a final state. Closing connection",
|
||||
workflow_run_status=workflow_run.status,
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
await websocket.send_json(
|
||||
{
|
||||
"workflow_run_id": workflow_run_id,
|
||||
"status": workflow_run.status,
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
if workflow_run.status == WorkflowRunStatus.running:
|
||||
file_name = f"{workflow_run_id}.png"
|
||||
screenshot = await app.STORAGE.get_streaming_file(organization_id, file_name)
|
||||
if screenshot:
|
||||
encoded_screenshot = base64.b64encode(screenshot).decode("utf-8")
|
||||
await websocket.send_json(
|
||||
{
|
||||
"workflow_run_id": workflow_run_id,
|
||||
"status": workflow_run.status,
|
||||
"screenshot": encoded_screenshot,
|
||||
}
|
||||
)
|
||||
last_activity_timestamp = datetime.utcnow()
|
||||
await asyncio.sleep(2)
|
||||
|
||||
except ValidationError as e:
|
||||
await websocket.send_text(f"Invalid data: {e}")
|
||||
except WebSocketDisconnect:
|
||||
LOG.info(
|
||||
"WofklowRun Streaming: WebSocket connection closed",
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
except ConnectionClosedOK:
|
||||
LOG.info(
|
||||
"WofklowRun Streaming: ConnectionClosedOK error while streaming",
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
return
|
||||
except ConnectionClosedError:
|
||||
LOG.warning(
|
||||
"WofklowRun Streaming: ConnectionClosedError while streaming (client likely disconnected)",
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
return
|
||||
except Exception:
|
||||
LOG.warning(
|
||||
"WofklowRun Streaming: Error while streaming",
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return
|
||||
LOG.info(
|
||||
"WofklowRun Streaming: WebSocket connection closed successfully",
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
return
|
||||
288
skyvern/forge/sdk/routes/streaming/verify.py
Normal file
288
skyvern/forge/sdk/routes/streaming/verify.py
Normal file
@@ -0,0 +1,288 @@
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
|
||||
import structlog
|
||||
|
||||
import skyvern.forge.sdk.routes.streaming.clients as sc
|
||||
from skyvern.config import settings
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.schemas.persistent_browser_sessions import AddressablePersistentBrowserSession
|
||||
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
|
||||
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRun, WorkflowRunStatus
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
async def verify_browser_session(
|
||||
browser_session_id: str,
|
||||
organization_id: str,
|
||||
) -> AddressablePersistentBrowserSession | None:
|
||||
"""
|
||||
Verify the browser session exists, and is usable.
|
||||
"""
|
||||
|
||||
if settings.ENV == "local":
|
||||
dummy_browser_session = AddressablePersistentBrowserSession(
|
||||
persistent_browser_session_id=browser_session_id,
|
||||
organization_id=organization_id,
|
||||
browser_address="0.0.0.0:9223",
|
||||
created_at=datetime.now(),
|
||||
modified_at=datetime.now(),
|
||||
)
|
||||
|
||||
return dummy_browser_session
|
||||
|
||||
browser_session = await app.DATABASE.get_persistent_browser_session(browser_session_id, organization_id)
|
||||
|
||||
if not browser_session:
|
||||
LOG.info(
|
||||
"No browser session found.",
|
||||
browser_session_id=browser_session_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
return None
|
||||
|
||||
browser_address = browser_session.browser_address
|
||||
|
||||
if not browser_address:
|
||||
LOG.info(
|
||||
"Waiting for browser session address.",
|
||||
browser_session_id=browser_session_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
|
||||
try:
|
||||
browser_address = await app.PERSISTENT_SESSIONS_MANAGER.get_browser_address(
|
||||
session_id=browser_session_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
except Exception as ex:
|
||||
LOG.info(
|
||||
"Browser session address not found for browser session.",
|
||||
browser_session_id=browser_session_id,
|
||||
organization_id=organization_id,
|
||||
ex=ex,
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
addressable_browser_session = AddressablePersistentBrowserSession(
|
||||
**browser_session.model_dump() | {"browser_address": browser_address},
|
||||
)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
return addressable_browser_session
|
||||
|
||||
|
||||
async def verify_task(
|
||||
task_id: str, organization_id: str
|
||||
) -> tuple[Task | None, AddressablePersistentBrowserSession | None]:
|
||||
"""
|
||||
Verify the task is running, and that it has a browser session associated
|
||||
with it.
|
||||
"""
|
||||
|
||||
task = await app.DATABASE.get_task(task_id=task_id, organization_id=organization_id)
|
||||
|
||||
if not task:
|
||||
LOG.info("Task not found.", task_id=task_id, organization_id=organization_id)
|
||||
return None, None
|
||||
|
||||
if task.status.is_final():
|
||||
LOG.info("Task is in a final state.", task_status=task.status, task_id=task_id, organization_id=organization_id)
|
||||
|
||||
return None, None
|
||||
|
||||
if task.status not in [TaskStatus.created, TaskStatus.queued, TaskStatus.running]:
|
||||
LOG.info(
|
||||
"Task is not created, queued, or running.",
|
||||
task_status=task.status,
|
||||
task_id=task_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
|
||||
return None, None
|
||||
|
||||
browser_session = await app.PERSISTENT_SESSIONS_MANAGER.get_session_by_runnable_id(
|
||||
organization_id=organization_id,
|
||||
runnable_id=task_id,
|
||||
)
|
||||
|
||||
if not browser_session:
|
||||
LOG.info("No browser session found for task.", task_id=task_id, organization_id=organization_id)
|
||||
return task, None
|
||||
|
||||
if not browser_session.browser_address:
|
||||
LOG.info("Browser session address not found for task.", task_id=task_id, organization_id=organization_id)
|
||||
return task, None
|
||||
|
||||
try:
|
||||
addressable_browser_session = AddressablePersistentBrowserSession(
|
||||
**browser_session.model_dump() | {"browser_address": browser_session.browser_address},
|
||||
)
|
||||
except Exception as e:
|
||||
LOG.error(
|
||||
"streaming-vnc.browser-session-reify-error", task_id=task_id, organization_id=organization_id, error=e
|
||||
)
|
||||
return task, None
|
||||
|
||||
return task, addressable_browser_session
|
||||
|
||||
|
||||
async def verify_workflow_run(
|
||||
workflow_run_id: str,
|
||||
organization_id: str,
|
||||
) -> tuple[WorkflowRun | None, AddressablePersistentBrowserSession | None]:
|
||||
"""
|
||||
Verify the workflow run is running, and that it has a browser session associated
|
||||
with it.
|
||||
"""
|
||||
|
||||
if settings.ENV == "local":
|
||||
dummy_workflow_run = WorkflowRun(
|
||||
workflow_id="123",
|
||||
workflow_permanent_id="wpid_123",
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
status=WorkflowRunStatus.running,
|
||||
created_at=datetime.now(),
|
||||
modified_at=datetime.now(),
|
||||
)
|
||||
|
||||
dummy_browser_session = AddressablePersistentBrowserSession(
|
||||
persistent_browser_session_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
browser_address="0.0.0.0:9223",
|
||||
created_at=datetime.now(),
|
||||
modified_at=datetime.now(),
|
||||
)
|
||||
|
||||
return dummy_workflow_run, dummy_browser_session
|
||||
|
||||
workflow_run = await app.DATABASE.get_workflow_run(
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
|
||||
if not workflow_run:
|
||||
LOG.info("Workflow run not found.", workflow_run_id=workflow_run_id, organization_id=organization_id)
|
||||
return None, None
|
||||
|
||||
if workflow_run.status.is_final():
|
||||
LOG.info(
|
||||
"Workflow run is in a final state. Closing connection.",
|
||||
workflow_run_status=workflow_run.status,
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
|
||||
return None, None
|
||||
|
||||
if workflow_run.status not in [
|
||||
WorkflowRunStatus.created,
|
||||
WorkflowRunStatus.queued,
|
||||
WorkflowRunStatus.running,
|
||||
WorkflowRunStatus.paused,
|
||||
]:
|
||||
LOG.info(
|
||||
"Workflow run is not running.",
|
||||
workflow_run_status=workflow_run.status,
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
|
||||
return None, None
|
||||
|
||||
browser_session = await app.PERSISTENT_SESSIONS_MANAGER.get_session_by_runnable_id(
|
||||
organization_id=organization_id,
|
||||
runnable_id=workflow_run_id,
|
||||
)
|
||||
|
||||
if not browser_session:
|
||||
LOG.info(
|
||||
"No browser session found for workflow run.",
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
return workflow_run, None
|
||||
|
||||
browser_address = browser_session.browser_address
|
||||
|
||||
if not browser_address:
|
||||
LOG.info(
|
||||
"Waiting for browser session address.", workflow_run_id=workflow_run_id, organization_id=organization_id
|
||||
)
|
||||
|
||||
try:
|
||||
browser_address = await app.PERSISTENT_SESSIONS_MANAGER.get_browser_address(
|
||||
session_id=browser_session.persistent_browser_session_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
except Exception as ex:
|
||||
LOG.info(
|
||||
"Browser session address not found for workflow run.",
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
ex=ex,
|
||||
)
|
||||
return workflow_run, None
|
||||
|
||||
try:
|
||||
addressable_browser_session = AddressablePersistentBrowserSession(
|
||||
**browser_session.model_dump() | {"browser_address": browser_address},
|
||||
)
|
||||
except Exception:
|
||||
return workflow_run, None
|
||||
|
||||
return workflow_run, addressable_browser_session
|
||||
|
||||
|
||||
async def loop_verify_browser_session(verifiable: sc.MessageChannel | sc.Streaming) -> None:
|
||||
"""
|
||||
Loop until the browser session is cleared or the websocket is closed.
|
||||
"""
|
||||
|
||||
while verifiable.browser_session and verifiable.is_open:
|
||||
browser_session = await verify_browser_session(
|
||||
browser_session_id=verifiable.browser_session.persistent_browser_session_id,
|
||||
organization_id=verifiable.organization_id,
|
||||
)
|
||||
|
||||
verifiable.browser_session = browser_session
|
||||
|
||||
await asyncio.sleep(2)
|
||||
|
||||
|
||||
async def loop_verify_task(streaming: sc.Streaming) -> None:
|
||||
"""
|
||||
Loop until the task is cleared or the websocket is closed.
|
||||
"""
|
||||
|
||||
while streaming.task and streaming.is_open:
|
||||
task, browser_session = await verify_task(
|
||||
task_id=streaming.task.task_id,
|
||||
organization_id=streaming.organization_id,
|
||||
)
|
||||
|
||||
streaming.task = task
|
||||
streaming.browser_session = browser_session
|
||||
|
||||
await asyncio.sleep(2)
|
||||
|
||||
|
||||
async def loop_verify_workflow_run(verifiable: sc.MessageChannel | sc.Streaming) -> None:
|
||||
"""
|
||||
Loop until the workflow run is cleared or the websocket is closed.
|
||||
"""
|
||||
|
||||
while verifiable.workflow_run and verifiable.is_open:
|
||||
workflow_run, browser_session = await verify_workflow_run(
|
||||
workflow_run_id=verifiable.workflow_run.workflow_run_id,
|
||||
organization_id=verifiable.organization_id,
|
||||
)
|
||||
|
||||
verifiable.workflow_run = workflow_run
|
||||
verifiable.browser_session = browser_session
|
||||
|
||||
await asyncio.sleep(2)
|
||||
626
skyvern/forge/sdk/routes/streaming/vnc.py
Normal file
626
skyvern/forge/sdk/routes/streaming/vnc.py
Normal file
@@ -0,0 +1,626 @@
|
||||
"""
|
||||
Streaming VNC WebSocket connections.
|
||||
|
||||
NOTE(jdo:streaming-local-dev)
|
||||
-----------------------------
|
||||
- grep the above for local development seams
|
||||
- augment those seams as indicated, then
|
||||
- stand up https://github.com/jomido/whyvern
|
||||
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import typing as t
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import structlog
|
||||
import websockets
|
||||
from fastapi import WebSocket, WebSocketDisconnect
|
||||
from websockets import Data
|
||||
from websockets.exceptions import ConnectionClosedError
|
||||
|
||||
import skyvern.forge.sdk.routes.streaming.clients as sc
|
||||
from skyvern.config import settings
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
|
||||
from skyvern.forge.sdk.routes.routers import base_router, legacy_base_router
|
||||
from skyvern.forge.sdk.routes.streaming.agent import connected_agent
|
||||
from skyvern.forge.sdk.routes.streaming.auth import auth
|
||||
from skyvern.forge.sdk.routes.streaming.verify import (
|
||||
loop_verify_browser_session,
|
||||
loop_verify_task,
|
||||
loop_verify_workflow_run,
|
||||
verify_browser_session,
|
||||
verify_task,
|
||||
verify_workflow_run,
|
||||
)
|
||||
from skyvern.forge.sdk.utils.aio import collect
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class Constants:
|
||||
MissingXApiKey = "<missing-x-api-key>"
|
||||
|
||||
|
||||
async def get_x_api_key(organization_id: str) -> str:
|
||||
token = await app.DATABASE.get_valid_org_auth_token(
|
||||
organization_id,
|
||||
OrganizationAuthTokenType.api.value,
|
||||
)
|
||||
|
||||
if not token:
|
||||
LOG.warning(
|
||||
"No valid API key found for organization when streaming.",
|
||||
organization_id=organization_id,
|
||||
)
|
||||
x_api_key = Constants.MissingXApiKey
|
||||
else:
|
||||
x_api_key = token.token
|
||||
|
||||
return x_api_key
|
||||
|
||||
|
||||
async def get_streaming_for_browser_session(
|
||||
client_id: str,
|
||||
browser_session_id: str,
|
||||
organization_id: str,
|
||||
websocket: WebSocket,
|
||||
) -> tuple[sc.Streaming, sc.Loops] | None:
|
||||
"""
|
||||
Return a streaming context for a browser session, with a list of loops to run concurrently.
|
||||
"""
|
||||
|
||||
LOG.info("Getting streaming context for browser session.", browser_session_id=browser_session_id)
|
||||
|
||||
browser_session = await verify_browser_session(
|
||||
browser_session_id=browser_session_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
|
||||
if not browser_session:
|
||||
LOG.info(
|
||||
"No initial browser session found.", browser_session_id=browser_session_id, organization_id=organization_id
|
||||
)
|
||||
return None
|
||||
|
||||
x_api_key = await get_x_api_key(organization_id)
|
||||
|
||||
streaming = sc.Streaming(
|
||||
client_id=client_id,
|
||||
interactor="agent",
|
||||
organization_id=organization_id,
|
||||
vnc_port=settings.SKYVERN_BROWSER_VNC_PORT,
|
||||
browser_session=browser_session,
|
||||
x_api_key=x_api_key,
|
||||
websocket=websocket,
|
||||
)
|
||||
|
||||
LOG.info("Got streaming context for browser session.", streaming=streaming)
|
||||
|
||||
loops = [
|
||||
asyncio.create_task(loop_verify_browser_session(streaming)),
|
||||
asyncio.create_task(loop_stream_vnc(streaming)),
|
||||
]
|
||||
|
||||
return streaming, loops
|
||||
|
||||
|
||||
async def get_streaming_for_task(
|
||||
client_id: str,
|
||||
task_id: str,
|
||||
organization_id: str,
|
||||
websocket: WebSocket,
|
||||
) -> tuple[sc.Streaming, sc.Loops] | None:
|
||||
"""
|
||||
Return a streaming context for a task, with a list of loops to run concurrently.
|
||||
"""
|
||||
|
||||
task, browser_session = await verify_task(task_id=task_id, organization_id=organization_id)
|
||||
|
||||
if not task:
|
||||
LOG.info("No initial task found.", task_id=task_id, organization_id=organization_id)
|
||||
return None
|
||||
|
||||
if not browser_session:
|
||||
LOG.info("No initial browser session found for task.", task_id=task_id, organization_id=organization_id)
|
||||
return None
|
||||
|
||||
x_api_key = await get_x_api_key(organization_id)
|
||||
|
||||
streaming = sc.Streaming(
|
||||
client_id=client_id,
|
||||
interactor="agent",
|
||||
organization_id=organization_id,
|
||||
vnc_port=settings.SKYVERN_BROWSER_VNC_PORT,
|
||||
x_api_key=x_api_key,
|
||||
websocket=websocket,
|
||||
browser_session=browser_session,
|
||||
task=task,
|
||||
)
|
||||
|
||||
loops = [
|
||||
asyncio.create_task(loop_verify_task(streaming)),
|
||||
asyncio.create_task(loop_stream_vnc(streaming)),
|
||||
]
|
||||
|
||||
return streaming, loops
|
||||
|
||||
|
||||
async def get_streaming_for_workflow_run(
|
||||
client_id: str,
|
||||
workflow_run_id: str,
|
||||
organization_id: str,
|
||||
websocket: WebSocket,
|
||||
) -> tuple[sc.Streaming, sc.Loops] | None:
|
||||
"""
|
||||
Return a streaming context for a workflow run, with a list of loops to run concurrently.
|
||||
"""
|
||||
|
||||
LOG.info("Getting streaming context for workflow run.", workflow_run_id=workflow_run_id)
|
||||
|
||||
workflow_run, browser_session = await verify_workflow_run(
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
|
||||
if not workflow_run:
|
||||
LOG.info("No initial workflow run found.", workflow_run_id=workflow_run_id, organization_id=organization_id)
|
||||
return None
|
||||
|
||||
if not browser_session:
|
||||
LOG.info(
|
||||
"No initial browser session found for workflow run.",
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
return None
|
||||
|
||||
x_api_key = await get_x_api_key(organization_id)
|
||||
|
||||
streaming = sc.Streaming(
|
||||
client_id=client_id,
|
||||
interactor="agent",
|
||||
organization_id=organization_id,
|
||||
vnc_port=settings.SKYVERN_BROWSER_VNC_PORT,
|
||||
browser_session=browser_session,
|
||||
workflow_run=workflow_run,
|
||||
x_api_key=x_api_key,
|
||||
websocket=websocket,
|
||||
)
|
||||
|
||||
LOG.info("Got streaming context for workflow run.", streaming=streaming)
|
||||
|
||||
loops = [
|
||||
asyncio.create_task(loop_verify_workflow_run(streaming)),
|
||||
asyncio.create_task(loop_stream_vnc(streaming)),
|
||||
]
|
||||
|
||||
return streaming, loops
|
||||
|
||||
|
||||
def verify_message_channel(
|
||||
message_channel: sc.MessageChannel | None, streaming: sc.Streaming
|
||||
) -> sc.MessageChannel | t.Literal[False]:
|
||||
if message_channel and message_channel.is_open:
|
||||
return message_channel
|
||||
|
||||
LOG.warning(
|
||||
"No message channel found for client, or it is not open",
|
||||
message_channel=message_channel,
|
||||
client_id=streaming.client_id,
|
||||
organization_id=streaming.organization_id,
|
||||
)
|
||||
|
||||
return False
|
||||
|
||||
|
||||
async def copy_text(streaming: sc.Streaming) -> None:
|
||||
try:
|
||||
async with connected_agent(streaming) as agent:
|
||||
copied_text = await agent.get_selected_text()
|
||||
|
||||
LOG.info(
|
||||
"Retrieved selected text via CDP",
|
||||
organization_id=streaming.organization_id,
|
||||
)
|
||||
|
||||
message_channel = sc.get_message_client(streaming.client_id)
|
||||
|
||||
if cc := verify_message_channel(message_channel, streaming):
|
||||
await cc.send_copied_text(copied_text, streaming)
|
||||
else:
|
||||
LOG.warning(
|
||||
"No message channel found for client, or it is not open",
|
||||
message_channel=message_channel,
|
||||
client_id=streaming.client_id,
|
||||
organization_id=streaming.organization_id,
|
||||
)
|
||||
except Exception:
|
||||
LOG.exception(
|
||||
"Failed to retrieve selected text via CDP",
|
||||
organization_id=streaming.organization_id,
|
||||
)
|
||||
|
||||
|
||||
async def ask_for_clipboard(streaming: sc.Streaming) -> None:
|
||||
try:
|
||||
LOG.info(
|
||||
"Asking for clipboard data via CDP",
|
||||
organization_id=streaming.organization_id,
|
||||
)
|
||||
|
||||
message_channel = sc.get_message_client(streaming.client_id)
|
||||
|
||||
if cc := verify_message_channel(message_channel, streaming):
|
||||
await cc.ask_for_clipboard(streaming)
|
||||
except Exception:
|
||||
LOG.exception(
|
||||
"Failed to ask for clipboard via CDP",
|
||||
organization_id=streaming.organization_id,
|
||||
)
|
||||
|
||||
|
||||
async def loop_stream_vnc(streaming: sc.Streaming) -> None:
|
||||
"""
|
||||
Actually stream the VNC session data between a frontend and a browser
|
||||
session.
|
||||
|
||||
Loops until the task is cleared or the websocket is closed.
|
||||
"""
|
||||
|
||||
if not streaming.browser_session:
|
||||
LOG.info("No browser session found for task.", task=streaming.task, organization_id=streaming.organization_id)
|
||||
return
|
||||
|
||||
vnc_url: str = ""
|
||||
if streaming.browser_session.ip_address:
|
||||
if ":" in streaming.browser_session.ip_address:
|
||||
ip, _ = streaming.browser_session.ip_address.split(":")
|
||||
vnc_url = f"ws://{ip}:{streaming.vnc_port}"
|
||||
else:
|
||||
vnc_url = f"ws://{streaming.browser_session.ip_address}:{streaming.vnc_port}"
|
||||
else:
|
||||
browser_address = streaming.browser_session.browser_address
|
||||
|
||||
parsed_browser_address = urlparse(browser_address)
|
||||
host = parsed_browser_address.hostname
|
||||
vnc_url = f"ws://{host}:{streaming.vnc_port}"
|
||||
|
||||
# NOTE(jdo:streaming-local-dev)
|
||||
# vnc_url = "ws://localhost:9001/ws/novnc"
|
||||
|
||||
LOG.info(
|
||||
"Connecting to VNC URL.",
|
||||
vnc_url=vnc_url,
|
||||
task=streaming.task,
|
||||
workflow_run=streaming.workflow_run,
|
||||
organization_id=streaming.organization_id,
|
||||
)
|
||||
|
||||
async with websockets.connect(vnc_url) as novnc_ws:
|
||||
|
||||
async def frontend_to_browser() -> None:
|
||||
LOG.info("Starting frontend-to-browser data transfer.", streaming=streaming)
|
||||
data: Data | None = None
|
||||
|
||||
while streaming.is_open:
|
||||
try:
|
||||
data = await streaming.websocket.receive_bytes()
|
||||
|
||||
if data:
|
||||
message_type = data[0]
|
||||
|
||||
if message_type == sc.MessageType.Keyboard.value:
|
||||
streaming.update_key_state(data)
|
||||
|
||||
if streaming.key_state.is_copy(data):
|
||||
await copy_text(streaming)
|
||||
|
||||
if streaming.key_state.is_paste(data):
|
||||
await ask_for_clipboard(streaming)
|
||||
|
||||
if streaming.key_state.is_forbidden(data):
|
||||
continue
|
||||
|
||||
if message_type == sc.MessageType.Mouse.value:
|
||||
if sc.Mouse.Up.Right(data):
|
||||
continue
|
||||
|
||||
if not streaming.interactor == "user" and message_type in (
|
||||
sc.MessageType.Keyboard.value,
|
||||
sc.MessageType.Mouse.value,
|
||||
):
|
||||
LOG.info(
|
||||
"Blocking user message.", task=streaming.task, organization_id=streaming.organization_id
|
||||
)
|
||||
continue
|
||||
|
||||
except WebSocketDisconnect:
|
||||
LOG.info("Frontend disconnected.", task=streaming.task, organization_id=streaming.organization_id)
|
||||
raise
|
||||
except ConnectionClosedError:
|
||||
LOG.info(
|
||||
"Frontend closed the streaming session.",
|
||||
task=streaming.task,
|
||||
organization_id=streaming.organization_id,
|
||||
)
|
||||
raise
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception:
|
||||
LOG.exception(
|
||||
"An unexpected exception occurred.",
|
||||
task=streaming.task,
|
||||
organization_id=streaming.organization_id,
|
||||
)
|
||||
raise
|
||||
|
||||
if not data:
|
||||
continue
|
||||
|
||||
try:
|
||||
await novnc_ws.send(data)
|
||||
except WebSocketDisconnect:
|
||||
LOG.info(
|
||||
"Browser disconnected from the streaming session.",
|
||||
task=streaming.task,
|
||||
organization_id=streaming.organization_id,
|
||||
)
|
||||
raise
|
||||
except ConnectionClosedError:
|
||||
LOG.info(
|
||||
"Browser closed the streaming session.",
|
||||
task=streaming.task,
|
||||
organization_id=streaming.organization_id,
|
||||
)
|
||||
raise
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception:
|
||||
LOG.exception(
|
||||
"An unexpected exception occurred in frontend-to-browser loop.",
|
||||
task=streaming.task,
|
||||
organization_id=streaming.organization_id,
|
||||
)
|
||||
raise
|
||||
|
||||
async def browser_to_frontend() -> None:
|
||||
LOG.info("Starting browser-to-frontend data transfer.", streaming=streaming)
|
||||
data: Data | None = None
|
||||
|
||||
while streaming.is_open:
|
||||
try:
|
||||
data = await novnc_ws.recv()
|
||||
|
||||
except WebSocketDisconnect:
|
||||
LOG.info(
|
||||
"Browser disconnected from the streaming session.",
|
||||
task=streaming.task,
|
||||
organization_id=streaming.organization_id,
|
||||
)
|
||||
await streaming.close(reason="browser-disconnected")
|
||||
except ConnectionClosedError:
|
||||
LOG.info(
|
||||
"Browser closed the streaming session.",
|
||||
task=streaming.task,
|
||||
organization_id=streaming.organization_id,
|
||||
)
|
||||
await streaming.close(reason="browser-closed")
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception:
|
||||
LOG.exception(
|
||||
"An unexpected exception occurred in browser-to-frontend loop.",
|
||||
task=streaming.task,
|
||||
organization_id=streaming.organization_id,
|
||||
)
|
||||
raise
|
||||
|
||||
if not data:
|
||||
continue
|
||||
|
||||
try:
|
||||
await streaming.websocket.send_bytes(data)
|
||||
except WebSocketDisconnect:
|
||||
LOG.info(
|
||||
"Frontend disconnected from the streaming session.",
|
||||
task=streaming.task,
|
||||
organization_id=streaming.organization_id,
|
||||
)
|
||||
await streaming.close(reason="frontend-disconnected")
|
||||
except ConnectionClosedError:
|
||||
LOG.info(
|
||||
"Frontend closed the streaming session.",
|
||||
task=streaming.task,
|
||||
organization_id=streaming.organization_id,
|
||||
)
|
||||
await streaming.close(reason="frontend-closed")
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception:
|
||||
LOG.exception(
|
||||
"An unexpected exception occurred.",
|
||||
task=streaming.task,
|
||||
organization_id=streaming.organization_id,
|
||||
)
|
||||
raise
|
||||
|
||||
loops = [
|
||||
asyncio.create_task(frontend_to_browser()),
|
||||
asyncio.create_task(browser_to_frontend()),
|
||||
]
|
||||
|
||||
try:
|
||||
await collect(loops)
|
||||
except Exception:
|
||||
LOG.exception(
|
||||
"An exception occurred in loop stream.", task=streaming.task, organization_id=streaming.organization_id
|
||||
)
|
||||
finally:
|
||||
LOG.info("Closing the loop stream.", task=streaming.task, organization_id=streaming.organization_id)
|
||||
await streaming.close(reason="loop-stream-vnc-closed")
|
||||
|
||||
|
||||
@base_router.websocket("/stream/vnc/browser_session/{browser_session_id}")
|
||||
async def browser_session_stream(
|
||||
websocket: WebSocket,
|
||||
browser_session_id: str,
|
||||
apikey: str | None = None,
|
||||
client_id: str | None = None,
|
||||
token: str | None = None,
|
||||
) -> None:
|
||||
await stream(websocket, apikey=apikey, client_id=client_id, browser_session_id=browser_session_id, token=token)
|
||||
|
||||
|
||||
@legacy_base_router.websocket("/stream/vnc/task/{task_id}")
|
||||
async def task_stream(
|
||||
websocket: WebSocket,
|
||||
task_id: str,
|
||||
apikey: str | None = None,
|
||||
client_id: str | None = None,
|
||||
token: str | None = None,
|
||||
) -> None:
|
||||
await stream(websocket, apikey=apikey, client_id=client_id, task_id=task_id, token=token)
|
||||
|
||||
|
||||
@legacy_base_router.websocket("/stream/vnc/workflow_run/{workflow_run_id}")
|
||||
async def workflow_run_stream(
|
||||
websocket: WebSocket,
|
||||
workflow_run_id: str,
|
||||
apikey: str | None = None,
|
||||
client_id: str | None = None,
|
||||
token: str | None = None,
|
||||
) -> None:
|
||||
await stream(websocket, apikey=apikey, client_id=client_id, workflow_run_id=workflow_run_id, token=token)
|
||||
|
||||
|
||||
async def stream(
|
||||
websocket: WebSocket,
|
||||
*,
|
||||
apikey: str | None = None,
|
||||
browser_session_id: str | None = None,
|
||||
client_id: str | None = None,
|
||||
task_id: str | None = None,
|
||||
token: str | None = None,
|
||||
workflow_run_id: str | None = None,
|
||||
) -> None:
|
||||
if not client_id:
|
||||
LOG.error(
|
||||
"Client ID not provided for VNC stream.",
|
||||
browser_session_id=browser_session_id,
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
)
|
||||
return
|
||||
|
||||
LOG.info(
|
||||
"Starting VNC stream.",
|
||||
browser_session_id=browser_session_id,
|
||||
client_id=client_id,
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
)
|
||||
|
||||
organization_id = await auth(apikey=apikey, token=token, websocket=websocket)
|
||||
|
||||
if not organization_id:
|
||||
LOG.error("Authentication failed.", task_id=task_id, workflow_run_id=workflow_run_id)
|
||||
return
|
||||
|
||||
streaming: sc.Streaming
|
||||
loops: list[asyncio.Task] = []
|
||||
|
||||
if browser_session_id:
|
||||
result = await get_streaming_for_browser_session(
|
||||
client_id=client_id,
|
||||
browser_session_id=browser_session_id,
|
||||
organization_id=organization_id,
|
||||
websocket=websocket,
|
||||
)
|
||||
|
||||
if not result:
|
||||
LOG.error(
|
||||
"No streaming context found for the browser session.",
|
||||
browser_session_id=browser_session_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
await websocket.close(code=1013)
|
||||
return
|
||||
|
||||
streaming, loops = result
|
||||
|
||||
LOG.info(
|
||||
"Starting streaming for browser session.",
|
||||
browser_session_id=browser_session_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
elif task_id:
|
||||
result = await get_streaming_for_task(
|
||||
client_id=client_id,
|
||||
task_id=task_id,
|
||||
organization_id=organization_id,
|
||||
websocket=websocket,
|
||||
)
|
||||
|
||||
if not result:
|
||||
LOG.error("No streaming context found for the task.", task_id=task_id, organization_id=organization_id)
|
||||
await websocket.close(code=1013)
|
||||
return
|
||||
|
||||
streaming, loops = result
|
||||
|
||||
LOG.info("Starting streaming for task.", task_id=task_id, organization_id=organization_id)
|
||||
|
||||
elif workflow_run_id:
|
||||
LOG.info(
|
||||
"Starting streaming for workflow run.", workflow_run_id=workflow_run_id, organization_id=organization_id
|
||||
)
|
||||
result = await get_streaming_for_workflow_run(
|
||||
client_id=client_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
websocket=websocket,
|
||||
)
|
||||
|
||||
if not result:
|
||||
LOG.error(
|
||||
"No streaming context found for the workflow run.",
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
await websocket.close(code=1013)
|
||||
return
|
||||
|
||||
streaming, loops = result
|
||||
|
||||
LOG.info(
|
||||
"Starting streaming for workflow run.", workflow_run_id=workflow_run_id, organization_id=organization_id
|
||||
)
|
||||
else:
|
||||
LOG.error("Neither task ID nor workflow run ID was provided.")
|
||||
return
|
||||
|
||||
try:
|
||||
LOG.info(
|
||||
"Starting streaming loops.",
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
await collect(loops)
|
||||
except Exception:
|
||||
LOG.exception(
|
||||
"An exception occurred in the stream function.",
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
finally:
|
||||
LOG.info(
|
||||
"Closing the streaming session.",
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
await streaming.close(reason="stream-closed")
|
||||
Reference in New Issue
Block a user