diff --git a/skyvern/forge/sdk/routes/streaming_agent.py b/skyvern/forge/sdk/routes/streaming_agent.py index a84f298e..4169e4c0 100644 --- a/skyvern/forge/sdk/routes/streaming_agent.py +++ b/skyvern/forge/sdk/routes/streaming_agent.py @@ -21,7 +21,8 @@ class StreamingAgent: Specifically for operations during streaming sessions (like copy/pasting selected text, etc.). """ - def __init__(self) -> None: + def __init__(self, streaming: sc.Streaming) -> None: + self.streaming = streaming self.browser: Browser | None = None self.browser_context: BrowserContext | None = None self.page: Page | None = None @@ -36,7 +37,27 @@ class StreamingAgent: self.pw = pw - self.browser = await pw.chromium.connect_over_cdp(url) + headers = { + "x-api-key": self.streaming.x_api_key, + } + + self.browser = await pw.chromium.connect_over_cdp(url, headers=headers) + + org_id = self.streaming.organization_id + browser_session_id = ( + self.streaming.browser_session.persistent_browser_session_id if self.streaming.browser_session else None + ) + + if browser_session_id: + cdp_session = await self.browser.new_browser_cdp_session() + await cdp_session.send( + "Browser.setDownloadBehavior", + { + "behavior": "allow", + "downloadPath": f"/app/downloads/{org_id}/{browser_session_id}", + "eventsEnabled": True, + }, + ) contexts = self.browser.contexts if contexts: @@ -145,6 +166,7 @@ async def connected_agent(streaming: sc.Streaming | None) -> typing.AsyncIterato LOG.error(msg) raise Exception(msg) + if not streaming.browser_session or not streaming.browser_session.browser_address: msg = "connected_agent: no browser session or browser address found for streaming client." @@ -156,7 +178,7 @@ async def connected_agent(streaming: sc.Streaming | None) -> typing.AsyncIterato raise Exception(msg) - agent = StreamingAgent() + agent = StreamingAgent(streaming=streaming) try: await agent.connect(streaming.browser_session.browser_address) diff --git a/skyvern/forge/sdk/routes/streaming_clients.py b/skyvern/forge/sdk/routes/streaming_clients.py index dbb6e16e..3ec90cd4 100644 --- a/skyvern/forge/sdk/routes/streaming_clients.py +++ b/skyvern/forge/sdk/routes/streaming_clients.py @@ -271,6 +271,7 @@ class Streaming: organization_id: str vnc_port: int + x_api_key: str websocket: WebSocket # -- diff --git a/skyvern/forge/sdk/routes/streaming_vnc.py b/skyvern/forge/sdk/routes/streaming_vnc.py index b46baf55..0e91c099 100644 --- a/skyvern/forge/sdk/routes/streaming_vnc.py +++ b/skyvern/forge/sdk/routes/streaming_vnc.py @@ -21,6 +21,8 @@ from websockets.exceptions import ConnectionClosedError import skyvern.forge.sdk.routes.streaming_clients as sc from skyvern.config import settings +from skyvern.forge import app +from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType from skyvern.forge.sdk.routes.routers import base_router, legacy_base_router from skyvern.forge.sdk.routes.streaming_agent import connected_agent from skyvern.forge.sdk.routes.streaming_auth import auth @@ -37,6 +39,28 @@ from skyvern.forge.sdk.utils.aio import collect LOG = structlog.get_logger() +class Constants: + MissingXApiKey = "" + + +async def get_x_api_key(organization_id: str) -> str: + token = await app.DATABASE.get_valid_org_auth_token( + organization_id, + OrganizationAuthTokenType.api.value, + ) + + if not token: + LOG.warning( + "No valid API key found for organization when streaming.", + organization_id=organization_id, + ) + x_api_key = Constants.MissingXApiKey + else: + x_api_key = token.token + + return x_api_key + + async def get_streaming_for_browser_session( client_id: str, browser_session_id: str, @@ -60,12 +84,15 @@ async def get_streaming_for_browser_session( ) return None + x_api_key = await get_x_api_key(organization_id) + streaming = sc.Streaming( client_id=client_id, interactor="agent", organization_id=organization_id, vnc_port=settings.SKYVERN_BROWSER_VNC_PORT, browser_session=browser_session, + x_api_key=x_api_key, websocket=websocket, ) @@ -99,11 +126,14 @@ async def get_streaming_for_task( LOG.info("No initial browser session found for task.", task_id=task_id, organization_id=organization_id) return None + x_api_key = await get_x_api_key(organization_id) + streaming = sc.Streaming( client_id=client_id, interactor="agent", organization_id=organization_id, vnc_port=settings.SKYVERN_BROWSER_VNC_PORT, + x_api_key=x_api_key, websocket=websocket, browser_session=browser_session, task=task, @@ -146,6 +176,8 @@ async def get_streaming_for_workflow_run( ) return None + x_api_key = await get_x_api_key(organization_id) + streaming = sc.Streaming( client_id=client_id, interactor="agent", @@ -153,6 +185,7 @@ async def get_streaming_for_workflow_run( vnc_port=settings.SKYVERN_BROWSER_VNC_PORT, browser_session=browser_session, workflow_run=workflow_run, + x_api_key=x_api_key, websocket=websocket, )