Files
Dorod-Sky/skyvern/forge/sdk/routes/streaming/messages.py
2025-11-19 09:35:05 -05:00

364 lines
12 KiB
Python

"""
Streaming messages for WebSocket connections.
"""
import asyncio
import structlog
from fastapi import WebSocket, WebSocketDisconnect
from websockets.exceptions import ConnectionClosedError
import skyvern.forge.sdk.routes.streaming.clients as sc
from skyvern.forge.sdk.routes.routers import base_router, legacy_base_router
from skyvern.forge.sdk.routes.streaming.agent import connected_agent
from skyvern.forge.sdk.routes.streaming.auth import auth
from skyvern.forge.sdk.routes.streaming.verify import (
loop_verify_browser_session,
loop_verify_workflow_run,
verify_browser_session,
verify_workflow_run,
)
from skyvern.forge.sdk.utils.aio import collect
# Module-level structlog logger used by every function in this module.
LOG = structlog.get_logger()
async def get_messages_for_browser_session(
    client_id: str,
    browser_session_id: str,
    organization_id: str,
    websocket: WebSocket,
) -> tuple[sc.MessageChannel, sc.Loops] | None:
    """
    Build a message channel bound to a browser session.

    The session is verified first. When verification yields nothing, ``None``
    is returned and no tasks are started. Otherwise two tasks are scheduled,
    in this order: the browser-session verification loop and the channel
    message loop.

    :param client_id: id of the streaming client the channel belongs to.
    :param browser_session_id: id of the browser session to attach to.
    :param organization_id: organization that owns the session.
    :param websocket: the frontend websocket carried by the channel.
    :returns: ``(channel, tasks)``, or ``None`` if the session was not found.
    """
    LOG.info("Getting message channel for browser session.", browser_session_id=browser_session_id)

    session = await verify_browser_session(
        browser_session_id=browser_session_id,
        organization_id=organization_id,
    )

    if session:
        channel = sc.MessageChannel(
            client_id=client_id,
            organization_id=organization_id,
            browser_session=session,
            websocket=websocket,
        )
        LOG.info("Got message channel for browser session.", message_channel=channel)

        # Coroutine objects are created first; tasks are scheduled in list order.
        tasks = [
            asyncio.create_task(coro)
            for coro in (loop_verify_browser_session(channel), loop_channel(channel))
        ]
        return channel, tasks

    LOG.info(
        "Message channel: no initial browser session found.",
        browser_session_id=browser_session_id,
        organization_id=organization_id,
    )
    return None
async def get_messages_for_workflow_run(
    client_id: str,
    workflow_run_id: str,
    organization_id: str,
    websocket: WebSocket,
) -> tuple[sc.MessageChannel, sc.Loops] | None:
    """
    Build a message channel bound to a workflow run and its browser session.

    Both the workflow run and its browser session must be found by
    ``verify_workflow_run``. If either is missing, ``None`` is returned and no
    tasks are started. Otherwise two tasks are scheduled, in this order: the
    workflow-run verification loop and the channel message loop.

    :param client_id: id of the streaming client the channel belongs to.
    :param workflow_run_id: id of the workflow run to attach to.
    :param organization_id: organization that owns the run.
    :param websocket: the frontend websocket carried by the channel.
    :returns: ``(channel, tasks)``, or ``None`` if the run or session is missing.
    """
    LOG.info("Getting message channel for workflow run.", workflow_run_id=workflow_run_id)

    run, session = await verify_workflow_run(
        workflow_run_id=workflow_run_id,
        organization_id=organization_id,
    )

    # Checked in order: the run first, then its browser session.
    guards = (
        (run, "Message channel: no initial workflow run found."),
        (session, "Message channel: no initial browser session found for workflow run."),
    )
    for found, event in guards:
        if not found:
            LOG.info(
                event,
                workflow_run_id=workflow_run_id,
                organization_id=organization_id,
            )
            return None

    channel = sc.MessageChannel(
        client_id=client_id,
        organization_id=organization_id,
        browser_session=session,
        workflow_run=run,
        websocket=websocket,
    )
    LOG.info("Got message channel for workflow run.", message_channel=channel)

    tasks = [
        asyncio.create_task(coro)
        for coro in (loop_verify_workflow_run(channel), loop_channel(channel))
    ]
    return channel, tasks
