Refactor script gen with block level code cache (#3910)
This commit is contained in:
@@ -3,7 +3,8 @@ import base64
|
||||
import structlog
|
||||
from jinja2.sandbox import SandboxedEnvironment
|
||||
|
||||
from skyvern.core.script_generations.generate_script import generate_workflow_script_python_code
|
||||
from skyvern.config import settings
|
||||
from skyvern.core.script_generations.generate_script import ScriptBlockSource, generate_workflow_script_python_code
|
||||
from skyvern.core.script_generations.transform_workflow_run import transform_workflow_run_to_code_gen_input
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
@@ -45,6 +46,7 @@ async def generate_or_update_pending_workflow_script(
|
||||
script=script,
|
||||
rendered_cache_key_value=rendered_cache_key_value,
|
||||
pending=True,
|
||||
cached_script=script,
|
||||
)
|
||||
|
||||
|
||||
@@ -104,12 +106,62 @@ async def get_workflow_script(
|
||||
return None, rendered_cache_key_value
|
||||
|
||||
|
||||
async def _load_cached_script_block_sources(
|
||||
script: Script,
|
||||
organization_id: str,
|
||||
) -> dict[str, ScriptBlockSource]:
|
||||
"""
|
||||
Load existing script block sources (code + metadata) for a script revision so they can be reused.
|
||||
"""
|
||||
cached_blocks: dict[str, ScriptBlockSource] = {}
|
||||
|
||||
script_blocks = await app.DATABASE.get_script_blocks_by_script_revision_id(
|
||||
script_revision_id=script.script_revision_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
|
||||
for script_block in script_blocks:
|
||||
if not script_block.script_block_label:
|
||||
continue
|
||||
|
||||
code_str: str | None = None
|
||||
if script_block.script_file_id:
|
||||
script_file = await app.DATABASE.get_script_file_by_id(
|
||||
script_revision_id=script.script_revision_id,
|
||||
file_id=script_block.script_file_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
if script_file and script_file.artifact_id:
|
||||
artifact = await app.DATABASE.get_artifact_by_id(script_file.artifact_id, organization_id)
|
||||
if artifact:
|
||||
file_content = await app.ARTIFACT_MANAGER.retrieve_artifact(artifact)
|
||||
if isinstance(file_content, bytes):
|
||||
code_str = file_content.decode("utf-8")
|
||||
elif isinstance(file_content, str):
|
||||
code_str = file_content
|
||||
|
||||
if not code_str:
|
||||
continue
|
||||
|
||||
cached_blocks[script_block.script_block_label] = ScriptBlockSource(
|
||||
label=script_block.script_block_label,
|
||||
code=code_str,
|
||||
run_signature=script_block.run_signature,
|
||||
workflow_run_id=script_block.workflow_run_id,
|
||||
workflow_run_block_id=script_block.workflow_run_block_id,
|
||||
)
|
||||
|
||||
return cached_blocks
|
||||
|
||||
|
||||
async def generate_workflow_script(
|
||||
workflow_run: WorkflowRun,
|
||||
workflow: Workflow,
|
||||
script: Script,
|
||||
rendered_cache_key_value: str,
|
||||
pending: bool = False,
|
||||
cached_script: Script | None = None,
|
||||
updated_block_labels: set[str] | None = None,
|
||||
) -> None:
|
||||
try:
|
||||
LOG.info(
|
||||
@@ -119,10 +171,26 @@ async def generate_workflow_script(
|
||||
workflow_name=workflow.title,
|
||||
cache_key_value=rendered_cache_key_value,
|
||||
)
|
||||
cached_block_sources: dict[str, ScriptBlockSource] = {}
|
||||
if cached_script:
|
||||
cached_block_sources = await _load_cached_script_block_sources(cached_script, workflow.organization_id)
|
||||
|
||||
codegen_input = await transform_workflow_run_to_code_gen_input(
|
||||
workflow_run_id=workflow_run.workflow_run_id,
|
||||
organization_id=workflow.organization_id,
|
||||
)
|
||||
|
||||
block_labels = [block.get("label") for block in codegen_input.workflow_blocks if block.get("label")]
|
||||
|
||||
if updated_block_labels is None:
|
||||
updated_block_labels = {label for label in block_labels if label}
|
||||
else:
|
||||
updated_block_labels = set(updated_block_labels)
|
||||
|
||||
missing_labels = {label for label in block_labels if label and label not in cached_block_sources}
|
||||
updated_block_labels.update(missing_labels)
|
||||
updated_block_labels.add(settings.WORKFLOW_START_BLOCK_LABEL)
|
||||
|
||||
python_src = await generate_workflow_script_python_code(
|
||||
file_name=codegen_input.file_name,
|
||||
workflow_run_request=codegen_input.workflow_run,
|
||||
@@ -134,6 +202,8 @@ async def generate_workflow_script(
|
||||
script_id=script.script_id,
|
||||
script_revision_id=script.script_revision_id,
|
||||
pending=pending,
|
||||
cached_blocks=cached_block_sources,
|
||||
updated_block_labels=updated_block_labels,
|
||||
)
|
||||
except Exception:
|
||||
LOG.error("Failed to generate workflow script source", exc_info=True)
|
||||
|
||||
Reference in New Issue
Block a user