Files
Dorod-Sky/skyvern/forge/sdk/routes/streaming.py
Shuchang Zheng 83ad2adabd Rename old router to legacy_base_router (#2048)
Co-authored-by: Suchintan Singh <suchintansingh@gmail.com>
2025-03-31 02:57:54 -04:00

254 lines
9.6 KiB
Python

import asyncio
import base64
import time
from datetime import datetime

import structlog
from fastapi import WebSocket, WebSocketDisconnect
from pydantic import ValidationError
from websockets.exceptions import ConnectionClosedOK

from skyvern.forge import app
from skyvern.forge.sdk.routes.routers import legacy_base_router
from skyvern.forge.sdk.schemas.tasks import TaskStatus
from skyvern.forge.sdk.services.org_auth_service import get_current_org
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus
# Structured logger shared by the streaming endpoints in this module.
LOG = structlog.get_logger()

# Seconds a stream may go without pushing a screenshot before it is closed
# with a "timeout" status (5 minutes).
STREAMING_TIMEOUT = 5 * 60
@legacy_base_router.websocket("/stream/tasks/{task_id}")
async def task_stream(
    websocket: WebSocket,
    task_id: str,
    apikey: str | None = None,
    token: str | None = None,
) -> None:
    """Stream live screenshots of a task over a websocket.

    The caller authenticates with either ``apikey`` or ``token``, which are
    forwarded to ``get_current_org``. While the task is ``running``, the latest
    streaming screenshot is polled every 2 seconds. Each one found is pushed as
    JSON ``{"task_id", "status", "screenshot"}``, with the image bytes
    base64-encoded.

    The stream ends with a final JSON status message when any of these happens:
    - the task is not found (``"not_found"``);
    - the task reaches a final state (the task's status);
    - no screenshot has been pushed for ``STREAMING_TIMEOUT`` seconds (``"timeout"``).

    The handler also returns quietly if the client disconnects.

    Args:
        websocket: The incoming websocket connection.
        task_id: ID of the task to stream.
        apikey: Optional API key credential.
        token: Optional authorization token credential.
    """
    try:
        await websocket.accept()
        if not token and not apikey:
            await websocket.send_text("No valid credential provided")
            return
    except ConnectionClosedOK:
        LOG.info("ConnectionClosedOK error. Streaming won't start")
        return

    try:
        organization = await get_current_org(x_api_key=apikey, authorization=token)
        organization_id = organization.organization_id
    except Exception:
        LOG.exception("Error while getting organization", task_id=task_id)
        try:
            await websocket.send_text("Invalid credential provided")
        except ConnectionClosedOK:
            LOG.info("ConnectionClosedOK error while sending invalid credential message")
        return

    LOG.info("Started task streaming", task_id=task_id, organization_id=organization_id)
    # Monotonic time of the last screenshot push. Only elapsed time matters here.
    # A monotonic clock is immune to wall-clock adjustments and avoids
    # datetime.utcnow(), which is deprecated since Python 3.12.
    last_activity_timestamp = time.monotonic()
    try:
        while True:
            # The inactivity clock only resets when a screenshot is actually sent
            # (below). A task that never produces one therefore times out.
            if time.monotonic() - last_activity_timestamp > STREAMING_TIMEOUT:
                LOG.info(
                    "No activity for 5 minutes. Closing connection", task_id=task_id, organization_id=organization_id
                )
                await websocket.send_json(
                    {
                        "task_id": task_id,
                        "status": "timeout",
                    }
                )
                return

            # The lookup is scoped by organization_id, presumably so callers
            # only see their own organization's tasks.
            task = await app.DATABASE.get_task(task_id=task_id, organization_id=organization_id)
            if not task:
                LOG.info("Task not found. Closing connection", task_id=task_id, organization_id=organization_id)
                await websocket.send_json(
                    {
                        "task_id": task_id,
                        "status": "not_found",
                    }
                )
                return

            if task.status.is_final():
                LOG.info(
                    "Task is in a final state. Closing connection",
                    task_status=task.status,
                    task_id=task_id,
                    organization_id=organization_id,
                )
                await websocket.send_json(
                    {
                        "task_id": task_id,
                        "status": task.status,
                    }
                )
                return

            if task.status == TaskStatus.running:
                # A task that is part of a workflow run reads the run-level file,
                # "<workflow_run_id>.png", instead of its own "<task_id>.png".
                file_name = f"{task_id}.png"
                if task.workflow_run_id:
                    file_name = f"{task.workflow_run_id}.png"
                screenshot = await app.STORAGE.get_streaming_file(organization_id, file_name)
                if screenshot:
                    encoded_screenshot = base64.b64encode(screenshot).decode("utf-8")
                    await websocket.send_json(
                        {
                            "task_id": task_id,
                            "status": task.status,
                            "screenshot": encoded_screenshot,
                        }
                    )
                    last_activity_timestamp = time.monotonic()
            await asyncio.sleep(2)
    except ValidationError as e:
        try:
            await websocket.send_text(f"Invalid data: {e}")
        except (WebSocketDisconnect, ConnectionClosedOK):
            # The client is already gone. Don't let the error report itself
            # escape the handler as an unhandled exception.
            LOG.info(
                "Connection closed before invalid data error could be sent",
                task_id=task_id,
                organization_id=organization_id,
            )
            return
    except WebSocketDisconnect:
        LOG.info("WebSocket connection closed", task_id=task_id, organization_id=organization_id)
    except ConnectionClosedOK:
        LOG.info("ConnectionClosedOK error while streaming", task_id=task_id, organization_id=organization_id)
        return
    except Exception:
        LOG.warning("Error while streaming", task_id=task_id, organization_id=organization_id, exc_info=True)
        return

    LOG.info("WebSocket connection closed successfully", task_id=task_id, organization_id=organization_id)
    return
def _is_final_workflow_run_status(status: WorkflowRunStatus) -> bool:
    """Return True when ``status`` is a terminal workflow run status.

    If the enum provides ``is_final()``, delegate to it. This mirrors how the task
    stream uses ``TaskStatus.is_final()``. It keeps any terminal states beyond
    completed/failed/terminated from leaving the stream open until the
    inactivity timeout. NOTE(review): presumably such states exist, e.g. a
    canceled or timed-out run; confirm against WorkflowRunStatus.

    Otherwise, fall back to the completed/failed/terminated set this endpoint
    originally checked.
    """
    is_final = getattr(status, "is_final", None)
    if callable(is_final):
        return bool(is_final())
    return status in (
        WorkflowRunStatus.completed,
        WorkflowRunStatus.failed,
        WorkflowRunStatus.terminated,
    )


@legacy_base_router.websocket("/stream/workflow_runs/{workflow_run_id}")
async def workflow_run_streaming(
    websocket: WebSocket,
    workflow_run_id: str,
    apikey: str | None = None,
    token: str | None = None,
) -> None:
    """Stream live screenshots of a workflow run over a websocket.

    The caller authenticates with either ``apikey`` or ``token``, which are
    forwarded to ``get_current_org``. While the run is ``running``, the
    screenshot file ``<workflow_run_id>.png`` is polled every 2 seconds. Each one
    found is pushed as JSON ``{"workflow_run_id", "status", "screenshot"}``, with
    the image bytes base64-encoded.

    The stream ends with a final JSON status message when any of these happens:
    - the run is not found or belongs to another organization (``"not_found"``);
    - the run reaches a final state (the run's status);
    - no screenshot has been pushed for ``STREAMING_TIMEOUT`` seconds (``"timeout"``).

    The handler also returns quietly if the client disconnects.

    Args:
        websocket: The incoming websocket connection.
        workflow_run_id: ID of the workflow run to stream.
        apikey: Optional API key credential.
        token: Optional authorization token credential.
    """
    try:
        await websocket.accept()
        if not token and not apikey:
            await websocket.send_text("No valid credential provided")
            return
    except ConnectionClosedOK:
        LOG.info("WofklowRun Streaming: ConnectionClosedOK error. Streaming won't start")
        return

    try:
        organization = await get_current_org(x_api_key=apikey, authorization=token)
        organization_id = organization.organization_id
    except Exception:
        LOG.exception("WofklowRun Streaming: Error while getting organization", workflow_run_id=workflow_run_id)
        try:
            await websocket.send_text("Invalid credential provided")
        except ConnectionClosedOK:
            LOG.info("WofklowRun Streaming: ConnectionClosedOK error while sending invalid credential message")
        return

    LOG.info(
        "WofklowRun Streaming: Started workflow run streaming",
        workflow_run_id=workflow_run_id,
        organization_id=organization_id,
    )
    # Monotonic time of the last screenshot push. Only elapsed time matters here.
    # A monotonic clock is immune to wall-clock adjustments and avoids
    # datetime.utcnow(), which is deprecated since Python 3.12.
    last_activity_timestamp = time.monotonic()
    try:
        while True:
            # The inactivity clock only resets when a screenshot is actually sent
            # (below). A run that never produces one therefore times out.
            if time.monotonic() - last_activity_timestamp > STREAMING_TIMEOUT:
                LOG.info(
                    "WofklowRun Streaming: No activity for 5 minutes. Closing connection",
                    workflow_run_id=workflow_run_id,
                    organization_id=organization_id,
                )
                await websocket.send_json(
                    {
                        "workflow_run_id": workflow_run_id,
                        "status": "timeout",
                    }
                )
                return

            workflow_run = await app.DATABASE.get_workflow_run(
                workflow_run_id=workflow_run_id,
                organization_id=organization_id,
            )
            # Explicit ownership check on top of the organization-scoped lookup.
            if not workflow_run or workflow_run.organization_id != organization_id:
                LOG.info(
                    "WofklowRun Streaming: Workflow not found",
                    workflow_run_id=workflow_run_id,
                    organization_id=organization_id,
                )
                await websocket.send_json(
                    {
                        "workflow_run_id": workflow_run_id,
                        "status": "not_found",
                    }
                )
                return

            if _is_final_workflow_run_status(workflow_run.status):
                LOG.info(
                    "Workflow run is in a final state. Closing connection",
                    workflow_run_status=workflow_run.status,
                    workflow_run_id=workflow_run_id,
                    organization_id=organization_id,
                )
                await websocket.send_json(
                    {
                        "workflow_run_id": workflow_run_id,
                        "status": workflow_run.status,
                    }
                )
                return

            if workflow_run.status == WorkflowRunStatus.running:
                file_name = f"{workflow_run_id}.png"
                screenshot = await app.STORAGE.get_streaming_file(organization_id, file_name)
                if screenshot:
                    encoded_screenshot = base64.b64encode(screenshot).decode("utf-8")
                    await websocket.send_json(
                        {
                            "workflow_run_id": workflow_run_id,
                            "status": workflow_run.status,
                            "screenshot": encoded_screenshot,
                        }
                    )
                    last_activity_timestamp = time.monotonic()
            await asyncio.sleep(2)
    except ValidationError as e:
        try:
            await websocket.send_text(f"Invalid data: {e}")
        except (WebSocketDisconnect, ConnectionClosedOK):
            # The client is already gone. Don't let the error report itself
            # escape the handler as an unhandled exception.
            LOG.info(
                "WofklowRun Streaming: Connection closed before invalid data error could be sent",
                workflow_run_id=workflow_run_id,
                organization_id=organization_id,
            )
            return
    except WebSocketDisconnect:
        LOG.info(
            "WofklowRun Streaming: WebSocket connection closed",
            workflow_run_id=workflow_run_id,
            organization_id=organization_id,
        )
    except ConnectionClosedOK:
        LOG.info(
            "WofklowRun Streaming: ConnectionClosedOK error while streaming",
            workflow_run_id=workflow_run_id,
            organization_id=organization_id,
        )
        return
    except Exception:
        LOG.warning(
            "WofklowRun Streaming: Error while streaming",
            workflow_run_id=workflow_run_id,
            organization_id=organization_id,
            exc_info=True,
        )
        return

    LOG.info(
        "WofklowRun Streaming: WebSocket connection closed successfully",
        workflow_run_id=workflow_run_id,
        organization_id=organization_id,
    )
    return