Screen streaming under docker environment (#674)

Co-authored-by: Shuchang Zheng <wintonzheng0325@gmail.com>
This commit is contained in:
Kerem Yilmaz
2024-08-12 19:36:24 +03:00
committed by GitHub
parent 9342dfbf2a
commit 3f92c50a8f
12 changed files with 222 additions and 10 deletions

View File

@@ -3,7 +3,7 @@ from datetime import datetime
from typing import Awaitable, Callable
import structlog
from fastapi import APIRouter, FastAPI, Response, status
from fastapi import FastAPI, Response, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import ValidationError
@@ -17,6 +17,7 @@ from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
from skyvern.forge.sdk.db.exceptions import NotFoundError
from skyvern.forge.sdk.routes.agent_protocol import base_router
from skyvern.forge.sdk.routes.streaming import websocket_router
from skyvern.forge.sdk.settings_manager import SettingsManager
from skyvern.scheduler import SCHEDULER
@@ -30,7 +31,7 @@ class ExecutionDatePlugin(Plugin):
return datetime.now()
def get_agent_app(router: APIRouter = base_router) -> FastAPI:
def get_agent_app() -> FastAPI:
"""
Start the agent server.
"""
@@ -46,7 +47,8 @@ def get_agent_app(router: APIRouter = base_router) -> FastAPI:
allow_headers=["*"],
)
app.include_router(router, prefix="/api/v1")
app.include_router(base_router, prefix="/api/v1")
app.include_router(websocket_router, prefix="/api/v1/stream")
app.add_middleware(
RawContextMiddleware,

View File

@@ -51,3 +51,11 @@ class BaseStorage(ABC):
@abstractmethod
async def store_artifact_from_path(self, artifact: Artifact, path: str) -> None:
    """Abstract hook taking an artifact and a path; the base class implements nothing.

    Concrete storage backends are expected to override this coroutine.
    As declared here it does no work and resolves to ``None``.
    """
@abstractmethod
async def save_streaming_file(self, organization_id: str, file_name: str) -> None:
    """Abstract hook for saving the streaming file ``file_name`` of ``organization_id``.

    The base class implements nothing; concrete storage backends override
    this coroutine. As declared here it does no work and resolves to ``None``.
    """
@abstractmethod
async def get_streaming_file(self, organization_id: str, file_name: str, use_default: bool = True) -> bytes | None:
pass

View File

@@ -67,6 +67,24 @@ class LocalStorage(BaseStorage):
async def get_share_links(self, artifacts: list[Artifact]) -> list[str]:
    """Return the ``uri`` attribute of each artifact, in input order."""
    return [entry.uri for entry in artifacts]
async def save_streaming_file(self, organization_id: str, file_name: str) -> None:
    """No-op: this implementation returns immediately without using its arguments."""
    return None
async def get_streaming_file(self, organization_id: str, file_name: str, use_default: bool = True) -> bytes | None:
    """Read a streaming file from the local filesystem and return its bytes.

    With ``use_default`` left at ``True`` the file read is
    ``<STREAMING_FILE_BASE_PATH>/skyvern_screenshot.png``; otherwise it is
    ``<STREAMING_FILE_BASE_PATH>/<organization_id>/<file_name>``.

    Returns:
        The file content, or ``None`` when opening or reading the file raises.
        The failure is logged with the organization id and file name.
    """
    target = Path(f"{SettingsManager.get_settings().STREAMING_FILE_BASE_PATH}/skyvern_screenshot.png")
    if not use_default:
        # Caller asked for the organization/file-specific path instead of the shared default.
        target = Path(f"{SettingsManager.get_settings().STREAMING_FILE_BASE_PATH}/{organization_id}/{file_name}")

    try:
        with open(target, "rb") as stream:
            payload = stream.read()
    except Exception:
        # Best-effort read: any failure is logged and reported to the caller as None.
        LOG.exception(
            "Failed to retrieve streaming file.",
            organization_id=organization_id,
            file_name=file_name,
        )
        return None
    else:
        return payload
@staticmethod
def _parse_uri_to_path(uri: str) -> str:
parsed_uri = urlparse(uri)

View File

@@ -0,0 +1,115 @@
import asyncio
import base64
from datetime import datetime
import structlog
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from pydantic import ValidationError
from websockets.exceptions import ConnectionClosedOK
from skyvern.forge import app
from skyvern.forge.sdk.schemas.tasks import TaskStatus
from skyvern.forge.sdk.services.org_auth_service import get_current_org
LOG = structlog.get_logger()
websocket_router = APIRouter()
STREAMING_TIMEOUT = 300
@websocket_router.websocket("/tasks/{task_id}")
async def task_stream(
    websocket: WebSocket,
    task_id: str,
    apikey: str | None = None,
    token: str | None = None,
) -> None:
    """Stream a task's status and latest screenshot over a WebSocket.

    The connection is accepted first; if neither ``apikey`` nor ``token`` is
    given, or ``get_current_org`` fails, a plain-text error is sent and the
    handler returns. Otherwise the task is polled every 2 seconds:

    * task not found -> send ``{"task_id", "status": "not_found"}`` and return;
    * task in a final state -> send ``{"task_id", "status"}`` and return;
    * task running and a streaming file is available -> send
      ``{"task_id", "status", "screenshot"}`` with the screenshot base64-encoded;
    * nothing sent for more than ``STREAMING_TIMEOUT`` seconds -> send
      ``{"task_id", "status": "timeout"}`` and return.

    Args:
        websocket: The WebSocket connection to stream on.
        task_id: Id of the task to stream, taken from the route path.
        apikey: Credential forwarded to ``get_current_org`` as ``x_api_key``.
        token: Credential forwarded to ``get_current_org`` as ``authorization``.
    """
    # Function-scope import: needed for timezone-aware timestamps below.
    # `datetime.utcnow()` returns a naive datetime and is deprecated since Python 3.12.
    from datetime import timezone

    try:
        await websocket.accept()
        if not token and not apikey:
            await websocket.send_text("No valid credential provided")
            return
    except ConnectionClosedOK:
        LOG.info("ConnectionClosedOK error. Streaming won't start")
        return

    try:
        organization = await get_current_org(x_api_key=apikey, authorization=token)
        organization_id = organization.organization_id
    except Exception:
        try:
            await websocket.send_text("Invalid credential provided")
        except ConnectionClosedOK:
            LOG.info("ConnectionClosedOK error while sending invalid credential message")
        # Without an organization there is nothing to stream, whether or not the message went out.
        return

    LOG.info("Started task streaming", task_id=task_id, organization_id=organization_id)
    # timestamp last time when streaming activity happens
    last_activity_timestamp = datetime.now(timezone.utc)

    try:
        while True:
            # if no activity for 5 minutes, close the connection
            idle_seconds = (datetime.now(timezone.utc) - last_activity_timestamp).total_seconds()
            if idle_seconds > STREAMING_TIMEOUT:
                LOG.info(
                    "No activity for 5 minutes. Closing connection", task_id=task_id, organization_id=organization_id
                )
                await websocket.send_json(
                    {
                        "task_id": task_id,
                        "status": "timeout",
                    }
                )
                return

            task = await app.DATABASE.get_task(task_id=task_id, organization_id=organization_id)
            if not task:
                LOG.info("Task not found. Closing connection", task_id=task_id, organization_id=organization_id)
                await websocket.send_json(
                    {
                        "task_id": task_id,
                        "status": "not_found",
                    }
                )
                return

            if task.status.is_final():
                LOG.info(
                    "Task is in a final state. Closing connection",
                    task_status=task.status,
                    task_id=task_id,
                    organization_id=organization_id,
                )
                await websocket.send_json(
                    {
                        "task_id": task_id,
                        "status": task.status,
                    }
                )
                return

            if task.status == TaskStatus.running:
                file_name = f"{task_id}.png"
                screenshot = await app.STORAGE.get_streaming_file(organization_id, file_name)
                if screenshot:
                    encoded_screenshot = base64.b64encode(screenshot).decode("utf-8")
                    await websocket.send_json(
                        {
                            "task_id": task_id,
                            "status": task.status,
                            "screenshot": encoded_screenshot,
                        }
                    )
                    # Only a successfully sent screenshot counts as activity for the idle timeout.
                    last_activity_timestamp = datetime.now(timezone.utc)

            # Poll interval between task status checks.
            await asyncio.sleep(2)

    except ValidationError as e:
        # NOTE(review): this send is not guarded; if it raises, the error propagates out of the handler.
        await websocket.send_text(f"Invalid data: {e}")
    except WebSocketDisconnect:
        LOG.info("WebSocket connection closed")
    except ConnectionClosedOK:
        LOG.info("ConnectionClosedOK error while streaming", exc_info=True)
        return
    except Exception:
        LOG.warning("Error while streaming", exc_info=True)
        return

    # Reached only after the ValidationError / WebSocketDisconnect handlers above;
    # every other exit path returns directly.
    LOG.info("WebSocket connection closed successfully")
    return