add lock to serialize workflow script regeneration (#4487)
This commit is contained in:
48
skyvern/forge/sdk/cache/base.py
vendored
48
skyvern/forge/sdk/cache/base.py
vendored
@@ -1,11 +1,49 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from typing import Any, Union
|
from types import TracebackType
|
||||||
|
from typing import Any, Protocol, Self, Union, runtime_checkable
|
||||||
|
|
||||||
CACHE_EXPIRE_TIME = timedelta(weeks=4)
|
CACHE_EXPIRE_TIME = timedelta(weeks=4)
|
||||||
MAX_CACHE_ITEM = 1000
|
MAX_CACHE_ITEM = 1000
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class AsyncLock(Protocol):
|
||||||
|
"""Protocol for async context manager locks (compatible with redis.asyncio.lock.Lock and NoopLock)."""
|
||||||
|
|
||||||
|
async def __aenter__(self) -> Self: ...
|
||||||
|
|
||||||
|
async def __aexit__(
|
||||||
|
self,
|
||||||
|
exc_type: type[BaseException] | None,
|
||||||
|
exc_val: BaseException | None,
|
||||||
|
exc_tb: TracebackType | None,
|
||||||
|
) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
|
class NoopLock:
|
||||||
|
"""
|
||||||
|
A no-op lock implementation for use when distributed locking is not available.
|
||||||
|
Acts as an async context manager that does nothing - suitable for OSS/local deployments.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, lock_name: str, blocking_timeout: int = 5, timeout: int = 10) -> None:
|
||||||
|
self.lock_name = lock_name
|
||||||
|
self.blocking_timeout = blocking_timeout
|
||||||
|
self.timeout = timeout
|
||||||
|
|
||||||
|
async def __aenter__(self) -> "NoopLock":
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(
|
||||||
|
self,
|
||||||
|
exc_type: type[BaseException] | None,
|
||||||
|
exc_val: BaseException | None,
|
||||||
|
exc_tb: TracebackType | None,
|
||||||
|
) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class BaseCache(ABC):
|
class BaseCache(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def set(self, key: str, value: Any, ex: Union[int, timedelta, None] = CACHE_EXPIRE_TIME) -> None:
|
async def set(self, key: str, value: Any, ex: Union[int, timedelta, None] = CACHE_EXPIRE_TIME) -> None:
|
||||||
@@ -14,3 +52,11 @@ class BaseCache(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def get(self, key: str) -> Any:
|
async def get(self, key: str) -> Any:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def get_lock(self, lock_name: str, blocking_timeout: int = 5, timeout: int = 10) -> AsyncLock:
    """
    Build the lock object that guards ``lock_name``.

    This base implementation hands back a ``NoopLock`` (no real locking), which is what
    OSS deployments get. Cloud cache implementations are expected to override this
    method and return a Redis-backed lock instead.

    ``blocking_timeout`` and ``timeout`` are forwarded unchanged to the lock.
    """
    noop_lock = NoopLock(
        lock_name=lock_name,
        blocking_timeout=blocking_timeout,
        timeout=timeout,
    )
    return noop_lock
|
||||||
|
|||||||
@@ -8,12 +8,22 @@ from collections import deque
|
|||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
|
from hashlib import sha256
|
||||||
from typing import Any, Literal, cast
|
from typing import Any, Literal, cast
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
import structlog
|
import structlog
|
||||||
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
|
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
|
||||||
|
|
||||||
|
# LockError is imported so lock failures can be caught specifically. redis may be
# missing in OSS installs, so fall back to a local stand-in when the import fails.
try:
    from redis.exceptions import LockError
except ImportError:

    class LockError(Exception):  # type: ignore[no-redef]
        """Placeholder used when redis is not installed; it is never raised."""
|
||||||
|
|
||||||
|
|
||||||
import skyvern
|
import skyvern
|
||||||
from skyvern import analytics
|
from skyvern import analytics
|
||||||
from skyvern.client.types.output_parameter import OutputParameter as BlockOutputParameter
|
from skyvern.client.types.output_parameter import OutputParameter as BlockOutputParameter
|
||||||
@@ -37,6 +47,7 @@ from skyvern.exceptions import (
|
|||||||
from skyvern.forge import app
|
from skyvern.forge import app
|
||||||
from skyvern.forge.prompts import prompt_engine
|
from skyvern.forge.prompts import prompt_engine
|
||||||
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
|
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
|
||||||
|
from skyvern.forge.sdk.cache.factory import CacheFactory
|
||||||
from skyvern.forge.sdk.core import skyvern_context
|
from skyvern.forge.sdk.core import skyvern_context
|
||||||
from skyvern.forge.sdk.core.security import generate_skyvern_webhook_signature
|
from skyvern.forge.sdk.core.security import generate_skyvern_webhook_signature
|
||||||
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
|
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
|
||||||
@@ -3267,50 +3278,103 @@ class WorkflowService:
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
LOG.info(
|
async def _regenerate_script() -> None:
|
||||||
"deleting old workflow script and generating new script",
|
"""Delete old script and generate new one.
|
||||||
workflow_id=workflow.workflow_id,
|
|
||||||
workflow_run_id=workflow_run.workflow_run_id,
|
|
||||||
cache_key_value=rendered_cache_key_value,
|
|
||||||
script_id=existing_script.script_id,
|
|
||||||
script_revision_id=existing_script.script_revision_id,
|
|
||||||
run_with=workflow_run.run_with,
|
|
||||||
blocks_to_update=list(blocks_to_update),
|
|
||||||
code_gen=code_gen,
|
|
||||||
)
|
|
||||||
|
|
||||||
# delete the existing workflow scripts if any
|
Uses double-check pattern: re-verify regeneration is needed after acquiring lock
|
||||||
await app.DATABASE.delete_workflow_scripts_by_permanent_id(
|
to handle race conditions where another process regenerated while we waited.
|
||||||
organization_id=workflow.organization_id,
|
"""
|
||||||
workflow_permanent_id=workflow.workflow_permanent_id,
|
# Double-check: another process may have regenerated while we waited for lock
|
||||||
script_ids=[existing_script.script_id],
|
fresh_script = await workflow_script_service.get_workflow_script_by_cache_key_value(
|
||||||
)
|
organization_id=workflow.organization_id,
|
||||||
|
workflow_permanent_id=workflow.workflow_permanent_id,
|
||||||
# create a new script
|
cache_key_value=rendered_cache_key_value,
|
||||||
regenerated_script = await app.DATABASE.create_script(
|
statuses=[ScriptStatus.published],
|
||||||
organization_id=workflow.organization_id,
|
use_cache=False,
|
||||||
run_id=workflow_run.workflow_run_id,
|
)
|
||||||
)
|
if fresh_script and fresh_script.script_id != existing_script.script_id:
|
||||||
|
LOG.info(
|
||||||
await workflow_script_service.generate_workflow_script(
|
"Script already regenerated by another process, skipping",
|
||||||
workflow_run=workflow_run,
|
workflow_id=workflow.workflow_id,
|
||||||
workflow=workflow,
|
workflow_run_id=workflow_run.workflow_run_id,
|
||||||
script=regenerated_script,
|
cache_key_value=rendered_cache_key_value,
|
||||||
rendered_cache_key_value=rendered_cache_key_value,
|
existing_script_id=existing_script.script_id,
|
||||||
cached_script=existing_script,
|
fresh_script_id=fresh_script.script_id,
|
||||||
updated_block_labels=blocks_to_update,
|
|
||||||
)
|
|
||||||
aio_task_primary_key = f"{regenerated_script.script_id}_{regenerated_script.version}"
|
|
||||||
if aio_task_primary_key in app.ARTIFACT_MANAGER.upload_aiotasks_map:
|
|
||||||
aio_tasks = app.ARTIFACT_MANAGER.upload_aiotasks_map[aio_task_primary_key]
|
|
||||||
if aio_tasks:
|
|
||||||
await asyncio.gather(*aio_tasks)
|
|
||||||
else:
|
|
||||||
LOG.warning(
|
|
||||||
"No upload aio tasks found for regenerated script",
|
|
||||||
script_id=regenerated_script.script_id,
|
|
||||||
version=regenerated_script.version,
|
|
||||||
)
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
LOG.info(
|
||||||
|
"deleting old workflow script and generating new script",
|
||||||
|
workflow_id=workflow.workflow_id,
|
||||||
|
workflow_run_id=workflow_run.workflow_run_id,
|
||||||
|
cache_key_value=rendered_cache_key_value,
|
||||||
|
script_id=existing_script.script_id,
|
||||||
|
script_revision_id=existing_script.script_revision_id,
|
||||||
|
run_with=workflow_run.run_with,
|
||||||
|
blocks_to_update=list(blocks_to_update),
|
||||||
|
code_gen=code_gen,
|
||||||
|
)
|
||||||
|
|
||||||
|
await app.DATABASE.delete_workflow_scripts_by_permanent_id(
|
||||||
|
organization_id=workflow.organization_id,
|
||||||
|
workflow_permanent_id=workflow.workflow_permanent_id,
|
||||||
|
script_ids=[existing_script.script_id],
|
||||||
|
)
|
||||||
|
|
||||||
|
regenerated_script = await app.DATABASE.create_script(
|
||||||
|
organization_id=workflow.organization_id,
|
||||||
|
run_id=workflow_run.workflow_run_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
await workflow_script_service.generate_workflow_script(
|
||||||
|
workflow_run=workflow_run,
|
||||||
|
workflow=workflow,
|
||||||
|
script=regenerated_script,
|
||||||
|
rendered_cache_key_value=rendered_cache_key_value,
|
||||||
|
cached_script=existing_script,
|
||||||
|
updated_block_labels=blocks_to_update,
|
||||||
|
)
|
||||||
|
aio_task_primary_key = f"{regenerated_script.script_id}_{regenerated_script.version}"
|
||||||
|
if aio_task_primary_key in app.ARTIFACT_MANAGER.upload_aiotasks_map:
|
||||||
|
aio_tasks = app.ARTIFACT_MANAGER.upload_aiotasks_map[aio_task_primary_key]
|
||||||
|
if aio_tasks:
|
||||||
|
await asyncio.gather(*aio_tasks)
|
||||||
|
else:
|
||||||
|
LOG.warning(
|
||||||
|
"No upload aio tasks found for regenerated script",
|
||||||
|
script_id=regenerated_script.script_id,
|
||||||
|
version=regenerated_script.version,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use distributed redis lock to prevent concurrent regenerations
|
||||||
|
cache = CacheFactory.get_cache()
|
||||||
|
lock = None
|
||||||
|
if cache is not None:
|
||||||
|
try:
|
||||||
|
digest = sha256(rendered_cache_key_value.encode("utf-8")).hexdigest()
|
||||||
|
lock_name = f"workflow_script_regen:{workflow.workflow_permanent_id}:{digest}"
|
||||||
|
# blocking_timeout=60s to wait for lock, timeout=60s for lock TTL (per wintonzheng: p99=44s)
|
||||||
|
lock = cache.get_lock(lock_name, blocking_timeout=60, timeout=60)
|
||||||
|
except AttributeError:
|
||||||
|
LOG.debug("Cache doesn't support locking, proceeding without lock")
|
||||||
|
|
||||||
|
if lock is not None:
|
||||||
|
try:
|
||||||
|
async with lock:
|
||||||
|
await _regenerate_script()
|
||||||
|
except LockError as exc:
|
||||||
|
# Lock acquisition failed (e.g., another process holds the lock, timeout)
|
||||||
|
# Skip regeneration and trust the lock holder to complete the work.
|
||||||
|
# The double-check pattern in _regenerate_script() will handle it on the next call.
|
||||||
|
LOG.info(
|
||||||
|
"Skipping regeneration - lock acquisition failed, another process may be regenerating",
|
||||||
|
workflow_id=workflow.workflow_id,
|
||||||
|
workflow_permanent_id=workflow.workflow_permanent_id,
|
||||||
|
error=str(exc),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# No Redis/cache available - proceed without lock (graceful degradation for OSS)
|
||||||
|
await _regenerate_script()
|
||||||
return
|
return
|
||||||
|
|
||||||
created_script = await app.DATABASE.create_script(
|
created_script = await app.DATABASE.create_script(
|
||||||
|
|||||||
Reference in New Issue
Block a user