Add @retry decorator for DB operations (#4328)

This commit is contained in:
Stanislav Novosad
2025-12-18 11:32:40 -07:00
committed by GitHub
parent f592ee1874
commit c61bd26c8c
2 changed files with 106 additions and 48 deletions

View File

@@ -20,13 +20,16 @@ from sqlalchemy import (
update,
)
from sqlalchemy.engine.url import make_url
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine
from sqlalchemy.exc import (
SQLAlchemyError,
)
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
from skyvern.config import settings
from skyvern.constants import DEFAULT_SCRIPT_RUN_ID
from skyvern.exceptions import BrowserProfileNotFound, WorkflowParameterNotFound, WorkflowRunNotFound
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
from skyvern.forge.sdk.db.base_alchemy_db import BaseAlchemyDB, read_retry
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType, TaskType
from skyvern.forge.sdk.db.exceptions import NotFoundError
from skyvern.forge.sdk.db.models import (
@@ -218,12 +221,14 @@ def make_async_engine(database_string: str) -> AsyncEngine:
return engine
class AgentDB:
class AgentDB(BaseAlchemyDB):
def __init__(self, database_string: str, debug_enabled: bool = False, db_engine: AsyncEngine | None = None) -> None:
super().__init__()
super().__init__(db_engine or make_async_engine(database_string))
self.debug_enabled = debug_enabled
self.engine = db_engine or make_async_engine(database_string)
self.Session = async_sessionmaker(bind=self.engine)
def is_retryable_error(self, error: SQLAlchemyError) -> bool:
error_msg = str(error).lower()
return "server closed the connection" in error_msg
async def create_task(
self,
@@ -404,29 +409,23 @@ class AgentDB:
LOG.exception("UnexpectedError during bulk artifact creation")
raise
@read_retry()
async def get_task(self, task_id: str, organization_id: str | None = None) -> Task | None:
"""Get a task by its id"""
try:
async with self.Session() as session:
if task_obj := (
await session.scalars(
select(TaskModel).filter_by(task_id=task_id).filter_by(organization_id=organization_id)
)
).first():
return convert_to_task(task_obj, self.debug_enabled)
else:
LOG.info(
"Task not found",
task_id=task_id,
organization_id=organization_id,
)
return None
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
except Exception:
LOG.error("UnexpectedError", exc_info=True)
raise
async with self.Session() as session:
if task_obj := (
await session.scalars(
select(TaskModel).filter_by(task_id=task_id).filter_by(organization_id=organization_id)
)
).first():
return convert_to_task(task_obj, self.debug_enabled)
else:
LOG.info(
"Task not found",
task_id=task_id,
organization_id=organization_id,
)
return None
async def get_tasks_by_ids(
self,
@@ -2758,6 +2757,7 @@ class AgentDB:
LOG.error("SQLAlchemyError", exc_info=True)
raise
@read_retry()
async def get_workflow_run(
self,
workflow_run_id: str,
@@ -2765,21 +2765,17 @@ class AgentDB:
job_id: str | None = None,
status: WorkflowRunStatus | None = None,
) -> WorkflowRun | None:
try:
async with self.Session() as session:
get_workflow_run_query = select(WorkflowRunModel).filter_by(workflow_run_id=workflow_run_id)
if organization_id:
get_workflow_run_query = get_workflow_run_query.filter_by(organization_id=organization_id)
if job_id:
get_workflow_run_query = get_workflow_run_query.filter_by(job_id=job_id)
if status:
get_workflow_run_query = get_workflow_run_query.filter_by(status=status.value)
if workflow_run := (await session.scalars(get_workflow_run_query)).first():
return convert_to_workflow_run(workflow_run)
return None
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
async with self.Session() as session:
get_workflow_run_query = select(WorkflowRunModel).filter_by(workflow_run_id=workflow_run_id)
if organization_id:
get_workflow_run_query = get_workflow_run_query.filter_by(organization_id=organization_id)
if job_id:
get_workflow_run_query = get_workflow_run_query.filter_by(job_id=job_id)
if status:
get_workflow_run_query = get_workflow_run_query.filter_by(status=status.value)
if workflow_run := (await session.scalars(get_workflow_run_query)).first():
return convert_to_workflow_run(workflow_run)
return None
async def get_last_queued_workflow_run(
self,
@@ -3822,6 +3818,7 @@ class AgentDB:
await session.execute(stmt)
await session.commit()
@read_retry()
async def get_task_v2(self, task_v2_id: str, organization_id: str | None = None) -> TaskV2 | None:
async with self.Session() as session:
if task_v2 := (
@@ -4474,6 +4471,7 @@ class AgentDB:
LOG.error("UnexpectedError", exc_info=True)
raise
@read_retry()
async def get_persistent_browser_session_by_runnable_id(
self, runnable_id: str, organization_id: str | None = None
) -> PersistentBrowserSession | None:
@@ -4495,12 +4493,6 @@ class AgentDB:
except NotFoundError:
LOG.error("NotFoundError", exc_info=True)
raise
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
except Exception:
LOG.error("UnexpectedError", exc_info=True)
raise
async def get_persistent_browser_session(
self,

View File

@@ -0,0 +1,66 @@
import asyncio
from functools import wraps
from typing import Any, Callable
import structlog
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker
LOG = structlog.get_logger()
def read_retry(retries: int = 3) -> Callable:
"""Decorator to retry async database operations on transient failures.
Args:
retries: Maximum number of retry attempts (default: 3)
"""
def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
@wraps(fn)
async def wrapper(
base_db: "BaseAlchemyDB",
*args: Any,
**kwargs: Any,
) -> Any:
for attempt in range(retries):
try:
return await fn(base_db, *args, **kwargs)
except SQLAlchemyError as e:
if not base_db.is_retryable_error(e):
LOG.error("SQLAlchemyError", exc_info=True, attempt=attempt)
raise
if attempt >= retries - 1:
LOG.error("SQLAlchemyError after all retries", exc_info=True, attempt=attempt)
raise
backoff_time = 0.1 * (2**attempt)
LOG.warning(
"SQLAlchemyError retrying",
attempt=attempt,
backoff_time=backoff_time,
exc_info=True,
)
await asyncio.sleep(backoff_time)
except Exception:
LOG.error("UnexpectedError", exc_info=True)
raise
raise RuntimeError(f"Retry logic error in {fn.__name__}")
return wrapper
return decorator
class BaseAlchemyDB:
"""Base database client with connection and session management."""
def __init__(self, db_engine: AsyncEngine) -> None:
self.engine = db_engine
self.Session = async_sessionmaker(bind=db_engine)
def is_retryable_error(self, error: SQLAlchemyError) -> bool:
"""Check if a database error is retryable. Override in subclasses for specific error handling."""
return False