Refactor script gen with block level code cache (#3910)
This commit is contained in:
@@ -635,7 +635,7 @@ class WorkflowService:
|
||||
# Unified execution: execute blocks one by one, using script code when available
|
||||
if is_script_run is False:
|
||||
workflow_script = None
|
||||
workflow_run = await self._execute_workflow_blocks(
|
||||
workflow_run, blocks_to_update = await self._execute_workflow_blocks(
|
||||
workflow=workflow,
|
||||
workflow_run=workflow_run,
|
||||
organization=organization,
|
||||
@@ -663,6 +663,7 @@ class WorkflowService:
|
||||
workflow=workflow,
|
||||
workflow_run=workflow_run,
|
||||
block_labels=block_labels,
|
||||
blocks_to_update=blocks_to_update,
|
||||
)
|
||||
else:
|
||||
LOG.info(
|
||||
@@ -690,7 +691,7 @@ class WorkflowService:
|
||||
block_labels: list[str] | None = None,
|
||||
block_outputs: dict[str, Any] | None = None,
|
||||
workflow_script: WorkflowScript | None = None,
|
||||
) -> WorkflowRun:
|
||||
) -> tuple[WorkflowRun, set[str]]:
|
||||
organization_id = organization.organization_id
|
||||
workflow_run_id = workflow_run.workflow_run_id
|
||||
top_level_blocks = workflow.workflow_definition.blocks
|
||||
@@ -699,6 +700,7 @@ class WorkflowService:
|
||||
# Load script blocks if workflow_script is provided
|
||||
script_blocks_by_label: dict[str, Any] = {}
|
||||
loaded_script_module = None
|
||||
blocks_to_update: set[str] = set()
|
||||
|
||||
if workflow_script:
|
||||
LOG.info(
|
||||
@@ -946,6 +948,13 @@ class WorkflowService:
|
||||
workflow_run_id=workflow_run_id, failure_reason="Block result is None"
|
||||
)
|
||||
break
|
||||
if (
|
||||
not block_executed_with_code
|
||||
and block.label
|
||||
and block_result.status == BlockStatus.completed
|
||||
and not getattr(block, "disable_cache", False)
|
||||
):
|
||||
blocks_to_update.add(block.label)
|
||||
if block_result.status == BlockStatus.canceled:
|
||||
LOG.info(
|
||||
f"Block with type {block.block_type} at index {block_idx}/{blocks_cnt - 1} was canceled for workflow run {workflow_run_id}, cancelling workflow run",
|
||||
@@ -1064,7 +1073,7 @@ class WorkflowService:
|
||||
workflow_run_id=workflow_run_id, failure_reason=failure_reason
|
||||
)
|
||||
break
|
||||
return workflow_run
|
||||
return workflow_run, blocks_to_update
|
||||
|
||||
async def create_workflow(
|
||||
self,
|
||||
@@ -3289,13 +3298,17 @@ class WorkflowService:
|
||||
workflow: Workflow,
|
||||
workflow_run: WorkflowRun,
|
||||
block_labels: list[str] | None = None,
|
||||
blocks_to_update: set[str] | None = None,
|
||||
) -> None:
|
||||
code_gen = workflow_run.code_gen
|
||||
blocks_to_update = set(blocks_to_update or [])
|
||||
|
||||
LOG.info(
|
||||
"Generate script?",
|
||||
block_labels=block_labels,
|
||||
code_gen=code_gen,
|
||||
workflow_run_id=workflow_run.workflow_run_id,
|
||||
blocks_to_update=list(blocks_to_update),
|
||||
)
|
||||
|
||||
if block_labels and not code_gen:
|
||||
@@ -3308,16 +3321,74 @@ class WorkflowService:
|
||||
workflow_run,
|
||||
block_labels,
|
||||
)
|
||||
|
||||
if existing_script:
|
||||
LOG.info(
|
||||
"Found cached script for workflow. Skipping script generation",
|
||||
workflow_id=workflow.workflow_id,
|
||||
workflow_run_id=workflow_run.workflow_run_id,
|
||||
cache_key_value=rendered_cache_key_value,
|
||||
script_id=existing_script.script_id,
|
||||
cached_block_labels: set[str] = set()
|
||||
script_blocks = await app.DATABASE.get_script_blocks_by_script_revision_id(
|
||||
script_revision_id=existing_script.script_revision_id,
|
||||
run_with=workflow_run.run_with,
|
||||
organization_id=workflow.organization_id,
|
||||
)
|
||||
for script_block in script_blocks:
|
||||
if script_block.script_block_label:
|
||||
cached_block_labels.add(script_block.script_block_label)
|
||||
|
||||
definition_labels = {block.label for block in workflow.workflow_definition.blocks if block.label}
|
||||
definition_labels.add(settings.WORKFLOW_START_BLOCK_LABEL)
|
||||
cached_block_labels.add(settings.WORKFLOW_START_BLOCK_LABEL)
|
||||
|
||||
if cached_block_labels != definition_labels:
|
||||
missing_labels = definition_labels - cached_block_labels
|
||||
if missing_labels:
|
||||
blocks_to_update.update(missing_labels)
|
||||
# Always rebuild the orchestrator if the definition changed
|
||||
blocks_to_update.add(settings.WORKFLOW_START_BLOCK_LABEL)
|
||||
|
||||
should_regenerate = bool(blocks_to_update) or bool(code_gen)
|
||||
|
||||
if not should_regenerate:
|
||||
LOG.info(
|
||||
"Workflow script already up to date; skipping regeneration",
|
||||
workflow_id=workflow.workflow_id,
|
||||
workflow_run_id=workflow_run.workflow_run_id,
|
||||
cache_key_value=rendered_cache_key_value,
|
||||
script_id=existing_script.script_id,
|
||||
script_revision_id=existing_script.script_revision_id,
|
||||
run_with=workflow_run.run_with,
|
||||
)
|
||||
return
|
||||
|
||||
# delete the existing workflow scripts if any
|
||||
await app.DATABASE.delete_workflow_scripts_by_permanent_id(
|
||||
organization_id=workflow.organization_id,
|
||||
workflow_permanent_id=workflow.workflow_permanent_id,
|
||||
script_ids=[existing_script.script_id],
|
||||
)
|
||||
|
||||
# create a new script
|
||||
regenerated_script = await app.DATABASE.create_script(
|
||||
organization_id=workflow.organization_id,
|
||||
run_id=workflow_run.workflow_run_id,
|
||||
)
|
||||
|
||||
await workflow_script_service.generate_workflow_script(
|
||||
workflow_run=workflow_run,
|
||||
workflow=workflow,
|
||||
script=regenerated_script,
|
||||
rendered_cache_key_value=rendered_cache_key_value,
|
||||
cached_script=existing_script,
|
||||
updated_block_labels=blocks_to_update,
|
||||
)
|
||||
aio_task_primary_key = f"{regenerated_script.script_id}_{regenerated_script.version}"
|
||||
if aio_task_primary_key in app.ARTIFACT_MANAGER.upload_aiotasks_map:
|
||||
aio_tasks = app.ARTIFACT_MANAGER.upload_aiotasks_map[aio_task_primary_key]
|
||||
if aio_tasks:
|
||||
await asyncio.gather(*aio_tasks)
|
||||
else:
|
||||
LOG.warning(
|
||||
"No upload aio tasks found for regenerated script",
|
||||
script_id=regenerated_script.script_id,
|
||||
version=regenerated_script.version,
|
||||
)
|
||||
return
|
||||
|
||||
created_script = await app.DATABASE.create_script(
|
||||
@@ -3330,6 +3401,8 @@ class WorkflowService:
|
||||
workflow=workflow,
|
||||
script=created_script,
|
||||
rendered_cache_key_value=rendered_cache_key_value,
|
||||
cached_script=None,
|
||||
updated_block_labels=None,
|
||||
)
|
||||
aio_task_primary_key = f"{created_script.script_id}_{created_script.version}"
|
||||
if aio_task_primary_key in app.ARTIFACT_MANAGER.upload_aiotasks_map:
|
||||
|
||||
Reference in New Issue
Block a user