Jon/browser session stream (#2910)
This commit is contained in:
@@ -12,11 +12,13 @@ from websockets.exceptions import ConnectionClosedError
|
||||
|
||||
import skyvern.forge.sdk.routes.streaming_clients as sc
|
||||
from skyvern.config import settings
|
||||
from skyvern.forge.sdk.routes.routers import legacy_base_router
|
||||
from skyvern.forge.sdk.routes.routers import base_router, legacy_base_router
|
||||
from skyvern.forge.sdk.routes.streaming_auth import auth
|
||||
from skyvern.forge.sdk.routes.streaming_verify import (
|
||||
loop_verify_browser_session,
|
||||
loop_verify_task,
|
||||
loop_verify_workflow_run,
|
||||
verify_browser_session,
|
||||
verify_task,
|
||||
verify_workflow_run,
|
||||
)
|
||||
@@ -25,6 +27,48 @@ from skyvern.forge.sdk.utils.aio import collect
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
async def get_streaming_for_browser_session(
|
||||
client_id: str,
|
||||
browser_session_id: str,
|
||||
organization_id: str,
|
||||
websocket: WebSocket,
|
||||
) -> tuple[sc.Streaming, sc.Loops] | None:
|
||||
"""
|
||||
Return a streaming context for a browser session, with a list of loops to run concurrently.
|
||||
"""
|
||||
|
||||
LOG.info("Getting streaming context for browser session.", browser_session_id=browser_session_id)
|
||||
|
||||
browser_session = await verify_browser_session(
|
||||
browser_session_id=browser_session_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
|
||||
if not browser_session:
|
||||
LOG.info(
|
||||
"No initial browser session found.", browser_session_id=browser_session_id, organization_id=organization_id
|
||||
)
|
||||
return None
|
||||
|
||||
streaming = sc.Streaming(
|
||||
client_id=client_id,
|
||||
interactor="agent",
|
||||
organization_id=organization_id,
|
||||
vnc_port=settings.SKYVERN_BROWSER_VNC_PORT,
|
||||
browser_session=browser_session,
|
||||
websocket=websocket,
|
||||
)
|
||||
|
||||
LOG.info("Got streaming context for browser session.", streaming=streaming)
|
||||
|
||||
loops = [
|
||||
asyncio.create_task(loop_verify_browser_session(streaming)),
|
||||
asyncio.create_task(loop_stream_vnc(streaming)),
|
||||
]
|
||||
|
||||
return streaming, loops
|
||||
|
||||
|
||||
async def get_streaming_for_task(
|
||||
client_id: str,
|
||||
task_id: str,
|
||||
@@ -295,6 +339,17 @@ async def loop_stream_vnc(streaming: sc.Streaming) -> None:
|
||||
await streaming.close(reason="loop-stream-vnc-closed")
|
||||
|
||||
|
||||
@base_router.websocket("/stream/vnc/browser_session/{browser_session_id}")
|
||||
async def browser_session_stream(
|
||||
websocket: WebSocket,
|
||||
browser_session_id: str,
|
||||
apikey: str | None = None,
|
||||
client_id: str | None = None,
|
||||
token: str | None = None,
|
||||
) -> None:
|
||||
await stream(websocket, apikey=apikey, client_id=client_id, browser_session_id=browser_session_id, token=token)
|
||||
|
||||
|
||||
@legacy_base_router.websocket("/stream/vnc/task/{task_id}")
|
||||
async def task_stream(
|
||||
websocket: WebSocket,
|
||||
@@ -321,16 +376,28 @@ async def stream(
|
||||
websocket: WebSocket,
|
||||
*,
|
||||
apikey: str | None = None,
|
||||
browser_session_id: str | None = None,
|
||||
client_id: str | None = None,
|
||||
task_id: str | None = None,
|
||||
token: str | None = None,
|
||||
workflow_run_id: str | None = None,
|
||||
) -> None:
|
||||
if not client_id:
|
||||
LOG.error("Client ID not provided for VNC stream.", task_id=task_id, workflow_run_id=workflow_run_id)
|
||||
LOG.error(
|
||||
"Client ID not provided for VNC stream.",
|
||||
browser_session_id=browser_session_id,
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
)
|
||||
return
|
||||
|
||||
LOG.info("Starting VNC stream.", client_id=client_id, task_id=task_id, workflow_run_id=workflow_run_id)
|
||||
LOG.info(
|
||||
"Starting VNC stream.",
|
||||
browser_session_id=browser_session_id,
|
||||
client_id=client_id,
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
)
|
||||
|
||||
organization_id = await auth(apikey=apikey, token=token, websocket=websocket)
|
||||
|
||||
@@ -341,7 +408,31 @@ async def stream(
|
||||
streaming: sc.Streaming
|
||||
loops: list[asyncio.Task] = []
|
||||
|
||||
if task_id:
|
||||
if browser_session_id:
|
||||
result = await get_streaming_for_browser_session(
|
||||
client_id=client_id,
|
||||
browser_session_id=browser_session_id,
|
||||
organization_id=organization_id,
|
||||
websocket=websocket,
|
||||
)
|
||||
|
||||
if not result:
|
||||
LOG.error(
|
||||
"No streaming context found for the browser session.",
|
||||
browser_session_id=browser_session_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
await websocket.close(code=1013)
|
||||
return
|
||||
|
||||
streaming, loops = result
|
||||
|
||||
LOG.info(
|
||||
"Starting streaming for browser session.",
|
||||
browser_session_id=browser_session_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
elif task_id:
|
||||
result = await get_streaming_for_task(
|
||||
client_id=client_id,
|
||||
task_id=task_id,
|
||||
|
||||
Reference in New Issue
Block a user