Reapply "Fix nested DB connections" (#4411)
This commit is contained in:
committed by
GitHub
parent
4401216346
commit
66d28bb24d
@@ -2679,8 +2679,8 @@ class AgentDB(BaseAlchemyDB):
|
|||||||
if browser_session_id:
|
if browser_session_id:
|
||||||
workflow_run.browser_session_id = browser_session_id
|
workflow_run.browser_session_id = browser_session_id
|
||||||
await session.commit()
|
await session.commit()
|
||||||
await session.refresh(workflow_run)
|
|
||||||
await save_workflow_run_logs(workflow_run_id)
|
await save_workflow_run_logs(workflow_run_id)
|
||||||
|
await session.refresh(workflow_run)
|
||||||
return convert_to_workflow_run(workflow_run)
|
return convert_to_workflow_run(workflow_run)
|
||||||
else:
|
else:
|
||||||
raise WorkflowRunNotFound(workflow_run_id)
|
raise WorkflowRunNotFound(workflow_run_id)
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import contextvars
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Any, Callable
|
from typing import Any, AsyncContextManager, AsyncIterator, Callable
|
||||||
|
|
||||||
import structlog
|
import structlog
|
||||||
from sqlalchemy.exc import SQLAlchemyError
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker
|
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
|
||||||
|
|
||||||
LOG = structlog.get_logger()
|
LOG = structlog.get_logger()
|
||||||
|
|
||||||
@@ -59,8 +61,39 @@ class BaseAlchemyDB:
|
|||||||
|
|
||||||
def __init__(self, db_engine: AsyncEngine) -> None:
|
def __init__(self, db_engine: AsyncEngine) -> None:
|
||||||
self.engine = db_engine
|
self.engine = db_engine
|
||||||
self.Session = async_sessionmaker(bind=db_engine)
|
self.Session = _SessionFactory(self, async_sessionmaker(bind=db_engine))
|
||||||
|
|
||||||
def is_retryable_error(self, error: SQLAlchemyError) -> bool:
|
def is_retryable_error(self, error: SQLAlchemyError) -> bool:
|
||||||
"""Check if a database error is retryable. Override in subclasses for specific error handling."""
|
"""Check if a database error is retryable. Override in subclasses for specific error handling."""
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class _SessionFactory:
|
||||||
|
def __init__(self, db: BaseAlchemyDB, sessionmaker: async_sessionmaker[AsyncSession]) -> None:
|
||||||
|
self._db = db
|
||||||
|
self._sessionmaker = sessionmaker
|
||||||
|
self._session_ctx: contextvars.ContextVar[AsyncSession | None] = contextvars.ContextVar(
|
||||||
|
"skyvern_db_session",
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self) -> AsyncContextManager[AsyncSession]:
|
||||||
|
return self._session()
|
||||||
|
|
||||||
|
def __getattr__(self, name: str) -> Any:
|
||||||
|
return getattr(self._sessionmaker, name)
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def _session(self) -> AsyncIterator[AsyncSession]:
|
||||||
|
existing_session = self._session_ctx.get()
|
||||||
|
if existing_session is not None:
|
||||||
|
yield existing_session
|
||||||
|
return
|
||||||
|
|
||||||
|
session = self._sessionmaker()
|
||||||
|
token = self._session_ctx.set(session)
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
finally:
|
||||||
|
self._session_ctx.reset(token)
|
||||||
|
await session.close()
|
||||||
|
|||||||
Reference in New Issue
Block a user