From 3d54a288ad54e8b9f72b283ea75d000ca7e297b7 Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Fri, 12 Dec 2025 20:02:29 -0800 Subject: [PATCH] test transaction pooler db connection in k8s workers (#4290) --- skyvern/config.py | 3 ++ skyvern/forge/sdk/db/agent_db.py | 73 +++++++++++++++++++++++++------- 2 files changed, 60 insertions(+), 16 deletions(-) diff --git a/skyvern/config.py b/skyvern/config.py index dc9d1a6e..0e63aeb7 100644 --- a/skyvern/config.py +++ b/skyvern/config.py @@ -54,6 +54,7 @@ class Settings(BaseSettings): LONG_RUNNING_TASK_WARNING_RATIO: float = 0.95 MAX_RETRIES_PER_STEP: int = 5 DEBUG_MODE: bool = False + # Database settings DATABASE_STRING: str = ( "postgresql+asyncpg://skyvern@localhost/skyvern" if platform.system() == "Windows" @@ -62,6 +63,8 @@ class Settings(BaseSettings): DATABASE_REPLICA_STRING: str | None = None DATABASE_STATEMENT_TIMEOUT_MS: int = 60000 DISABLE_CONNECTION_POOL: bool = False + DB_DISABLE_PREPARED_STATEMENTS: bool = False + PROMPT_ACTION_HISTORY_WINDOW: int = 1 TASK_RESPONSE_ACTION_SCREENSHOT_COUNT: int = 3 diff --git a/skyvern/forge/sdk/db/agent_db.py b/skyvern/forge/sdk/db/agent_db.py index dcd90f75..061ac025 100644 --- a/skyvern/forge/sdk/db/agent_db.py +++ b/skyvern/forge/sdk/db/agent_db.py @@ -3,7 +3,23 @@ from datetime import datetime, timedelta from typing import Any, List, Literal, Sequence, overload import structlog -from sqlalchemy import and_, asc, case, delete, distinct, exists, func, or_, pool, select, tuple_, update +from sqlalchemy import ( + Connection, + and_, + asc, + case, + delete, + distinct, + event, + exists, + func, + or_, + pool, + select, + tuple_, + update, +) +from sqlalchemy.engine.url import make_url from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine @@ -153,28 +169,53 @@ def _serialize_proxy_location(proxy_location: ProxyLocationInput) -> str | None: return result -DB_CONNECT_ARGS: dict[str, Any] = {} 
+def _connect_args_for_driver(database_string: str) -> dict[str, Any]:
+    driver = make_url(database_string).drivername  # "postgresql+psycopg" or "postgresql+asyncpg"
+    args: dict[str, Any] = {}
 
-if "postgresql+psycopg" in settings.DATABASE_STRING:
-    DB_CONNECT_ARGS = {"options": f"-c statement_timeout={settings.DATABASE_STATEMENT_TIMEOUT_MS}"}
-elif "postgresql+asyncpg" in settings.DATABASE_STRING:
-    DB_CONNECT_ARGS = {"server_settings": {"statement_timeout": str(settings.DATABASE_STATEMENT_TIMEOUT_MS)}}
+    if settings.DB_DISABLE_PREPARED_STATEMENTS:
+        if driver == "postgresql+psycopg":
+            # psycopg3: None disables server-side prepares; 0 would prepare every query on first execution
+            args["prepare_threshold"] = None
+        elif driver == "postgresql+asyncpg":
+            # asyncpg: disable statement cache (prepared statements)
+            args["statement_cache_size"] = 0
+        else:
+            LOG.warning(
+                "The database driver might not be well optimized or supported by skyvern: {driver}", driver=driver
+            )
+
+    return args
+
+
+def _install_statement_timeout(engine: AsyncEngine, timeout_ms: int) -> None:
+    if not timeout_ms or timeout_ms <= 0:
+        return
+
+    # Works for direct AND poolers because it's not a startup parameter.
+    # Applies per-transaction, which is the most reliable behavior with transaction pooling.
+    @event.listens_for(engine.sync_engine, "begin")
+    def _set_timeout(conn: Connection) -> None:
+        conn.exec_driver_sql(f"SET LOCAL statement_timeout = {int(timeout_ms)}")
+
+
+def make_async_engine(database_string: str) -> AsyncEngine:
+    engine = create_async_engine(
+        database_string,
+        json_serializer=_custom_json_serializer,
+        connect_args=_connect_args_for_driver(database_string),
+        poolclass=pool.NullPool if settings.DISABLE_CONNECTION_POOL else None,
+    )
+
+    _install_statement_timeout(engine, settings.DATABASE_STATEMENT_TIMEOUT_MS)
+    return engine
 
 
 class AgentDB:
     def __init__(self, database_string: str, debug_enabled: bool = False, db_engine: AsyncEngine | None = None) -> None:
         super().__init__()
         self.debug_enabled = debug_enabled
-        self.engine = (
-            create_async_engine(
-                database_string,
-                json_serializer=_custom_json_serializer,
-                connect_args=DB_CONNECT_ARGS,
-                poolclass=pool.NullPool if settings.DISABLE_CONNECTION_POOL else None,
-            )
-            if db_engine is None
-            else db_engine
-        )
+        self.engine = db_engine or make_async_engine(database_string)
         self.Session = async_sessionmaker(bind=self.engine)
 
     async def create_task(