test transaction pooler db connection in k8s workers (#4290)

This commit is contained in:
Shuchang Zheng
2025-12-12 20:02:29 -08:00
committed by GitHub
parent a902fa7a6e
commit 3d54a288ad
2 changed files with 60 additions and 16 deletions

View File

@@ -54,6 +54,7 @@ class Settings(BaseSettings):
LONG_RUNNING_TASK_WARNING_RATIO: float = 0.95
MAX_RETRIES_PER_STEP: int = 5
DEBUG_MODE: bool = False
# Database settings
DATABASE_STRING: str = (
"postgresql+asyncpg://skyvern@localhost/skyvern"
if platform.system() == "Windows"
@@ -62,6 +63,8 @@ class Settings(BaseSettings):
DATABASE_REPLICA_STRING: str | None = None
DATABASE_STATEMENT_TIMEOUT_MS: int = 60000
DISABLE_CONNECTION_POOL: bool = False
DB_DISABLE_PREPARED_STATEMENTS: bool = False
PROMPT_ACTION_HISTORY_WINDOW: int = 1
TASK_RESPONSE_ACTION_SCREENSHOT_COUNT: int = 3

View File

@@ -3,7 +3,23 @@ from datetime import datetime, timedelta
from typing import Any, List, Literal, Sequence, overload
import structlog
from sqlalchemy import and_, asc, case, delete, distinct, exists, func, or_, pool, select, tuple_, update
from sqlalchemy import (
Connection,
and_,
asc,
case,
delete,
distinct,
event,
exists,
func,
or_,
pool,
select,
tuple_,
update,
)
from sqlalchemy.engine.url import make_url
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine
@@ -153,28 +169,53 @@ def _serialize_proxy_location(proxy_location: ProxyLocationInput) -> str | None:
return result
DB_CONNECT_ARGS: dict[str, Any] = {}
def _connect_args_for_driver(database_string: str) -> dict[str, Any]:
driver = make_url(database_string).drivername # "postgresql+psycopg" or "postgresql+asyncpg"
args: dict[str, Any] = {}
if "postgresql+psycopg" in settings.DATABASE_STRING:
DB_CONNECT_ARGS = {"options": f"-c statement_timeout={settings.DATABASE_STATEMENT_TIMEOUT_MS}"}
elif "postgresql+asyncpg" in settings.DATABASE_STRING:
DB_CONNECT_ARGS = {"server_settings": {"statement_timeout": str(settings.DATABASE_STATEMENT_TIMEOUT_MS)}}
if settings.DB_DISABLE_PREPARED_STATEMENTS:
if driver == "postgresql+psycopg":
# psycopg3: disable server-side prepares
args["prepare_threshold"] = 0
elif driver == "postgresql+asyncpg":
# asyncpg: disable statement cache (prepared statements)
args["statement_cache_size"] = 0
else:
LOG.warning(
"The database driver might not be well optimized or supported by skyvern: {driver}", driver=driver
)
return args
def _install_statement_timeout(engine: AsyncEngine, timeout_ms: int) -> None:
if not timeout_ms or timeout_ms <= 0:
return
# Works for direct AND poolers because it's not a startup parameter.
# Applies per-transaction, which is the most reliable behavior with transaction pooling.
@event.listens_for(engine.sync_engine, "begin")
def _set_timeout(conn: Connection) -> None:
conn.exec_driver_sql(f"SET LOCAL statement_timeout = {int(timeout_ms)}")
def make_async_engine(database_string: str) -> AsyncEngine:
engine = create_async_engine(
database_string,
json_serializer=_custom_json_serializer,
connect_args=_connect_args_for_driver(database_string),
poolclass=pool.NullPool if settings.DISABLE_CONNECTION_POOL else None,
)
_install_statement_timeout(engine, settings.DATABASE_STATEMENT_TIMEOUT_MS)
return engine
class AgentDB:
def __init__(self, database_string: str, debug_enabled: bool = False, db_engine: AsyncEngine | None = None) -> None:
super().__init__()
self.debug_enabled = debug_enabled
self.engine = (
create_async_engine(
database_string,
json_serializer=_custom_json_serializer,
connect_args=DB_CONNECT_ARGS,
poolclass=pool.NullPool if settings.DISABLE_CONNECTION_POOL else None,
)
if db_engine is None
else db_engine
)
self.engine = db_engine or make_async_engine(database_string)
self.Session = async_sessionmaker(bind=self.engine)
async def create_task(