From c61bd26c8cefe9579ca1a3003cc9a20346862e32 Mon Sep 17 00:00:00 2001 From: Stanislav Novosad Date: Thu, 18 Dec 2025 11:32:40 -0700 Subject: [PATCH] Add @retry decorator for DB operations (#4328) --- skyvern/forge/sdk/db/agent_db.py | 88 +++++++++++-------------- skyvern/forge/sdk/db/base_alchemy_db.py | 66 +++++++++++++++++++ 2 files changed, 106 insertions(+), 48 deletions(-) create mode 100644 skyvern/forge/sdk/db/base_alchemy_db.py diff --git a/skyvern/forge/sdk/db/agent_db.py b/skyvern/forge/sdk/db/agent_db.py index 380a34d0..0f1df535 100644 --- a/skyvern/forge/sdk/db/agent_db.py +++ b/skyvern/forge/sdk/db/agent_db.py @@ -20,13 +20,16 @@ from sqlalchemy import ( update, ) from sqlalchemy.engine.url import make_url -from sqlalchemy.exc import SQLAlchemyError -from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine +from sqlalchemy.exc import ( + SQLAlchemyError, +) +from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine from skyvern.config import settings from skyvern.constants import DEFAULT_SCRIPT_RUN_ID from skyvern.exceptions import BrowserProfileNotFound, WorkflowParameterNotFound, WorkflowRunNotFound from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType +from skyvern.forge.sdk.db.base_alchemy_db import BaseAlchemyDB, read_retry from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType, TaskType from skyvern.forge.sdk.db.exceptions import NotFoundError from skyvern.forge.sdk.db.models import ( @@ -218,12 +221,14 @@ def make_async_engine(database_string: str) -> AsyncEngine: return engine -class AgentDB: +class AgentDB(BaseAlchemyDB): def __init__(self, database_string: str, debug_enabled: bool = False, db_engine: AsyncEngine | None = None) -> None: - super().__init__() + super().__init__(db_engine or make_async_engine(database_string)) self.debug_enabled = debug_enabled - self.engine = db_engine or make_async_engine(database_string) - self.Session = async_sessionmaker(bind=self.engine) + + def is_retryable_error(self, error: SQLAlchemyError) -> bool: + error_msg = str(error).lower() + return "server closed the connection" in error_msg async def create_task( self, @@ -404,29 +409,23 @@ class AgentDB: LOG.exception("UnexpectedError during bulk artifact creation") raise + @read_retry() async def get_task(self, task_id: str, organization_id: str | None = None) -> Task | None: """Get a task by its id""" - try: - async with self.Session() as session: - if task_obj := ( - await session.scalars( - select(TaskModel).filter_by(task_id=task_id).filter_by(organization_id=organization_id) - ) - ).first(): - return convert_to_task(task_obj, self.debug_enabled) - else: - LOG.info( - "Task not found", - task_id=task_id, - organization_id=organization_id, - ) - return None - except SQLAlchemyError: - LOG.error("SQLAlchemyError", exc_info=True) - raise - except Exception: - LOG.error("UnexpectedError", exc_info=True) - raise + async with self.Session() as session: + if task_obj := ( + await session.scalars( + select(TaskModel).filter_by(task_id=task_id).filter_by(organization_id=organization_id) + ) + ).first(): + return convert_to_task(task_obj, self.debug_enabled) + else: + LOG.info( + "Task not found", + task_id=task_id, + organization_id=organization_id, + ) + return None async def get_tasks_by_ids( self, @@ -2758,6 +2757,7 @@ class AgentDB: LOG.error("SQLAlchemyError", exc_info=True) raise + @read_retry() async def get_workflow_run( self, workflow_run_id: str, @@ -2765,21 +2765,17 @@ class AgentDB: job_id: str | None = None, status: WorkflowRunStatus | None = None, ) -> WorkflowRun | None: - try: - async with self.Session() as session: - get_workflow_run_query = select(WorkflowRunModel).filter_by(workflow_run_id=workflow_run_id) - if organization_id: - get_workflow_run_query = get_workflow_run_query.filter_by(organization_id=organization_id) - if job_id: - get_workflow_run_query = get_workflow_run_query.filter_by(job_id=job_id) - if status: - get_workflow_run_query = get_workflow_run_query.filter_by(status=status.value) - if workflow_run := (await session.scalars(get_workflow_run_query)).first(): - return convert_to_workflow_run(workflow_run) - return None - except SQLAlchemyError: - LOG.error("SQLAlchemyError", exc_info=True) - raise + async with self.Session() as session: + get_workflow_run_query = select(WorkflowRunModel).filter_by(workflow_run_id=workflow_run_id) + if organization_id: + get_workflow_run_query = get_workflow_run_query.filter_by(organization_id=organization_id) + if job_id: + get_workflow_run_query = get_workflow_run_query.filter_by(job_id=job_id) + if status: + get_workflow_run_query = get_workflow_run_query.filter_by(status=status.value) + if workflow_run := (await session.scalars(get_workflow_run_query)).first(): + return convert_to_workflow_run(workflow_run) + return None async def get_last_queued_workflow_run( self, @@ -3822,6 +3818,7 @@ class AgentDB: await session.execute(stmt) await session.commit() + @read_retry() async def get_task_v2(self, task_v2_id: str, organization_id: str | None = None) -> TaskV2 | None: async with self.Session() as session: if task_v2 := ( @@ -4474,6 +4471,7 @@ class AgentDB: LOG.error("UnexpectedError", exc_info=True) raise + @read_retry() async def get_persistent_browser_session_by_runnable_id( self, runnable_id: str, organization_id: str | None = None ) -> PersistentBrowserSession | None: @@ -4495,12 +4493,6 @@ class AgentDB: except NotFoundError: LOG.error("NotFoundError", exc_info=True) raise - except SQLAlchemyError: - LOG.error("SQLAlchemyError", exc_info=True) - raise - except Exception: - LOG.error("UnexpectedError", exc_info=True) - raise async def get_persistent_browser_session( self, diff --git a/skyvern/forge/sdk/db/base_alchemy_db.py b/skyvern/forge/sdk/db/base_alchemy_db.py new file mode 100644 index 00000000..a0d10ac9 --- /dev/null +++ b/skyvern/forge/sdk/db/base_alchemy_db.py @@ -0,0 +1,66 @@ +import asyncio +from functools import wraps +from typing import Any, Callable + +import structlog +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker + +LOG = structlog.get_logger() + + +def read_retry(retries: int = 3) -> Callable: + """Decorator to retry async database operations on transient failures. + + Args: + retries: Maximum number of retry attempts (default: 3) + """ + + def decorator(fn: Callable[..., Any]) -> Callable[..., Any]: + @wraps(fn) + async def wrapper( + base_db: "BaseAlchemyDB", + *args: Any, + **kwargs: Any, + ) -> Any: + for attempt in range(retries): + try: + return await fn(base_db, *args, **kwargs) + except SQLAlchemyError as e: + if not base_db.is_retryable_error(e): + LOG.error("SQLAlchemyError", exc_info=True, attempt=attempt) + raise + if attempt >= retries - 1: + LOG.error("SQLAlchemyError after all retries", exc_info=True, attempt=attempt) + raise + + backoff_time = 0.1 * (2**attempt) + LOG.warning( + "SQLAlchemyError retrying", + attempt=attempt, + backoff_time=backoff_time, + exc_info=True, + ) + await asyncio.sleep(backoff_time) + + except Exception: + LOG.error("UnexpectedError", exc_info=True) + raise + + raise RuntimeError(f"Retry logic error in {fn.__name__}") + + return wrapper + + return decorator + + +class BaseAlchemyDB: + """Base database client with connection and session management.""" + + def __init__(self, db_engine: AsyncEngine) -> None: + self.engine = db_engine + self.Session = async_sessionmaker(bind=db_engine) + + def is_retryable_error(self, error: SQLAlchemyError) -> bool: + """Check if a database error is retryable. Override in subclasses for specific error handling.""" + return False