script gen post action (#3480)
This commit is contained in:
@@ -17,7 +17,7 @@ from jinja2.sandbox import SandboxedEnvironment
|
||||
from skyvern.config import settings
|
||||
from skyvern.constants import GET_DOWNLOADED_FILES_TIMEOUT
|
||||
from skyvern.core.script_generations.constants import SCRIPT_TASK_BLOCKS
|
||||
from skyvern.core.script_generations.generate_script import _build_block_fn, create_script_block
|
||||
from skyvern.core.script_generations.generate_script import _build_block_fn, create_or_update_script_block
|
||||
from skyvern.core.script_generations.skyvern_page import script_run_context_manager
|
||||
from skyvern.exceptions import ScriptNotFound, WorkflowRunNotFound
|
||||
from skyvern.forge import app
|
||||
@@ -45,10 +45,10 @@ from skyvern.forge.sdk.workflow.models.block import (
|
||||
from skyvern.forge.sdk.workflow.models.parameter import PARAMETER_TYPE, OutputParameter, ParameterType
|
||||
from skyvern.forge.sdk.workflow.models.workflow import Workflow
|
||||
from skyvern.schemas.runs import RunEngine
|
||||
from skyvern.schemas.scripts import CreateScriptResponse, FileEncoding, FileNode, ScriptFileCreate
|
||||
from skyvern.schemas.scripts import CreateScriptResponse, FileEncoding, FileNode, ScriptFileCreate, ScriptStatus
|
||||
from skyvern.schemas.workflows import BlockStatus, BlockType, FileStorageType, FileType
|
||||
|
||||
LOG = structlog.get_logger(__name__)
|
||||
LOG = structlog.get_logger()
|
||||
jinja_sandbox_env = SandboxedEnvironment()
|
||||
|
||||
|
||||
@@ -58,6 +58,7 @@ async def build_file_tree(
|
||||
script_id: str,
|
||||
script_version: int,
|
||||
script_revision_id: str,
|
||||
draft: bool = False,
|
||||
) -> dict[str, FileNode]:
|
||||
"""Build a hierarchical file tree from a list of files and upload the files to s3 with the same tree structure."""
|
||||
file_tree: dict[str, FileNode] = {}
|
||||
@@ -70,33 +71,94 @@ async def build_file_tree(
|
||||
|
||||
# Create artifact and upload to S3
|
||||
try:
|
||||
artifact_id = await app.ARTIFACT_MANAGER.create_script_file_artifact(
|
||||
organization_id=organization_id,
|
||||
script_id=script_id,
|
||||
script_version=script_version,
|
||||
file_path=file.path,
|
||||
data=content_bytes,
|
||||
)
|
||||
LOG.debug(
|
||||
"Created script file artifact",
|
||||
artifact_id=artifact_id,
|
||||
file_path=file.path,
|
||||
script_id=script_id,
|
||||
script_version=script_version,
|
||||
)
|
||||
# create a script file record
|
||||
await app.DATABASE.create_script_file(
|
||||
script_revision_id=script_revision_id,
|
||||
script_id=script_id,
|
||||
organization_id=organization_id,
|
||||
file_path=file.path,
|
||||
file_name=file.path.split("/")[-1],
|
||||
file_type="file",
|
||||
content_hash=f"sha256:{content_hash}",
|
||||
file_size=file_size,
|
||||
mime_type=file.mime_type,
|
||||
artifact_id=artifact_id,
|
||||
)
|
||||
if draft:
|
||||
# get the script file object
|
||||
script_file = await app.DATABASE.get_script_file_by_path(
|
||||
script_revision_id=script_revision_id,
|
||||
file_path=file.path,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
if script_file:
|
||||
if not script_file.artifact_id:
|
||||
LOG.error(
|
||||
"Failed to update file. An existing script file has no artifact id",
|
||||
script_file_id=script_file.file_id,
|
||||
)
|
||||
continue
|
||||
artifact = await app.DATABASE.get_artifact_by_id(script_file.artifact_id, organization_id)
|
||||
if artifact:
|
||||
# override the actual file in the storage
|
||||
asyncio.create_task(app.STORAGE.store_artifact(artifact, content_bytes))
|
||||
else:
|
||||
artifact_id = await app.ARTIFACT_MANAGER.create_script_file_artifact(
|
||||
organization_id=organization_id,
|
||||
script_id=script_id,
|
||||
script_version=script_version,
|
||||
file_path=file.path,
|
||||
data=content_bytes,
|
||||
)
|
||||
# update the artifact_id in the script file
|
||||
await app.DATABASE.update_script_file(
|
||||
script_file_id=script_file.file_id,
|
||||
organization_id=organization_id,
|
||||
artifact_id=artifact_id,
|
||||
)
|
||||
else:
|
||||
artifact_id = await app.ARTIFACT_MANAGER.create_script_file_artifact(
|
||||
organization_id=organization_id,
|
||||
script_id=script_id,
|
||||
script_version=script_version,
|
||||
file_path=file.path,
|
||||
data=content_bytes,
|
||||
)
|
||||
LOG.debug(
|
||||
"Created script file artifact",
|
||||
artifact_id=artifact_id,
|
||||
file_path=file.path,
|
||||
script_id=script_id,
|
||||
script_version=script_version,
|
||||
)
|
||||
# create a script file record
|
||||
await app.DATABASE.create_script_file(
|
||||
script_revision_id=script_revision_id,
|
||||
script_id=script_id,
|
||||
organization_id=organization_id,
|
||||
file_path=file.path,
|
||||
file_name=file.path.split("/")[-1],
|
||||
file_type="file",
|
||||
content_hash=f"sha256:{content_hash}",
|
||||
file_size=file_size,
|
||||
mime_type=file.mime_type,
|
||||
artifact_id=artifact_id,
|
||||
)
|
||||
else:
|
||||
artifact_id = await app.ARTIFACT_MANAGER.create_script_file_artifact(
|
||||
organization_id=organization_id,
|
||||
script_id=script_id,
|
||||
script_version=script_version,
|
||||
file_path=file.path,
|
||||
data=content_bytes,
|
||||
)
|
||||
LOG.debug(
|
||||
"Created script file artifact",
|
||||
artifact_id=artifact_id,
|
||||
file_path=file.path,
|
||||
script_id=script_id,
|
||||
script_version=script_version,
|
||||
)
|
||||
# create a script file record
|
||||
await app.DATABASE.create_script_file(
|
||||
script_revision_id=script_revision_id,
|
||||
script_id=script_id,
|
||||
organization_id=organization_id,
|
||||
file_path=file.path,
|
||||
file_name=file.path.split("/")[-1],
|
||||
file_type="file",
|
||||
content_hash=f"sha256:{content_hash}",
|
||||
file_size=file_size,
|
||||
mime_type=file.mime_type,
|
||||
artifact_id=artifact_id,
|
||||
)
|
||||
except Exception:
|
||||
LOG.exception(
|
||||
"Failed to create script file artifact",
|
||||
@@ -794,6 +856,7 @@ async def _regenerate_script_block_after_ai_fallback(
|
||||
workflow_permanent_id=workflow.workflow_permanent_id,
|
||||
cache_key_value=cache_key_value,
|
||||
cache_key=workflow.cache_key,
|
||||
statuses=[ScriptStatus.published],
|
||||
)
|
||||
|
||||
if not existing_scripts:
|
||||
@@ -898,12 +961,12 @@ async def _regenerate_script_block_after_ai_fallback(
|
||||
)
|
||||
continue
|
||||
|
||||
await create_script_block(
|
||||
await create_or_update_script_block(
|
||||
block_code=block_file_content,
|
||||
script_revision_id=new_script.script_revision_id,
|
||||
script_id=new_script.script_id,
|
||||
organization_id=organization_id,
|
||||
block_name=existing_block.script_block_label,
|
||||
block_label=existing_block.script_block_label,
|
||||
)
|
||||
block_file_content_bytes = (
|
||||
block_file_content if isinstance(block_file_content, bytes) else block_file_content.encode("utf-8")
|
||||
|
||||
@@ -466,6 +466,8 @@ async def run_task_v2_helper(
|
||||
|
||||
context: skyvern_context.SkyvernContext | None = skyvern_context.current()
|
||||
current_run_id = context.run_id if context and context.run_id else task_v2_id
|
||||
# task v2 can be nested inside a workflow run, so we need to use the root workflow run id
|
||||
root_workflow_run_id = context.root_workflow_run_id if context and context.root_workflow_run_id else workflow_run_id
|
||||
enable_parse_select_in_extract = app.EXPERIMENTATION_PROVIDER.is_feature_enabled_cached(
|
||||
"ENABLE_PARSE_SELECT_IN_EXTRACT",
|
||||
current_run_id,
|
||||
@@ -476,6 +478,7 @@ async def run_task_v2_helper(
|
||||
organization_id=organization_id,
|
||||
workflow_id=workflow_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
root_workflow_run_id=root_workflow_run_id,
|
||||
request_id=request_id,
|
||||
task_v2_id=task_v2_id,
|
||||
run_id=current_run_id,
|
||||
|
||||
188
skyvern/services/workflow_script_service.py
Normal file
188
skyvern/services/workflow_script_service.py
Normal file
@@ -0,0 +1,188 @@
|
||||
import base64
|
||||
|
||||
import structlog
|
||||
from jinja2.sandbox import SandboxedEnvironment
|
||||
|
||||
from skyvern.core.script_generations.generate_script import generate_workflow_script_python_code
|
||||
from skyvern.core.script_generations.transform_workflow_run import transform_workflow_run_to_code_gen_input
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.forge.sdk.workflow.models.workflow import Workflow, WorkflowRun
|
||||
from skyvern.schemas.scripts import FileEncoding, Script, ScriptFileCreate, ScriptStatus
|
||||
from skyvern.services import script_service
|
||||
|
||||
# Module-wide structured logger shared by every function in this service.
LOG = structlog.get_logger()

# Sandboxed Jinja environment used to render workflow cache-key templates.
jinja_sandbox_env = SandboxedEnvironment()
|
||||
|
||||
|
||||
async def generate_or_update_draft_workflow_script(
    workflow_run: WorkflowRun,
    workflow: Workflow,
) -> None:
    """Generate (or refresh) the draft script for a workflow run.

    The script referenced by the current Skyvern context is reused when it can
    be loaded; otherwise a new script is created and its ids are written back
    to the context. The script is then generated with ``draft=True``.

    Does nothing when there is no active Skyvern context.

    Args:
        workflow_run: The run whose recorded execution is turned into a script.
        workflow: The workflow the run belongs to.
    """
    organization_id = workflow.organization_id
    context = skyvern_context.current()
    if not context:
        return

    # Reuse the script already attached to this context, if it can be loaded.
    known_script_id = context.script_id
    script = (
        await app.DATABASE.get_script(script_id=known_script_id, organization_id=organization_id)
        if known_script_id
        else None
    )

    if not script:
        script = await app.DATABASE.create_script(
            organization_id=organization_id,
            run_id=workflow_run.workflow_run_id,
        )
        # The context is known to be present here (checked above), so the
        # freshly created ids can be recorded on it directly.
        context.script_id = script.script_id
        context.script_revision_id = script.script_revision_id

    # Only the rendered cache key value is needed from the lookup; the script
    # it may return is intentionally ignored.
    lookup_result = await get_workflow_script(
        workflow=workflow,
        workflow_run=workflow_run,
        status=ScriptStatus.pending,
    )
    rendered_cache_key_value = lookup_result[1]

    await generate_workflow_script(
        workflow_run=workflow_run,
        workflow=workflow,
        script=script,
        rendered_cache_key_value=rendered_cache_key_value,
        draft=True,
    )
|
||||
|
||||
|
||||
async def get_workflow_script(
    workflow: Workflow,
    workflow_run: WorkflowRun,
    block_labels: list[str] | None = None,
    status: ScriptStatus = ScriptStatus.published,
) -> tuple[Script | None, str]:
    """Look up a stored script matching this workflow run's cache key.

    The workflow's cache key template is rendered with the run's parameters and
    the result is used to search for workflow scripts in the given ``status``.

    Args:
        workflow: Workflow whose ``cache_key`` template is rendered.
        workflow_run: Run supplying the parameter values for rendering.
        block_labels: When non-empty, the lookup is skipped entirely.
        status: Script status to filter the lookup by.

    Returns:
        ``(script, rendered_cache_key_value)``. ``script`` is ``None`` when
        script generation is disabled for the workflow, ``block_labels`` is
        provided, nothing matches, or the lookup raised. The rendered value is
        ``""`` unless rendering succeeded.
    """
    cache_key = workflow.cache_key or ""
    rendered_cache_key_value = ""

    # Do not generate or run a script when script generation is disabled or
    # when specific block labels are requested.
    if not workflow.generate_script or block_labels:
        return None, rendered_cache_key_value

    matched_script: Script | None = None
    try:
        parameter_tuples = await app.DATABASE.get_workflow_run_parameters(
            workflow_run_id=workflow_run.workflow_run_id,
        )
        parameters = {}
        for wf_param, run_param in parameter_tuples:
            parameters[wf_param.key] = run_param.value

        template = jinja_sandbox_env.from_string(cache_key)
        rendered_cache_key_value = template.render(parameters)

        # Check for existing cached scripts for this workflow + cache key value.
        candidates = await app.DATABASE.get_workflow_scripts_by_cache_key_value(
            organization_id=workflow.organization_id,
            workflow_permanent_id=workflow.workflow_permanent_id,
            cache_key_value=rendered_cache_key_value,
            statuses=[status],
        )

        if candidates:
            LOG.info(
                "Found cached script for workflow",
                workflow_id=workflow.workflow_id,
                cache_key_value=rendered_cache_key_value,
                workflow_run_id=workflow_run.workflow_run_id,
                script_count=len(candidates),
            )
            matched_script = candidates[0]
    except Exception as e:
        # Best effort: a failed lookup must not block the caller, so report it
        # and fall back to "no script found".
        LOG.warning(
            "Failed to check for workflow script, proceeding with normal workflow execution",
            workflow_id=workflow.workflow_id,
            workflow_run_id=workflow_run.workflow_run_id,
            error=str(e),
            exc_info=True,
        )
        return None, rendered_cache_key_value

    return matched_script, rendered_cache_key_value
|
||||
|
||||
|
||||
async def generate_workflow_script(
    workflow_run: WorkflowRun,
    workflow: Workflow,
    script: Script,
    rendered_cache_key_value: str,
    draft: bool = False,
) -> None:
    """Generate the Python script for a workflow run and persist it.

    Steps:
      1. Transform the workflow run into code-generation input and produce the
         Python source. Any failure here is logged and the function returns.
      2. Store the source as ``main.py`` through ``script_service.build_file_tree``.
      3. Record the workflow -> script mapping, unless ``draft`` is set and a
         pending mapping already exists for this workflow run.

    Args:
        workflow_run: Run the script is generated from.
        workflow: Workflow the run belongs to.
        script: Script record that the generated files are attached to.
        rendered_cache_key_value: Rendered cache key value stored on the mapping.
        draft: When True the mapping is recorded with ``ScriptStatus.pending``
            instead of ``ScriptStatus.published``.
    """
    try:
        LOG.info(
            "Generating script for workflow",
            workflow_run_id=workflow_run.workflow_run_id,
            workflow_id=workflow.workflow_id,
            workflow_name=workflow.title,
            cache_key_value=rendered_cache_key_value,
        )
        codegen_input = await transform_workflow_run_to_code_gen_input(
            workflow_run_id=workflow_run.workflow_run_id,
            organization_id=workflow.organization_id,
        )
        python_src = await generate_workflow_script_python_code(
            file_name=codegen_input.file_name,
            workflow_run_request=codegen_input.workflow_run,
            workflow=codegen_input.workflow,
            blocks=codegen_input.workflow_blocks,
            actions_by_task=codegen_input.actions_by_task,
            task_v2_child_blocks=codegen_input.task_v2_child_blocks,
            organization_id=workflow.organization_id,
            script_id=script.script_id,
            script_revision_id=script.script_revision_id,
            draft=draft,
        )
    except Exception:
        LOG.error("Failed to generate workflow script source", exc_info=True)
        return

    # Persist script and files, then record the mapping.
    encoded_source = base64.b64encode(python_src.encode("utf-8")).decode("utf-8")
    main_file = ScriptFileCreate(
        path="main.py",
        content=encoded_source,
        encoding=FileEncoding.BASE64,
        mime_type="text/x-python",
    )

    # Upload script file(s) as artifacts and create rows.
    await script_service.build_file_tree(
        files=[main_file],
        organization_id=workflow.organization_id,
        script_id=script.script_id,
        script_version=script.version,
        script_revision_id=script.script_revision_id,
        draft=draft,
    )

    status = ScriptStatus.pending if draft else ScriptStatus.published
    if draft:
        # Check whether a draft workflow script already exists for this
        # workflow run; if so, there is nothing more to record.
        existing_draft_workflow_script = await app.DATABASE.get_workflow_script(
            organization_id=workflow.organization_id,
            workflow_permanent_id=workflow.workflow_permanent_id,
            workflow_run_id=workflow_run.workflow_run_id,
            statuses=[status],
        )
        if existing_draft_workflow_script:
            return

    # Record the workflow->script mapping for cache lookup.
    await app.DATABASE.create_workflow_script(
        organization_id=workflow.organization_id,
        script_id=script.script_id,
        workflow_permanent_id=workflow.workflow_permanent_id,
        cache_key=workflow.cache_key or "",
        cache_key_value=rendered_cache_key_value,
        workflow_id=workflow.workflow_id,
        workflow_run_id=workflow_run.workflow_run_id,
        status=status,
    )
|
||||
Reference in New Issue
Block a user