Files
Dorod-Sky/skyvern/forge/sdk/routes/streaming_vnc.py

643 lines
21 KiB
Python
Raw Normal View History

2025-06-12 09:43:16 -07:00
import asyncio
import dataclasses
import typing as t
from enum import IntEnum
import structlog
import websockets
from fastapi import WebSocket, WebSocketDisconnect
from starlette.websockets import WebSocketState
from websockets import Data
from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK
2025-06-12 13:17:50 -07:00
from skyvern.config import settings
2025-06-12 09:43:16 -07:00
from skyvern.forge import app
from skyvern.forge.sdk.routes.routers import legacy_base_router
from skyvern.forge.sdk.schemas.persistent_browser_sessions import AddressablePersistentBrowserSession
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
from skyvern.forge.sdk.services.org_auth_service import get_current_org
from skyvern.forge.sdk.utils.aio import collect
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRun, WorkflowRunStatus
Interactor = t.Literal["agent", "user"]
Loops = list[asyncio.Task] # aka "queue-less actors"; or "programs"
class MessageType(IntEnum):
    """
    RFB (VNC) client-to-server message types inspected by the stream relay.

    The value is the first byte of the client message. Per RFC 6143, section 7.5,
    4 is KeyEvent and 5 is PointerEvent.
    """

    keyboard = 4  # RFB KeyEvent
    mouse = 5  # RFB PointerEvent
# Module-level structlog logger used by every function in this file.
LOG = structlog.get_logger()
@dataclasses.dataclass
class Streaming:
    """
    Mutable state shared by the loops that serve one VNC streaming websocket.

    The verification loops refresh ``task`` / ``workflow_run`` / ``browser_session``.
    ``close`` clears them, which makes ``is_open`` false; the loops poll ``is_open``
    to decide whether to keep running.
    """

    # Whether the user or the agent is currently the interactor.
    interactor: Interactor
    organization_id: str
    vnc_port: int
    websocket: WebSocket
    # -- optional run context; whichever kind of run is being streamed is set
    browser_session: AddressablePersistentBrowserSession | None = None
    task: Task | None = None
    workflow_run: WorkflowRun | None = None

    @property
    def is_open(self) -> bool:
        """
        True while the websocket is connecting or connected AND a task or workflow
        run is still attached.
        """
        socket_live = self.websocket.client_state in (WebSocketState.CONNECTED, WebSocketState.CONNECTING)
        return socket_live and bool(self.task or self.workflow_run)

    async def close(self, code: int = 1000, reason: str | None = None) -> "Streaming":
        """
        Detach the run context and close the websocket; returns ``self``.

        Any error raised while closing the websocket is ignored.
        """
        LOG.info("Closing Streaming.", reason=reason, code=code)
        self.browser_session, self.task, self.workflow_run = None, None, None
        try:
            await self.websocket.close(code=code, reason=reason)
        except Exception:
            # Best effort: the socket may already be closed by either side.
            pass
        return self
async def auth(apikey: str | None, token: str | None, websocket: WebSocket) -> str | None:
    """
    Accept the websocket and resolve the caller's organization.

    Returns the organization id. Returns ``None`` in three cases:
    - the connection closed cleanly during the handshake;
    - neither ``apikey`` nor ``token`` was supplied;
    - the credentials could not be resolved to an organization id.
    In the last two cases the websocket is closed with code 1002. The connection
    must not proceed without an organization id.
    """
    try:
        await websocket.accept()
        if not (token or apikey):
            await websocket.close(code=1002)
            return None
    except ConnectionClosedOK:
        LOG.info("WebSocket connection closed cleanly.")
        return None

    try:
        org_id = (await get_current_org(x_api_key=apikey, authorization=token)).organization_id
        if org_id:
            return org_id
        await websocket.close(code=1002)
        return None
    except Exception:
        LOG.exception("Error occurred while retrieving organization information.")
        try:
            await websocket.close(code=1002)
        except ConnectionClosedOK:
            LOG.info("WebSocket connection closed due to invalid credentials.")
        return None
