Reapply "Fix nested DB connections" (#4411)
This commit is contained in:
committed by
GitHub
parent
4401216346
commit
66d28bb24d
@@ -2679,8 +2679,8 @@ class AgentDB(BaseAlchemyDB):
|
||||
if browser_session_id:
|
||||
workflow_run.browser_session_id = browser_session_id
|
||||
await session.commit()
|
||||
await session.refresh(workflow_run)
|
||||
await save_workflow_run_logs(workflow_run_id)
|
||||
await session.refresh(workflow_run)
|
||||
return convert_to_workflow_run(workflow_run)
|
||||
else:
|
||||
raise WorkflowRunNotFound(workflow_run_id)
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import asyncio
|
||||
import contextvars
|
||||
from contextlib import asynccontextmanager
|
||||
from functools import wraps
|
||||
from typing import Any, Callable
|
||||
from typing import Any, AsyncContextManager, AsyncIterator, Callable
|
||||
|
||||
import structlog
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
@@ -59,8 +61,39 @@ class BaseAlchemyDB:
|
||||
|
||||
def __init__(self, db_engine: AsyncEngine) -> None:
|
||||
self.engine = db_engine
|
||||
self.Session = async_sessionmaker(bind=db_engine)
|
||||
self.Session = _SessionFactory(self, async_sessionmaker(bind=db_engine))
|
||||
|
||||
def is_retryable_error(self, error: SQLAlchemyError) -> bool:
|
||||
"""Check if a database error is retryable. Override in subclasses for specific error handling."""
|
||||
return False
|
||||
|
||||
|
||||
class _SessionFactory:
|
||||
def __init__(self, db: BaseAlchemyDB, sessionmaker: async_sessionmaker[AsyncSession]) -> None:
|
||||
self._db = db
|
||||
self._sessionmaker = sessionmaker
|
||||
self._session_ctx: contextvars.ContextVar[AsyncSession | None] = contextvars.ContextVar(
|
||||
"skyvern_db_session",
|
||||
default=None,
|
||||
)
|
||||
|
||||
def __call__(self) -> AsyncContextManager[AsyncSession]:
|
||||
return self._session()
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
return getattr(self._sessionmaker, name)
|
||||
|
||||
@asynccontextmanager
|
||||
async def _session(self) -> AsyncIterator[AsyncSession]:
|
||||
existing_session = self._session_ctx.get()
|
||||
if existing_session is not None:
|
||||
yield existing_session
|
||||
return
|
||||
|
||||
session = self._sessionmaker()
|
||||
token = self._session_ctx.set(session)
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
self._session_ctx.reset(token)
|
||||
await session.close()
|
||||
|
||||
Reference in New Issue
Block a user