Revert "Fix nested DB connections" (#4408)
This commit is contained in:
@@ -1,12 +1,10 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import contextvars
|
|
||||||
from contextlib import asynccontextmanager
|
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Any, AsyncContextManager, AsyncIterator, Callable
|
from typing import Any, Callable
|
||||||
|
|
||||||
import structlog
|
import structlog
|
||||||
from sqlalchemy.exc import SQLAlchemyError
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
|
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker
|
||||||
|
|
||||||
LOG = structlog.get_logger()
|
LOG = structlog.get_logger()
|
||||||
|
|
||||||
@@ -61,39 +59,8 @@ class BaseAlchemyDB:
|
|||||||
|
|
||||||
def __init__(self, db_engine: AsyncEngine) -> None:
|
def __init__(self, db_engine: AsyncEngine) -> None:
|
||||||
self.engine = db_engine
|
self.engine = db_engine
|
||||||
self.Session = _SessionFactory(self, async_sessionmaker(bind=db_engine))
|
self.Session = async_sessionmaker(bind=db_engine)
|
||||||
|
|
||||||
def is_retryable_error(self, error: SQLAlchemyError) -> bool:
|
def is_retryable_error(self, error: SQLAlchemyError) -> bool:
|
||||||
"""Check if a database error is retryable. Override in subclasses for specific error handling."""
|
"""Check if a database error is retryable. Override in subclasses for specific error handling."""
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
class _SessionFactory:
|
|
||||||
def __init__(self, db: BaseAlchemyDB, sessionmaker: async_sessionmaker[AsyncSession]) -> None:
|
|
||||||
self._db = db
|
|
||||||
self._sessionmaker = sessionmaker
|
|
||||||
self._session_ctx: contextvars.ContextVar[AsyncSession | None] = contextvars.ContextVar(
|
|
||||||
"skyvern_db_session",
|
|
||||||
default=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
def __call__(self) -> AsyncContextManager[AsyncSession]:
|
|
||||||
return self._session()
|
|
||||||
|
|
||||||
def __getattr__(self, name: str) -> Any:
|
|
||||||
return getattr(self._sessionmaker, name)
|
|
||||||
|
|
||||||
@asynccontextmanager
|
|
||||||
async def _session(self) -> AsyncIterator[AsyncSession]:
|
|
||||||
existing_session = self._session_ctx.get()
|
|
||||||
if existing_session is not None:
|
|
||||||
yield existing_session
|
|
||||||
return
|
|
||||||
|
|
||||||
session = self._sessionmaker()
|
|
||||||
token = self._session_ctx.set(session)
|
|
||||||
try:
|
|
||||||
yield session
|
|
||||||
finally:
|
|
||||||
self._session_ctx.reset(token)
|
|
||||||
await session.close()
|
|
||||||
|
|||||||
Reference in New Issue
Block a user