BE portion of seamless clipboard transfer in browser stream (#3788)
This commit is contained in:
363
skyvern/forge/sdk/routes/streaming_messages.py
Normal file
363
skyvern/forge/sdk/routes/streaming_messages.py
Normal file
@@ -0,0 +1,363 @@
|
||||
"""
|
||||
Streaming messages for WebSocket connections.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
import structlog
|
||||
from fastapi import WebSocket, WebSocketDisconnect
|
||||
from websockets.exceptions import ConnectionClosedError
|
||||
|
||||
import skyvern.forge.sdk.routes.streaming_clients as sc
|
||||
from skyvern.forge.sdk.routes.routers import base_router, legacy_base_router
|
||||
from skyvern.forge.sdk.routes.streaming_agent import connected_agent
|
||||
from skyvern.forge.sdk.routes.streaming_auth import auth
|
||||
from skyvern.forge.sdk.routes.streaming_verify import (
|
||||
loop_verify_browser_session,
|
||||
loop_verify_workflow_run,
|
||||
verify_browser_session,
|
||||
verify_workflow_run,
|
||||
)
|
||||
from skyvern.forge.sdk.utils.aio import collect
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
async def get_messages_for_browser_session(
|
||||
client_id: str,
|
||||
browser_session_id: str,
|
||||
organization_id: str,
|
||||
websocket: WebSocket,
|
||||
) -> tuple[sc.MessageChannel, sc.Loops] | None:
|
||||
"""
|
||||
Return a message channel for a browser session, with a list of loops to run concurrently.
|
||||
"""
|
||||
|
||||
LOG.info("Getting message channel for browser session.", browser_session_id=browser_session_id)
|
||||
|
||||
browser_session = await verify_browser_session(
|
||||
browser_session_id=browser_session_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
|
||||
if not browser_session:
|
||||
LOG.info(
|
||||
"Message channel: no initial browser session found.",
|
||||
browser_session_id=browser_session_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
return None
|
||||
|
||||
message_channel = sc.MessageChannel(
|
||||
client_id=client_id,
|
||||
organization_id=organization_id,
|
||||
browser_session=browser_session,
|
||||
websocket=websocket,
|
||||
)
|
||||
|
||||
LOG.info("Got message channel for browser session.", message_channel=message_channel)
|
||||
|
||||
loops = [
|
||||
asyncio.create_task(loop_verify_browser_session(message_channel)),
|
||||
asyncio.create_task(loop_channel(message_channel)),
|
||||
]
|
||||
|
||||
return message_channel, loops
|
||||
|
||||
|
||||
async def get_messages_for_workflow_run(
|
||||
client_id: str,
|
||||
workflow_run_id: str,
|
||||
organization_id: str,
|
||||
websocket: WebSocket,
|
||||
) -> tuple[sc.MessageChannel, sc.Loops] | None:
|
||||
"""
|
||||
Return a message channel for a workflow run, with a list of loops to run concurrently.
|
||||
"""
|
||||
|
||||
LOG.info("Getting message channel for workflow run.", workflow_run_id=workflow_run_id)
|
||||
|
||||
workflow_run, browser_session = await verify_workflow_run(
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
|
||||
if not workflow_run:
|
||||
LOG.info(
|
||||
"Message channel: no initial workflow run found.",
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
return None
|
||||
|
||||
if not browser_session:
|
||||
LOG.info(
|
||||
"Message channel: no initial browser session found for workflow run.",
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
return None
|
||||
|
||||
message_channel = sc.MessageChannel(
|
||||
client_id=client_id,
|
||||
organization_id=organization_id,
|
||||
browser_session=browser_session,
|
||||
workflow_run=workflow_run,
|
||||
websocket=websocket,
|
||||
)
|
||||
|
||||
LOG.info("Got message channel for workflow run.", message_channel=message_channel)
|
||||
|
||||
loops = [
|
||||
asyncio.create_task(loop_verify_workflow_run(message_channel)),
|
||||
asyncio.create_task(loop_channel(message_channel)),
|
||||
]
|
||||
|
||||
return message_channel, loops
|
||||
|
||||
|
||||
async def loop_channel(message_channel: sc.MessageChannel) -> None:
|
||||
"""
|
||||
Stream messages and their results back and forth.
|
||||
|
||||
Loops until the workflow run is cleared or the websocket is closed.
|
||||
"""
|
||||
|
||||
if not message_channel.browser_session:
|
||||
LOG.info(
|
||||
"No browser session found for workflow run.",
|
||||
workflow_run=message_channel.workflow_run,
|
||||
organization_id=message_channel.organization_id,
|
||||
)
|
||||
return
|
||||
|
||||
async def frontend_to_backend() -> None:
|
||||
LOG.info("Starting frontend-to-backend channel loop.", message_channel=message_channel)
|
||||
|
||||
while message_channel.is_open:
|
||||
try:
|
||||
data = await message_channel.websocket.receive_json()
|
||||
|
||||
if not isinstance(data, dict):
|
||||
LOG.error(f"Cannot create channel message: expected dict, got {type(data)}")
|
||||
continue
|
||||
|
||||
try:
|
||||
message = sc.reify_channel_message(data)
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
message_kind = message.kind
|
||||
|
||||
match message_kind:
|
||||
case "take-control":
|
||||
streaming = sc.get_streaming_client(message_channel.client_id)
|
||||
if not streaming:
|
||||
LOG.error(
|
||||
"No streaming client found for message.",
|
||||
message_channel=message_channel,
|
||||
message=message,
|
||||
)
|
||||
continue
|
||||
streaming.interactor = "user"
|
||||
case "cede-control":
|
||||
streaming = sc.get_streaming_client(message_channel.client_id)
|
||||
if not streaming:
|
||||
LOG.error(
|
||||
"No streaming client found for message.",
|
||||
message_channel=message_channel,
|
||||
message=message,
|
||||
)
|
||||
continue
|
||||
streaming.interactor = "agent"
|
||||
case "ask-for-clipboard-response":
|
||||
if not isinstance(message, sc.MessageInAskForClipboardResponse):
|
||||
LOG.error(
|
||||
"Invalid message type for ask-for-clipboard-response.",
|
||||
message_channel=message_channel,
|
||||
message=message,
|
||||
)
|
||||
continue
|
||||
|
||||
streaming = sc.get_streaming_client(message_channel.client_id)
|
||||
text = message.text
|
||||
|
||||
async with connected_agent(streaming) as agent:
|
||||
await agent.paste_text(text)
|
||||
case _:
|
||||
LOG.error(f"Unknown message kind: '{message_kind}'")
|
||||
continue
|
||||
|
||||
except WebSocketDisconnect:
|
||||
LOG.info(
|
||||
"Frontend disconnected.",
|
||||
workflow_run=message_channel.workflow_run,
|
||||
organization_id=message_channel.organization_id,
|
||||
)
|
||||
raise
|
||||
except ConnectionClosedError:
|
||||
LOG.info(
|
||||
"Frontend closed the streaming session.",
|
||||
workflow_run=message_channel.workflow_run,
|
||||
organization_id=message_channel.organization_id,
|
||||
)
|
||||
raise
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception:
|
||||
LOG.exception(
|
||||
"An unexpected exception occurred.",
|
||||
workflow_run=message_channel.workflow_run,
|
||||
organization_id=message_channel.organization_id,
|
||||
)
|
||||
raise
|
||||
|
||||
loops = [
|
||||
asyncio.create_task(frontend_to_backend()),
|
||||
]
|
||||
|
||||
try:
|
||||
await collect(loops)
|
||||
except Exception:
|
||||
LOG.exception(
|
||||
"An exception occurred in loop channel stream.",
|
||||
workflow_run=message_channel.workflow_run,
|
||||
organization_id=message_channel.organization_id,
|
||||
)
|
||||
finally:
|
||||
LOG.info(
|
||||
"Closing the loop channel stream.",
|
||||
workflow_run=message_channel.workflow_run,
|
||||
organization_id=message_channel.organization_id,
|
||||
)
|
||||
await message_channel.close(reason="loop-channel-closed")
|
||||
|
||||
|
||||
@base_router.websocket("/stream/messages/browser_session/{browser_session_id}")
|
||||
@base_router.websocket("/stream/commands/browser_session/{browser_session_id}")
|
||||
async def browser_session_messages(
|
||||
websocket: WebSocket,
|
||||
browser_session_id: str,
|
||||
apikey: str | None = None,
|
||||
client_id: str | None = None,
|
||||
token: str | None = None,
|
||||
) -> None:
|
||||
LOG.info("Starting message stream for browser session.", browser_session_id=browser_session_id)
|
||||
|
||||
organization_id = await auth(apikey=apikey, token=token, websocket=websocket)
|
||||
|
||||
if not organization_id:
|
||||
LOG.error("Authentication failed.", browser_session_id=browser_session_id)
|
||||
return
|
||||
|
||||
if not client_id:
|
||||
LOG.error("No client ID provided.", browser_session_id=browser_session_id)
|
||||
return
|
||||
|
||||
message_channel: sc.MessageChannel
|
||||
loops: list[asyncio.Task] = []
|
||||
|
||||
result = await get_messages_for_browser_session(
|
||||
client_id=client_id,
|
||||
browser_session_id=browser_session_id,
|
||||
organization_id=organization_id,
|
||||
websocket=websocket,
|
||||
)
|
||||
|
||||
if not result:
|
||||
LOG.error(
|
||||
"No streaming context found for the browser session.",
|
||||
browser_session_id=browser_session_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
await websocket.close(code=1013)
|
||||
return
|
||||
|
||||
message_channel, loops = result
|
||||
|
||||
try:
|
||||
LOG.info(
|
||||
"Starting message stream loops for browser session.",
|
||||
browser_session_id=browser_session_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
await collect(loops)
|
||||
except Exception:
|
||||
LOG.exception(
|
||||
"An exception occurred in the message stream function for browser session.",
|
||||
browser_session_id=browser_session_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
finally:
|
||||
LOG.info(
|
||||
"Closing the message stream session for browser session.",
|
||||
browser_session_id=browser_session_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
|
||||
await message_channel.close(reason="stream-closed")
|
||||
|
||||
|
||||
@legacy_base_router.websocket("/stream/messages/workflow_run/{workflow_run_id}")
|
||||
@legacy_base_router.websocket("/stream/commands/workflow_run/{workflow_run_id}")
|
||||
async def workflow_run_messages(
|
||||
websocket: WebSocket,
|
||||
workflow_run_id: str,
|
||||
apikey: str | None = None,
|
||||
client_id: str | None = None,
|
||||
token: str | None = None,
|
||||
) -> None:
|
||||
LOG.info("Starting message stream.", workflow_run_id=workflow_run_id)
|
||||
|
||||
organization_id = await auth(apikey=apikey, token=token, websocket=websocket)
|
||||
|
||||
if not organization_id:
|
||||
LOG.error("Authentication failed.", workflow_run_id=workflow_run_id)
|
||||
return
|
||||
|
||||
if not client_id:
|
||||
LOG.error("No client ID provided.", workflow_run_id=workflow_run_id)
|
||||
return
|
||||
|
||||
message_channel: sc.MessageChannel
|
||||
loops: list[asyncio.Task] = []
|
||||
|
||||
result = await get_messages_for_workflow_run(
|
||||
client_id=client_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
websocket=websocket,
|
||||
)
|
||||
|
||||
if not result:
|
||||
LOG.error(
|
||||
"No streaming context found for the workflow run.",
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
await websocket.close(code=1013)
|
||||
return
|
||||
|
||||
message_channel, loops = result
|
||||
|
||||
try:
|
||||
LOG.info(
|
||||
"Starting message stream loops.",
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
await collect(loops)
|
||||
except Exception:
|
||||
LOG.exception(
|
||||
"An exception occurred in the message stream function.",
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
finally:
|
||||
LOG.info(
|
||||
"Closing the message stream session.",
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
|
||||
await message_channel.close(reason="stream-closed")
|
||||
Reference in New Issue
Block a user