diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py index 63d01be2..1fcd4dbb 100644 --- a/skyvern/forge/sdk/db/client.py +++ b/skyvern/forge/sdk/db/client.py @@ -5,7 +5,7 @@ from typing import Any, List, Sequence import structlog from sqlalchemy import and_, delete, distinct, func, pool, select, tuple_, update from sqlalchemy.exc import SQLAlchemyError -from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine +from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine from skyvern.config import settings from skyvern.exceptions import WorkflowParameterNotFound, WorkflowRunNotFound @@ -106,14 +106,18 @@ elif "postgresql+asyncpg" in settings.DATABASE_STRING: class AgentDB: - def __init__(self, database_string: str, debug_enabled: bool = False) -> None: + def __init__(self, database_string: str, debug_enabled: bool = False, db_engine: AsyncEngine | None = None) -> None: super().__init__() self.debug_enabled = debug_enabled - self.engine = create_async_engine( - database_string, - json_serializer=_custom_json_serializer, - connect_args=DB_CONNECT_ARGS, - poolclass=pool.NullPool if settings.DISABLE_CONNECTION_POOL else None, + self.engine = ( + create_async_engine( + database_string, + json_serializer=_custom_json_serializer, + connect_args=DB_CONNECT_ARGS, + poolclass=pool.NullPool if settings.DISABLE_CONNECTION_POOL else None, + ) + if db_engine is None + else db_engine ) self.Session = async_sessionmaker(bind=self.engine)