async def verify_task(
    task_id: str, organization_id: str
) -> tuple[Task | None, AddressablePersistentBrowserSession | None]:
    """
    Check that the task is running and resolve an addressable browser session for it.

    Returns:
        ``(None, None)`` when the task is missing, is final, or is not in the running status.
        ``(task, None)`` when the task is running but has no browser session, the session
        has no address, or the session cannot be reified.
        ``(task, session)`` otherwise.
    """
    log_ctx = {"task_id": task_id, "organization_id": organization_id}

    task = await app.DATABASE.get_task(task_id=task_id, organization_id=organization_id)
    if not task:
        LOG.info("Task not found.", **log_ctx)
        return None, None
    if task.status.is_final():
        LOG.info("Task is in a final state.", task_status=task.status, **log_ctx)
        return None, None
    if task.status != TaskStatus.running:
        LOG.info("Task is not running.", task_status=task.status, **log_ctx)
        return None, None

    # NOTE(review): the session is looked up by task id; confirm there is no separate
    # task-run id that should be used instead.
    session = await app.PERSISTENT_SESSIONS_MANAGER.get_session_by_runnable_id(
        organization_id=organization_id,
        runnable_id=task_id,
    )
    if not session:
        LOG.info("No browser session found for task.", **log_ctx)
        return task, None

    address = session.browser_address
    if not address:
        LOG.info("Browser session address not found for task.", **log_ctx)
        return task, None

    try:
        addressable = AddressablePersistentBrowserSession(**{**session.model_dump(), "browser_address": address})
    except Exception as e:
        LOG.error("streaming-vnc.browser-session-reify-error", **log_ctx, error=e)
        return task, None
    return task, addressable
async def get_streaming_for_task(
    task_id: str,
    organization_id: str,
    websocket: WebSocket,
) -> tuple[Streaming, Loops] | None:
    """
    Build the streaming context for a task and schedule its loops.

    Returns ``None`` when ``verify_task`` yields no task or no browser session.
    Otherwise returns the context plus two already-scheduled tasks:
    one re-verifies the task, the other relays VNC traffic.
    """
    task, browser_session = await verify_task(task_id=task_id, organization_id=organization_id)

    problem = None
    if not task:
        problem = "No initial task found."
    elif not browser_session:
        problem = "No initial browser session found for task."
    if problem:
        LOG.info(problem, task_id=task_id, organization_id=organization_id)
        return None

    streaming = Streaming(
        interactor="user",
        organization_id=organization_id,
        vnc_port=settings.SKYVERN_BROWSER_VNC_PORT,
        websocket=websocket,
        browser_session=browser_session,
        task=task,
    )
    verifier = asyncio.create_task(loop_verify_task(streaming))
    relay = asyncio.create_task(loop_stream_vnc(streaming))
    return streaming, [verifier, relay]
async def get_streaming_for_workflow_run(
    workflow_run_id: str,
    organization_id: str,
    websocket: WebSocket,
) -> tuple[Streaming, Loops] | None:
    """
    Build the streaming context for a workflow run and schedule its loops.

    Returns ``None`` when ``verify_workflow_run`` yields no run or no browser session.
    Otherwise returns the context plus two already-scheduled tasks: one re-verifies
    the workflow run, the other relays VNC traffic.
    """
    LOG.info("Getting streaming context for workflow run.", workflow_run_id=workflow_run_id)
    workflow_run, browser_session = await verify_workflow_run(
        workflow_run_id=workflow_run_id,
        organization_id=organization_id,
    )

    ids = {"workflow_run_id": workflow_run_id, "organization_id": organization_id}
    if not workflow_run:
        LOG.info("No initial workflow run found.", **ids)
        return None
    if not browser_session:
        LOG.info("No initial browser session found for workflow run.", **ids)
        return None

    streaming = Streaming(
        interactor="user",
        organization_id=organization_id,
        vnc_port=settings.SKYVERN_BROWSER_VNC_PORT,
        browser_session=browser_session,
        workflow_run=workflow_run,
        websocket=websocket,
    )
    return streaming, [
        asyncio.create_task(loop_verify_workflow_run(streaming)),
        asyncio.create_task(loop_stream_vnc(streaming)),
    ]
