add vnc streaming endpoints (#2695)
This commit is contained in:
@@ -1,3 +1,4 @@
|
|||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Any, List, Sequence
|
from typing import Any, List, Sequence
|
||||||
@@ -3207,3 +3208,21 @@ class AgentDB:
|
|||||||
query = query.filter_by(organization_id=organization_id)
|
query = query.filter_by(organization_id=organization_id)
|
||||||
task_run = (await session.scalars(query)).first()
|
task_run = (await session.scalars(query)).first()
|
||||||
return Run.model_validate(task_run) if task_run else None
|
return Run.model_validate(task_run) if task_run else None
|
||||||
|
|
||||||
|
async def wait_on_persistent_browser_address(self, session_id: str, organization_id: str) -> str:
|
||||||
|
async with asyncio.timeout(10 * 60):
|
||||||
|
while True:
|
||||||
|
persistent_browser_session = await self.get_persistent_browser_session(session_id, organization_id)
|
||||||
|
if persistent_browser_session is None:
|
||||||
|
raise Exception(f"Persistent browser session not found for {session_id}")
|
||||||
|
|
||||||
|
LOG.info(
|
||||||
|
"Checking browser address",
|
||||||
|
session_id=session_id,
|
||||||
|
address=persistent_browser_session.browser_address,
|
||||||
|
)
|
||||||
|
|
||||||
|
if persistent_browser_session.browser_address:
|
||||||
|
return persistent_browser_session.browser_address
|
||||||
|
|
||||||
|
await asyncio.sleep(2)
|
||||||
|
|||||||
31
skyvern/forge/sdk/db/polls.py
Normal file
31
skyvern/forge/sdk/db/polls.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
import asyncio
|
||||||
|
|
||||||
|
from structlog import get_logger
|
||||||
|
|
||||||
|
from skyvern.forge.sdk.db.client import AgentDB
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def wait_on_persistent_browser_address(db: AgentDB, session_id: str, organization_id: str) -> str | None:
|
||||||
|
try:
|
||||||
|
async with asyncio.timeout(10 * 60):
|
||||||
|
while True:
|
||||||
|
persistent_browser_session = await db.get_persistent_browser_session(session_id, organization_id)
|
||||||
|
if persistent_browser_session is None:
|
||||||
|
raise Exception(f"Persistent browser session not found for {session_id}")
|
||||||
|
|
||||||
|
LOG.info(
|
||||||
|
"Checking browser address",
|
||||||
|
session_id=session_id,
|
||||||
|
address=persistent_browser_session.browser_address,
|
||||||
|
)
|
||||||
|
|
||||||
|
if persistent_browser_session.browser_address:
|
||||||
|
return persistent_browser_session.browser_address
|
||||||
|
|
||||||
|
await asyncio.sleep(2)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
LOG.warning(f"Browser address not found for persistent browser session {session_id}")
|
||||||
|
|
||||||
|
return None
|
||||||
@@ -2,3 +2,4 @@ from skyvern.forge.sdk.routes import agent_protocol # noqa: F401
|
|||||||
from skyvern.forge.sdk.routes import browser_sessions # noqa: F401
|
from skyvern.forge.sdk.routes import browser_sessions # noqa: F401
|
||||||
from skyvern.forge.sdk.routes import credentials # noqa: F401
|
from skyvern.forge.sdk.routes import credentials # noqa: F401
|
||||||
from skyvern.forge.sdk.routes import streaming # noqa: F401
|
from skyvern.forge.sdk.routes import streaming # noqa: F401
|
||||||
|
from skyvern.forge.sdk.routes import streaming_vnc # noqa: F401
|
||||||
|
|||||||
644
skyvern/forge/sdk/routes/streaming_vnc.py
Normal file
644
skyvern/forge/sdk/routes/streaming_vnc.py
Normal file
@@ -0,0 +1,644 @@
|
|||||||
|
import asyncio
|
||||||
|
import dataclasses
|
||||||
|
import typing as t
|
||||||
|
from enum import IntEnum
|
||||||
|
|
||||||
|
import structlog
|
||||||
|
import websockets
|
||||||
|
from cloud.config import settings
|
||||||
|
from fastapi import WebSocket, WebSocketDisconnect
|
||||||
|
from starlette.websockets import WebSocketState
|
||||||
|
from websockets import Data
|
||||||
|
from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK
|
||||||
|
|
||||||
|
from skyvern.forge import app
|
||||||
|
from skyvern.forge.sdk.routes.routers import legacy_base_router
|
||||||
|
from skyvern.forge.sdk.schemas.persistent_browser_sessions import AddressablePersistentBrowserSession
|
||||||
|
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
|
||||||
|
from skyvern.forge.sdk.services.org_auth_service import get_current_org
|
||||||
|
from skyvern.forge.sdk.utils.aio import collect
|
||||||
|
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRun, WorkflowRunStatus
|
||||||
|
|
||||||
|
Interactor = t.Literal["agent", "user"]
|
||||||
|
Loops = list[asyncio.Task] # aka "queue-less actors"; or "programs"
|
||||||
|
|
||||||
|
|
||||||
|
class MessageType(IntEnum):
|
||||||
|
keyboard = 4
|
||||||
|
mouse = 5
|
||||||
|
|
||||||
|
|
||||||
|
LOG = structlog.get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class Streaming:
|
||||||
|
"""
|
||||||
|
Streaming state.
|
||||||
|
"""
|
||||||
|
|
||||||
|
interactor: Interactor
|
||||||
|
"""
|
||||||
|
Whether the user or the agent are the interactor.
|
||||||
|
"""
|
||||||
|
|
||||||
|
organization_id: str
|
||||||
|
vnc_port: int
|
||||||
|
websocket: WebSocket
|
||||||
|
|
||||||
|
# --
|
||||||
|
|
||||||
|
browser_session: AddressablePersistentBrowserSession | None = None
|
||||||
|
task: Task | None = None
|
||||||
|
workflow_run: WorkflowRun | None = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_open(self) -> bool:
|
||||||
|
if self.websocket.client_state not in (WebSocketState.CONNECTED, WebSocketState.CONNECTING):
|
||||||
|
return False
|
||||||
|
|
||||||
|
if not self.task and not self.workflow_run:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def close(self, code: int = 1000, reason: str | None = None) -> "Streaming":
|
||||||
|
LOG.info("Closing Streaming.", reason=reason, code=code)
|
||||||
|
|
||||||
|
self.browser_session = None
|
||||||
|
self.task = None
|
||||||
|
self.workflow_run = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self.websocket.close(code=code, reason=reason)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
async def auth(apikey: str | None, token: str | None, websocket: WebSocket) -> str | None:
|
||||||
|
"""
|
||||||
|
Accepts the websocket connection.
|
||||||
|
|
||||||
|
Authenticates the user; cannot proceed with WS connection if an organization_id cannot be
|
||||||
|
determined.
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
await websocket.accept()
|
||||||
|
if not token and not apikey:
|
||||||
|
await websocket.close(code=1002)
|
||||||
|
return None
|
||||||
|
except ConnectionClosedOK:
|
||||||
|
LOG.info("WebSocket connection closed cleanly.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
organization = await get_current_org(x_api_key=apikey, authorization=token)
|
||||||
|
organization_id = organization.organization_id
|
||||||
|
|
||||||
|
if not organization_id:
|
||||||
|
await websocket.close(code=1002)
|
||||||
|
return None
|
||||||
|
except Exception:
|
||||||
|
LOG.exception("Error occurred while retrieving organization information.")
|
||||||
|
try:
|
||||||
|
await websocket.close(code=1002)
|
||||||
|
except ConnectionClosedOK:
|
||||||
|
LOG.info("WebSocket connection closed due to invalid credentials.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
return organization_id
|
||||||
|
|
||||||
|
|
||||||
|
async def verify_task(
|
||||||
|
task_id: str, organization_id: str
|
||||||
|
) -> tuple[Task | None, AddressablePersistentBrowserSession | None]:
|
||||||
|
"""
|
||||||
|
Verify the task is running, and that it has a browser session associated
|
||||||
|
with it.
|
||||||
|
"""
|
||||||
|
|
||||||
|
task = await app.DATABASE.get_task(task_id=task_id, organization_id=organization_id)
|
||||||
|
|
||||||
|
if not task:
|
||||||
|
LOG.info("Task not found.", task_id=task_id, organization_id=organization_id)
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
if task.status.is_final():
|
||||||
|
LOG.info("Task is in a final state.", task_status=task.status, task_id=task_id, organization_id=organization_id)
|
||||||
|
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
if not task.status == TaskStatus.running:
|
||||||
|
LOG.info("Task is not running.", task_status=task.status, task_id=task_id, organization_id=organization_id)
|
||||||
|
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
browser_session = await app.PERSISTENT_SESSIONS_MANAGER.get_session_by_runnable_id(
|
||||||
|
organization_id=organization_id,
|
||||||
|
runnable_id=task_id, # is this correct; is there a task_run_id?
|
||||||
|
)
|
||||||
|
|
||||||
|
if not browser_session:
|
||||||
|
LOG.info("No browser session found for task.", task_id=task_id, organization_id=organization_id)
|
||||||
|
return task, None
|
||||||
|
|
||||||
|
if not browser_session.browser_address:
|
||||||
|
LOG.info("Browser session address not found for task.", task_id=task_id, organization_id=organization_id)
|
||||||
|
return task, None
|
||||||
|
|
||||||
|
try:
|
||||||
|
addressable_browser_session = AddressablePersistentBrowserSession(
|
||||||
|
**browser_session.model_dump() | {"browser_address": browser_session.browser_address},
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
LOG.error(
|
||||||
|
"streaming-vnc.browser-session-reify-error", task_id=task_id, organization_id=organization_id, error=e
|
||||||
|
)
|
||||||
|
return task, None
|
||||||
|
|
||||||
|
return task, addressable_browser_session
|
||||||
|
|
||||||
|
|
||||||
|
async def get_streaming_for_task(
|
||||||
|
task_id: str,
|
||||||
|
organization_id: str,
|
||||||
|
websocket: WebSocket,
|
||||||
|
) -> tuple[Streaming, Loops] | None:
|
||||||
|
"""
|
||||||
|
Return a streaming context for a task, with a list of loops to run concurrently.
|
||||||
|
"""
|
||||||
|
|
||||||
|
task, browser_session = await verify_task(task_id=task_id, organization_id=organization_id)
|
||||||
|
|
||||||
|
if not task:
|
||||||
|
LOG.info("No initial task found.", task_id=task_id, organization_id=organization_id)
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not browser_session:
|
||||||
|
LOG.info("No initial browser session found for task.", task_id=task_id, organization_id=organization_id)
|
||||||
|
return None
|
||||||
|
|
||||||
|
streaming = Streaming(
|
||||||
|
interactor="user",
|
||||||
|
organization_id=organization_id,
|
||||||
|
vnc_port=settings.PERSISTENT_BROWSER_VNC_PORT,
|
||||||
|
websocket=websocket,
|
||||||
|
# --
|
||||||
|
browser_session=browser_session,
|
||||||
|
task=task,
|
||||||
|
)
|
||||||
|
|
||||||
|
loops = [
|
||||||
|
asyncio.create_task(loop_verify_task(streaming)),
|
||||||
|
asyncio.create_task(loop_stream_vnc(streaming)),
|
||||||
|
]
|
||||||
|
|
||||||
|
return streaming, loops
|
||||||
|
|
||||||
|
|
||||||
|
async def get_streaming_for_workflow_run(
|
||||||
|
workflow_run_id: str,
|
||||||
|
organization_id: str,
|
||||||
|
websocket: WebSocket,
|
||||||
|
) -> tuple[Streaming, Loops] | None:
|
||||||
|
"""
|
||||||
|
Return a streaming context for a workflow run, with a list of loops to run concurrently.
|
||||||
|
"""
|
||||||
|
|
||||||
|
LOG.info("Getting streaming context for workflow run.", workflow_run_id=workflow_run_id)
|
||||||
|
|
||||||
|
workflow_run, browser_session = await verify_workflow_run(
|
||||||
|
workflow_run_id=workflow_run_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not workflow_run:
|
||||||
|
LOG.info("No initial workflow run found.", workflow_run_id=workflow_run_id, organization_id=organization_id)
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not browser_session:
|
||||||
|
LOG.info(
|
||||||
|
"No initial browser session found for workflow run.",
|
||||||
|
workflow_run_id=workflow_run_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
streaming = Streaming(
|
||||||
|
interactor="user",
|
||||||
|
organization_id=organization_id,
|
||||||
|
vnc_port=settings.PERSISTENT_BROWSER_VNC_PORT,
|
||||||
|
# --
|
||||||
|
browser_session=browser_session,
|
||||||
|
workflow_run=workflow_run,
|
||||||
|
websocket=websocket,
|
||||||
|
)
|
||||||
|
|
||||||
|
loops = [
|
||||||
|
asyncio.create_task(loop_verify_workflow_run(streaming)),
|
||||||
|
asyncio.create_task(loop_stream_vnc(streaming)),
|
||||||
|
]
|
||||||
|
|
||||||
|
return streaming, loops
|
||||||
|
|
||||||
|
|
||||||
|
async def verify_workflow_run(
|
||||||
|
workflow_run_id: str,
|
||||||
|
organization_id: str,
|
||||||
|
) -> tuple[WorkflowRun | None, AddressablePersistentBrowserSession | None]:
|
||||||
|
"""
|
||||||
|
Verify the workflow run is running, and that it has a browser session associated
|
||||||
|
with it.
|
||||||
|
"""
|
||||||
|
|
||||||
|
workflow_run = await app.DATABASE.get_workflow_run(
|
||||||
|
workflow_run_id=workflow_run_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not workflow_run:
|
||||||
|
LOG.info("Workflow run not found.", workflow_run_id=workflow_run_id, organization_id=organization_id)
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
if workflow_run.status in [
|
||||||
|
WorkflowRunStatus.completed,
|
||||||
|
WorkflowRunStatus.failed,
|
||||||
|
WorkflowRunStatus.terminated,
|
||||||
|
]:
|
||||||
|
LOG.info(
|
||||||
|
"Workflow run is in a final state. Closing connection.",
|
||||||
|
workflow_run_status=workflow_run.status,
|
||||||
|
workflow_run_id=workflow_run_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
if workflow_run.status not in [WorkflowRunStatus.created, WorkflowRunStatus.queued, WorkflowRunStatus.running]:
|
||||||
|
LOG.info(
|
||||||
|
"Workflow run is not running.",
|
||||||
|
workflow_run_status=workflow_run.status,
|
||||||
|
workflow_run_id=workflow_run_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
browser_session = await app.PERSISTENT_SESSIONS_MANAGER.get_session_by_runnable_id(
|
||||||
|
organization_id=organization_id,
|
||||||
|
runnable_id=workflow_run_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not browser_session:
|
||||||
|
LOG.info(
|
||||||
|
"No browser session found for workflow run.",
|
||||||
|
workflow_run_id=workflow_run_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
)
|
||||||
|
return workflow_run, None
|
||||||
|
|
||||||
|
browser_address = browser_session.browser_address
|
||||||
|
|
||||||
|
if not browser_address:
|
||||||
|
LOG.info(
|
||||||
|
"Waiting for browser session address.", workflow_run_id=workflow_run_id, organization_id=organization_id
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
_, host, cdp_port = await app.PERSISTENT_SESSIONS_MANAGER.get_browser_address(
|
||||||
|
session_id=browser_session.persistent_browser_session_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
)
|
||||||
|
browser_address = f"{host}:{cdp_port}"
|
||||||
|
except Exception as ex:
|
||||||
|
LOG.info(
|
||||||
|
"Browser session address not found for workflow run.",
|
||||||
|
workflow_run_id=workflow_run_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
ex=ex,
|
||||||
|
)
|
||||||
|
return workflow_run, None
|
||||||
|
|
||||||
|
try:
|
||||||
|
addressable_browser_session = AddressablePersistentBrowserSession(
|
||||||
|
**browser_session.model_dump() | {"browser_address": browser_address},
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
return workflow_run, None
|
||||||
|
|
||||||
|
return workflow_run, addressable_browser_session
|
||||||
|
|
||||||
|
|
||||||
|
async def loop_verify_task(streaming: Streaming) -> None:
|
||||||
|
"""
|
||||||
|
Loop until the task is cleared or the websocket is closed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
while streaming.task and streaming.is_open:
|
||||||
|
task, browser_session = await verify_task(
|
||||||
|
task_id=streaming.task.task_id,
|
||||||
|
organization_id=streaming.organization_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
streaming.task = task
|
||||||
|
streaming.browser_session = browser_session
|
||||||
|
|
||||||
|
await asyncio.sleep(2)
|
||||||
|
|
||||||
|
|
||||||
|
async def loop_verify_workflow_run(streaming: Streaming) -> None:
|
||||||
|
"""
|
||||||
|
Loop until the workflow run is cleared or the websocket is closed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
while streaming.workflow_run and streaming.is_open:
|
||||||
|
workflow_run, browser_session = await verify_workflow_run(
|
||||||
|
workflow_run_id=streaming.workflow_run.workflow_run_id,
|
||||||
|
organization_id=streaming.organization_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
streaming.workflow_run = workflow_run
|
||||||
|
streaming.browser_session = browser_session
|
||||||
|
|
||||||
|
await asyncio.sleep(2)
|
||||||
|
|
||||||
|
|
||||||
|
async def loop_stream_vnc(streaming: Streaming) -> None:
|
||||||
|
"""
|
||||||
|
Actually stream the VNC session data between a frontend and a browser
|
||||||
|
session.
|
||||||
|
|
||||||
|
Loops until the task is cleared or the websocket is closed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not streaming.browser_session:
|
||||||
|
LOG.info("No browser session found for task.", task=streaming.task, organization_id=streaming.organization_id)
|
||||||
|
return
|
||||||
|
|
||||||
|
browser_address = streaming.browser_session.browser_address
|
||||||
|
host, _ = browser_address.rsplit(":")
|
||||||
|
vnc_url = f"ws://{host}:{streaming.vnc_port}"
|
||||||
|
|
||||||
|
LOG.info(
|
||||||
|
"Connecting to VNC URL.",
|
||||||
|
browser_address=browser_address,
|
||||||
|
vnc_url=vnc_url,
|
||||||
|
task=streaming.task,
|
||||||
|
workflow_run=streaming.workflow_run,
|
||||||
|
organization_id=streaming.organization_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
async with websockets.connect(vnc_url) as novnc_ws:
|
||||||
|
|
||||||
|
async def frontend_to_browser() -> None:
|
||||||
|
LOG.info("Starting frontend-to-browser data transfer.", streaming=streaming)
|
||||||
|
data: Data | None = None
|
||||||
|
|
||||||
|
while streaming.is_open:
|
||||||
|
try:
|
||||||
|
data = await streaming.websocket.receive_bytes()
|
||||||
|
|
||||||
|
if data:
|
||||||
|
message_type = data[0]
|
||||||
|
|
||||||
|
# TODO: verify 4,5 are keyboard/mouse; they seem to be
|
||||||
|
if not streaming.interactor == "user" and message_type in (
|
||||||
|
MessageType.keyboard.value,
|
||||||
|
MessageType.mouse.value,
|
||||||
|
):
|
||||||
|
LOG.info(
|
||||||
|
"Blocking user message.", task=streaming.task, organization_id=streaming.organization_id
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
except WebSocketDisconnect:
|
||||||
|
LOG.info("Frontend disconnected.", task=streaming.task, organization_id=streaming.organization_id)
|
||||||
|
raise
|
||||||
|
except ConnectionClosedError:
|
||||||
|
LOG.info(
|
||||||
|
"Frontend closed the streaming session.",
|
||||||
|
task=streaming.task,
|
||||||
|
organization_id=streaming.organization_id,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
except Exception:
|
||||||
|
LOG.exception(
|
||||||
|
"An unexpected exception occurred.",
|
||||||
|
task=streaming.task,
|
||||||
|
organization_id=streaming.organization_id,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
if not data:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
await novnc_ws.send(data)
|
||||||
|
except WebSocketDisconnect:
|
||||||
|
LOG.info(
|
||||||
|
"Browser disconnected from the streaming session.",
|
||||||
|
task=streaming.task,
|
||||||
|
organization_id=streaming.organization_id,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
except ConnectionClosedError:
|
||||||
|
LOG.info(
|
||||||
|
"Browser closed the streaming session.",
|
||||||
|
task=streaming.task,
|
||||||
|
organization_id=streaming.organization_id,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
except Exception:
|
||||||
|
LOG.exception(
|
||||||
|
"An unexpected exception occurred in frontend-to-browser loop.",
|
||||||
|
task=streaming.task,
|
||||||
|
organization_id=streaming.organization_id,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def browser_to_frontend() -> None:
|
||||||
|
LOG.info("Starting browser-to-frontend data transfer.", streaming=streaming)
|
||||||
|
data: Data | None = None
|
||||||
|
|
||||||
|
while streaming.is_open:
|
||||||
|
try:
|
||||||
|
data = await novnc_ws.recv()
|
||||||
|
|
||||||
|
except WebSocketDisconnect:
|
||||||
|
LOG.info(
|
||||||
|
"Browser disconnected from the streaming session.",
|
||||||
|
task=streaming.task,
|
||||||
|
organization_id=streaming.organization_id,
|
||||||
|
)
|
||||||
|
await streaming.close(reason="browser-disconnected")
|
||||||
|
except ConnectionClosedError:
|
||||||
|
LOG.info(
|
||||||
|
"Browser closed the streaming session.",
|
||||||
|
task=streaming.task,
|
||||||
|
organization_id=streaming.organization_id,
|
||||||
|
)
|
||||||
|
await streaming.close(reason="browser-closed")
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
except Exception:
|
||||||
|
LOG.exception(
|
||||||
|
"An unexpected exception occurred in browser-to-frontend loop.",
|
||||||
|
task=streaming.task,
|
||||||
|
organization_id=streaming.organization_id,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
if not data:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
await streaming.websocket.send_bytes(data)
|
||||||
|
except WebSocketDisconnect:
|
||||||
|
LOG.info(
|
||||||
|
"Frontend disconnected from the streaming session.",
|
||||||
|
task=streaming.task,
|
||||||
|
organization_id=streaming.organization_id,
|
||||||
|
)
|
||||||
|
await streaming.close(reason="frontend-disconnected")
|
||||||
|
except ConnectionClosedError:
|
||||||
|
LOG.info(
|
||||||
|
"Frontend closed the streaming session.",
|
||||||
|
task=streaming.task,
|
||||||
|
organization_id=streaming.organization_id,
|
||||||
|
)
|
||||||
|
await streaming.close(reason="frontend-closed")
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
except Exception:
|
||||||
|
LOG.exception(
|
||||||
|
"An unexpected exception occurred.",
|
||||||
|
task=streaming.task,
|
||||||
|
organization_id=streaming.organization_id,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
loops = [
|
||||||
|
asyncio.create_task(frontend_to_browser()),
|
||||||
|
asyncio.create_task(browser_to_frontend()),
|
||||||
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
await collect(loops)
|
||||||
|
except Exception:
|
||||||
|
LOG.exception(
|
||||||
|
"An exception occurred in loop stream.", task=streaming.task, organization_id=streaming.organization_id
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
LOG.info("Closing the loop stream.", task=streaming.task, organization_id=streaming.organization_id)
|
||||||
|
await streaming.close(reason="loop-stream-vnc-closed")
|
||||||
|
|
||||||
|
|
||||||
|
@legacy_base_router.websocket("/stream/vnc/task/{task_id}")
|
||||||
|
async def task_stream(
|
||||||
|
websocket: WebSocket,
|
||||||
|
task_id: str,
|
||||||
|
apikey: str | None = None,
|
||||||
|
token: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
await stream(websocket, apikey=apikey, task_id=task_id, token=token)
|
||||||
|
|
||||||
|
|
||||||
|
@legacy_base_router.websocket("/stream/vnc/workflow_run/{workflow_run_id}")
|
||||||
|
async def workflow_run_stream(
|
||||||
|
websocket: WebSocket,
|
||||||
|
workflow_run_id: str,
|
||||||
|
apikey: str | None = None,
|
||||||
|
token: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
await stream(websocket, apikey=apikey, workflow_run_id=workflow_run_id, token=token)
|
||||||
|
|
||||||
|
|
||||||
|
async def stream(
|
||||||
|
websocket: WebSocket,
|
||||||
|
*,
|
||||||
|
apikey: str | None = None,
|
||||||
|
task_id: str | None = None,
|
||||||
|
token: str | None = None,
|
||||||
|
workflow_run_id: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
LOG.info("Starting VNC stream.", task_id=task_id, workflow_run_id=workflow_run_id)
|
||||||
|
|
||||||
|
organization_id = await auth(apikey=apikey, token=token, websocket=websocket)
|
||||||
|
|
||||||
|
if not organization_id:
|
||||||
|
LOG.info("Authentication failed.", task_id=task_id, workflow_run_id=workflow_run_id)
|
||||||
|
return
|
||||||
|
|
||||||
|
streaming: Streaming
|
||||||
|
loops: list[asyncio.Task] = []
|
||||||
|
|
||||||
|
if task_id:
|
||||||
|
result = await get_streaming_for_task(task_id=task_id, organization_id=organization_id, websocket=websocket)
|
||||||
|
|
||||||
|
if not result:
|
||||||
|
LOG.error("No streaming context found for the task.", task_id=task_id, organization_id=organization_id)
|
||||||
|
await websocket.close(code=1013)
|
||||||
|
return
|
||||||
|
|
||||||
|
streaming, loops = result
|
||||||
|
|
||||||
|
LOG.info("Starting streaming for task.", task_id=task_id, organization_id=organization_id)
|
||||||
|
|
||||||
|
elif workflow_run_id:
|
||||||
|
LOG.info(
|
||||||
|
"Starting streaming for workflow run.", workflow_run_id=workflow_run_id, organization_id=organization_id
|
||||||
|
)
|
||||||
|
result = await get_streaming_for_workflow_run(
|
||||||
|
workflow_run_id=workflow_run_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
websocket=websocket,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not result:
|
||||||
|
LOG.error(
|
||||||
|
"No streaming context found for the workflow run.",
|
||||||
|
workflow_run_id=workflow_run_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
)
|
||||||
|
await websocket.close(code=1013)
|
||||||
|
return
|
||||||
|
|
||||||
|
streaming, loops = result
|
||||||
|
|
||||||
|
LOG.info(
|
||||||
|
"Starting streaming for workflow run.", workflow_run_id=workflow_run_id, organization_id=organization_id
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
LOG.error("Neither task ID nor workflow run ID was provided.")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
LOG.info(
|
||||||
|
"Starting streaming loops.",
|
||||||
|
task_id=task_id,
|
||||||
|
workflow_run_id=workflow_run_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
)
|
||||||
|
await collect(loops)
|
||||||
|
except Exception:
|
||||||
|
LOG.exception(
|
||||||
|
"An exception occurred in the stream function.",
|
||||||
|
task_id=task_id,
|
||||||
|
workflow_run_id=workflow_run_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
LOG.info(
|
||||||
|
"Closing the streaming session.",
|
||||||
|
task_id=task_id,
|
||||||
|
workflow_run_id=workflow_run_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
)
|
||||||
|
await streaming.close(reason="stream-closed")
|
||||||
@@ -18,3 +18,7 @@ class PersistentBrowserSession(BaseModel):
|
|||||||
created_at: datetime
|
created_at: datetime
|
||||||
modified_at: datetime
|
modified_at: datetime
|
||||||
deleted_at: datetime | None = None
|
deleted_at: datetime | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class AddressablePersistentBrowserSession(PersistentBrowserSession):
|
||||||
|
browser_address: str
|
||||||
|
|||||||
27
skyvern/forge/sdk/utils/aio.py
Normal file
27
skyvern/forge/sdk/utils/aio.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
import asyncio
|
||||||
|
from typing import Any, Sequence
|
||||||
|
|
||||||
|
|
||||||
|
async def collect(tasks: Sequence[asyncio.Task]) -> list[Any]:
|
||||||
|
"""
|
||||||
|
An alternative to 'gather'.
|
||||||
|
|
||||||
|
Waits for the first task to complete or fail, cancels others, and propagates
|
||||||
|
the first exception.
|
||||||
|
|
||||||
|
Returns the results of all tasks (if all tasks succeed).
|
||||||
|
"""
|
||||||
|
|
||||||
|
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION)
|
||||||
|
|
||||||
|
for p in pending:
|
||||||
|
p.cancel()
|
||||||
|
|
||||||
|
await asyncio.gather(*pending, return_exceptions=True)
|
||||||
|
|
||||||
|
for task in done:
|
||||||
|
exc = task.exception()
|
||||||
|
if exc:
|
||||||
|
raise exc
|
||||||
|
|
||||||
|
return [task.result() for task in done]
|
||||||
@@ -6,6 +6,7 @@ import structlog
|
|||||||
from playwright._impl._errors import TargetClosedError
|
from playwright._impl._errors import TargetClosedError
|
||||||
|
|
||||||
from skyvern.forge.sdk.db.client import AgentDB
|
from skyvern.forge.sdk.db.client import AgentDB
|
||||||
|
from skyvern.forge.sdk.db.polls import wait_on_persistent_browser_address
|
||||||
from skyvern.forge.sdk.schemas.persistent_browser_sessions import PersistentBrowserSession
|
from skyvern.forge.sdk.schemas.persistent_browser_sessions import PersistentBrowserSession
|
||||||
from skyvern.webeye.browser_factory import BrowserState
|
from skyvern.webeye.browser_factory import BrowserState
|
||||||
|
|
||||||
@@ -30,6 +31,23 @@ class PersistentSessionsManager:
|
|||||||
cls.instance.database = database
|
cls.instance.database = database
|
||||||
return cls.instance
|
return cls.instance
|
||||||
|
|
||||||
|
async def get_browser_address(self, session_id: str, organization_id: str) -> tuple[str, str, str]:
|
||||||
|
address = await wait_on_persistent_browser_address(self.database, session_id, organization_id)
|
||||||
|
|
||||||
|
if address is None:
|
||||||
|
raise Exception(f"Browser address not found for persistent browser session {session_id}")
|
||||||
|
|
||||||
|
protocol = "http"
|
||||||
|
host, cdp_port = address.split(":")
|
||||||
|
|
||||||
|
return protocol, host, cdp_port
|
||||||
|
|
||||||
|
async def get_session_by_runnable_id(
|
||||||
|
self, runnable_id: str, organization_id: str
|
||||||
|
) -> PersistentBrowserSession | None:
|
||||||
|
"""Get a specific browser session by runnable ID."""
|
||||||
|
return await self.database.get_persistent_browser_session_by_runnable_id(runnable_id, organization_id)
|
||||||
|
|
||||||
async def get_active_sessions(self, organization_id: str) -> list[PersistentBrowserSession]:
|
async def get_active_sessions(self, organization_id: str) -> list[PersistentBrowserSession]:
|
||||||
"""Get all active sessions for an organization."""
|
"""Get all active sessions for an organization."""
|
||||||
return await self.database.get_active_persistent_browser_sessions(organization_id)
|
return await self.database.get_active_persistent_browser_sessions(organization_id)
|
||||||
|
|||||||
Reference in New Issue
Block a user