Allow AgentDB init to accept a db_engine (for unit tests) (#2626)
This commit is contained in:
@@ -5,7 +5,7 @@ from typing import Any, List, Sequence
|
||||
import structlog
|
||||
from sqlalchemy import and_, delete, distinct, func, pool, select, tuple_, update
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine
|
||||
|
||||
from skyvern.config import settings
|
||||
from skyvern.exceptions import WorkflowParameterNotFound, WorkflowRunNotFound
|
||||
@@ -106,14 +106,18 @@ elif "postgresql+asyncpg" in settings.DATABASE_STRING:
|
||||
|
||||
|
||||
class AgentDB:
|
||||
def __init__(self, database_string: str, debug_enabled: bool = False) -> None:
|
||||
def __init__(self, database_string: str, debug_enabled: bool = False, db_engine: AsyncEngine | None = None) -> None:
|
||||
super().__init__()
|
||||
self.debug_enabled = debug_enabled
|
||||
self.engine = create_async_engine(
|
||||
database_string,
|
||||
json_serializer=_custom_json_serializer,
|
||||
connect_args=DB_CONNECT_ARGS,
|
||||
poolclass=pool.NullPool if settings.DISABLE_CONNECTION_POOL else None,
|
||||
self.engine = (
|
||||
create_async_engine(
|
||||
database_string,
|
||||
json_serializer=_custom_json_serializer,
|
||||
connect_args=DB_CONNECT_ARGS,
|
||||
poolclass=pool.NullPool if settings.DISABLE_CONNECTION_POOL else None,
|
||||
)
|
||||
if db_engine is None
|
||||
else db_engine
|
||||
)
|
||||
self.Session = async_sessionmaker(bind=self.engine)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user