async def verify_workflow_run(
    workflow_run_id: str,
    organization_id: str,
) -> tuple[WorkflowRun | None, AddressablePersistentBrowserSession | None]:
    """
    Check that the workflow run is live and resolve an addressable browser session for it.

    Returns:
        ``(None, None)`` in two cases: the run is missing or is completed/failed/terminated;
        or its status is anything other than created/queued/running.
        ``(workflow_run, None)`` when the run is live but no browser session exists, no
        browser address can be resolved, or the session cannot be reified.
        ``(workflow_run, session)`` otherwise.
    """
    workflow_run = await app.DATABASE.get_workflow_run(
        workflow_run_id=workflow_run_id,
        organization_id=organization_id,
    )
    if not workflow_run:
        LOG.info("Workflow run not found.", workflow_run_id=workflow_run_id, organization_id=organization_id)
        return None, None
    if workflow_run.status in (
        WorkflowRunStatus.completed,
        WorkflowRunStatus.failed,
        WorkflowRunStatus.terminated,
    ):
        LOG.info(
            "Workflow run is in a final state. Closing connection.",
            workflow_run_status=workflow_run.status,
            workflow_run_id=workflow_run_id,
            organization_id=organization_id,
        )
        return None, None
    if workflow_run.status not in (WorkflowRunStatus.created, WorkflowRunStatus.queued, WorkflowRunStatus.running):
        LOG.info(
            "Workflow run is not running.",
            workflow_run_status=workflow_run.status,
            workflow_run_id=workflow_run_id,
            organization_id=organization_id,
        )
        return None, None

    browser_session = await app.PERSISTENT_SESSIONS_MANAGER.get_session_by_runnable_id(
        organization_id=organization_id,
        runnable_id=workflow_run_id,
    )
    if not browser_session:
        LOG.info(
            "No browser session found for workflow run.",
            workflow_run_id=workflow_run_id,
            organization_id=organization_id,
        )
        return workflow_run, None

    browser_address = browser_session.browser_address
    if not browser_address:
        # The session has no stored address yet; ask the sessions manager to resolve one.
        LOG.info(
            "Waiting for browser session address.", workflow_run_id=workflow_run_id, organization_id=organization_id
        )
        try:
            _, host, cdp_port = await app.PERSISTENT_SESSIONS_MANAGER.get_browser_address(
                session_id=browser_session.persistent_browser_session_id,
                organization_id=organization_id,
            )
        except Exception as ex:
            LOG.info(
                "Browser session address not found for workflow run.",
                workflow_run_id=workflow_run_id,
                organization_id=organization_id,
                ex=ex,
            )
            return workflow_run, None
        browser_address = f"{host}:{cdp_port}"

    try:
        addressable_browser_session = AddressablePersistentBrowserSession(
            **browser_session.model_dump() | {"browser_address": browser_address},
        )
    except Exception as e:
        # Log the failure the same way verify_task does, so a session that cannot be
        # reified shows up in the logs instead of silently ending the stream.
        LOG.error(
            "streaming-vnc.browser-session-reify-error",
            workflow_run_id=workflow_run_id,
            organization_id=organization_id,
            error=e,
        )
        return workflow_run, None
    return workflow_run, addressable_browser_session
async def loop_verify_task(streaming: Streaming) -> None:
    """
    Re-verify the streamed task every two seconds.

    Each pass refreshes ``streaming.task`` and ``streaming.browser_session``. The loop
    stops once the task is cleared (``verify_task`` returned no task, or the context was
    closed) or the websocket is no longer open.

    However the loop ends (normally, by error, or by cancellation), the streaming
    context is closed. This closes the frontend websocket promptly, rather than only
    when the relay loops next re-check ``is_open``.
    """
    try:
        while streaming.task and streaming.is_open:
            task, browser_session = await verify_task(
                task_id=streaming.task.task_id,
                organization_id=streaming.organization_id,
            )
            streaming.task = task
            streaming.browser_session = browser_session
            await asyncio.sleep(2)
    finally:
        # Streaming.close() ignores websocket errors, so closing an already-closed stream is harmless.
        await streaming.close(reason="task-verification-ended")
async def loop_verify_workflow_run(streaming: Streaming) -> None:
    """
    Re-verify the streamed workflow run every two seconds.

    Each pass refreshes ``streaming.workflow_run`` and ``streaming.browser_session``.
    The loop stops once the workflow run is cleared (``verify_workflow_run`` returned no
    run, or the context was closed) or the websocket is no longer open.

    However the loop ends (normally, by error, or by cancellation), the streaming
    context is closed. This closes the frontend websocket promptly, rather than only
    when the relay loops next re-check ``is_open``.
    """
    try:
        while streaming.workflow_run and streaming.is_open:
            workflow_run, browser_session = await verify_workflow_run(
                workflow_run_id=streaming.workflow_run.workflow_run_id,
                organization_id=streaming.organization_id,
            )
            streaming.workflow_run = workflow_run
            streaming.browser_session = browser_session
            await asyncio.sleep(2)
    finally:
        # Streaming.close() ignores websocket errors, so closing an already-closed stream is harmless.
        await streaming.close(reason="workflow-run-verification-ended")
