diff --git a/skyvern/forge/sdk/routes/streaming_clients.py b/skyvern/forge/sdk/routes/streaming_clients.py index 3d54e3a7..8baf8c7b 100644 --- a/skyvern/forge/sdk/routes/streaming_clients.py +++ b/skyvern/forge/sdk/routes/streaming_clients.py @@ -78,7 +78,7 @@ class CommandChannel: if self.websocket.client_state not in (WebSocketState.CONNECTED, WebSocketState.CONNECTING): return False - if not self.workflow_run: + if not self.workflow_run and not self.browser_session: return False if not get_command_client(self.client_id): @@ -231,7 +231,7 @@ class Streaming: if self.websocket.client_state not in (WebSocketState.CONNECTED, WebSocketState.CONNECTING): return False - if not self.task and not self.workflow_run: + if not self.task and not self.workflow_run and not self.browser_session: return False if not get_streaming_client(self.client_id): diff --git a/skyvern/forge/sdk/routes/streaming_commands.py b/skyvern/forge/sdk/routes/streaming_commands.py index fa2f8b6b..7854357c 100644 --- a/skyvern/forge/sdk/routes/streaming_commands.py +++ b/skyvern/forge/sdk/routes/streaming_commands.py @@ -9,14 +9,61 @@ from fastapi import WebSocket, WebSocketDisconnect from websockets.exceptions import ConnectionClosedError import skyvern.forge.sdk.routes.streaming_clients as sc -from skyvern.forge.sdk.routes.routers import legacy_base_router +from skyvern.forge.sdk.routes.routers import base_router, legacy_base_router from skyvern.forge.sdk.routes.streaming_auth import auth -from skyvern.forge.sdk.routes.streaming_verify import loop_verify_workflow_run, verify_workflow_run +from skyvern.forge.sdk.routes.streaming_verify import ( + loop_verify_browser_session, + loop_verify_workflow_run, + verify_browser_session, + verify_workflow_run, +) from skyvern.forge.sdk.utils.aio import collect LOG = structlog.get_logger() +async def get_commands_for_browser_session( + client_id: str, + browser_session_id: str, + organization_id: str, + websocket: WebSocket, +) -> tuple[sc.CommandChannel, sc.Loops] | None: + """ + Return a commands channel for a browser session, with a list of loops to run concurrently. + """ + + LOG.info("Getting commands channel for browser session.", browser_session_id=browser_session_id) + + browser_session = await verify_browser_session( + browser_session_id=browser_session_id, + organization_id=organization_id, + ) + + if not browser_session: + LOG.info( + "Command channel: no initial browser session found.", + browser_session_id=browser_session_id, + organization_id=organization_id, + ) + return None + + commands = sc.CommandChannel( + client_id=client_id, + organization_id=organization_id, + browser_session=browser_session, + websocket=websocket, + ) + + LOG.info("Got command channel for browser session.", commands=commands) + + loops = [ + asyncio.create_task(loop_verify_browser_session(commands)), + asyncio.create_task(loop_channel(commands)), + ] + + return commands, loops + + async def get_commands_for_workflow_run( client_id: str, workflow_run_id: str, @@ -163,6 +210,70 @@ async def loop_channel(commands: sc.CommandChannel) -> None: await commands.close(reason="loop-channel-closed") +@base_router.websocket("/stream/commands/browser_session/{browser_session_id}") +async def browser_session_commands( + websocket: WebSocket, + browser_session_id: str, + apikey: str | None = None, + client_id: str | None = None, + token: str | None = None, +) -> None: + LOG.info("Starting stream commands for browser session.", browser_session_id=browser_session_id) + + organization_id = await auth(apikey=apikey, token=token, websocket=websocket) + + if not organization_id: + LOG.error("Authentication failed.", browser_session_id=browser_session_id) + return + + if not client_id: + LOG.error("No client ID provided.", browser_session_id=browser_session_id) + return + + commands: sc.CommandChannel + loops: list[asyncio.Task] = [] + + result = await get_commands_for_browser_session( + client_id=client_id, + browser_session_id=browser_session_id, + organization_id=organization_id, + websocket=websocket, + ) + + if not result: + LOG.error( + "No streaming context found for the browser session.", + browser_session_id=browser_session_id, + organization_id=organization_id, + ) + await websocket.close(code=1013) + return + + commands, loops = result + + try: + LOG.info( + "Starting command stream loops for browser session.", + browser_session_id=browser_session_id, + organization_id=organization_id, + ) + await collect(loops) + except Exception: + LOG.exception( + "An exception occurred in the command stream function for browser session.", + browser_session_id=browser_session_id, + organization_id=organization_id, + ) + finally: + LOG.info( + "Closing the command stream session for browser session.", + browser_session_id=browser_session_id, + organization_id=organization_id, + ) + + await commands.close(reason="stream-closed") + + @legacy_base_router.websocket("/stream/commands/workflow_run/{workflow_run_id}") async def workflow_run_commands( websocket: WebSocket, diff --git a/skyvern/forge/sdk/routes/streaming_verify.py b/skyvern/forge/sdk/routes/streaming_verify.py index 4a74b5f7..24d9ab61 100644 --- a/skyvern/forge/sdk/routes/streaming_verify.py +++ b/skyvern/forge/sdk/routes/streaming_verify.py @@ -13,6 +13,69 @@ from skyvern.forge.sdk.workflow.models.workflow import WorkflowRun, WorkflowRunS LOG = structlog.get_logger() +async def verify_browser_session( + browser_session_id: str, + organization_id: str, +) -> AddressablePersistentBrowserSession | None: + """ + Verify the browser session exists, and is usable. + """ + + if settings.ENV == "local": + dummy_browser_session = AddressablePersistentBrowserSession( + persistent_browser_session_id=browser_session_id, + organization_id=organization_id, + browser_address="0.0.0.0:9223", + created_at=datetime.now(), + modified_at=datetime.now(), + ) + + return dummy_browser_session + + browser_session = await app.DATABASE.get_persistent_browser_session(browser_session_id, organization_id) + + if not browser_session: + LOG.info( + "No browser session found.", + browser_session_id=browser_session_id, + organization_id=organization_id, + ) + return None + + browser_address = browser_session.browser_address + + if not browser_address: + LOG.info( + "Waiting for browser session address.", + browser_session_id=browser_session_id, + organization_id=organization_id, + ) + + try: + _, host, cdp_port = await app.PERSISTENT_SESSIONS_MANAGER.get_browser_address( + session_id=browser_session_id, + organization_id=organization_id, + ) + browser_address = f"{host}:{cdp_port}" + except Exception as ex: + LOG.info( + "Browser session address not found for browser session.", + browser_session_id=browser_session_id, + organization_id=organization_id, + ex=ex, + ) + return None + + try: + addressable_browser_session = AddressablePersistentBrowserSession( + **browser_session.model_dump() | {"browser_address": browser_address}, + ) + except Exception: + return None + + return addressable_browser_session + + async def verify_task( task_id: str, organization_id: str ) -> tuple[Task | None, AddressablePersistentBrowserSession | None]: @@ -172,6 +235,22 @@ async def verify_workflow_run( return workflow_run, addressable_browser_session +async def loop_verify_browser_session(verifiable: sc.CommandChannel | sc.Streaming) -> None: + """ + Loop until the browser session is cleared or the websocket is closed. + """ + + while verifiable.browser_session and verifiable.is_open: + browser_session = await verify_browser_session( + browser_session_id=verifiable.browser_session.persistent_browser_session_id, + organization_id=verifiable.organization_id, + ) + + verifiable.browser_session = browser_session + + await asyncio.sleep(2) + + async def loop_verify_task(streaming: sc.Streaming) -> None: """ Loop until the task is cleared or the websocket is closed. diff --git a/skyvern/forge/sdk/routes/streaming_vnc.py b/skyvern/forge/sdk/routes/streaming_vnc.py index 0e065d13..77d51fbf 100644 --- a/skyvern/forge/sdk/routes/streaming_vnc.py +++ b/skyvern/forge/sdk/routes/streaming_vnc.py @@ -12,11 +12,13 @@ from websockets.exceptions import ConnectionClosedError import skyvern.forge.sdk.routes.streaming_clients as sc from skyvern.config import settings -from skyvern.forge.sdk.routes.routers import legacy_base_router +from skyvern.forge.sdk.routes.routers import base_router, legacy_base_router from skyvern.forge.sdk.routes.streaming_auth import auth from skyvern.forge.sdk.routes.streaming_verify import ( + loop_verify_browser_session, loop_verify_task, loop_verify_workflow_run, + verify_browser_session, verify_task, verify_workflow_run, ) @@ -25,6 +27,48 @@ from skyvern.forge.sdk.utils.aio import collect LOG = structlog.get_logger() +async def get_streaming_for_browser_session( + client_id: str, + browser_session_id: str, + organization_id: str, + websocket: WebSocket, +) -> tuple[sc.Streaming, sc.Loops] | None: + """ + Return a streaming context for a browser session, with a list of loops to run concurrently. + """ + + LOG.info("Getting streaming context for browser session.", browser_session_id=browser_session_id) + + browser_session = await verify_browser_session( + browser_session_id=browser_session_id, + organization_id=organization_id, + ) + + if not browser_session: + LOG.info( + "No initial browser session found.", browser_session_id=browser_session_id, organization_id=organization_id + ) + return None + + streaming = sc.Streaming( + client_id=client_id, + interactor="agent", + organization_id=organization_id, + vnc_port=settings.SKYVERN_BROWSER_VNC_PORT, + browser_session=browser_session, + websocket=websocket, + ) + + LOG.info("Got streaming context for browser session.", streaming=streaming) + + loops = [ + asyncio.create_task(loop_verify_browser_session(streaming)), + asyncio.create_task(loop_stream_vnc(streaming)), + ] + + return streaming, loops + + async def get_streaming_for_task( client_id: str, task_id: str, @@ -295,6 +339,17 @@ async def loop_stream_vnc(streaming: sc.Streaming) -> None: await streaming.close(reason="loop-stream-vnc-closed") +@base_router.websocket("/stream/vnc/browser_session/{browser_session_id}") +async def browser_session_stream( + websocket: WebSocket, + browser_session_id: str, + apikey: str | None = None, + client_id: str | None = None, + token: str | None = None, +) -> None: + await stream(websocket, apikey=apikey, client_id=client_id, browser_session_id=browser_session_id, token=token) + + @legacy_base_router.websocket("/stream/vnc/task/{task_id}") async def task_stream( websocket: WebSocket, @@ -321,16 +376,28 @@ async def stream( websocket: WebSocket, *, apikey: str | None = None, + browser_session_id: str | None = None, client_id: str | None = None, task_id: str | None = None, token: str | None = None, workflow_run_id: str | None = None, ) -> None: if not client_id: - LOG.error("Client ID not provided for VNC stream.", task_id=task_id, workflow_run_id=workflow_run_id) + LOG.error( + "Client ID not provided for VNC stream.", + browser_session_id=browser_session_id, + task_id=task_id, + workflow_run_id=workflow_run_id, + ) return - LOG.info("Starting VNC stream.", client_id=client_id, task_id=task_id, workflow_run_id=workflow_run_id) + LOG.info( + "Starting VNC stream.", + browser_session_id=browser_session_id, + client_id=client_id, + task_id=task_id, + workflow_run_id=workflow_run_id, + ) organization_id = await auth(apikey=apikey, token=token, websocket=websocket) @@ -341,7 +408,31 @@ async def stream( streaming: sc.Streaming loops: list[asyncio.Task] = [] - if task_id: + if browser_session_id: + result = await get_streaming_for_browser_session( + client_id=client_id, + browser_session_id=browser_session_id, + organization_id=organization_id, + websocket=websocket, + ) + + if not result: + LOG.error( + "No streaming context found for the browser session.", + browser_session_id=browser_session_id, + organization_id=organization_id, + ) + await websocket.close(code=1013) + return + + streaming, loops = result + + LOG.info( + "Starting streaming for browser session.", + browser_session_id=browser_session_id, + organization_id=organization_id, + ) + elif task_id: result = await get_streaming_for_task( client_id=client_id, task_id=task_id,