add lock to serialize workflow script regeneration (#4487)

This commit is contained in:
pedrohsdb
2026-01-19 10:43:58 -08:00
committed by GitHub
parent e55af4b078
commit 13d9e63268
2 changed files with 153 additions and 43 deletions

View File

@@ -1,11 +1,49 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from datetime import timedelta from datetime import timedelta
from typing import Any, Union from types import TracebackType
from typing import Any, Protocol, Self, Union, runtime_checkable
CACHE_EXPIRE_TIME = timedelta(weeks=4) CACHE_EXPIRE_TIME = timedelta(weeks=4)
MAX_CACHE_ITEM = 1000 MAX_CACHE_ITEM = 1000
@runtime_checkable
class AsyncLock(Protocol):
"""Protocol for async context manager locks (compatible with redis.asyncio.lock.Lock and NoopLock)."""
async def __aenter__(self) -> Self: ...
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None: ...
class NoopLock:
"""
A no-op lock implementation for use when distributed locking is not available.
Acts as an async context manager that does nothing - suitable for OSS/local deployments.
"""
def __init__(self, lock_name: str, blocking_timeout: int = 5, timeout: int = 10) -> None:
self.lock_name = lock_name
self.blocking_timeout = blocking_timeout
self.timeout = timeout
async def __aenter__(self) -> "NoopLock":
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
pass
class BaseCache(ABC): class BaseCache(ABC):
@abstractmethod @abstractmethod
async def set(self, key: str, value: Any, ex: Union[int, timedelta, None] = CACHE_EXPIRE_TIME) -> None: async def set(self, key: str, value: Any, ex: Union[int, timedelta, None] = CACHE_EXPIRE_TIME) -> None:
@@ -14,3 +52,11 @@ class BaseCache(ABC):
@abstractmethod @abstractmethod
async def get(self, key: str) -> Any: async def get(self, key: str) -> Any:
pass pass
def get_lock(self, lock_name: str, blocking_timeout: int = 5, timeout: int = 10) -> AsyncLock:
    """Return a lock object guarding *lock_name*.

    This base implementation hands back a ``NoopLock`` — no real
    locking — which suits single-process OSS deployments. Cache
    backends with a distributed store (e.g. Redis in cloud
    deployments) are expected to override this method.
    """
    return NoopLock(lock_name, blocking_timeout=blocking_timeout, timeout=timeout)

View File

@@ -8,12 +8,22 @@ from collections import deque
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import UTC, datetime from datetime import UTC, datetime
from hashlib import sha256
from typing import Any, Literal, cast from typing import Any, Literal, cast
import httpx import httpx
import structlog import structlog
from sqlalchemy.exc import IntegrityError, SQLAlchemyError from sqlalchemy.exc import IntegrityError, SQLAlchemyError
# Import LockError so lock-acquisition failures can be caught specifically;
# fall back to a local placeholder for OSS installs without redis.
try:
    from redis.exceptions import LockError
except ImportError:
    # redis not installed (OSS deployment) - this placeholder is never raised;
    # it only keeps the `except LockError` handlers below syntactically valid.
    class LockError(Exception):  # type: ignore[no-redef]
        pass
import skyvern import skyvern
from skyvern import analytics from skyvern import analytics
from skyvern.client.types.output_parameter import OutputParameter as BlockOutputParameter from skyvern.client.types.output_parameter import OutputParameter as BlockOutputParameter
@@ -37,6 +47,7 @@ from skyvern.exceptions import (
from skyvern.forge import app from skyvern.forge import app
from skyvern.forge.prompts import prompt_engine from skyvern.forge.prompts import prompt_engine
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
from skyvern.forge.sdk.cache.factory import CacheFactory
from skyvern.forge.sdk.core import skyvern_context from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.security import generate_skyvern_webhook_signature from skyvern.forge.sdk.core.security import generate_skyvern_webhook_signature
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
@@ -3267,50 +3278,103 @@ class WorkflowService:
) )
return return
LOG.info( async def _regenerate_script() -> None:
"deleting old workflow script and generating new script", """Delete old script and generate new one.
workflow_id=workflow.workflow_id,
workflow_run_id=workflow_run.workflow_run_id,
cache_key_value=rendered_cache_key_value,
script_id=existing_script.script_id,
script_revision_id=existing_script.script_revision_id,
run_with=workflow_run.run_with,
blocks_to_update=list(blocks_to_update),
code_gen=code_gen,
)
# delete the existing workflow scripts if any Uses double-check pattern: re-verify regeneration is needed after acquiring lock
await app.DATABASE.delete_workflow_scripts_by_permanent_id( to handle race conditions where another process regenerated while we waited.
organization_id=workflow.organization_id, """
workflow_permanent_id=workflow.workflow_permanent_id, # Double-check: another process may have regenerated while we waited for lock
script_ids=[existing_script.script_id], fresh_script = await workflow_script_service.get_workflow_script_by_cache_key_value(
) organization_id=workflow.organization_id,
workflow_permanent_id=workflow.workflow_permanent_id,
# create a new script cache_key_value=rendered_cache_key_value,
regenerated_script = await app.DATABASE.create_script( statuses=[ScriptStatus.published],
organization_id=workflow.organization_id, use_cache=False,
run_id=workflow_run.workflow_run_id, )
) if fresh_script and fresh_script.script_id != existing_script.script_id:
LOG.info(
await workflow_script_service.generate_workflow_script( "Script already regenerated by another process, skipping",
workflow_run=workflow_run, workflow_id=workflow.workflow_id,
workflow=workflow, workflow_run_id=workflow_run.workflow_run_id,
script=regenerated_script, cache_key_value=rendered_cache_key_value,
rendered_cache_key_value=rendered_cache_key_value, existing_script_id=existing_script.script_id,
cached_script=existing_script, fresh_script_id=fresh_script.script_id,
updated_block_labels=blocks_to_update,
)
aio_task_primary_key = f"{regenerated_script.script_id}_{regenerated_script.version}"
if aio_task_primary_key in app.ARTIFACT_MANAGER.upload_aiotasks_map:
aio_tasks = app.ARTIFACT_MANAGER.upload_aiotasks_map[aio_task_primary_key]
if aio_tasks:
await asyncio.gather(*aio_tasks)
else:
LOG.warning(
"No upload aio tasks found for regenerated script",
script_id=regenerated_script.script_id,
version=regenerated_script.version,
) )
return
LOG.info(
"deleting old workflow script and generating new script",
workflow_id=workflow.workflow_id,
workflow_run_id=workflow_run.workflow_run_id,
cache_key_value=rendered_cache_key_value,
script_id=existing_script.script_id,
script_revision_id=existing_script.script_revision_id,
run_with=workflow_run.run_with,
blocks_to_update=list(blocks_to_update),
code_gen=code_gen,
)
await app.DATABASE.delete_workflow_scripts_by_permanent_id(
organization_id=workflow.organization_id,
workflow_permanent_id=workflow.workflow_permanent_id,
script_ids=[existing_script.script_id],
)
regenerated_script = await app.DATABASE.create_script(
organization_id=workflow.organization_id,
run_id=workflow_run.workflow_run_id,
)
await workflow_script_service.generate_workflow_script(
workflow_run=workflow_run,
workflow=workflow,
script=regenerated_script,
rendered_cache_key_value=rendered_cache_key_value,
cached_script=existing_script,
updated_block_labels=blocks_to_update,
)
aio_task_primary_key = f"{regenerated_script.script_id}_{regenerated_script.version}"
if aio_task_primary_key in app.ARTIFACT_MANAGER.upload_aiotasks_map:
aio_tasks = app.ARTIFACT_MANAGER.upload_aiotasks_map[aio_task_primary_key]
if aio_tasks:
await asyncio.gather(*aio_tasks)
else:
LOG.warning(
"No upload aio tasks found for regenerated script",
script_id=regenerated_script.script_id,
version=regenerated_script.version,
)
# Use distributed redis lock to prevent concurrent regenerations
cache = CacheFactory.get_cache()
lock = None
if cache is not None:
try:
digest = sha256(rendered_cache_key_value.encode("utf-8")).hexdigest()
lock_name = f"workflow_script_regen:{workflow.workflow_permanent_id}:{digest}"
# blocking_timeout=60s to wait for lock, timeout=60s for lock TTL (per wintonzheng: p99=44s)
lock = cache.get_lock(lock_name, blocking_timeout=60, timeout=60)
except AttributeError:
LOG.debug("Cache doesn't support locking, proceeding without lock")
if lock is not None:
try:
async with lock:
await _regenerate_script()
except LockError as exc:
# Lock acquisition failed (e.g., another process holds the lock, timeout)
# Skip regeneration and trust the lock holder to complete the work.
# The double-check pattern in _regenerate_script() will handle it on the next call.
LOG.info(
"Skipping regeneration - lock acquisition failed, another process may be regenerating",
workflow_id=workflow.workflow_id,
workflow_permanent_id=workflow.workflow_permanent_id,
error=str(exc),
)
else:
# No Redis/cache available - proceed without lock (graceful degradation for OSS)
await _regenerate_script()
return return
created_script = await app.DATABASE.create_script( created_script = await app.DATABASE.create_script(