Browser streaming: add org token to CDP connection header (#3792)
This commit is contained in:
@@ -21,7 +21,8 @@ class StreamingAgent:
|
|||||||
Specifically for operations during streaming sessions (like copy/pasting selected text, etc.).
|
Specifically for operations during streaming sessions (like copy/pasting selected text, etc.).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self, streaming: sc.Streaming) -> None:
|
||||||
|
self.streaming = streaming
|
||||||
self.browser: Browser | None = None
|
self.browser: Browser | None = None
|
||||||
self.browser_context: BrowserContext | None = None
|
self.browser_context: BrowserContext | None = None
|
||||||
self.page: Page | None = None
|
self.page: Page | None = None
|
||||||
@@ -36,7 +37,27 @@ class StreamingAgent:
|
|||||||
|
|
||||||
self.pw = pw
|
self.pw = pw
|
||||||
|
|
||||||
self.browser = await pw.chromium.connect_over_cdp(url)
|
headers = {
|
||||||
|
"x-api-key": self.streaming.x_api_key,
|
||||||
|
}
|
||||||
|
|
||||||
|
self.browser = await pw.chromium.connect_over_cdp(url, headers=headers)
|
||||||
|
|
||||||
|
org_id = self.streaming.organization_id
|
||||||
|
browser_session_id = (
|
||||||
|
self.streaming.browser_session.persistent_browser_session_id if self.streaming.browser_session else None
|
||||||
|
)
|
||||||
|
|
||||||
|
if browser_session_id:
|
||||||
|
cdp_session = await self.browser.new_browser_cdp_session()
|
||||||
|
await cdp_session.send(
|
||||||
|
"Browser.setDownloadBehavior",
|
||||||
|
{
|
||||||
|
"behavior": "allow",
|
||||||
|
"downloadPath": f"/app/downloads/{org_id}/{browser_session_id}",
|
||||||
|
"eventsEnabled": True,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
contexts = self.browser.contexts
|
contexts = self.browser.contexts
|
||||||
if contexts:
|
if contexts:
|
||||||
@@ -145,6 +166,7 @@ async def connected_agent(streaming: sc.Streaming | None) -> typing.AsyncIterato
|
|||||||
LOG.error(msg)
|
LOG.error(msg)
|
||||||
|
|
||||||
raise Exception(msg)
|
raise Exception(msg)
|
||||||
|
|
||||||
if not streaming.browser_session or not streaming.browser_session.browser_address:
|
if not streaming.browser_session or not streaming.browser_session.browser_address:
|
||||||
msg = "connected_agent: no browser session or browser address found for streaming client."
|
msg = "connected_agent: no browser session or browser address found for streaming client."
|
||||||
|
|
||||||
@@ -156,7 +178,7 @@ async def connected_agent(streaming: sc.Streaming | None) -> typing.AsyncIterato
|
|||||||
|
|
||||||
raise Exception(msg)
|
raise Exception(msg)
|
||||||
|
|
||||||
agent = StreamingAgent()
|
agent = StreamingAgent(streaming=streaming)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await agent.connect(streaming.browser_session.browser_address)
|
await agent.connect(streaming.browser_session.browser_address)
|
||||||
|
|||||||
@@ -271,6 +271,7 @@ class Streaming:
|
|||||||
|
|
||||||
organization_id: str
|
organization_id: str
|
||||||
vnc_port: int
|
vnc_port: int
|
||||||
|
x_api_key: str
|
||||||
websocket: WebSocket
|
websocket: WebSocket
|
||||||
|
|
||||||
# --
|
# --
|
||||||
|
|||||||
@@ -21,6 +21,8 @@ from websockets.exceptions import ConnectionClosedError
|
|||||||
|
|
||||||
import skyvern.forge.sdk.routes.streaming_clients as sc
|
import skyvern.forge.sdk.routes.streaming_clients as sc
|
||||||
from skyvern.config import settings
|
from skyvern.config import settings
|
||||||
|
from skyvern.forge import app
|
||||||
|
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
|
||||||
from skyvern.forge.sdk.routes.routers import base_router, legacy_base_router
|
from skyvern.forge.sdk.routes.routers import base_router, legacy_base_router
|
||||||
from skyvern.forge.sdk.routes.streaming_agent import connected_agent
|
from skyvern.forge.sdk.routes.streaming_agent import connected_agent
|
||||||
from skyvern.forge.sdk.routes.streaming_auth import auth
|
from skyvern.forge.sdk.routes.streaming_auth import auth
|
||||||
@@ -37,6 +39,28 @@ from skyvern.forge.sdk.utils.aio import collect
|
|||||||
LOG = structlog.get_logger()
|
LOG = structlog.get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class Constants:
|
||||||
|
MissingXApiKey = "<missing-x-api-key>"
|
||||||
|
|
||||||
|
|
||||||
|
async def get_x_api_key(organization_id: str) -> str:
|
||||||
|
token = await app.DATABASE.get_valid_org_auth_token(
|
||||||
|
organization_id,
|
||||||
|
OrganizationAuthTokenType.api.value,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not token:
|
||||||
|
LOG.warning(
|
||||||
|
"No valid API key found for organization when streaming.",
|
||||||
|
organization_id=organization_id,
|
||||||
|
)
|
||||||
|
x_api_key = Constants.MissingXApiKey
|
||||||
|
else:
|
||||||
|
x_api_key = token.token
|
||||||
|
|
||||||
|
return x_api_key
|
||||||
|
|
||||||
|
|
||||||
async def get_streaming_for_browser_session(
|
async def get_streaming_for_browser_session(
|
||||||
client_id: str,
|
client_id: str,
|
||||||
browser_session_id: str,
|
browser_session_id: str,
|
||||||
@@ -60,12 +84,15 @@ async def get_streaming_for_browser_session(
|
|||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
x_api_key = await get_x_api_key(organization_id)
|
||||||
|
|
||||||
streaming = sc.Streaming(
|
streaming = sc.Streaming(
|
||||||
client_id=client_id,
|
client_id=client_id,
|
||||||
interactor="agent",
|
interactor="agent",
|
||||||
organization_id=organization_id,
|
organization_id=organization_id,
|
||||||
vnc_port=settings.SKYVERN_BROWSER_VNC_PORT,
|
vnc_port=settings.SKYVERN_BROWSER_VNC_PORT,
|
||||||
browser_session=browser_session,
|
browser_session=browser_session,
|
||||||
|
x_api_key=x_api_key,
|
||||||
websocket=websocket,
|
websocket=websocket,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -99,11 +126,14 @@ async def get_streaming_for_task(
|
|||||||
LOG.info("No initial browser session found for task.", task_id=task_id, organization_id=organization_id)
|
LOG.info("No initial browser session found for task.", task_id=task_id, organization_id=organization_id)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
x_api_key = await get_x_api_key(organization_id)
|
||||||
|
|
||||||
streaming = sc.Streaming(
|
streaming = sc.Streaming(
|
||||||
client_id=client_id,
|
client_id=client_id,
|
||||||
interactor="agent",
|
interactor="agent",
|
||||||
organization_id=organization_id,
|
organization_id=organization_id,
|
||||||
vnc_port=settings.SKYVERN_BROWSER_VNC_PORT,
|
vnc_port=settings.SKYVERN_BROWSER_VNC_PORT,
|
||||||
|
x_api_key=x_api_key,
|
||||||
websocket=websocket,
|
websocket=websocket,
|
||||||
browser_session=browser_session,
|
browser_session=browser_session,
|
||||||
task=task,
|
task=task,
|
||||||
@@ -146,6 +176,8 @@ async def get_streaming_for_workflow_run(
|
|||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
x_api_key = await get_x_api_key(organization_id)
|
||||||
|
|
||||||
streaming = sc.Streaming(
|
streaming = sc.Streaming(
|
||||||
client_id=client_id,
|
client_id=client_id,
|
||||||
interactor="agent",
|
interactor="agent",
|
||||||
@@ -153,6 +185,7 @@ async def get_streaming_for_workflow_run(
|
|||||||
vnc_port=settings.SKYVERN_BROWSER_VNC_PORT,
|
vnc_port=settings.SKYVERN_BROWSER_VNC_PORT,
|
||||||
browser_session=browser_session,
|
browser_session=browser_session,
|
||||||
workflow_run=workflow_run,
|
workflow_run=workflow_run,
|
||||||
|
x_api_key=x_api_key,
|
||||||
websocket=websocket,
|
websocket=websocket,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user