Reapply "Fix nested DB connections" (#4411)

This commit is contained in:
Stanislav Novosad
2026-01-07 14:13:26 -07:00
committed by GitHub
parent 4401216346
commit 66d28bb24d
2 changed files with 37 additions and 4 deletions

View File

@@ -2679,8 +2679,8 @@ class AgentDB(BaseAlchemyDB):
if browser_session_id: if browser_session_id:
workflow_run.browser_session_id = browser_session_id workflow_run.browser_session_id = browser_session_id
await session.commit() await session.commit()
await session.refresh(workflow_run)
await save_workflow_run_logs(workflow_run_id) await save_workflow_run_logs(workflow_run_id)
await session.refresh(workflow_run)
return convert_to_workflow_run(workflow_run) return convert_to_workflow_run(workflow_run)
else: else:
raise WorkflowRunNotFound(workflow_run_id) raise WorkflowRunNotFound(workflow_run_id)

View File

@@ -1,10 +1,12 @@
import asyncio import asyncio
import contextvars
from contextlib import asynccontextmanager
from functools import wraps from functools import wraps
from typing import Any, Callable from typing import Any, AsyncContextManager, AsyncIterator, Callable
import structlog import structlog
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
LOG = structlog.get_logger() LOG = structlog.get_logger()
@@ -59,8 +61,39 @@ class BaseAlchemyDB:
def __init__(self, db_engine: AsyncEngine) -> None: def __init__(self, db_engine: AsyncEngine) -> None:
self.engine = db_engine self.engine = db_engine
self.Session = async_sessionmaker(bind=db_engine) self.Session = _SessionFactory(self, async_sessionmaker(bind=db_engine))
def is_retryable_error(self, error: SQLAlchemyError) -> bool: def is_retryable_error(self, error: SQLAlchemyError) -> bool:
"""Check if a database error is retryable. Override in subclasses for specific error handling.""" """Check if a database error is retryable. Override in subclasses for specific error handling."""
return False return False
class _SessionFactory:
def __init__(self, db: BaseAlchemyDB, sessionmaker: async_sessionmaker[AsyncSession]) -> None:
self._db = db
self._sessionmaker = sessionmaker
self._session_ctx: contextvars.ContextVar[AsyncSession | None] = contextvars.ContextVar(
"skyvern_db_session",
default=None,
)
def __call__(self) -> AsyncContextManager[AsyncSession]:
return self._session()
def __getattr__(self, name: str) -> Any:
return getattr(self._sessionmaker, name)
@asynccontextmanager
async def _session(self) -> AsyncIterator[AsyncSession]:
existing_session = self._session_ctx.get()
if existing_session is not None:
yield existing_session
return
session = self._sessionmaker()
token = self._session_ctx.set(session)
try:
yield session
finally:
self._session_ctx.reset(token)
await session.close()