async def loop_channel(message_channel: sc.MessageChannel) -> None:
"""
Stream messages and their results back and forth.
Loops until the workflow run is cleared or the websocket is closed.
"""
if not message_channel.browser_session:
LOG.info(
"No browser session found for workflow run.",
workflow_run=message_channel.workflow_run,
organization_id=message_channel.organization_id,
)
return
async def frontend_to_backend() -> None:
LOG.info("Starting frontend-to-backend channel loop.", message_channel=message_channel)
while message_channel.is_open:
try:
data = await message_channel.websocket.receive_json()
if not isinstance(data, dict):
LOG.error(f"Cannot create channel message: expected dict, got {type(data)}")
continue
try:
message = sc.reify_channel_message(data)
except ValueError:
continue
message_kind = message.kind
match message_kind:
case "take-control":
streaming = sc.get_streaming_client(message_channel.client_id)
if not streaming:
LOG.error(
"No streaming client found for message.",
message_channel=message_channel,
message=message,
)
continue
streaming.interactor = "user"
case "cede-control":
streaming = sc.get_streaming_client(message_channel.client_id)
if not streaming:
LOG.error(
"No streaming client found for message.",
message_channel=message_channel,
message=message,
)
continue
streaming.interactor = "agent"
case "ask-for-clipboard-response":
if not isinstance(message, sc.MessageInAskForClipboardResponse):
LOG.error(
"Invalid message type for ask-for-clipboard-response.",
message_channel=message_channel,
message=message,
)
continue
streaming = sc.get_streaming_client(message_channel.client_id)
text = message.text
async with connected_agent(streaming) as agent:
await agent.paste_text(text)
case _:
LOG.error(f"Unknown message kind: '{message_kind}'")
continue
except WebSocketDisconnect:
LOG.info(
"Frontend disconnected.",
workflow_run=message_channel.workflow_run,
organization_id=message_channel.organization_id,
)
raise
except ConnectionClosedError:
LOG.info(
"Frontend closed the streaming session.",
workflow_run=message_channel.workflow_run,
organization_id=message_channel.organization_id,
)
raise
except asyncio.CancelledError:
pass
except Exception:
LOG.exception(
"An unexpected exception occurred.",
workflow_run=message_channel.workflow_run,
organization_id=message_channel.organization_id,
)
raise
loops = [
asyncio.create_task(frontend_to_backend()),
]
try:
await collect(loops)
except Exception:
LOG.exception(
"An exception occurred in loop channel stream.",
workflow_run=message_channel.workflow_run,
organization_id=message_channel.organization_id,
)
finally:
LOG.info(
"Closing the loop channel stream.",
workflow_run=message_channel.workflow_run,
organization_id=message_channel.organization_id,
)
await message_channel.close(reason="loop-channel-closed")
@base_router.websocket("/stream/messages/browser_session/{browser_session_id}")
@base_router.websocket("/stream/commands/browser_session/{browser_session_id}")
async def browser_session_messages(
    websocket: WebSocket,
    browser_session_id: str,
    apikey: str | None = None,
    client_id: str | None = None,
    token: str | None = None,
) -> None:
    """
    WebSocket endpoint that streams channel messages for a browser session.

    Authenticates the connection and requires a ``client_id``. It then builds
    a message channel for the browser session and runs its loops until they
    finish or fail. Once the channel exists, it is always closed on exit.

    :param websocket: the incoming WebSocket connection.
    :param browser_session_id: id of the browser session to attach to.
    :param apikey: credential forwarded to ``auth``.
    :param client_id: id of the streaming client; required.
    :param token: credential forwarded to ``auth``.
    """
    LOG.info("Starting message stream for browser session.", browser_session_id=browser_session_id)

    organization_id = await auth(apikey=apikey, token=token, websocket=websocket)
    if not organization_id:
        # NOTE(review): assumes ``auth`` closes/rejects the socket itself on failure.
        LOG.error("Authentication failed.", browser_session_id=browser_session_id)
        return

    if not client_id:
        LOG.error("No client ID provided.", browser_session_id=browser_session_id)
        # Close with an explicit status (1008, policy violation), as the
        # no-context branch below does. Otherwise the connection is just
        # dropped when the handler returns.
        await websocket.close(code=1008)
        return

    result = await get_messages_for_browser_session(
        client_id=client_id,
        browser_session_id=browser_session_id,
        organization_id=organization_id,
        websocket=websocket,
    )

    if not result:
        LOG.error(
            "No streaming context found for the browser session.",
            browser_session_id=browser_session_id,
            organization_id=organization_id,
        )
        # 1013: "try again later" — the session may not be available yet.
        await websocket.close(code=1013)
        return

    message_channel, loops = result

    try:
        LOG.info(
            "Starting message stream loops for browser session.",
            browser_session_id=browser_session_id,
            organization_id=organization_id,
        )
        await collect(loops)
    except Exception:
        LOG.exception(
            "An exception occurred in the message stream function for browser session.",
            browser_session_id=browser_session_id,
            organization_id=organization_id,
        )
    finally:
        LOG.info(
            "Closing the message stream session for browser session.",
            browser_session_id=browser_session_id,
            organization_id=organization_id,
        )
        await message_channel.close(reason="stream-closed")
