Add browser session id permission checking for task v1, v2, and workflow runs (#1980)
This commit is contained in:
@@ -5,10 +5,10 @@ from skyvern.forge.sdk.schemas.organizations import Organization
|
||||
|
||||
class PermissionChecker(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
async def check(self, organization: Organization) -> None:
|
||||
async def check(self, organization: Organization, browser_session_id: str | None = None) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class NoopPermissionChecker(PermissionChecker):
|
||||
async def check(self, organization: Organization) -> None:
|
||||
async def check(self, organization: Organization, browser_session_id: str | None = None) -> None:
|
||||
return
|
||||
|
||||
@@ -2585,17 +2585,20 @@ class AgentDB:
|
||||
LOG.error("UnexpectedError", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_persistent_browser_session_by_id(self, session_id: str) -> Optional[PersistentBrowserSession]:
|
||||
async def get_persistent_browser_session_by_id(
|
||||
self, session_id: str, organization_id: str | None = None
|
||||
) -> Optional[PersistentBrowserSession]:
|
||||
"""Get a specific persistent browser session."""
|
||||
try:
|
||||
async with self.Session() as session:
|
||||
persistent_browser_session = (
|
||||
await session.scalars(
|
||||
select(PersistentBrowserSessionModel)
|
||||
.filter_by(persistent_browser_session_id=session_id)
|
||||
.filter_by(deleted_at=None)
|
||||
)
|
||||
).first()
|
||||
query = (
|
||||
select(PersistentBrowserSessionModel)
|
||||
.filter_by(persistent_browser_session_id=session_id)
|
||||
.filter_by(deleted_at=None)
|
||||
)
|
||||
if organization_id:
|
||||
query = query.filter_by(organization_id=organization_id)
|
||||
persistent_browser_session = (await session.scalars(query)).first()
|
||||
if persistent_browser_session:
|
||||
return PersistentBrowserSession.model_validate(persistent_browser_session)
|
||||
raise NotFoundError(f"PersistentBrowserSession {session_id} not found")
|
||||
|
||||
@@ -168,7 +168,7 @@ async def run_task(
|
||||
x_max_steps_override: Annotated[int | None, Header()] = None,
|
||||
) -> CreateTaskResponse:
|
||||
analytics.capture("skyvern-oss-agent-task-create", data={"url": task.url})
|
||||
await PermissionCheckerFactory.get_instance().check(current_org)
|
||||
await PermissionCheckerFactory.get_instance().check(current_org, browser_session_id=task.browser_session_id)
|
||||
|
||||
created_task = await app.agent.create_task(task, current_org.organization_id)
|
||||
url_hash = generate_url_hash(task.url)
|
||||
@@ -574,6 +574,7 @@ async def run_workflow(
|
||||
analytics.capture("skyvern-oss-agent-workflow-execute")
|
||||
context = skyvern_context.ensure_context()
|
||||
request_id = context.request_id
|
||||
await PermissionCheckerFactory.get_instance().check(current_org)
|
||||
|
||||
if template:
|
||||
if workflow_id not in await app.STORAGE.retrieve_global_workflows():
|
||||
@@ -1121,6 +1122,7 @@ async def run_task_v2(
|
||||
max_iterations_override=x_max_iterations_override,
|
||||
max_steps_override=x_max_steps_override,
|
||||
)
|
||||
await PermissionCheckerFactory.get_instance().check(organization, browser_session_id=data.browser_session_id)
|
||||
|
||||
try:
|
||||
task_v2 = await task_v2_service.initialize_task_v2(
|
||||
|
||||
@@ -81,7 +81,9 @@ class BrowserManager:
|
||||
"Getting browser state for task from persistent sessions manager",
|
||||
browser_session_id=browser_session_id,
|
||||
)
|
||||
browser_state = await app.PERSISTENT_SESSIONS_MANAGER.get_browser_state(browser_session_id)
|
||||
browser_state = await app.PERSISTENT_SESSIONS_MANAGER.get_browser_state(
|
||||
browser_session_id, organization_id=task.organization_id
|
||||
)
|
||||
if browser_state is None:
|
||||
LOG.warning(
|
||||
"Browser state not found in persistent sessions manager",
|
||||
@@ -148,7 +150,9 @@ class BrowserManager:
|
||||
"Getting browser state for workflow run from persistent sessions manager",
|
||||
browser_session_id=browser_session_id,
|
||||
)
|
||||
browser_state = await app.PERSISTENT_SESSIONS_MANAGER.get_browser_state(browser_session_id)
|
||||
browser_state = await app.PERSISTENT_SESSIONS_MANAGER.get_browser_state(
|
||||
browser_session_id, organization_id=workflow_run.organization_id
|
||||
)
|
||||
if browser_state is None:
|
||||
LOG.warning(
|
||||
"Browser state not found in persistent sessions manager", browser_session_id=browser_session_id
|
||||
|
||||
@@ -35,7 +35,7 @@ class PersistentSessionsManager:
|
||||
"""Get all active sessions for an organization."""
|
||||
return await self.database.get_active_persistent_browser_sessions(organization_id)
|
||||
|
||||
async def get_browser_state(self, session_id: str) -> BrowserState | None:
|
||||
async def get_browser_state(self, session_id: str, organization_id: str | None = None) -> BrowserState | None:
|
||||
"""Get a specific browser session's state by session ID."""
|
||||
browser_session = self._browser_sessions.get(session_id)
|
||||
return browser_session.browser_state if browser_session else None
|
||||
|
||||
Reference in New Issue
Block a user