Revert "Fix nested DB connections" (#4408)
This commit is contained in:
@@ -1,12 +1,10 @@
|
||||
import asyncio
|
||||
import contextvars
|
||||
from contextlib import asynccontextmanager
|
||||
from functools import wraps
|
||||
from typing import Any, AsyncContextManager, AsyncIterator, Callable
|
||||
from typing import Any, Callable
|
||||
|
||||
import structlog
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
@@ -61,39 +59,8 @@ class BaseAlchemyDB:
|
||||
|
||||
def __init__(self, db_engine: AsyncEngine) -> None:
|
||||
self.engine = db_engine
|
||||
self.Session = _SessionFactory(self, async_sessionmaker(bind=db_engine))
|
||||
self.Session = async_sessionmaker(bind=db_engine)
|
||||
|
||||
def is_retryable_error(self, error: SQLAlchemyError) -> bool:
|
||||
"""Check if a database error is retryable. Override in subclasses for specific error handling."""
|
||||
return False
|
||||
|
||||
|
||||
class _SessionFactory:
|
||||
def __init__(self, db: BaseAlchemyDB, sessionmaker: async_sessionmaker[AsyncSession]) -> None:
|
||||
self._db = db
|
||||
self._sessionmaker = sessionmaker
|
||||
self._session_ctx: contextvars.ContextVar[AsyncSession | None] = contextvars.ContextVar(
|
||||
"skyvern_db_session",
|
||||
default=None,
|
||||
)
|
||||
|
||||
def __call__(self) -> AsyncContextManager[AsyncSession]:
|
||||
return self._session()
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
return getattr(self._sessionmaker, name)
|
||||
|
||||
@asynccontextmanager
|
||||
async def _session(self) -> AsyncIterator[AsyncSession]:
|
||||
existing_session = self._session_ctx.get()
|
||||
if existing_session is not None:
|
||||
yield existing_session
|
||||
return
|
||||
|
||||
session = self._sessionmaker()
|
||||
token = self._session_ctx.set(session)
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
self._session_ctx.reset(token)
|
||||
await session.close()
|
||||
|
||||
Reference in New Issue
Block a user