144 lines
4.1 KiB
Python
144 lines
4.1 KiB
Python
"""
|
|
Provides WS endpoints for streaming messages to/from our frontend application.
|
|
"""
|
|
|
|
import structlog
|
|
from fastapi import WebSocket
|
|
|
|
from skyvern.forge.sdk.routes.routers import base_router, legacy_base_router
|
|
from skyvern.forge.sdk.routes.streaming.auth import auth
|
|
from skyvern.forge.sdk.routes.streaming.channels.message import (
|
|
Loops,
|
|
MessageChannel,
|
|
get_message_channel_for_browser_session,
|
|
get_message_channel_for_workflow_run,
|
|
)
|
|
from skyvern.forge.sdk.utils.aio import collect
|
|
|
|
LOG = structlog.get_logger()
|
|
|
|
|
|
@base_router.websocket("/stream/messages/browser_session/{browser_session_id}")
async def browser_session_messages(
    websocket: WebSocket,
    browser_session_id: str,
    apikey: str | None = None,
    client_id: str | None = None,
    token: str | None = None,
) -> None:
    """
    WebSocket route: message stream scoped to a browser session.

    This is a thin adapter; authentication, channel lookup and the streaming
    loops all live in `messages`, which receives every argument unchanged.

    Args:
        websocket: The incoming WebSocket connection.
        browser_session_id: Browser session id taken from the URL path.
        apikey: Optional API key credential, forwarded as-is.
        client_id: Optional client identifier, forwarded as-is.
        token: Optional token credential, forwarded as-is.
    """
    # `messages` returns None, so there is no result to propagate.
    await messages(
        websocket,
        browser_session_id=browser_session_id,
        apikey=apikey,
        client_id=client_id,
        token=token,
    )
|
|
|
|
|
|
@legacy_base_router.websocket("/stream/messages/workflow_run/{workflow_run_id}")
async def workflow_run_messages(
    websocket: WebSocket,
    workflow_run_id: str,
    apikey: str | None = None,
    client_id: str | None = None,
    token: str | None = None,
) -> None:
    """
    WebSocket route: message stream scoped to a workflow run.

    This is a thin adapter; authentication, channel lookup and the streaming
    loops all live in `messages`, which receives every argument unchanged.

    Args:
        websocket: The incoming WebSocket connection.
        workflow_run_id: Workflow run id taken from the URL path.
        apikey: Optional API key credential, forwarded as-is.
        client_id: Optional client identifier, forwarded as-is.
        token: Optional token credential, forwarded as-is.
    """
    # `messages` returns None, so there is no result to propagate.
    await messages(
        websocket,
        workflow_run_id=workflow_run_id,
        apikey=apikey,
        client_id=client_id,
        token=token,
    )
|
|
|
|
|
|
async def messages(
    websocket: WebSocket,
    browser_session_id: str | None = None,
    workflow_run_id: str | None = None,
    apikey: str | None = None,
    client_id: str | None = None,
    token: str | None = None,
) -> None:
    """
    Authenticate a WebSocket client and run its message channel until it ends.

    Steps, in order:
      1. Resolve the organization via `auth`; a falsy result is logged and the
         function returns without touching the socket here.
      2. Require a `client_id`; otherwise close the socket with code 1002.
      3. Require a `browser_session_id` or a `workflow_run_id` (the browser
         session wins when both are given); otherwise close with code 1002.
      4. Look up the message channel for that target; when the lookup result
         is falsy, close the socket with code 1013.
      5. Await the channel's loops via `collect`, log any `Exception` they
         raise, and close the message channel afterwards.

    Args:
        websocket: The WebSocket connection being served.
        browser_session_id: Browser session to stream messages for, if any.
        workflow_run_id: Workflow run to stream messages for, if any.
        apikey: Optional API key credential handed to `auth`.
        client_id: Identifier of the connecting client; required.
        token: Optional token credential handed to `auth`.
    """
    # Identifiers attached to (almost) every log record emitted below.
    target_ids = {
        "browser_session_id": browser_session_id,
        "workflow_run_id": workflow_run_id,
    }

    organization_id = await auth(apikey=apikey, token=token, websocket=websocket)
    if not organization_id:
        # NOTE(review): the socket is not closed here; presumably `auth` deals
        # with rejecting the connection itself -- confirm.
        LOG.warning("Authentication failed.", **target_ids)
        return

    if not client_id:
        LOG.error("No client ID provided.", **target_ids)
        await websocket.close(code=1002)
        return

    if not (browser_session_id or workflow_run_id):
        LOG.error(
            "Message channel: no browser_session_id or workflow_run_id provided.",
            client_id=client_id,
            organization_id=organization_id,
        )
        await websocket.close(code=1002)
        return

    # At least one target id is truthy from here on; the browser session takes
    # precedence when both are supplied.
    if browser_session_id:
        channel_and_loops = await get_message_channel_for_browser_session(
            client_id=client_id,
            browser_session_id=browser_session_id,
            organization_id=organization_id,
            websocket=websocket,
        )
    else:
        channel_and_loops = await get_message_channel_for_workflow_run(
            client_id=client_id,
            workflow_run_id=workflow_run_id,
            organization_id=organization_id,
            websocket=websocket,
        )

    log_context = {**target_ids, "organization_id": organization_id}

    if not channel_and_loops:
        LOG.debug("No message channel found.", **log_context)
        await websocket.close(code=1013)
        return

    channel, channel_loops = channel_and_loops

    try:
        LOG.info("Starting message channel loops.", **log_context)
        await collect(channel_loops)
    except Exception:
        # Log and swallow: a failing loop ends the stream rather than
        # propagating out of the route handler.
        LOG.exception(
            "An exception occurred in the message loop function(s).",
            **log_context,
        )
    finally:
        LOG.info("Closing the message channel.", **log_context)

        await channel.close(reason="message-stream-closed")
|