Add @retry decorator for DB operations (#4328)

This commit is contained in:
Stanislav Novosad
2025-12-18 11:32:40 -07:00
committed by GitHub
parent f592ee1874
commit c61bd26c8c
2 changed files with 106 additions and 48 deletions

View File

@@ -20,13 +20,16 @@ from sqlalchemy import (
update, update,
) )
from sqlalchemy.engine.url import make_url from sqlalchemy.engine.url import make_url
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import (
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine SQLAlchemyError,
)
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
from skyvern.config import settings from skyvern.config import settings
from skyvern.constants import DEFAULT_SCRIPT_RUN_ID from skyvern.constants import DEFAULT_SCRIPT_RUN_ID
from skyvern.exceptions import BrowserProfileNotFound, WorkflowParameterNotFound, WorkflowRunNotFound from skyvern.exceptions import BrowserProfileNotFound, WorkflowParameterNotFound, WorkflowRunNotFound
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
from skyvern.forge.sdk.db.base_alchemy_db import BaseAlchemyDB, read_retry
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType, TaskType from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType, TaskType
from skyvern.forge.sdk.db.exceptions import NotFoundError from skyvern.forge.sdk.db.exceptions import NotFoundError
from skyvern.forge.sdk.db.models import ( from skyvern.forge.sdk.db.models import (
@@ -218,12 +221,14 @@ def make_async_engine(database_string: str) -> AsyncEngine:
return engine return engine
class AgentDB: class AgentDB(BaseAlchemyDB):
def __init__(self, database_string: str, debug_enabled: bool = False, db_engine: AsyncEngine | None = None) -> None: def __init__(self, database_string: str, debug_enabled: bool = False, db_engine: AsyncEngine | None = None) -> None:
super().__init__() super().__init__(db_engine or make_async_engine(database_string))
self.debug_enabled = debug_enabled self.debug_enabled = debug_enabled
self.engine = db_engine or make_async_engine(database_string)
self.Session = async_sessionmaker(bind=self.engine) def is_retryable_error(self, error: SQLAlchemyError) -> bool:
error_msg = str(error).lower()
return "server closed the connection" in error_msg
async def create_task( async def create_task(
self, self,
@@ -404,29 +409,23 @@ class AgentDB:
LOG.exception("UnexpectedError during bulk artifact creation") LOG.exception("UnexpectedError during bulk artifact creation")
raise raise
@read_retry()
async def get_task(self, task_id: str, organization_id: str | None = None) -> Task | None: async def get_task(self, task_id: str, organization_id: str | None = None) -> Task | None:
"""Get a task by its id""" """Get a task by its id"""
try: async with self.Session() as session:
async with self.Session() as session: if task_obj := (
if task_obj := ( await session.scalars(
await session.scalars( select(TaskModel).filter_by(task_id=task_id).filter_by(organization_id=organization_id)
select(TaskModel).filter_by(task_id=task_id).filter_by(organization_id=organization_id) )
) ).first():
).first(): return convert_to_task(task_obj, self.debug_enabled)
return convert_to_task(task_obj, self.debug_enabled) else:
else: LOG.info(
LOG.info( "Task not found",
"Task not found", task_id=task_id,
task_id=task_id, organization_id=organization_id,
organization_id=organization_id, )
) return None
return None
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
except Exception:
LOG.error("UnexpectedError", exc_info=True)
raise
async def get_tasks_by_ids( async def get_tasks_by_ids(
self, self,
@@ -2758,6 +2757,7 @@ class AgentDB:
LOG.error("SQLAlchemyError", exc_info=True) LOG.error("SQLAlchemyError", exc_info=True)
raise raise
@read_retry()
async def get_workflow_run( async def get_workflow_run(
self, self,
workflow_run_id: str, workflow_run_id: str,
@@ -2765,21 +2765,17 @@ class AgentDB:
job_id: str | None = None, job_id: str | None = None,
status: WorkflowRunStatus | None = None, status: WorkflowRunStatus | None = None,
) -> WorkflowRun | None: ) -> WorkflowRun | None:
try: async with self.Session() as session:
async with self.Session() as session: get_workflow_run_query = select(WorkflowRunModel).filter_by(workflow_run_id=workflow_run_id)
get_workflow_run_query = select(WorkflowRunModel).filter_by(workflow_run_id=workflow_run_id) if organization_id:
if organization_id: get_workflow_run_query = get_workflow_run_query.filter_by(organization_id=organization_id)
get_workflow_run_query = get_workflow_run_query.filter_by(organization_id=organization_id) if job_id:
if job_id: get_workflow_run_query = get_workflow_run_query.filter_by(job_id=job_id)
get_workflow_run_query = get_workflow_run_query.filter_by(job_id=job_id) if status:
if status: get_workflow_run_query = get_workflow_run_query.filter_by(status=status.value)
get_workflow_run_query = get_workflow_run_query.filter_by(status=status.value) if workflow_run := (await session.scalars(get_workflow_run_query)).first():
if workflow_run := (await session.scalars(get_workflow_run_query)).first(): return convert_to_workflow_run(workflow_run)
return convert_to_workflow_run(workflow_run) return None
return None
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
async def get_last_queued_workflow_run( async def get_last_queued_workflow_run(
self, self,
@@ -3822,6 +3818,7 @@ class AgentDB:
await session.execute(stmt) await session.execute(stmt)
await session.commit() await session.commit()
@read_retry()
async def get_task_v2(self, task_v2_id: str, organization_id: str | None = None) -> TaskV2 | None: async def get_task_v2(self, task_v2_id: str, organization_id: str | None = None) -> TaskV2 | None:
async with self.Session() as session: async with self.Session() as session:
if task_v2 := ( if task_v2 := (
@@ -4474,6 +4471,7 @@ class AgentDB:
LOG.error("UnexpectedError", exc_info=True) LOG.error("UnexpectedError", exc_info=True)
raise raise
@read_retry()
async def get_persistent_browser_session_by_runnable_id( async def get_persistent_browser_session_by_runnable_id(
self, runnable_id: str, organization_id: str | None = None self, runnable_id: str, organization_id: str | None = None
) -> PersistentBrowserSession | None: ) -> PersistentBrowserSession | None:
@@ -4495,12 +4493,6 @@ class AgentDB:
except NotFoundError: except NotFoundError:
LOG.error("NotFoundError", exc_info=True) LOG.error("NotFoundError", exc_info=True)
raise raise
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
except Exception:
LOG.error("UnexpectedError", exc_info=True)
raise
async def get_persistent_browser_session( async def get_persistent_browser_session(
self, self,

View File

@@ -0,0 +1,66 @@
import asyncio
from functools import wraps
from typing import Any, Callable
import structlog
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker
LOG = structlog.get_logger()
def read_retry(retries: int = 3) -> Callable:
"""Decorator to retry async database operations on transient failures.
Args:
retries: Maximum number of retry attempts (default: 3)
"""
def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
@wraps(fn)
async def wrapper(
base_db: "BaseAlchemyDB",
*args: Any,
**kwargs: Any,
) -> Any:
for attempt in range(retries):
try:
return await fn(base_db, *args, **kwargs)
except SQLAlchemyError as e:
if not base_db.is_retryable_error(e):
LOG.error("SQLAlchemyError", exc_info=True, attempt=attempt)
raise
if attempt >= retries - 1:
LOG.error("SQLAlchemyError after all retries", exc_info=True, attempt=attempt)
raise
backoff_time = 0.1 * (2**attempt)
LOG.warning(
"SQLAlchemyError retrying",
attempt=attempt,
backoff_time=backoff_time,
exc_info=True,
)
await asyncio.sleep(backoff_time)
except Exception:
LOG.error("UnexpectedError", exc_info=True)
raise
raise RuntimeError(f"Retry logic error in {fn.__name__}")
return wrapper
return decorator
class BaseAlchemyDB:
"""Base database client with connection and session management."""
def __init__(self, db_engine: AsyncEngine) -> None:
self.engine = db_engine
self.Session = async_sessionmaker(bind=db_engine)
def is_retryable_error(self, error: SQLAlchemyError) -> bool:
"""Check if a database error is retryable. Override in subclasses for specific error handling."""
return False