script gen post action (#3480)

This commit is contained in:
Shuchang Zheng
2025-09-19 08:50:21 -07:00
committed by GitHub
parent b4669f7477
commit c5280782b0
17 changed files with 536 additions and 264 deletions

View File

@@ -17,7 +17,7 @@ from jinja2.sandbox import SandboxedEnvironment
from skyvern.config import settings
from skyvern.constants import GET_DOWNLOADED_FILES_TIMEOUT
from skyvern.core.script_generations.constants import SCRIPT_TASK_BLOCKS
from skyvern.core.script_generations.generate_script import _build_block_fn, create_script_block
from skyvern.core.script_generations.generate_script import _build_block_fn, create_or_update_script_block
from skyvern.core.script_generations.skyvern_page import script_run_context_manager
from skyvern.exceptions import ScriptNotFound, WorkflowRunNotFound
from skyvern.forge import app
@@ -45,10 +45,10 @@ from skyvern.forge.sdk.workflow.models.block import (
from skyvern.forge.sdk.workflow.models.parameter import PARAMETER_TYPE, OutputParameter, ParameterType
from skyvern.forge.sdk.workflow.models.workflow import Workflow
from skyvern.schemas.runs import RunEngine
from skyvern.schemas.scripts import CreateScriptResponse, FileEncoding, FileNode, ScriptFileCreate
from skyvern.schemas.scripts import CreateScriptResponse, FileEncoding, FileNode, ScriptFileCreate, ScriptStatus
from skyvern.schemas.workflows import BlockStatus, BlockType, FileStorageType, FileType
LOG = structlog.get_logger(__name__)
LOG = structlog.get_logger()
jinja_sandbox_env = SandboxedEnvironment()
@@ -58,6 +58,7 @@ async def build_file_tree(
script_id: str,
script_version: int,
script_revision_id: str,
draft: bool = False,
) -> dict[str, FileNode]:
"""Build a hierarchical file tree from a list of files and upload the files to s3 with the same tree structure."""
file_tree: dict[str, FileNode] = {}
@@ -70,33 +71,94 @@ async def build_file_tree(
# Create artifact and upload to S3
try:
artifact_id = await app.ARTIFACT_MANAGER.create_script_file_artifact(
organization_id=organization_id,
script_id=script_id,
script_version=script_version,
file_path=file.path,
data=content_bytes,
)
LOG.debug(
"Created script file artifact",
artifact_id=artifact_id,
file_path=file.path,
script_id=script_id,
script_version=script_version,
)
# create a script file record
await app.DATABASE.create_script_file(
script_revision_id=script_revision_id,
script_id=script_id,
organization_id=organization_id,
file_path=file.path,
file_name=file.path.split("/")[-1],
file_type="file",
content_hash=f"sha256:{content_hash}",
file_size=file_size,
mime_type=file.mime_type,
artifact_id=artifact_id,
)
if draft:
# get the script file object
script_file = await app.DATABASE.get_script_file_by_path(
script_revision_id=script_revision_id,
file_path=file.path,
organization_id=organization_id,
)
if script_file:
if not script_file.artifact_id:
LOG.error(
"Failed to update file. An existing script file has no artifact id",
script_file_id=script_file.file_id,
)
continue
artifact = await app.DATABASE.get_artifact_by_id(script_file.artifact_id, organization_id)
if artifact:
# override the actual file in the storage
asyncio.create_task(app.STORAGE.store_artifact(artifact, content_bytes))
else:
artifact_id = await app.ARTIFACT_MANAGER.create_script_file_artifact(
organization_id=organization_id,
script_id=script_id,
script_version=script_version,
file_path=file.path,
data=content_bytes,
)
# update the artifact_id in the script file
await app.DATABASE.update_script_file(
script_file_id=script_file.file_id,
organization_id=organization_id,
artifact_id=artifact_id,
)
else:
artifact_id = await app.ARTIFACT_MANAGER.create_script_file_artifact(
organization_id=organization_id,
script_id=script_id,
script_version=script_version,
file_path=file.path,
data=content_bytes,
)
LOG.debug(
"Created script file artifact",
artifact_id=artifact_id,
file_path=file.path,
script_id=script_id,
script_version=script_version,
)
# create a script file record
await app.DATABASE.create_script_file(
script_revision_id=script_revision_id,
script_id=script_id,
organization_id=organization_id,
file_path=file.path,
file_name=file.path.split("/")[-1],
file_type="file",
content_hash=f"sha256:{content_hash}",
file_size=file_size,
mime_type=file.mime_type,
artifact_id=artifact_id,
)
else:
artifact_id = await app.ARTIFACT_MANAGER.create_script_file_artifact(
organization_id=organization_id,
script_id=script_id,
script_version=script_version,
file_path=file.path,
data=content_bytes,
)
LOG.debug(
"Created script file artifact",
artifact_id=artifact_id,
file_path=file.path,
script_id=script_id,
script_version=script_version,
)
# create a script file record
await app.DATABASE.create_script_file(
script_revision_id=script_revision_id,
script_id=script_id,
organization_id=organization_id,
file_path=file.path,
file_name=file.path.split("/")[-1],
file_type="file",
content_hash=f"sha256:{content_hash}",
file_size=file_size,
mime_type=file.mime_type,
artifact_id=artifact_id,
)
except Exception:
LOG.exception(
"Failed to create script file artifact",
@@ -794,6 +856,7 @@ async def _regenerate_script_block_after_ai_fallback(
workflow_permanent_id=workflow.workflow_permanent_id,
cache_key_value=cache_key_value,
cache_key=workflow.cache_key,
statuses=[ScriptStatus.published],
)
if not existing_scripts:
@@ -898,12 +961,12 @@ async def _regenerate_script_block_after_ai_fallback(
)
continue
await create_script_block(
await create_or_update_script_block(
block_code=block_file_content,
script_revision_id=new_script.script_revision_id,
script_id=new_script.script_id,
organization_id=organization_id,
block_name=existing_block.script_block_label,
block_label=existing_block.script_block_label,
)
block_file_content_bytes = (
block_file_content if isinstance(block_file_content, bytes) else block_file_content.encode("utf-8")

View File

@@ -466,6 +466,8 @@ async def run_task_v2_helper(
context: skyvern_context.SkyvernContext | None = skyvern_context.current()
current_run_id = context.run_id if context and context.run_id else task_v2_id
# task v2 can be nested inside a workflow run, so we need to use the root workflow run id
root_workflow_run_id = context.root_workflow_run_id if context and context.root_workflow_run_id else workflow_run_id
enable_parse_select_in_extract = app.EXPERIMENTATION_PROVIDER.is_feature_enabled_cached(
"ENABLE_PARSE_SELECT_IN_EXTRACT",
current_run_id,
@@ -476,6 +478,7 @@ async def run_task_v2_helper(
organization_id=organization_id,
workflow_id=workflow_id,
workflow_run_id=workflow_run_id,
root_workflow_run_id=root_workflow_run_id,
request_id=request_id,
task_v2_id=task_v2_id,
run_id=current_run_id,

View File

@@ -0,0 +1,188 @@
import base64
import structlog
from jinja2.sandbox import SandboxedEnvironment
from skyvern.core.script_generations.generate_script import generate_workflow_script_python_code
from skyvern.core.script_generations.transform_workflow_run import transform_workflow_run_to_code_gen_input
from skyvern.forge import app
from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.workflow.models.workflow import Workflow, WorkflowRun
from skyvern.schemas.scripts import FileEncoding, Script, ScriptFileCreate, ScriptStatus
from skyvern.services import script_service
LOG = structlog.get_logger()
jinja_sandbox_env = SandboxedEnvironment()
async def generate_or_update_draft_workflow_script(
    workflow_run: WorkflowRun,
    workflow: Workflow,
) -> None:
    """Create or reuse a draft script for this workflow run and (re)generate its code.

    Reads the script id stashed on the current Skyvern context; when it is missing
    or does not resolve to an existing script, a new script row is created and its
    ids are written back onto the context so later calls in the same run keep
    updating the same draft. No-op when there is no active context.

    Args:
        workflow_run: The workflow run the draft script is generated from.
        workflow: The workflow definition owning the run.
    """
    organization_id = workflow.organization_id
    context = skyvern_context.current()
    if not context:
        return
    script = None
    if context.script_id:
        script = await app.DATABASE.get_script(script_id=context.script_id, organization_id=organization_id)
    if not script:
        script = await app.DATABASE.create_script(organization_id=organization_id, run_id=workflow_run.workflow_run_id)
        # context is guaranteed non-None here (early return above), so record the
        # new script ids directly — the original redundant `if context:` check
        # was dead code.
        context.script_id = script.script_id
        context.script_revision_id = script.script_revision_id
    # Render the cache key value for this run; drafts are looked up/recorded
    # under ScriptStatus.pending rather than published.
    _, rendered_cache_key_value = await get_workflow_script(
        workflow=workflow,
        workflow_run=workflow_run,
        status=ScriptStatus.pending,
    )
    await generate_workflow_script(
        workflow_run=workflow_run,
        workflow=workflow,
        script=script,
        rendered_cache_key_value=rendered_cache_key_value,
        draft=True,
    )
async def get_workflow_script(
    workflow: Workflow,
    workflow_run: WorkflowRun,
    block_labels: list[str] | None = None,
    status: ScriptStatus = ScriptStatus.published,
) -> tuple[Script | None, str]:
    """
    Check if there's a related workflow script that should be used instead of running the workflow.
    Returns the tuple of (script, rendered_cache_key_value).
    """
    rendered_value = ""
    # Script lookup only applies when script generation is enabled for the
    # workflow AND the whole workflow is being run (no block-label subset).
    if not workflow.generate_script or block_labels:
        return None, rendered_value
    try:
        # Render the workflow's cache key template against this run's parameters.
        param_pairs = await app.DATABASE.get_workflow_run_parameters(
            workflow_run_id=workflow_run.workflow_run_id,
        )
        template_params = {wf_param.key: run_param.value for wf_param, run_param in param_pairs}
        rendered_value = jinja_sandbox_env.from_string(workflow.cache_key or "").render(template_params)
        # Look for existing cached scripts for this workflow + cache key value.
        cached_scripts = await app.DATABASE.get_workflow_scripts_by_cache_key_value(
            organization_id=workflow.organization_id,
            workflow_permanent_id=workflow.workflow_permanent_id,
            cache_key_value=rendered_value,
            statuses=[status],
        )
        if not cached_scripts:
            return None, rendered_value
        LOG.info(
            "Found cached script for workflow",
            workflow_id=workflow.workflow_id,
            cache_key_value=rendered_value,
            workflow_run_id=workflow_run.workflow_run_id,
            script_count=len(cached_scripts),
        )
        return cached_scripts[0], rendered_value
    except Exception as e:
        # Best-effort lookup: any failure here falls back to a normal run.
        LOG.warning(
            "Failed to check for workflow script, proceeding with normal workflow execution",
            workflow_id=workflow.workflow_id,
            workflow_run_id=workflow_run.workflow_run_id,
            error=str(e),
            exc_info=True,
        )
        return None, rendered_value
async def generate_workflow_script(
    workflow_run: WorkflowRun,
    workflow: Workflow,
    script: Script,
    rendered_cache_key_value: str,
    draft: bool = False,
) -> None:
    """
    Generate python source code for a workflow run, persist it as script file
    artifacts, and record the workflow->script mapping used for cache lookup.

    Code-generation failures are logged and swallowed — generation is best
    effort and must not fail the surrounding workflow run.
    """
    try:
        LOG.info(
            "Generating script for workflow",
            workflow_run_id=workflow_run.workflow_run_id,
            workflow_id=workflow.workflow_id,
            workflow_name=workflow.title,
            cache_key_value=rendered_cache_key_value,
        )
        gen_input = await transform_workflow_run_to_code_gen_input(
            workflow_run_id=workflow_run.workflow_run_id,
            organization_id=workflow.organization_id,
        )
        source_code = await generate_workflow_script_python_code(
            file_name=gen_input.file_name,
            workflow_run_request=gen_input.workflow_run,
            workflow=gen_input.workflow,
            blocks=gen_input.workflow_blocks,
            actions_by_task=gen_input.actions_by_task,
            task_v2_child_blocks=gen_input.task_v2_child_blocks,
            organization_id=workflow.organization_id,
            script_id=script.script_id,
            script_revision_id=script.script_revision_id,
            draft=draft,
        )
    except Exception:
        LOG.error("Failed to generate workflow script source", exc_info=True)
        return

    # Persist the generated source as a single base64-encoded main.py file.
    raw_bytes = source_code.encode("utf-8")
    encoded_content = base64.b64encode(raw_bytes).decode("utf-8")
    script_files = [
        ScriptFileCreate(
            path="main.py",
            content=encoded_content,
            encoding=FileEncoding.BASE64,
            mime_type="text/x-python",
        )
    ]
    # Upload script file(s) as artifacts and create the matching DB rows.
    await script_service.build_file_tree(
        files=script_files,
        organization_id=workflow.organization_id,
        script_id=script.script_id,
        script_version=script.version,
        script_revision_id=script.script_revision_id,
        draft=draft,
    )

    # Draft runs record a pending mapping; otherwise the mapping is published.
    mapping_status = ScriptStatus.pending if draft else ScriptStatus.published
    prior_mapping = None
    if draft:
        # Check whether a draft workflow script already exists for this run so
        # we don't create a duplicate mapping.
        prior_mapping = await app.DATABASE.get_workflow_script(
            organization_id=workflow.organization_id,
            workflow_permanent_id=workflow.workflow_permanent_id,
            workflow_run_id=workflow_run.workflow_run_id,
            statuses=[mapping_status],
        )
    if not prior_mapping:
        # Record the workflow->script mapping for cache lookup.
        await app.DATABASE.create_workflow_script(
            organization_id=workflow.organization_id,
            script_id=script.script_id,
            workflow_permanent_id=workflow.workflow_permanent_id,
            cache_key=workflow.cache_key or "",
            cache_key_value=rendered_cache_key_value,
            workflow_id=workflow.workflow_id,
            workflow_run_id=workflow_run.workflow_run_id,
            status=mapping_status,
        )