Jon/browser session stream (#2910)
This commit is contained in:
@@ -78,7 +78,7 @@ class CommandChannel:
|
|||||||
if self.websocket.client_state not in (WebSocketState.CONNECTED, WebSocketState.CONNECTING):
|
if self.websocket.client_state not in (WebSocketState.CONNECTED, WebSocketState.CONNECTING):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if not self.workflow_run:
|
if not self.workflow_run and not self.browser_session:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if not get_command_client(self.client_id):
|
if not get_command_client(self.client_id):
|
||||||
@@ -231,7 +231,7 @@ class Streaming:
|
|||||||
if self.websocket.client_state not in (WebSocketState.CONNECTED, WebSocketState.CONNECTING):
|
if self.websocket.client_state not in (WebSocketState.CONNECTED, WebSocketState.CONNECTING):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if not self.task and not self.workflow_run:
|
if not self.task and not self.workflow_run and not self.browser_session:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if not get_streaming_client(self.client_id):
|
if not get_streaming_client(self.client_id):
|
||||||
|
|||||||
@@ -9,14 +9,61 @@ from fastapi import WebSocket, WebSocketDisconnect
|
|||||||
from websockets.exceptions import ConnectionClosedError
|
from websockets.exceptions import ConnectionClosedError
|
||||||
|
|
||||||
import skyvern.forge.sdk.routes.streaming_clients as sc
|
import skyvern.forge.sdk.routes.streaming_clients as sc
|
||||||
from skyvern.forge.sdk.routes.routers import legacy_base_router
|
from skyvern.forge.sdk.routes.routers import base_router, legacy_base_router
|
||||||
from skyvern.forge.sdk.routes.streaming_auth import auth
|
from skyvern.forge.sdk.routes.streaming_auth import auth
|
||||||
from skyvern.forge.sdk.routes.streaming_verify import loop_verify_workflow_run, verify_workflow_run
|
from skyvern.forge.sdk.routes.streaming_verify import (
|
||||||
|
loop_verify_browser_session,
|
||||||
|
loop_verify_workflow_run,
|
||||||
|
verify_browser_session,
|
||||||
|
verify_workflow_run,
|
||||||
|
)
|
||||||
from skyvern.forge.sdk.utils.aio import collect
|
from skyvern.forge.sdk.utils.aio import collect
|
||||||
|
|
||||||
LOG = structlog.get_logger()
|
LOG = structlog.get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
async def get_commands_for_browser_session(
|
||||||
|
client_id: str,
|
||||||
|
browser_session_id: str,
|
||||||
|
organization_id: str,
|
||||||
|
websocket: WebSocket,
|
||||||
|
) -> tuple[sc.CommandChannel, sc.Loops] | None:
|
||||||
|
"""
|
||||||
|
Return a commands channel for a browser session, with a list of loops to run concurrently.
|
||||||
|
"""
|
||||||
|
|
||||||
|
LOG.info("Getting commands channel for browser session.", browser_session_id=browser_session_id)
|
||||||
|
|
||||||
|
browser_session = await verify_browser_session(
|
||||||
|
browser_session_id=browser_session_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not browser_session:
|
||||||
|
LOG.info(
|
||||||
|
"Command channel: no initial browser session found.",
|
||||||
|
browser_session_id=browser_session_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
commands = sc.CommandChannel(
|
||||||
|
client_id=client_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
browser_session=browser_session,
|
||||||
|
websocket=websocket,
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG.info("Got command channel for browser session.", commands=commands)
|
||||||
|
|
||||||
|
loops = [
|
||||||
|
asyncio.create_task(loop_verify_browser_session(commands)),
|
||||||
|
asyncio.create_task(loop_channel(commands)),
|
||||||
|
]
|
||||||
|
|
||||||
|
return commands, loops
|
||||||
|
|
||||||
|
|
||||||
async def get_commands_for_workflow_run(
|
async def get_commands_for_workflow_run(
|
||||||
client_id: str,
|
client_id: str,
|
||||||
workflow_run_id: str,
|
workflow_run_id: str,
|
||||||
@@ -163,6 +210,70 @@ async def loop_channel(commands: sc.CommandChannel) -> None:
|
|||||||
await commands.close(reason="loop-channel-closed")
|
await commands.close(reason="loop-channel-closed")
|
||||||
|
|
||||||
|
|
||||||
|
@base_router.websocket("/stream/commands/browser_session/{browser_session_id}")
|
||||||
|
async def browser_session_commands(
|
||||||
|
websocket: WebSocket,
|
||||||
|
browser_session_id: str,
|
||||||
|
apikey: str | None = None,
|
||||||
|
client_id: str | None = None,
|
||||||
|
token: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
LOG.info("Starting stream commands for browser session.", browser_session_id=browser_session_id)
|
||||||
|
|
||||||
|
organization_id = await auth(apikey=apikey, token=token, websocket=websocket)
|
||||||
|
|
||||||
|
if not organization_id:
|
||||||
|
LOG.error("Authentication failed.", browser_session_id=browser_session_id)
|
||||||
|
return
|
||||||
|
|
||||||
|
if not client_id:
|
||||||
|
LOG.error("No client ID provided.", browser_session_id=browser_session_id)
|
||||||
|
return
|
||||||
|
|
||||||
|
commands: sc.CommandChannel
|
||||||
|
loops: list[asyncio.Task] = []
|
||||||
|
|
||||||
|
result = await get_commands_for_browser_session(
|
||||||
|
client_id=client_id,
|
||||||
|
browser_session_id=browser_session_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
websocket=websocket,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not result:
|
||||||
|
LOG.error(
|
||||||
|
"No streaming context found for the browser session.",
|
||||||
|
browser_session_id=browser_session_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
)
|
||||||
|
await websocket.close(code=1013)
|
||||||
|
return
|
||||||
|
|
||||||
|
commands, loops = result
|
||||||
|
|
||||||
|
try:
|
||||||
|
LOG.info(
|
||||||
|
"Starting command stream loops for browser session.",
|
||||||
|
browser_session_id=browser_session_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
)
|
||||||
|
await collect(loops)
|
||||||
|
except Exception:
|
||||||
|
LOG.exception(
|
||||||
|
"An exception occurred in the command stream function for browser session.",
|
||||||
|
browser_session_id=browser_session_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
LOG.info(
|
||||||
|
"Closing the command stream session for browser session.",
|
||||||
|
browser_session_id=browser_session_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
await commands.close(reason="stream-closed")
|
||||||
|
|
||||||
|
|
||||||
@legacy_base_router.websocket("/stream/commands/workflow_run/{workflow_run_id}")
|
@legacy_base_router.websocket("/stream/commands/workflow_run/{workflow_run_id}")
|
||||||
async def workflow_run_commands(
|
async def workflow_run_commands(
|
||||||
websocket: WebSocket,
|
websocket: WebSocket,
|
||||||
|
|||||||
@@ -13,6 +13,69 @@ from skyvern.forge.sdk.workflow.models.workflow import WorkflowRun, WorkflowRunS
|
|||||||
LOG = structlog.get_logger()
|
LOG = structlog.get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
async def verify_browser_session(
|
||||||
|
browser_session_id: str,
|
||||||
|
organization_id: str,
|
||||||
|
) -> AddressablePersistentBrowserSession | None:
|
||||||
|
"""
|
||||||
|
Verify the browser session exists, and is usable.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if settings.ENV == "local":
|
||||||
|
dummy_browser_session = AddressablePersistentBrowserSession(
|
||||||
|
persistent_browser_session_id=browser_session_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
browser_address="0.0.0.0:9223",
|
||||||
|
created_at=datetime.now(),
|
||||||
|
modified_at=datetime.now(),
|
||||||
|
)
|
||||||
|
|
||||||
|
return dummy_browser_session
|
||||||
|
|
||||||
|
browser_session = await app.DATABASE.get_persistent_browser_session(browser_session_id, organization_id)
|
||||||
|
|
||||||
|
if not browser_session:
|
||||||
|
LOG.info(
|
||||||
|
"No browser session found.",
|
||||||
|
browser_session_id=browser_session_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
browser_address = browser_session.browser_address
|
||||||
|
|
||||||
|
if not browser_address:
|
||||||
|
LOG.info(
|
||||||
|
"Waiting for browser session address.",
|
||||||
|
browser_session_id=browser_session_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
_, host, cdp_port = await app.PERSISTENT_SESSIONS_MANAGER.get_browser_address(
|
||||||
|
session_id=browser_session_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
)
|
||||||
|
browser_address = f"{host}:{cdp_port}"
|
||||||
|
except Exception as ex:
|
||||||
|
LOG.info(
|
||||||
|
"Browser session address not found for browser session.",
|
||||||
|
browser_session_id=browser_session_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
ex=ex,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
addressable_browser_session = AddressablePersistentBrowserSession(
|
||||||
|
**browser_session.model_dump() | {"browser_address": browser_address},
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return addressable_browser_session
|
||||||
|
|
||||||
|
|
||||||
async def verify_task(
|
async def verify_task(
|
||||||
task_id: str, organization_id: str
|
task_id: str, organization_id: str
|
||||||
) -> tuple[Task | None, AddressablePersistentBrowserSession | None]:
|
) -> tuple[Task | None, AddressablePersistentBrowserSession | None]:
|
||||||
@@ -172,6 +235,22 @@ async def verify_workflow_run(
|
|||||||
return workflow_run, addressable_browser_session
|
return workflow_run, addressable_browser_session
|
||||||
|
|
||||||
|
|
||||||
|
async def loop_verify_browser_session(verifiable: sc.CommandChannel | sc.Streaming) -> None:
|
||||||
|
"""
|
||||||
|
Loop until the browser session is cleared or the websocket is closed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
while verifiable.browser_session and verifiable.is_open:
|
||||||
|
browser_session = await verify_browser_session(
|
||||||
|
browser_session_id=verifiable.browser_session.persistent_browser_session_id,
|
||||||
|
organization_id=verifiable.organization_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
verifiable.browser_session = browser_session
|
||||||
|
|
||||||
|
await asyncio.sleep(2)
|
||||||
|
|
||||||
|
|
||||||
async def loop_verify_task(streaming: sc.Streaming) -> None:
|
async def loop_verify_task(streaming: sc.Streaming) -> None:
|
||||||
"""
|
"""
|
||||||
Loop until the task is cleared or the websocket is closed.
|
Loop until the task is cleared or the websocket is closed.
|
||||||
|
|||||||
@@ -12,11 +12,13 @@ from websockets.exceptions import ConnectionClosedError
|
|||||||
|
|
||||||
import skyvern.forge.sdk.routes.streaming_clients as sc
|
import skyvern.forge.sdk.routes.streaming_clients as sc
|
||||||
from skyvern.config import settings
|
from skyvern.config import settings
|
||||||
from skyvern.forge.sdk.routes.routers import legacy_base_router
|
from skyvern.forge.sdk.routes.routers import base_router, legacy_base_router
|
||||||
from skyvern.forge.sdk.routes.streaming_auth import auth
|
from skyvern.forge.sdk.routes.streaming_auth import auth
|
||||||
from skyvern.forge.sdk.routes.streaming_verify import (
|
from skyvern.forge.sdk.routes.streaming_verify import (
|
||||||
|
loop_verify_browser_session,
|
||||||
loop_verify_task,
|
loop_verify_task,
|
||||||
loop_verify_workflow_run,
|
loop_verify_workflow_run,
|
||||||
|
verify_browser_session,
|
||||||
verify_task,
|
verify_task,
|
||||||
verify_workflow_run,
|
verify_workflow_run,
|
||||||
)
|
)
|
||||||
@@ -25,6 +27,48 @@ from skyvern.forge.sdk.utils.aio import collect
|
|||||||
LOG = structlog.get_logger()
|
LOG = structlog.get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
async def get_streaming_for_browser_session(
|
||||||
|
client_id: str,
|
||||||
|
browser_session_id: str,
|
||||||
|
organization_id: str,
|
||||||
|
websocket: WebSocket,
|
||||||
|
) -> tuple[sc.Streaming, sc.Loops] | None:
|
||||||
|
"""
|
||||||
|
Return a streaming context for a browser session, with a list of loops to run concurrently.
|
||||||
|
"""
|
||||||
|
|
||||||
|
LOG.info("Getting streaming context for browser session.", browser_session_id=browser_session_id)
|
||||||
|
|
||||||
|
browser_session = await verify_browser_session(
|
||||||
|
browser_session_id=browser_session_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not browser_session:
|
||||||
|
LOG.info(
|
||||||
|
"No initial browser session found.", browser_session_id=browser_session_id, organization_id=organization_id
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
streaming = sc.Streaming(
|
||||||
|
client_id=client_id,
|
||||||
|
interactor="agent",
|
||||||
|
organization_id=organization_id,
|
||||||
|
vnc_port=settings.SKYVERN_BROWSER_VNC_PORT,
|
||||||
|
browser_session=browser_session,
|
||||||
|
websocket=websocket,
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG.info("Got streaming context for browser session.", streaming=streaming)
|
||||||
|
|
||||||
|
loops = [
|
||||||
|
asyncio.create_task(loop_verify_browser_session(streaming)),
|
||||||
|
asyncio.create_task(loop_stream_vnc(streaming)),
|
||||||
|
]
|
||||||
|
|
||||||
|
return streaming, loops
|
||||||
|
|
||||||
|
|
||||||
async def get_streaming_for_task(
|
async def get_streaming_for_task(
|
||||||
client_id: str,
|
client_id: str,
|
||||||
task_id: str,
|
task_id: str,
|
||||||
@@ -295,6 +339,17 @@ async def loop_stream_vnc(streaming: sc.Streaming) -> None:
|
|||||||
await streaming.close(reason="loop-stream-vnc-closed")
|
await streaming.close(reason="loop-stream-vnc-closed")
|
||||||
|
|
||||||
|
|
||||||
|
@base_router.websocket("/stream/vnc/browser_session/{browser_session_id}")
|
||||||
|
async def browser_session_stream(
|
||||||
|
websocket: WebSocket,
|
||||||
|
browser_session_id: str,
|
||||||
|
apikey: str | None = None,
|
||||||
|
client_id: str | None = None,
|
||||||
|
token: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
await stream(websocket, apikey=apikey, client_id=client_id, browser_session_id=browser_session_id, token=token)
|
||||||
|
|
||||||
|
|
||||||
@legacy_base_router.websocket("/stream/vnc/task/{task_id}")
|
@legacy_base_router.websocket("/stream/vnc/task/{task_id}")
|
||||||
async def task_stream(
|
async def task_stream(
|
||||||
websocket: WebSocket,
|
websocket: WebSocket,
|
||||||
@@ -321,16 +376,28 @@ async def stream(
|
|||||||
websocket: WebSocket,
|
websocket: WebSocket,
|
||||||
*,
|
*,
|
||||||
apikey: str | None = None,
|
apikey: str | None = None,
|
||||||
|
browser_session_id: str | None = None,
|
||||||
client_id: str | None = None,
|
client_id: str | None = None,
|
||||||
task_id: str | None = None,
|
task_id: str | None = None,
|
||||||
token: str | None = None,
|
token: str | None = None,
|
||||||
workflow_run_id: str | None = None,
|
workflow_run_id: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
if not client_id:
|
if not client_id:
|
||||||
LOG.error("Client ID not provided for VNC stream.", task_id=task_id, workflow_run_id=workflow_run_id)
|
LOG.error(
|
||||||
|
"Client ID not provided for VNC stream.",
|
||||||
|
browser_session_id=browser_session_id,
|
||||||
|
task_id=task_id,
|
||||||
|
workflow_run_id=workflow_run_id,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
LOG.info("Starting VNC stream.", client_id=client_id, task_id=task_id, workflow_run_id=workflow_run_id)
|
LOG.info(
|
||||||
|
"Starting VNC stream.",
|
||||||
|
browser_session_id=browser_session_id,
|
||||||
|
client_id=client_id,
|
||||||
|
task_id=task_id,
|
||||||
|
workflow_run_id=workflow_run_id,
|
||||||
|
)
|
||||||
|
|
||||||
organization_id = await auth(apikey=apikey, token=token, websocket=websocket)
|
organization_id = await auth(apikey=apikey, token=token, websocket=websocket)
|
||||||
|
|
||||||
@@ -341,7 +408,31 @@ async def stream(
|
|||||||
streaming: sc.Streaming
|
streaming: sc.Streaming
|
||||||
loops: list[asyncio.Task] = []
|
loops: list[asyncio.Task] = []
|
||||||
|
|
||||||
if task_id:
|
if browser_session_id:
|
||||||
|
result = await get_streaming_for_browser_session(
|
||||||
|
client_id=client_id,
|
||||||
|
browser_session_id=browser_session_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
websocket=websocket,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not result:
|
||||||
|
LOG.error(
|
||||||
|
"No streaming context found for the browser session.",
|
||||||
|
browser_session_id=browser_session_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
)
|
||||||
|
await websocket.close(code=1013)
|
||||||
|
return
|
||||||
|
|
||||||
|
streaming, loops = result
|
||||||
|
|
||||||
|
LOG.info(
|
||||||
|
"Starting streaming for browser session.",
|
||||||
|
browser_session_id=browser_session_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
)
|
||||||
|
elif task_id:
|
||||||
result = await get_streaming_for_task(
|
result = await get_streaming_for_task(
|
||||||
client_id=client_id,
|
client_id=client_id,
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
|
|||||||
Reference in New Issue
Block a user