add lock to serialize workflow script regeneration (#4487)

This commit is contained in:
pedrohsdb
2026-01-19 10:43:58 -08:00
committed by GitHub
parent e55af4b078
commit 13d9e63268
2 changed files with 153 additions and 43 deletions

View File

@@ -1,11 +1,49 @@
from abc import ABC, abstractmethod
from datetime import timedelta
from typing import Any, Union
from types import TracebackType
from typing import Any, Protocol, Self, Union, runtime_checkable
CACHE_EXPIRE_TIME = timedelta(weeks=4)
MAX_CACHE_ITEM = 1000
@runtime_checkable
class AsyncLock(Protocol):
"""Protocol for async context manager locks (compatible with redis.asyncio.lock.Lock and NoopLock)."""
async def __aenter__(self) -> Self: ...
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None: ...
class NoopLock:
"""
A no-op lock implementation for use when distributed locking is not available.
Acts as an async context manager that does nothing - suitable for OSS/local deployments.
"""
def __init__(self, lock_name: str, blocking_timeout: int = 5, timeout: int = 10) -> None:
self.lock_name = lock_name
self.blocking_timeout = blocking_timeout
self.timeout = timeout
async def __aenter__(self) -> "NoopLock":
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
pass
class BaseCache(ABC):
@abstractmethod
async def set(self, key: str, value: Any, ex: Union[int, timedelta, None] = CACHE_EXPIRE_TIME) -> None:
@@ -14,3 +52,11 @@ class BaseCache(ABC):
@abstractmethod
async def get(self, key: str) -> Any:
pass
def get_lock(self, lock_name: str, blocking_timeout: int = 5, timeout: int = 10) -> AsyncLock:
"""
Get a distributed lock for the given name.
Default implementation returns a no-op lock for OSS deployments.
Cloud implementations should override this to use Redis locks.
"""
return NoopLock(lock_name, blocking_timeout, timeout)

View File

@@ -8,12 +8,22 @@ from collections import deque
from collections.abc import Sequence
from dataclasses import dataclass, field
from datetime import UTC, datetime
from hashlib import sha256
from typing import Any, Literal, cast
import httpx
import structlog
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
# LockError is what redis raises when a lock cannot be acquired/released;
# the regeneration path catches it specifically, so fall back gracefully
# when redis is not installed.
try:
    from redis.exceptions import LockError
except ImportError:
    # redis not installed (OSS deployment) - define a placeholder subclass
    # so `except LockError` remains valid; nothing in OSS ever raises it.
    class LockError(Exception):  # type: ignore[no-redef]
        pass
import skyvern
from skyvern import analytics
from skyvern.client.types.output_parameter import OutputParameter as BlockOutputParameter
@@ -37,6 +47,7 @@ from skyvern.exceptions import (
from skyvern.forge import app
from skyvern.forge.prompts import prompt_engine
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
from skyvern.forge.sdk.cache.factory import CacheFactory
from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.security import generate_skyvern_webhook_signature
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
@@ -3267,50 +3278,103 @@ class WorkflowService:
)
return
LOG.info(
"deleting old workflow script and generating new script",
workflow_id=workflow.workflow_id,
workflow_run_id=workflow_run.workflow_run_id,
cache_key_value=rendered_cache_key_value,
script_id=existing_script.script_id,
script_revision_id=existing_script.script_revision_id,
run_with=workflow_run.run_with,
blocks_to_update=list(blocks_to_update),
code_gen=code_gen,
)
async def _regenerate_script() -> None:
"""Delete old script and generate new one.
# delete the existing workflow scripts if any
await app.DATABASE.delete_workflow_scripts_by_permanent_id(
organization_id=workflow.organization_id,
workflow_permanent_id=workflow.workflow_permanent_id,
script_ids=[existing_script.script_id],
)
# create a new script
regenerated_script = await app.DATABASE.create_script(
organization_id=workflow.organization_id,
run_id=workflow_run.workflow_run_id,
)
await workflow_script_service.generate_workflow_script(
workflow_run=workflow_run,
workflow=workflow,
script=regenerated_script,
rendered_cache_key_value=rendered_cache_key_value,
cached_script=existing_script,
updated_block_labels=blocks_to_update,
)
aio_task_primary_key = f"{regenerated_script.script_id}_{regenerated_script.version}"
if aio_task_primary_key in app.ARTIFACT_MANAGER.upload_aiotasks_map:
aio_tasks = app.ARTIFACT_MANAGER.upload_aiotasks_map[aio_task_primary_key]
if aio_tasks:
await asyncio.gather(*aio_tasks)
else:
LOG.warning(
"No upload aio tasks found for regenerated script",
script_id=regenerated_script.script_id,
version=regenerated_script.version,
Uses double-check pattern: re-verify regeneration is needed after acquiring lock
to handle race conditions where another process regenerated while we waited.
"""
# Double-check: another process may have regenerated while we waited for lock
fresh_script = await workflow_script_service.get_workflow_script_by_cache_key_value(
organization_id=workflow.organization_id,
workflow_permanent_id=workflow.workflow_permanent_id,
cache_key_value=rendered_cache_key_value,
statuses=[ScriptStatus.published],
use_cache=False,
)
if fresh_script and fresh_script.script_id != existing_script.script_id:
LOG.info(
"Script already regenerated by another process, skipping",
workflow_id=workflow.workflow_id,
workflow_run_id=workflow_run.workflow_run_id,
cache_key_value=rendered_cache_key_value,
existing_script_id=existing_script.script_id,
fresh_script_id=fresh_script.script_id,
)
return
LOG.info(
"deleting old workflow script and generating new script",
workflow_id=workflow.workflow_id,
workflow_run_id=workflow_run.workflow_run_id,
cache_key_value=rendered_cache_key_value,
script_id=existing_script.script_id,
script_revision_id=existing_script.script_revision_id,
run_with=workflow_run.run_with,
blocks_to_update=list(blocks_to_update),
code_gen=code_gen,
)
await app.DATABASE.delete_workflow_scripts_by_permanent_id(
organization_id=workflow.organization_id,
workflow_permanent_id=workflow.workflow_permanent_id,
script_ids=[existing_script.script_id],
)
regenerated_script = await app.DATABASE.create_script(
organization_id=workflow.organization_id,
run_id=workflow_run.workflow_run_id,
)
await workflow_script_service.generate_workflow_script(
workflow_run=workflow_run,
workflow=workflow,
script=regenerated_script,
rendered_cache_key_value=rendered_cache_key_value,
cached_script=existing_script,
updated_block_labels=blocks_to_update,
)
aio_task_primary_key = f"{regenerated_script.script_id}_{regenerated_script.version}"
if aio_task_primary_key in app.ARTIFACT_MANAGER.upload_aiotasks_map:
aio_tasks = app.ARTIFACT_MANAGER.upload_aiotasks_map[aio_task_primary_key]
if aio_tasks:
await asyncio.gather(*aio_tasks)
else:
LOG.warning(
"No upload aio tasks found for regenerated script",
script_id=regenerated_script.script_id,
version=regenerated_script.version,
)
# Use distributed redis lock to prevent concurrent regenerations
cache = CacheFactory.get_cache()
lock = None
if cache is not None:
try:
digest = sha256(rendered_cache_key_value.encode("utf-8")).hexdigest()
lock_name = f"workflow_script_regen:{workflow.workflow_permanent_id}:{digest}"
# blocking_timeout=60s to wait for lock, timeout=60s for lock TTL (per wintonzheng: p99=44s)
lock = cache.get_lock(lock_name, blocking_timeout=60, timeout=60)
except AttributeError:
LOG.debug("Cache doesn't support locking, proceeding without lock")
if lock is not None:
try:
async with lock:
await _regenerate_script()
except LockError as exc:
# Lock acquisition failed (e.g., another process holds the lock, timeout)
# Skip regeneration and trust the lock holder to complete the work.
# The double-check pattern in _regenerate_script() will handle it on the next call.
LOG.info(
"Skipping regeneration - lock acquisition failed, another process may be regenerating",
workflow_id=workflow.workflow_id,
workflow_permanent_id=workflow.workflow_permanent_id,
error=str(exc),
)
else:
# No Redis/cache available - proceed without lock (graceful degradation for OSS)
await _regenerate_script()
return
created_script = await app.DATABASE.create_script(