diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py index 13fc2a7a..823ccc78 100644 --- a/skyvern/forge/sdk/db/client.py +++ b/skyvern/forge/sdk/db/client.py @@ -2674,35 +2674,10 @@ class AgentDB: LOG.error("UnexpectedError", exc_info=True) raise - async def get_persistent_browser_session_by_id( - self, session_id: str, organization_id: str | None = None - ) -> PersistentBrowserSession | None: - """Get a specific persistent browser session.""" - try: - async with self.Session() as session: - query = ( - select(PersistentBrowserSessionModel) - .filter_by(persistent_browser_session_id=session_id) - .filter_by(deleted_at=None) - ) - if organization_id: - query = query.filter_by(organization_id=organization_id) - persistent_browser_session = (await session.scalars(query)).first() - if persistent_browser_session: - return PersistentBrowserSession.model_validate(persistent_browser_session) - raise NotFoundError(f"PersistentBrowserSession {session_id} not found") - except NotFoundError: - LOG.error("NotFoundError", exc_info=True) - raise - except SQLAlchemyError: - LOG.error("SQLAlchemyError", exc_info=True) - raise - except Exception: - LOG.error("UnexpectedError", exc_info=True) - raise - async def get_persistent_browser_session( - self, session_id: str, organization_id: str + self, + session_id: str, + organization_id: str | None = None, ) -> PersistentBrowserSession | None: """Get a specific persistent browser session.""" try: @@ -2755,6 +2730,41 @@ class AgentDB: LOG.error("UnexpectedError", exc_info=True) raise + async def set_persistent_browser_session_browser_address( + self, + browser_session_id: str, + browser_address: str, + organization_id: str | None = None, + ) -> None: + """Set the browser address for a persistent browser session.""" + try: + async with self.Session() as session: + persistent_browser_session = ( + await session.scalars( + select(PersistentBrowserSessionModel) + .filter_by(persistent_browser_session_id=browser_session_id) + .filter_by(organization_id=organization_id) + .filter_by(deleted_at=None) + ) + ).first() + if persistent_browser_session: + persistent_browser_session.browser_address = browser_address + # once the address is set, the session is started + persistent_browser_session.started_at = datetime.utcnow() + await session.commit() + await session.refresh(persistent_browser_session) + else: + raise NotFoundError(f"PersistentBrowserSession {browser_session_id} not found") + except NotFoundError: + LOG.error("NotFoundError", exc_info=True) + raise + except SQLAlchemyError: + LOG.error("SQLAlchemyError", exc_info=True) + raise + except Exception: + LOG.error("UnexpectedError", exc_info=True) + raise + async def mark_persistent_browser_session_deleted(self, session_id: str, organization_id: str) -> None: """Mark a persistent browser session as deleted.""" try: @@ -2813,7 +2823,11 @@ class AgentDB: LOG.error("UnexpectedError", exc_info=True) raise - async def release_persistent_browser_session(self, session_id: str, organization_id: str) -> None: + async def release_persistent_browser_session( + self, + session_id: str, + organization_id: str, + ) -> PersistentBrowserSession: """Release a specific persistent browser session.""" try: async with self.Session() as session: @@ -2830,6 +2844,7 @@ class AgentDB: persistent_browser_session.runnable_id = None await session.commit() await session.refresh(persistent_browser_session) + return PersistentBrowserSession.model_validate(persistent_browser_session) else: raise NotFoundError(f"PersistentBrowserSession {session_id} not found") except SQLAlchemyError: @@ -2842,6 +2857,34 @@ class AgentDB: LOG.error("UnexpectedError", exc_info=True) raise + async def close_persistent_browser_session(self, session_id: str, organization_id: str) -> PersistentBrowserSession: + """Close a specific persistent browser session.""" + try: + async with self.Session() as session: + persistent_browser_session = ( + await session.scalars( + select(PersistentBrowserSessionModel) + .filter_by(persistent_browser_session_id=session_id) + .filter_by(organization_id=organization_id) + .filter_by(deleted_at=None) + ) + ).first() + if persistent_browser_session: + persistent_browser_session.completed_at = datetime.utcnow() + await session.commit() + await session.refresh(persistent_browser_session) + return PersistentBrowserSession.model_validate(persistent_browser_session) + raise NotFoundError(f"PersistentBrowserSession {session_id} not found") + except NotFoundError: + LOG.error("NotFoundError", exc_info=True) + raise + except SQLAlchemyError: + LOG.error("SQLAlchemyError", exc_info=True) + raise + except Exception: + LOG.error("UnexpectedError", exc_info=True) + raise + async def get_all_active_persistent_browser_sessions(self) -> List[PersistentBrowserSessionModel]: """Get all active persistent browser sessions across all organizations.""" try: