diff --git a/skyvern/forge/sdk/db/base_alchemy_db.py b/skyvern/forge/sdk/db/base_alchemy_db.py index b218fc14..5153e8d9 100644 --- a/skyvern/forge/sdk/db/base_alchemy_db.py +++ b/skyvern/forge/sdk/db/base_alchemy_db.py @@ -1,10 +1,12 @@ import asyncio +import contextvars +from contextlib import asynccontextmanager from functools import wraps -from typing import Any, Callable +from typing import Any, AsyncContextManager, AsyncIterator, Callable import structlog from sqlalchemy.exc import SQLAlchemyError -from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker +from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker LOG = structlog.get_logger() @@ -59,8 +61,39 @@ class BaseAlchemyDB: def __init__(self, db_engine: AsyncEngine) -> None: self.engine = db_engine - self.Session = async_sessionmaker(bind=db_engine) + self.Session = _SessionFactory(self, async_sessionmaker(bind=db_engine)) def is_retryable_error(self, error: SQLAlchemyError) -> bool: """Check if a database error is retryable. Override in subclasses for specific error handling.""" return False + + +class _SessionFactory: + def __init__(self, db: BaseAlchemyDB, sessionmaker: async_sessionmaker[AsyncSession]) -> None: + self._db = db + self._sessionmaker = sessionmaker + self._session_ctx: contextvars.ContextVar[AsyncSession | None] = contextvars.ContextVar( + "skyvern_db_session", + default=None, + ) + + def __call__(self) -> AsyncContextManager[AsyncSession]: + return self._session() + + def __getattr__(self, name: str) -> Any: + return getattr(self._sessionmaker, name) + + @asynccontextmanager + async def _session(self) -> AsyncIterator[AsyncSession]: + existing_session = self._session_ctx.get() + if existing_session is not None: + yield existing_session + return + + session = self._sessionmaker() + token = self._session_ctx.set(session) + try: + yield session + finally: + self._session_ctx.reset(token) + await session.close()