async def loop_stream_vnc(streaming: Streaming) -> None:
    """
    Relay VNC session data between the frontend websocket and the browser session's
    VNC websocket server.

    Two relay tasks run concurrently, one per direction. Each loops while
    ``streaming.is_open``, and the relay ends when a side disconnects or a relay fails.
    Unexpected errors are logged rather than propagated; cancellation still propagates.
    Once a connection attempt has been made, the streaming context is always closed and
    any relay task still pending is cancelled before returning.
    """
    if not streaming.browser_session:
        LOG.info("No browser session found for task.", task=streaming.task, organization_id=streaming.organization_id)
        return

    browser_address = streaming.browser_session.browser_address
    # The address is expected to be "host:port"; VNC is served on the same host at
    # streaming.vnc_port. Split on the LAST colon only: an unbounded rsplit(":") fails to
    # unpack when the host itself contains colons (e.g. a bracketed IPv6 literal).
    host = browser_address.rsplit(":", 1)[0]
    vnc_url = f"ws://{host}:{streaming.vnc_port}"
    LOG.info(
        "Connecting to VNC URL.",
        browser_address=browser_address,
        vnc_url=vnc_url,
        task=streaming.task,
        workflow_run=streaming.workflow_run,
        organization_id=streaming.organization_id,
    )

    relays: Loops = []
    try:
        async with websockets.connect(vnc_url) as novnc_ws:

            async def frontend_to_browser() -> None:
                """
                Forward frontend messages to the VNC server.

                Keyboard/mouse messages are dropped unless the user is the interactor.
                """
                LOG.info("Starting frontend-to-browser data transfer.", streaming=streaming)
                while streaming.is_open:
                    try:
                        data = await streaming.websocket.receive_bytes()
                    except WebSocketDisconnect:
                        LOG.info("Frontend disconnected.", task=streaming.task, organization_id=streaming.organization_id)
                        raise
                    except ConnectionClosedError:
                        LOG.info(
                            "Frontend closed the streaming session.",
                            task=streaming.task,
                            organization_id=streaming.organization_id,
                        )
                        raise
                    except Exception:
                        LOG.exception(
                            "An unexpected exception occurred.",
                            task=streaming.task,
                            organization_id=streaming.organization_id,
                        )
                        raise
                    if not data:
                        continue
                    # RFB client-to-server message types (RFC 6143, section 7.5): 4 is KeyEvent
                    # and 5 is PointerEvent. NOTE(review): this assumes every websocket frame
                    # begins with an RFB message header.
                    if streaming.interactor != "user" and data[0] in (
                        MessageType.keyboard.value,
                        MessageType.mouse.value,
                    ):
                        LOG.info("Blocking user message.", task=streaming.task, organization_id=streaming.organization_id)
                        continue
                    try:
                        await novnc_ws.send(data)
                    except (ConnectionClosedError, ConnectionClosedOK):
                        LOG.info(
                            "Browser closed the streaming session.",
                            task=streaming.task,
                            organization_id=streaming.organization_id,
                        )
                        raise
                    except Exception:
                        LOG.exception(
                            "An unexpected exception occurred in frontend-to-browser loop.",
                            task=streaming.task,
                            organization_id=streaming.organization_id,
                        )
                        raise

            async def browser_to_frontend() -> None:
                """
                Forward VNC server messages to the frontend.

                Closes the streaming context and stops when either side goes away.
                """
                LOG.info("Starting browser-to-frontend data transfer.", streaming=streaming)
                while streaming.is_open:
                    try:
                        data: Data = await novnc_ws.recv()
                    except (ConnectionClosedError, ConnectionClosedOK):
                        LOG.info(
                            "Browser closed the streaming session.",
                            task=streaming.task,
                            organization_id=streaming.organization_id,
                        )
                        await streaming.close(reason="browser-closed")
                        # Stop here: there is no fresh data to forward.
                        return
                    except Exception:
                        LOG.exception(
                            "An unexpected exception occurred in browser-to-frontend loop.",
                            task=streaming.task,
                            organization_id=streaming.organization_id,
                        )
                        raise
                    if not data:
                        continue
                    try:
                        await streaming.websocket.send_bytes(data)
                    except WebSocketDisconnect:
                        LOG.info(
                            "Frontend disconnected from the streaming session.",
                            task=streaming.task,
                            organization_id=streaming.organization_id,
                        )
                        await streaming.close(reason="frontend-disconnected")
                        return
                    except ConnectionClosedError:
                        LOG.info(
                            "Frontend closed the streaming session.",
                            task=streaming.task,
                            organization_id=streaming.organization_id,
                        )
                        await streaming.close(reason="frontend-closed")
                        return
                    except Exception:
                        LOG.exception(
                            "An unexpected exception occurred.",
                            task=streaming.task,
                            organization_id=streaming.organization_id,
                        )
                        raise

            relays = [
                asyncio.create_task(frontend_to_browser()),
                asyncio.create_task(browser_to_frontend()),
            ]
            await collect(relays)
    except Exception:
        # Also covers a failed websockets.connect(), so the context still gets closed below.
        LOG.exception(
            "An exception occurred in loop stream.", task=streaming.task, organization_id=streaming.organization_id
        )
    finally:
        LOG.info("Closing the loop stream.", task=streaming.task, organization_id=streaming.organization_id)
        # Don't leave a relay running after this function exits; cancel() is a no-op on finished tasks.
        for relay in relays:
            relay.cancel()
        await streaming.close(reason="loop-stream-vnc-closed")
