Screen streaming under docker environment (#674)
Co-authored-by: Shuchang Zheng <wintonzheng0325@gmail.com>
This commit is contained in:
@@ -3,7 +3,7 @@ from datetime import datetime
|
||||
from typing import Awaitable, Callable
|
||||
|
||||
import structlog
|
||||
from fastapi import APIRouter, FastAPI, Response, status
|
||||
from fastapi import FastAPI, Response, status
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import ValidationError
|
||||
@@ -17,6 +17,7 @@ from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
|
||||
from skyvern.forge.sdk.db.exceptions import NotFoundError
|
||||
from skyvern.forge.sdk.routes.agent_protocol import base_router
|
||||
from skyvern.forge.sdk.routes.streaming import websocket_router
|
||||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||
from skyvern.scheduler import SCHEDULER
|
||||
|
||||
@@ -30,7 +31,7 @@ class ExecutionDatePlugin(Plugin):
|
||||
return datetime.now()
|
||||
|
||||
|
||||
def get_agent_app(router: APIRouter = base_router) -> FastAPI:
|
||||
def get_agent_app() -> FastAPI:
|
||||
"""
|
||||
Start the agent server.
|
||||
"""
|
||||
@@ -46,7 +47,8 @@ def get_agent_app(router: APIRouter = base_router) -> FastAPI:
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
app.include_router(router, prefix="/api/v1")
|
||||
app.include_router(base_router, prefix="/api/v1")
|
||||
app.include_router(websocket_router, prefix="/api/v1/stream")
|
||||
|
||||
app.add_middleware(
|
||||
RawContextMiddleware,
|
||||
|
||||
@@ -51,3 +51,11 @@ class BaseStorage(ABC):
|
||||
@abstractmethod
|
||||
async def store_artifact_from_path(self, artifact: Artifact, path: str) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def save_streaming_file(self, organization_id: str, file_name: str) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_streaming_file(self, organization_id: str, file_name: str, use_default: bool = True) -> bytes | None:
|
||||
pass
|
||||
|
||||
@@ -67,6 +67,24 @@ class LocalStorage(BaseStorage):
|
||||
async def get_share_links(self, artifacts: list[Artifact]) -> list[str]:
|
||||
return [artifact.uri for artifact in artifacts]
|
||||
|
||||
async def save_streaming_file(self, organization_id: str, file_name: str) -> None:
|
||||
return
|
||||
|
||||
async def get_streaming_file(self, organization_id: str, file_name: str, use_default: bool = True) -> bytes | None:
|
||||
file_path = Path(f"{SettingsManager.get_settings().STREAMING_FILE_BASE_PATH}/skyvern_screenshot.png")
|
||||
if not use_default:
|
||||
file_path = Path(f"{SettingsManager.get_settings().STREAMING_FILE_BASE_PATH}/{organization_id}/{file_name}")
|
||||
try:
|
||||
with open(file_path, "rb") as f:
|
||||
return f.read()
|
||||
except Exception:
|
||||
LOG.exception(
|
||||
"Failed to retrieve streaming file.",
|
||||
organization_id=organization_id,
|
||||
file_name=file_name,
|
||||
)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _parse_uri_to_path(uri: str) -> str:
|
||||
parsed_uri = urlparse(uri)
|
||||
|
||||
115
skyvern/forge/sdk/routes/streaming.py
Normal file
115
skyvern/forge/sdk/routes/streaming.py
Normal file
@@ -0,0 +1,115 @@
|
||||
import asyncio
|
||||
import base64
|
||||
from datetime import datetime
|
||||
|
||||
import structlog
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||
from pydantic import ValidationError
|
||||
from websockets.exceptions import ConnectionClosedOK
|
||||
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.schemas.tasks import TaskStatus
|
||||
from skyvern.forge.sdk.services.org_auth_service import get_current_org
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
websocket_router = APIRouter()
|
||||
STREAMING_TIMEOUT = 300
|
||||
|
||||
|
||||
@websocket_router.websocket("/tasks/{task_id}")
|
||||
async def task_stream(
|
||||
websocket: WebSocket,
|
||||
task_id: str,
|
||||
apikey: str | None = None,
|
||||
token: str | None = None,
|
||||
) -> None:
|
||||
try:
|
||||
await websocket.accept()
|
||||
if not token and not apikey:
|
||||
await websocket.send_text("No valid credential provided")
|
||||
return
|
||||
except ConnectionClosedOK:
|
||||
LOG.info("ConnectionClosedOK error. Streaming won't start")
|
||||
return
|
||||
|
||||
try:
|
||||
organization = await get_current_org(x_api_key=apikey, authorization=token)
|
||||
organization_id = organization.organization_id
|
||||
except Exception:
|
||||
try:
|
||||
await websocket.send_text("Invalid credential provided")
|
||||
except ConnectionClosedOK:
|
||||
LOG.info("ConnectionClosedOK error while sending invalid credential message")
|
||||
return
|
||||
|
||||
LOG.info("Started task streaming", task_id=task_id, organization_id=organization_id)
|
||||
# timestamp last time when streaming activity happens
|
||||
last_activity_timestamp = datetime.utcnow()
|
||||
|
||||
try:
|
||||
while True:
|
||||
# if no activity for 5 minutes, close the connection
|
||||
if (datetime.utcnow() - last_activity_timestamp).total_seconds() > STREAMING_TIMEOUT:
|
||||
LOG.info(
|
||||
"No activity for 5 minutes. Closing connection", task_id=task_id, organization_id=organization_id
|
||||
)
|
||||
await websocket.send_json(
|
||||
{
|
||||
"task_id": task_id,
|
||||
"status": "timeout",
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
task = await app.DATABASE.get_task(task_id=task_id, organization_id=organization_id)
|
||||
if not task:
|
||||
LOG.info("Task not found. Closing connection", task_id=task_id, organization_id=organization_id)
|
||||
await websocket.send_json(
|
||||
{
|
||||
"task_id": task_id,
|
||||
"status": "not_found",
|
||||
}
|
||||
)
|
||||
return
|
||||
if task.status.is_final():
|
||||
LOG.info(
|
||||
"Task is in a final state. Closing connection",
|
||||
task_status=task.status,
|
||||
task_id=task_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
await websocket.send_json(
|
||||
{
|
||||
"task_id": task_id,
|
||||
"status": task.status,
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
if task.status == TaskStatus.running:
|
||||
file_name = f"{task_id}.png"
|
||||
screenshot = await app.STORAGE.get_streaming_file(organization_id, file_name)
|
||||
if screenshot:
|
||||
encoded_screenshot = base64.b64encode(screenshot).decode("utf-8")
|
||||
await websocket.send_json(
|
||||
{
|
||||
"task_id": task_id,
|
||||
"status": task.status,
|
||||
"screenshot": encoded_screenshot,
|
||||
}
|
||||
)
|
||||
last_activity_timestamp = datetime.utcnow()
|
||||
await asyncio.sleep(2)
|
||||
|
||||
except ValidationError as e:
|
||||
await websocket.send_text(f"Invalid data: {e}")
|
||||
except WebSocketDisconnect:
|
||||
LOG.info("WebSocket connection closed")
|
||||
except ConnectionClosedOK:
|
||||
LOG.info("ConnectionClosedOK error while streaming", exc_info=True)
|
||||
return
|
||||
except Exception:
|
||||
LOG.warning("Error while streaming", exc_info=True)
|
||||
return
|
||||
LOG.info("WebSocket connection closed successfully")
|
||||
return
|
||||
Reference in New Issue
Block a user