Add @retry decorator for DB operations (#4328)
This commit is contained in:
committed by
GitHub
parent
f592ee1874
commit
c61bd26c8c
@@ -20,13 +20,16 @@ from sqlalchemy import (
|
||||
update,
|
||||
)
|
||||
from sqlalchemy.engine.url import make_url
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.exc import (
|
||||
SQLAlchemyError,
|
||||
)
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
|
||||
|
||||
from skyvern.config import settings
|
||||
from skyvern.constants import DEFAULT_SCRIPT_RUN_ID
|
||||
from skyvern.exceptions import BrowserProfileNotFound, WorkflowParameterNotFound, WorkflowRunNotFound
|
||||
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
|
||||
from skyvern.forge.sdk.db.base_alchemy_db import BaseAlchemyDB, read_retry
|
||||
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType, TaskType
|
||||
from skyvern.forge.sdk.db.exceptions import NotFoundError
|
||||
from skyvern.forge.sdk.db.models import (
|
||||
@@ -218,12 +221,14 @@ def make_async_engine(database_string: str) -> AsyncEngine:
|
||||
return engine
|
||||
|
||||
|
||||
class AgentDB:
|
||||
class AgentDB(BaseAlchemyDB):
|
||||
def __init__(self, database_string: str, debug_enabled: bool = False, db_engine: AsyncEngine | None = None) -> None:
|
||||
super().__init__()
|
||||
super().__init__(db_engine or make_async_engine(database_string))
|
||||
self.debug_enabled = debug_enabled
|
||||
self.engine = db_engine or make_async_engine(database_string)
|
||||
self.Session = async_sessionmaker(bind=self.engine)
|
||||
|
||||
def is_retryable_error(self, error: SQLAlchemyError) -> bool:
    """Report whether ``error`` should be retried.

    An error counts as retryable when its text contains the phrase
    "server closed the connection"; the match is case-insensitive.
    """
    return "server closed the connection" in str(error).lower()
|
||||
|
||||
async def create_task(
|
||||
self,
|
||||
@@ -404,29 +409,23 @@ class AgentDB:
|
||||
LOG.exception("UnexpectedError during bulk artifact creation")
|
||||
raise
|
||||
|
||||
@read_retry()
|
||||
async def get_task(self, task_id: str, organization_id: str | None = None) -> Task | None:
|
||||
"""Get a task by its id"""
|
||||
try:
|
||||
async with self.Session() as session:
|
||||
if task_obj := (
|
||||
await session.scalars(
|
||||
select(TaskModel).filter_by(task_id=task_id).filter_by(organization_id=organization_id)
|
||||
)
|
||||
).first():
|
||||
return convert_to_task(task_obj, self.debug_enabled)
|
||||
else:
|
||||
LOG.info(
|
||||
"Task not found",
|
||||
task_id=task_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
return None
|
||||
except SQLAlchemyError:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
raise
|
||||
except Exception:
|
||||
LOG.error("UnexpectedError", exc_info=True)
|
||||
raise
|
||||
async with self.Session() as session:
|
||||
if task_obj := (
|
||||
await session.scalars(
|
||||
select(TaskModel).filter_by(task_id=task_id).filter_by(organization_id=organization_id)
|
||||
)
|
||||
).first():
|
||||
return convert_to_task(task_obj, self.debug_enabled)
|
||||
else:
|
||||
LOG.info(
|
||||
"Task not found",
|
||||
task_id=task_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
return None
|
||||
|
||||
async def get_tasks_by_ids(
|
||||
self,
|
||||
@@ -2758,6 +2757,7 @@ class AgentDB:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
raise
|
||||
|
||||
@read_retry()
|
||||
async def get_workflow_run(
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
@@ -2765,21 +2765,17 @@ class AgentDB:
|
||||
job_id: str | None = None,
|
||||
status: WorkflowRunStatus | None = None,
|
||||
) -> WorkflowRun | None:
|
||||
try:
|
||||
async with self.Session() as session:
|
||||
get_workflow_run_query = select(WorkflowRunModel).filter_by(workflow_run_id=workflow_run_id)
|
||||
if organization_id:
|
||||
get_workflow_run_query = get_workflow_run_query.filter_by(organization_id=organization_id)
|
||||
if job_id:
|
||||
get_workflow_run_query = get_workflow_run_query.filter_by(job_id=job_id)
|
||||
if status:
|
||||
get_workflow_run_query = get_workflow_run_query.filter_by(status=status.value)
|
||||
if workflow_run := (await session.scalars(get_workflow_run_query)).first():
|
||||
return convert_to_workflow_run(workflow_run)
|
||||
return None
|
||||
except SQLAlchemyError:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
raise
|
||||
async with self.Session() as session:
|
||||
get_workflow_run_query = select(WorkflowRunModel).filter_by(workflow_run_id=workflow_run_id)
|
||||
if organization_id:
|
||||
get_workflow_run_query = get_workflow_run_query.filter_by(organization_id=organization_id)
|
||||
if job_id:
|
||||
get_workflow_run_query = get_workflow_run_query.filter_by(job_id=job_id)
|
||||
if status:
|
||||
get_workflow_run_query = get_workflow_run_query.filter_by(status=status.value)
|
||||
if workflow_run := (await session.scalars(get_workflow_run_query)).first():
|
||||
return convert_to_workflow_run(workflow_run)
|
||||
return None
|
||||
|
||||
async def get_last_queued_workflow_run(
|
||||
self,
|
||||
@@ -3822,6 +3818,7 @@ class AgentDB:
|
||||
await session.execute(stmt)
|
||||
await session.commit()
|
||||
|
||||
@read_retry()
|
||||
async def get_task_v2(self, task_v2_id: str, organization_id: str | None = None) -> TaskV2 | None:
|
||||
async with self.Session() as session:
|
||||
if task_v2 := (
|
||||
@@ -4474,6 +4471,7 @@ class AgentDB:
|
||||
LOG.error("UnexpectedError", exc_info=True)
|
||||
raise
|
||||
|
||||
@read_retry()
|
||||
async def get_persistent_browser_session_by_runnable_id(
|
||||
self, runnable_id: str, organization_id: str | None = None
|
||||
) -> PersistentBrowserSession | None:
|
||||
@@ -4495,12 +4493,6 @@ class AgentDB:
|
||||
except NotFoundError:
|
||||
LOG.error("NotFoundError", exc_info=True)
|
||||
raise
|
||||
except SQLAlchemyError:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
raise
|
||||
except Exception:
|
||||
LOG.error("UnexpectedError", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_persistent_browser_session(
|
||||
self,
|
||||
|
||||
66
skyvern/forge/sdk/db/base_alchemy_db.py
Normal file
66
skyvern/forge/sdk/db/base_alchemy_db.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import asyncio
|
||||
from functools import wraps
|
||||
from typing import Any, Callable
|
||||
|
||||
import structlog
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
def read_retry(retries: int = 3) -> Callable:
    """Build a decorator that re-runs an async DB method on transient failures.

    The decorated coroutine function must take, as its first positional
    argument, an object exposing ``is_retryable_error(error) -> bool`` (for
    example a ``BaseAlchemyDB`` instance). When the call raises a
    ``SQLAlchemyError`` that this predicate accepts, the wrapper sleeps with
    exponential backoff (0.1s, 0.2s, 0.4s, ...) and tries again.

    Args:
        retries: Total number of attempts, including the first call
            (default: 3, i.e. at most two retries). Must be at least 1.

    Returns:
        A decorator for async methods.

    Raises:
        ValueError: If ``retries`` is less than 1. With no attempts the
            wrapped function would never be called, so this is rejected
            at decoration time instead of failing on every call.
    """
    if retries < 1:
        raise ValueError(f"retries must be at least 1, got {retries}")

    def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
        @wraps(fn)
        async def wrapper(
            base_db: "BaseAlchemyDB",
            *args: Any,
            **kwargs: Any,
        ) -> Any:
            for attempt in range(retries):
                try:
                    return await fn(base_db, *args, **kwargs)
                except SQLAlchemyError as e:
                    # The DB object decides which errors are transient.
                    if not base_db.is_retryable_error(e):
                        LOG.error("SQLAlchemyError", exc_info=True, attempt=attempt)
                        raise

                    # Last attempt: give up and surface the original error.
                    if attempt >= retries - 1:
                        LOG.error("SQLAlchemyError after all retries", exc_info=True, attempt=attempt)
                        raise

                    # Exponential backoff: 0.1s, 0.2s, 0.4s, ...
                    backoff_time = 0.1 * (2**attempt)
                    LOG.warning(
                        "SQLAlchemyError retrying",
                        attempt=attempt,
                        backoff_time=backoff_time,
                        exc_info=True,
                    )
                    await asyncio.sleep(backoff_time)

                except Exception:
                    # Non-SQLAlchemy failures are logged and never retried.
                    LOG.error("UnexpectedError", exc_info=True)
                    raise

            # Defensive guard: every loop iteration either returns or raises
            # on the final attempt, so this line should be unreachable.
            raise RuntimeError(f"Retry logic error in {fn.__name__}")

        return wrapper

    return decorator
|
||||
|
||||
|
||||
class BaseAlchemyDB:
    """Common base for database clients backed by an async SQLAlchemy engine.

    Holds the engine and a session factory bound to it, and exposes the
    ``is_retryable_error`` hook for subclasses to override.
    """

    def __init__(self, db_engine: AsyncEngine) -> None:
        """Keep ``db_engine`` and create a session factory bound to it."""
        self.engine = db_engine
        self.Session = async_sessionmaker(bind=self.engine)

    def is_retryable_error(self, error: SQLAlchemyError) -> bool:
        """Tell whether ``error`` is worth retrying.

        The base implementation treats no error as retryable; subclasses
        override this to recognise the failures they consider transient.
        """
        return False
|
||||
Reference in New Issue
Block a user