diff --git a/skyvern/forge/sdk/db/agent_db.py b/skyvern/forge/sdk/db/agent_db.py index 464c5bd6..dfb052df 100644 --- a/skyvern/forge/sdk/db/agent_db.py +++ b/skyvern/forge/sdk/db/agent_db.py @@ -2679,8 +2679,8 @@ class AgentDB(BaseAlchemyDB): if browser_session_id: workflow_run.browser_session_id = browser_session_id await session.commit() - await session.refresh(workflow_run) await save_workflow_run_logs(workflow_run_id) + await session.refresh(workflow_run) return convert_to_workflow_run(workflow_run) else: raise WorkflowRunNotFound(workflow_run_id) diff --git a/skyvern/forge/sdk/db/base_alchemy_db.py b/skyvern/forge/sdk/db/base_alchemy_db.py index b218fc14..5153e8d9 100644 --- a/skyvern/forge/sdk/db/base_alchemy_db.py +++ b/skyvern/forge/sdk/db/base_alchemy_db.py @@ -1,10 +1,12 @@ import asyncio +import contextvars +from contextlib import asynccontextmanager from functools import wraps -from typing import Any, Callable +from typing import Any, AsyncContextManager, AsyncIterator, Callable import structlog from sqlalchemy.exc import SQLAlchemyError -from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker +from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker LOG = structlog.get_logger() @@ -59,8 +61,39 @@ class BaseAlchemyDB: def __init__(self, db_engine: AsyncEngine) -> None: self.engine = db_engine - self.Session = async_sessionmaker(bind=db_engine) + self.Session = _SessionFactory(self, async_sessionmaker(bind=db_engine)) def is_retryable_error(self, error: SQLAlchemyError) -> bool: """Check if a database error is retryable. Override in subclasses for specific error handling.""" return False + + +class _SessionFactory: + def __init__(self, db: BaseAlchemyDB, sessionmaker: async_sessionmaker[AsyncSession]) -> None: + self._db = db + self._sessionmaker = sessionmaker + self._session_ctx: contextvars.ContextVar[AsyncSession | None] = contextvars.ContextVar( + "skyvern_db_session", + default=None, + ) + + def __call__(self) -> AsyncContextManager[AsyncSession]: + return self._session() + + def __getattr__(self, name: str) -> Any: + return getattr(self._sessionmaker, name) + + @asynccontextmanager + async def _session(self) -> AsyncIterator[AsyncSession]: + existing_session = self._session_ctx.get() + if existing_session is not None: + yield existing_session + return + + session = self._sessionmaker() + token = self._session_ctx.set(session) + try: + yield session + finally: + self._session_ctx.reset(token) + await session.close()