Remove frontend hack for requesting persistent browser sessions, part ii (backend) (#3052)

This commit is contained in:
Jonathan Dobson
2025-07-29 09:32:52 -04:00
committed by GitHub
parent d9537327da
commit 8ff1c5dfa2
6 changed files with 228 additions and 36 deletions

View File

@@ -294,6 +294,21 @@ class Settings(BaseSettings):
TRACE_PROVIDER_HOST: str | None = None TRACE_PROVIDER_HOST: str | None = None
TRACE_PROVIDER_API_KEY: str = "fillmein" TRACE_PROVIDER_API_KEY: str = "fillmein"
# Debug Session Settings
# Both values below are expressed in minutes.
# 60 * 4 = 240 minutes, i.e. 4 hours.
DEBUG_SESSION_TIMEOUT_MINUTES: int = 60 * 4
"""
The timeout for a persistent browser session backing a debug session,
in minutes.
"""
# Minimum remaining lifetime (in minutes) a session must still have to be
# eligible for an extension; see the attribute docstring below.
DEBUG_SESSION_TIMEOUT_THRESHOLD_MINUTES: int = 5
"""
If there are `DEBUG_SESSION_TIMEOUT_THRESHOLD_MINUTES` or more minutes left
in the persistent browser session (`started_at` + `timeout_minutes`), then
the `timeout_minutes` of the persistent browser session can be extended.
Otherwise we'll consider the persistent browser session to be expired.
"""
def get_model_name_to_llm_key(self) -> dict[str, dict[str, str]]: def get_model_name_to_llm_key(self) -> dict[str, dict[str, str]]:
""" """
Keys are model names available to blocks in the frontend. These map to key names Keys are model names available to blocks in the frontend. These map to key names

View File

@@ -75,6 +75,7 @@ AGENT_FUNCTION = AgentFunction()
PERSISTENT_SESSIONS_MANAGER = PersistentSessionsManager(database=DATABASE) PERSISTENT_SESSIONS_MANAGER = PersistentSessionsManager(database=DATABASE)
scrape_exclude: ScrapeExcludeFunc | None = None scrape_exclude: ScrapeExcludeFunc | None = None
authentication_function: Callable[[str], Awaitable[Organization]] | None = None authentication_function: Callable[[str], Awaitable[Organization]] | None = None
authenticate_user_function: Callable[[str], Awaitable[str | None]] | None = None
setup_api_app: Callable[[FastAPI], None] | None = None setup_api_app: Callable[[FastAPI], None] | None = None
agent = ForgeAgent() agent = ForgeAgent()

View File

@@ -3023,6 +3023,38 @@ class AgentDB:
LOG.error("UnexpectedError", exc_info=True) LOG.error("UnexpectedError", exc_info=True)
raise raise
async def update_persistent_browser_session(
    self,
    browser_session_id: str,
    timeout_minutes: int,
    organization_id: str | None = None,
) -> PersistentBrowserSession:
    """
    Overwrite `timeout_minutes` on a non-deleted persistent browser session.

    The row is looked up by `browser_session_id` and `organization_id`, and
    only rows whose `deleted_at` is NULL are considered.

    Returns the refreshed row validated as a `PersistentBrowserSession`.

    Raises `NotFoundError` when no matching row exists. Any exception
    (including `SQLAlchemyError`) is logged and re-raised unchanged.
    """
    try:
        async with self.Session() as session:
            query = select(PersistentBrowserSessionModel).filter_by(
                persistent_browser_session_id=browser_session_id,
                organization_id=organization_id,
                deleted_at=None,
            )
            matches = await session.scalars(query)
            record = matches.first()

            if not record:
                raise NotFoundError(f"PersistentBrowserSession {browser_session_id} not found")

            record.timeout_minutes = timeout_minutes
            await session.commit()
            # Reload so the returned schema reflects the committed state.
            await session.refresh(record)

            return PersistentBrowserSession.model_validate(record)
    except NotFoundError:
        LOG.error("NotFoundError", exc_info=True)
        raise
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
    except Exception:
        LOG.error("UnexpectedError", exc_info=True)
        raise
async def set_persistent_browser_session_browser_address( async def set_persistent_browser_session_browser_address(
self, self,
browser_session_id: str, browser_session_id: str,
@@ -3373,10 +3405,10 @@ class AgentDB:
async def get_debug_session( async def get_debug_session(
self, self,
*,
organization_id: str, organization_id: str,
workflow_permanent_id: str,
user_id: str, user_id: str,
timeout_minutes: int = 10, workflow_permanent_id: str,
) -> DebugSession | None: ) -> DebugSession | None:
async with self.Session() as session: async with self.Session() as session:
debug_session = ( debug_session = (
@@ -3391,48 +3423,22 @@ class AgentDB:
if not debug_session: if not debug_session:
return None return None
browser_session = await self.get_persistent_browser_session(
debug_session.browser_session_id, organization_id
)
if browser_session and browser_session.completed_at is None:
# TODO: should we check for expiry here - within some threshold?
return DebugSession.model_validate(debug_session)
browser_session = await self.create_persistent_browser_session(
organization_id=organization_id,
runnable_type="workflow",
runnable_id=workflow_permanent_id,
timeout_minutes=timeout_minutes,
)
debug_session.browser_session_id = browser_session.persistent_browser_session_id
await session.commit()
await session.refresh(debug_session)
return DebugSession.model_validate(debug_session) return DebugSession.model_validate(debug_session)
async def create_debug_session( async def create_debug_session(
self, self,
*,
browser_session_id: str,
organization_id: str, organization_id: str,
workflow_permanent_id: str,
user_id: str, user_id: str,
timeout_minutes: int = 10, workflow_permanent_id: str,
) -> DebugSession: ) -> DebugSession:
async with self.Session() as session: async with self.Session() as session:
browser_session = await self.create_persistent_browser_session(
organization_id=organization_id,
runnable_type="workflow",
runnable_id=workflow_permanent_id,
timeout_minutes=timeout_minutes,
)
debug_session = DebugSessionModel( debug_session = DebugSessionModel(
organization_id=organization_id, organization_id=organization_id,
workflow_permanent_id=workflow_permanent_id, workflow_permanent_id=workflow_permanent_id,
user_id=user_id, user_id=user_id,
browser_session_id=browser_session.persistent_browser_session_id, browser_session_id=browser_session_id,
) )
session.add(debug_session) session.add(debug_session)
@@ -3440,3 +3446,25 @@ class AgentDB:
await session.refresh(debug_session) await session.refresh(debug_session)
return DebugSession.model_validate(debug_session) return DebugSession.model_validate(debug_session)
async def update_debug_session(
    self,
    *,
    debug_session_id: str,
    browser_session_id: str | None = None,
) -> DebugSession:
    """
    Update a debug session identified by `debug_session_id`.

    When `browser_session_id` is truthy it replaces the stored value;
    otherwise the row is left as-is. The session is committed and the row
    refreshed in either case.

    Returns the refreshed row validated as a `DebugSession`.

    Raises `NotFoundError` when no debug session has the given id.
    """
    async with self.Session() as session:
        lookup = select(DebugSessionModel).filter_by(debug_session_id=debug_session_id)
        record = (await session.scalars(lookup)).first()

        if not record:
            raise NotFoundError(f"Debug session {debug_session_id} not found")

        if browser_session_id:
            record.browser_session_id = browser_session_id

        await session.commit()
        await session.refresh(record)

        return DebugSession.model_validate(record)

View File

@@ -1,5 +1,7 @@
import asyncio import asyncio
from datetime import datetime, timedelta, timezone
from enum import Enum from enum import Enum
from math import floor
from typing import Annotated, Any from typing import Annotated, Any
import structlog import structlog
@@ -37,6 +39,7 @@ from skyvern.forge.sdk.routes.code_samples import (
) )
from skyvern.forge.sdk.routes.routers import base_router, legacy_base_router, legacy_v2_router from skyvern.forge.sdk.routes.routers import base_router, legacy_base_router, legacy_v2_router
from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestionBase, AISuggestionRequest from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestionBase, AISuggestionRequest
from skyvern.forge.sdk.schemas.debug_sessions import DebugSession
from skyvern.forge.sdk.schemas.organizations import ( from skyvern.forge.sdk.schemas.organizations import (
GetOrganizationAPIKeysResponse, GetOrganizationAPIKeysResponse,
GetOrganizationsResponse, GetOrganizationsResponse,
@@ -2005,3 +2008,110 @@ async def _flatten_workflow_run_timeline(organization_id: str, workflow_run_id:
final_workflow_run_block_timeline.extend(thought_timeline) final_workflow_run_block_timeline.extend(thought_timeline)
final_workflow_run_block_timeline.sort(key=lambda x: x.created_at, reverse=True) final_workflow_run_block_timeline.sort(key=lambda x: x.created_at, reverse=True)
return final_workflow_run_block_timeline return final_workflow_run_block_timeline
@base_router.get(
    "/debug-session/{workflow_permanent_id}",
    include_in_schema=False,
)
async def get_or_create_debug_session_by_user_and_workflow_permanent_id(
    workflow_permanent_id: str,
    current_org: Organization = Depends(org_auth_service.get_current_org),
    current_user_id: str = Depends(org_auth_service.get_current_user_id),
) -> DebugSession:
    """
    Return the debug session for the current user and workflow, creating or
    repairing it as needed.

    `current_user_id` is a unique identifier for a user, but does not map to an
    entity in the database (at time of writing)

    Behavior:

    - If no debug session exists, a persistent browser session is created
      with `settings.DEBUG_SESSION_TIMEOUT_MINUTES` and a new debug session
      pointing at it is returned.
    - If the backing browser session is not completed and has at least
      `settings.DEBUG_SESSION_TIMEOUT_THRESHOLD_MINUTES` left, its
      `timeout_minutes` is extended so that it expires
      `settings.DEBUG_SESSION_TIMEOUT_MINUTES` from the time of the request.
    - If the backing browser session is missing, completed, or has less than
      the threshold left, a new browser session is created and the debug
      session is re-pointed at it. The *updated* debug session is returned so
      the caller sees the new `browser_session_id`.
    """
    debug_session = await app.DATABASE.get_debug_session(
        organization_id=current_org.organization_id,
        user_id=current_user_id,
        workflow_permanent_id=workflow_permanent_id,
    )

    if not debug_session:
        new_browser_session = await app.PERSISTENT_SESSIONS_MANAGER.create_session(
            organization_id=current_org.organization_id,
            timeout_minutes=settings.DEBUG_SESSION_TIMEOUT_MINUTES,
        )

        debug_session = await app.DATABASE.create_debug_session(
            browser_session_id=new_browser_session.persistent_browser_session_id,
            organization_id=current_org.organization_id,
            user_id=current_user_id,
            workflow_permanent_id=workflow_permanent_id,
        )

        return debug_session

    browser_session = await app.DATABASE.get_persistent_browser_session(
        debug_session.browser_session_id, current_org.organization_id
    )

    if browser_session and browser_session.completed_at is None:
        if browser_session.started_at is None or browser_session.timeout_minutes is None:
            LOG.warning(
                "Persistent browser session started_at or timeout_minutes is None; assumption == this is normal, and they will become non-None shortly",
                debug_session_id=debug_session.debug_session_id,
                organization_id=current_org.organization_id,
                workflow_permanent_id=workflow_permanent_id,
                user_id=current_user_id,
            )
            return debug_session

        right_now = datetime.now(timezone.utc)
        current_timeout_minutes = browser_session.timeout_minutes
        # A naive `started_at` is treated as UTC so it can be compared with
        # the timezone-aware `right_now`.
        started_at_utc = (
            browser_session.started_at.replace(tzinfo=timezone.utc)
            if browser_session.started_at.tzinfo is None
            else browser_session.started_at
        )
        current_timeout_datetime = started_at_utc + timedelta(minutes=float(current_timeout_minutes))
        minutes_left = (current_timeout_datetime - right_now).total_seconds() / 60

        if minutes_left >= settings.DEBUG_SESSION_TIMEOUT_THRESHOLD_MINUTES:
            # `timeout_minutes` is relative to `started_at`, so extend it by
            # the gap between the current expiry and the desired new expiry.
            new_timeout_datetime = right_now + timedelta(minutes=settings.DEBUG_SESSION_TIMEOUT_MINUTES)
            minutes_diff = floor((new_timeout_datetime - current_timeout_datetime).total_seconds() / 60)
            new_timeout_minutes = current_timeout_minutes + minutes_diff

            await app.DATABASE.update_persistent_browser_session(
                browser_session_id=debug_session.browser_session_id,
                organization_id=current_org.organization_id,
                timeout_minutes=new_timeout_minutes,
            )

            # Logged only after the update succeeded, so the message is never
            # emitted for an extension that did not happen.
            LOG.info(
                f"Extended persistent browser session (for debugging) by {minutes_diff} minute(s)",
                minutes_diff=minutes_diff,
                debug_session_id=debug_session.debug_session_id,
                organization_id=current_org.organization_id,
                workflow_permanent_id=workflow_permanent_id,
                user_id=current_user_id,
            )

            return debug_session
        else:
            LOG.info(
                "pbs for debug session has expired",
                debug_session_id=debug_session.debug_session_id,
                organization_id=current_org.organization_id,
                workflow_permanent_id=workflow_permanent_id,
                user_id=current_user_id,
            )

    # Reached when the browser session is missing, completed, or expired:
    # back the debug session with a fresh browser session.
    browser_session = await app.PERSISTENT_SESSIONS_MANAGER.create_session(
        organization_id=current_org.organization_id,
        timeout_minutes=settings.DEBUG_SESSION_TIMEOUT_MINUTES,
    )

    # Return the updated record; the local `debug_session` still carries the
    # old `browser_session_id`.
    return await app.DATABASE.update_debug_session(
        debug_session_id=debug_session.debug_session_id,
        browser_session_id=browser_session.persistent_browser_session_id,
    )

View File

@@ -1,13 +1,13 @@
from datetime import datetime from datetime import datetime
from pydantic import BaseModel from pydantic import BaseModel, ConfigDict
class DebugSession(BaseModel): class DebugSession(BaseModel):
model_config = ConfigDict(from_attributes=True)
debug_session_id: str debug_session_id: str
organization_id: str
browser_session_id: str browser_session_id: str
workflow_permanent_id: str workflow_permanent_id: str | None = None
user_id: str
created_at: datetime created_at: datetime
modified_at: datetime modified_at: datetime

View File

@@ -86,6 +86,44 @@ async def _authenticate_helper(authorization: str) -> Organization:
return organization return organization
async def get_current_user_id(
    authorization: Annotated[str | None, Header(include_in_schema=False)] = None,
) -> str:
    """
    Resolve the current user id from the `Authorization` header.

    The header is declared with `include_in_schema=False`. A missing or
    empty header is rejected with HTTP 403; otherwise the value is passed
    to `_authenticate_user_helper`.
    """
    if authorization:
        return await _authenticate_user_helper(authorization)

    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Invalid credentials",
    )
async def get_current_user_id_with_authentication(
    authorization: Annotated[str | None, Header()] = None,
) -> str:
    """
    Resolve the current user id from the `Authorization` header.

    Same checks as `get_current_user_id`, but the header is declared with a
    plain `Header()`. A missing or empty header is rejected with HTTP 403;
    otherwise the value is passed to `_authenticate_user_helper`.
    """
    if authorization:
        return await _authenticate_user_helper(authorization)

    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Invalid credentials",
    )
async def _authenticate_user_helper(authorization: str) -> str:
    """
    Authenticate a user from an `Authorization` header value.

    The header is expected to look like `"<scheme> <token>"`; the second
    space-separated field is handed to `app.authenticate_user_function`.

    Returns the user id produced by `app.authenticate_user_function`.

    Raises `HTTPException` (403) when the header carries no token, when no
    `app.authenticate_user_function` is configured, or when that function
    returns a falsy user id.
    """
    # A header without a space-separated token (e.g. just "Bearer") used to
    # raise IndexError, surfacing as a 500; reject it as bad credentials.
    parts = authorization.split(" ")
    if len(parts) < 2 or not parts[1]:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid credentials",
        )
    token = parts[1]

    if not app.authenticate_user_function:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid user authentication method",
        )

    user_id = await app.authenticate_user_function(token)
    if not user_id:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid credentials",
        )

    return user_id
@cached(cache=TTLCache(maxsize=CACHE_SIZE, ttl=AUTHENTICATION_TTL)) @cached(cache=TTLCache(maxsize=CACHE_SIZE, ttl=AUTHENTICATION_TTL))
async def _get_current_org_cached(x_api_key: str, db: AgentDB) -> Organization: async def _get_current_org_cached(x_api_key: str, db: AgentDB) -> Organization:
""" """