Fix nested DB connections (#4402)
This commit is contained in:
committed by
GitHub
parent
f5cb826a37
commit
2f6fd5262b
@@ -1,10 +1,12 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import contextvars
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Any, Callable
|
from typing import Any, AsyncContextManager, AsyncIterator, Callable
|
||||||
|
|
||||||
import structlog
|
import structlog
|
||||||
from sqlalchemy.exc import SQLAlchemyError
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker
|
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
|
||||||
|
|
||||||
LOG = structlog.get_logger()
|
LOG = structlog.get_logger()
|
||||||
|
|
||||||
@@ -59,8 +61,39 @@ class BaseAlchemyDB:
|
|||||||
|
|
||||||
def __init__(self, db_engine: AsyncEngine) -> None:
|
def __init__(self, db_engine: AsyncEngine) -> None:
|
||||||
self.engine = db_engine
|
self.engine = db_engine
|
||||||
self.Session = async_sessionmaker(bind=db_engine)
|
self.Session = _SessionFactory(self, async_sessionmaker(bind=db_engine))
|
||||||
|
|
||||||
def is_retryable_error(self, error: SQLAlchemyError) -> bool:
|
def is_retryable_error(self, error: SQLAlchemyError) -> bool:
|
||||||
"""Check if a database error is retryable. Override in subclasses for specific error handling."""
|
"""Check if a database error is retryable. Override in subclasses for specific error handling."""
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class _SessionFactory:
|
||||||
|
def __init__(self, db: BaseAlchemyDB, sessionmaker: async_sessionmaker[AsyncSession]) -> None:
|
||||||
|
self._db = db
|
||||||
|
self._sessionmaker = sessionmaker
|
||||||
|
self._session_ctx: contextvars.ContextVar[AsyncSession | None] = contextvars.ContextVar(
|
||||||
|
"skyvern_db_session",
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self) -> AsyncContextManager[AsyncSession]:
|
||||||
|
return self._session()
|
||||||
|
|
||||||
|
def __getattr__(self, name: str) -> Any:
|
||||||
|
return getattr(self._sessionmaker, name)
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def _session(self) -> AsyncIterator[AsyncSession]:
|
||||||
|
existing_session = self._session_ctx.get()
|
||||||
|
if existing_session is not None:
|
||||||
|
yield existing_session
|
||||||
|
return
|
||||||
|
|
||||||
|
session = self._sessionmaker()
|
||||||
|
token = self._session_ctx.set(session)
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
finally:
|
||||||
|
self._session_ctx.reset(token)
|
||||||
|
await session.close()
|
||||||
|
|||||||
Reference in New Issue
Block a user