@legacy_base_router.websocket("/stream/messages/workflow_run/{workflow_run_id}")
@legacy_base_router.websocket("/stream/commands/workflow_run/{workflow_run_id}")
async def workflow_run_messages(
    websocket: WebSocket,
    workflow_run_id: str,
    apikey: str | None = None,
    client_id: str | None = None,
    token: str | None = None,
) -> None:
    """
    WebSocket endpoint that streams channel messages for a workflow run.

    Authenticates the connection and requires a ``client_id``. It then builds
    a message channel for the workflow run (and its browser session) and runs
    its loops until they finish or fail. Once the channel exists, it is always
    closed on exit.

    :param websocket: the incoming WebSocket connection.
    :param workflow_run_id: id of the workflow run to attach to.
    :param apikey: credential forwarded to ``auth``.
    :param client_id: id of the streaming client; required.
    :param token: credential forwarded to ``auth``.
    """
    LOG.info("Starting message stream.", workflow_run_id=workflow_run_id)

    organization_id = await auth(apikey=apikey, token=token, websocket=websocket)
    if not organization_id:
        # NOTE(review): assumes ``auth`` closes/rejects the socket itself on failure.
        LOG.error("Authentication failed.", workflow_run_id=workflow_run_id)
        return

    if not client_id:
        LOG.error("No client ID provided.", workflow_run_id=workflow_run_id)
        # Close with an explicit status (1008, policy violation), as the
        # no-context branch below does. Otherwise the connection is just
        # dropped when the handler returns.
        await websocket.close(code=1008)
        return

    result = await get_messages_for_workflow_run(
        client_id=client_id,
        workflow_run_id=workflow_run_id,
        organization_id=organization_id,
        websocket=websocket,
    )

    if not result:
        LOG.error(
            "No streaming context found for the workflow run.",
            workflow_run_id=workflow_run_id,
            organization_id=organization_id,
        )
        # 1013: "try again later" — the run/session may not be available yet.
        await websocket.close(code=1013)
        return

    message_channel, loops = result

    try:
        LOG.info(
            "Starting message stream loops.",
            workflow_run_id=workflow_run_id,
            organization_id=organization_id,
        )
        await collect(loops)
    except Exception:
        LOG.exception(
            "An exception occurred in the message stream function.",
            workflow_run_id=workflow_run_id,
            organization_id=organization_id,
        )
    finally:
        LOG.info(
            "Closing the message stream session.",
            workflow_run_id=workflow_run_id,
            organization_id=organization_id,
        )
        await message_channel.close(reason="stream-closed")