add lock to serialize workflow script regeneration (#4487)

This commit is contained in:
pedrohsdb
2026-01-19 10:43:58 -08:00
committed by GitHub
parent e55af4b078
commit 13d9e63268
2 changed files with 153 additions and 43 deletions

View File

@@ -1,11 +1,49 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from datetime import timedelta from datetime import timedelta
from typing import Any, Union from types import TracebackType
from typing import Any, Protocol, Self, Union, runtime_checkable
CACHE_EXPIRE_TIME = timedelta(weeks=4) CACHE_EXPIRE_TIME = timedelta(weeks=4)
MAX_CACHE_ITEM = 1000 MAX_CACHE_ITEM = 1000
@runtime_checkable
class AsyncLock(Protocol):
    """Protocol for async context manager locks (compatible with redis.asyncio.lock.Lock and NoopLock).

    Matched structurally: any object exposing async ``__aenter__``/``__aexit__``
    satisfies this protocol without inheriting from it, which is how
    ``redis.asyncio.lock.Lock`` qualifies. ``@runtime_checkable`` additionally
    allows ``isinstance(obj, AsyncLock)`` checks (method presence only).
    """

    async def __aenter__(self) -> Self: ...

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None: ...
class NoopLock:
"""
A no-op lock implementation for use when distributed locking is not available.
Acts as an async context manager that does nothing - suitable for OSS/local deployments.
"""
def __init__(self, lock_name: str, blocking_timeout: int = 5, timeout: int = 10) -> None:
self.lock_name = lock_name
self.blocking_timeout = blocking_timeout
self.timeout = timeout
async def __aenter__(self) -> "NoopLock":
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
pass
class BaseCache(ABC): class BaseCache(ABC):
@abstractmethod @abstractmethod
async def set(self, key: str, value: Any, ex: Union[int, timedelta, None] = CACHE_EXPIRE_TIME) -> None: async def set(self, key: str, value: Any, ex: Union[int, timedelta, None] = CACHE_EXPIRE_TIME) -> None:
@@ -14,3 +52,11 @@ class BaseCache(ABC):
@abstractmethod @abstractmethod
async def get(self, key: str) -> Any: async def get(self, key: str) -> Any:
pass pass
def get_lock(self, lock_name: str, blocking_timeout: int = 5, timeout: int = 10) -> AsyncLock:
"""
Get a distributed lock for the given name.
Default implementation returns a no-op lock for OSS deployments.
Cloud implementations should override this to use Redis locks.
"""
return NoopLock(lock_name, blocking_timeout, timeout)

View File

@@ -8,12 +8,22 @@ from collections import deque
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import UTC, datetime from datetime import UTC, datetime
from hashlib import sha256
from typing import Any, Literal, cast from typing import Any, Literal, cast
import httpx import httpx
import structlog import structlog
from sqlalchemy.exc import IntegrityError, SQLAlchemyError from sqlalchemy.exc import IntegrityError, SQLAlchemyError
# Import LockError for specific exception handling; fallback for OSS without redis.
try:
    from redis.exceptions import LockError
except ImportError:
    # redis not installed (OSS deployment) - create placeholder that's never raised.
    # Keeping the name defined means `except LockError` handlers elsewhere in this
    # module stay valid either way; the no-op lock path never raises it.
    class LockError(Exception):  # type: ignore[no-redef]
        pass
import skyvern import skyvern
from skyvern import analytics from skyvern import analytics
from skyvern.client.types.output_parameter import OutputParameter as BlockOutputParameter from skyvern.client.types.output_parameter import OutputParameter as BlockOutputParameter
@@ -37,6 +47,7 @@ from skyvern.exceptions import (
from skyvern.forge import app from skyvern.forge import app
from skyvern.forge.prompts import prompt_engine from skyvern.forge.prompts import prompt_engine
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
from skyvern.forge.sdk.cache.factory import CacheFactory
from skyvern.forge.sdk.core import skyvern_context from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.security import generate_skyvern_webhook_signature from skyvern.forge.sdk.core.security import generate_skyvern_webhook_signature
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
@@ -3267,6 +3278,31 @@ class WorkflowService:
) )
return return
async def _regenerate_script() -> None:
"""Delete old script and generate new one.
Uses double-check pattern: re-verify regeneration is needed after acquiring lock
to handle race conditions where another process regenerated while we waited.
"""
# Double-check: another process may have regenerated while we waited for lock
fresh_script = await workflow_script_service.get_workflow_script_by_cache_key_value(
organization_id=workflow.organization_id,
workflow_permanent_id=workflow.workflow_permanent_id,
cache_key_value=rendered_cache_key_value,
statuses=[ScriptStatus.published],
use_cache=False,
)
if fresh_script and fresh_script.script_id != existing_script.script_id:
LOG.info(
"Script already regenerated by another process, skipping",
workflow_id=workflow.workflow_id,
workflow_run_id=workflow_run.workflow_run_id,
cache_key_value=rendered_cache_key_value,
existing_script_id=existing_script.script_id,
fresh_script_id=fresh_script.script_id,
)
return
LOG.info( LOG.info(
"deleting old workflow script and generating new script", "deleting old workflow script and generating new script",
workflow_id=workflow.workflow_id, workflow_id=workflow.workflow_id,
@@ -3279,14 +3315,12 @@ class WorkflowService:
code_gen=code_gen, code_gen=code_gen,
) )
# delete the existing workflow scripts if any
await app.DATABASE.delete_workflow_scripts_by_permanent_id( await app.DATABASE.delete_workflow_scripts_by_permanent_id(
organization_id=workflow.organization_id, organization_id=workflow.organization_id,
workflow_permanent_id=workflow.workflow_permanent_id, workflow_permanent_id=workflow.workflow_permanent_id,
script_ids=[existing_script.script_id], script_ids=[existing_script.script_id],
) )
# create a new script
regenerated_script = await app.DATABASE.create_script( regenerated_script = await app.DATABASE.create_script(
organization_id=workflow.organization_id, organization_id=workflow.organization_id,
run_id=workflow_run.workflow_run_id, run_id=workflow_run.workflow_run_id,
@@ -3311,6 +3345,36 @@ class WorkflowService:
script_id=regenerated_script.script_id, script_id=regenerated_script.script_id,
version=regenerated_script.version, version=regenerated_script.version,
) )
# Use distributed redis lock to prevent concurrent regenerations
cache = CacheFactory.get_cache()
lock = None
if cache is not None:
try:
digest = sha256(rendered_cache_key_value.encode("utf-8")).hexdigest()
lock_name = f"workflow_script_regen:{workflow.workflow_permanent_id}:{digest}"
# blocking_timeout=60s to wait for lock, timeout=60s for lock TTL (per wintonzheng: p99=44s)
lock = cache.get_lock(lock_name, blocking_timeout=60, timeout=60)
except AttributeError:
LOG.debug("Cache doesn't support locking, proceeding without lock")
if lock is not None:
try:
async with lock:
await _regenerate_script()
except LockError as exc:
# Lock acquisition failed (e.g., another process holds the lock, timeout)
# Skip regeneration and trust the lock holder to complete the work.
# The double-check pattern in _regenerate_script() will handle it on the next call.
LOG.info(
"Skipping regeneration - lock acquisition failed, another process may be regenerating",
workflow_id=workflow.workflow_id,
workflow_permanent_id=workflow.workflow_permanent_id,
error=str(exc),
)
else:
# No Redis/cache available - proceed without lock (graceful degradation for OSS)
await _regenerate_script()
return return
created_script = await app.DATABASE.create_script( created_script = await app.DATABASE.create_script(