@legacy_base_router.websocket("/stream/vnc/task/{task_id}")
async def task_stream(
    websocket: WebSocket,
    task_id: str,
    apikey: str | None = None,
    token: str | None = None,
) -> None:
    """
    Websocket endpoint that streams the VNC session of a task.

    Authenticates with either ``apikey`` or ``token``.
    """
    await stream(websocket, task_id=task_id, apikey=apikey, token=token)
@legacy_base_router.websocket("/stream/vnc/workflow_run/{workflow_run_id}")
async def workflow_run_stream(
    websocket: WebSocket,
    workflow_run_id: str,
    apikey: str | None = None,
    token: str | None = None,
) -> None:
    """
    Websocket endpoint that streams the VNC session of a workflow run.

    Authenticates with either ``apikey`` or ``token``.
    """
    await stream(websocket, workflow_run_id=workflow_run_id, apikey=apikey, token=token)
async def stream(
    websocket: WebSocket,
    *,
    apikey: str | None = None,
    task_id: str | None = None,
    token: str | None = None,
    workflow_run_id: str | None = None,
) -> None:
    """
    Serve one VNC streaming websocket for a task or a workflow run.

    Authenticates the caller, then builds the streaming context for ``task_id``
    (checked first) or ``workflow_run_id``, and waits for the context's loops.

    Once authentication succeeds, every exit path closes the websocket. Any loop still
    pending when the session ends is cancelled.
    """
    LOG.info("Starting VNC stream.", task_id=task_id, workflow_run_id=workflow_run_id)
    organization_id = await auth(apikey=apikey, token=token, websocket=websocket)
    if not organization_id:
        LOG.info("Authentication failed.", task_id=task_id, workflow_run_id=workflow_run_id)
        return

    streaming: Streaming
    loops: Loops = []
    if task_id:
        result = await get_streaming_for_task(task_id=task_id, organization_id=organization_id, websocket=websocket)
        if not result:
            LOG.error("No streaming context found for the task.", task_id=task_id, organization_id=organization_id)
            await websocket.close(code=1013)
            return
        streaming, loops = result
        LOG.info("Starting streaming for task.", task_id=task_id, organization_id=organization_id)
    elif workflow_run_id:
        LOG.info(
            "Starting streaming for workflow run.", workflow_run_id=workflow_run_id, organization_id=organization_id
        )
        result = await get_streaming_for_workflow_run(
            workflow_run_id=workflow_run_id,
            organization_id=organization_id,
            websocket=websocket,
        )
        if not result:
            LOG.error(
                "No streaming context found for the workflow run.",
                workflow_run_id=workflow_run_id,
                organization_id=organization_id,
            )
            await websocket.close(code=1013)
            return
        streaming, loops = result
        LOG.info(
            "Starting streaming for workflow run.", workflow_run_id=workflow_run_id, organization_id=organization_id
        )
    else:
        LOG.error("Neither task ID nor workflow run ID was provided.")
        # auth() already accepted the websocket; close it instead of abandoning it.
        await websocket.close(code=1002)
        return

    try:
        LOG.info(
            "Starting streaming loops.",
            task_id=task_id,
            workflow_run_id=workflow_run_id,
            organization_id=organization_id,
        )
        await collect(loops)
    except Exception:
        LOG.exception(
            "An exception occurred in the stream function.",
            task_id=task_id,
            workflow_run_id=workflow_run_id,
            organization_id=organization_id,
        )
    finally:
        LOG.info(
            "Closing the streaming session.",
            task_id=task_id,
            workflow_run_id=workflow_run_id,
            organization_id=organization_id,
        )
        # Cancel any loop still pending (e.g. collect() raised while others were running,
        # or this handler itself was cancelled) so no background task outlives the
        # connection. cancel() is a no-op on tasks that already finished.
        for loop in loops:
            loop.cancel()
        await streaming.close(reason="stream-closed")