Script regeneration after AI fallback (#3330)

This commit is contained in:
Shuchang Zheng
2025-08-31 11:46:31 +08:00
committed by GitHub
parent dd8f189234
commit 83b3cfb6af
4 changed files with 369 additions and 20 deletions

View File

@@ -7,12 +7,15 @@ import os
from datetime import datetime
from typing import Any, cast
import libcst as cst
import structlog
from fastapi import BackgroundTasks, HTTPException
from jinja2.sandbox import SandboxedEnvironment
from skyvern.config import settings
from skyvern.constants import GET_DOWNLOADED_FILES_TIMEOUT
from skyvern.core.script_generations.constants import SCRIPT_TASK_BLOCKS
from skyvern.core.script_generations.generate_script import _build_block_fn, create_script_block
from skyvern.core.script_generations.script_run_context_manager import script_run_context_manager
from skyvern.exceptions import ScriptNotFound, WorkflowRunNotFound
from skyvern.forge import app
@@ -22,8 +25,9 @@ from skyvern.forge.sdk.models import StepStatus
from skyvern.forge.sdk.schemas.files import FileInfo
from skyvern.forge.sdk.schemas.tasks import TaskOutput, TaskStatus
from skyvern.forge.sdk.workflow.models.block import TaskBlock
from skyvern.forge.sdk.workflow.models.workflow import Workflow
from skyvern.schemas.runs import RunEngine
from skyvern.schemas.scripts import CreateScriptResponse, FileNode, ScriptFileCreate
from skyvern.schemas.scripts import CreateScriptResponse, FileEncoding, FileNode, ScriptFileCreate
from skyvern.schemas.workflows import BlockStatus, BlockType
LOG = structlog.get_logger(__name__)
@@ -446,6 +450,7 @@ async def _run_cached_function(cache_key: str) -> Any:
async def _fallback_to_ai_run(
block_type: BlockType,
cache_key: str,
prompt: str | None = None,
url: str | None = None,
@@ -475,32 +480,38 @@ async def _fallback_to_ai_run(
and context.step_id
):
return
organization_id = context.organization_id
workflow_id = context.workflow_id
workflow_run_id = context.workflow_run_id
workflow_permanent_id = context.workflow_permanent_id
task_id = context.task_id
script_step_id = context.step_id
try:
organization_id = context.organization_id
LOG.info(
"Script trying to fallback to AI run",
cache_key=cache_key,
organization_id=organization_id,
workflow_id=context.workflow_id,
workflow_run_id=context.workflow_run_id,
task_id=context.task_id,
step_id=context.step_id,
workflow_id=workflow_id,
workflow_run_id=workflow_run_id,
task_id=task_id,
step_id=script_step_id,
)
# 1. fail the previous step
previous_step = await app.DATABASE.update_step(
step_id=context.step_id,
task_id=context.task_id,
step_id=script_step_id,
task_id=task_id,
organization_id=organization_id,
status=StepStatus.failed,
)
# 2. create a new step for ai run
ai_step = await app.DATABASE.create_step(
task_id=context.task_id,
task_id=task_id,
organization_id=organization_id,
order=previous_step.order + 1,
retry_index=0,
)
context.step_id = ai_step.step_id
ai_step_id = ai_step.step_id
# 3. build the task block
# 4. run execute_step
organization = await app.DATABASE.get_organization(organization_id=organization_id)
@@ -510,21 +521,35 @@ async def _fallback_to_ai_run(
if not task:
raise Exception(f"Task is missing task_id={context.task_id}")
workflow = await app.DATABASE.get_workflow(workflow_id=context.workflow_id, organization_id=organization_id)
if not workflow or not workflow.ai_fallback:
if not workflow:
return
if not workflow.ai_fallback:
LOG.info(
"AI fallback is not enabled for the workflow",
workflow_id=workflow_id,
workflow_permanent_id=workflow_permanent_id,
workflow_run_id=workflow_run_id,
)
return
# get the output_parameter
output_parameter = workflow.get_output_parameter(cache_key)
if not output_parameter:
LOG.exception(
"Output parameter not found for the workflow",
workflow_id=workflow_id,
workflow_permanent_id=workflow_permanent_id,
workflow_run_id=workflow_run_id,
)
return
LOG.info(
"Script starting to fallback to AI run",
cache_key=cache_key,
organization_id=organization_id,
workflow_id=context.workflow_id,
workflow_run_id=context.workflow_run_id,
task_id=context.task_id,
step_id=context.step_id,
workflow_id=workflow_id,
workflow_run_id=workflow_run_id,
task_id=task_id,
step_id=script_step_id,
)
task_block = TaskBlock(
@@ -553,6 +578,7 @@ async def _fallback_to_ai_run(
step=ai_step,
task_block=task_block,
)
# Update block status to completed if workflow block was created
if workflow_run_block_id:
await _update_workflow_block(
@@ -562,6 +588,37 @@ async def _fallback_to_ai_run(
step_id=context.step_id,
label=cache_key,
)
# 5. After successful AI execution, regenerate the script block and create new version
try:
await _regenerate_script_block_after_ai_fallback(
block_type=block_type,
cache_key=cache_key,
task_id=context.task_id,
script_step_id=ai_step_id,
ai_step_id=ai_step_id,
organization_id=organization_id,
workflow=workflow,
workflow_run_id=context.workflow_run_id,
prompt=prompt,
url=url,
engine=engine,
complete_criterion=complete_criterion,
terminate_criterion=terminate_criterion,
data_extraction_goal=data_extraction_goal,
schema=schema,
error_code_mapping=error_code_mapping,
max_steps=max_steps,
complete_on_download=complete_on_download,
download_suffix=download_suffix,
totp_verification_url=totp_verification_url,
totp_identifier=totp_identifier,
complete_verification=complete_verification,
include_action_history_in_verification=include_action_history_in_verification,
)
except Exception as e:
LOG.warning("Failed to regenerate script block after AI fallback", error=str(e), exc_info=True)
# Don't fail the entire fallback process if script regeneration fails
except Exception as e:
LOG.warning("Failed to fallback to AI run", cache_key=cache_key, exc_info=True)
# Update block status to failed if workflow block was created
@@ -577,6 +634,293 @@ async def _fallback_to_ai_run(
raise e
async def _regenerate_script_block_after_ai_fallback(
    block_type: BlockType,
    cache_key: str,
    task_id: str,
    script_step_id: str,
    ai_step_id: str,
    organization_id: str,
    workflow: Workflow,
    workflow_run_id: str,
    prompt: str | None = None,
    url: str | None = None,
    engine: RunEngine = RunEngine.skyvern_v1,
    complete_criterion: str | None = None,
    terminate_criterion: str | None = None,
    data_extraction_goal: str | None = None,
    schema: dict[str, Any] | list | str | None = None,
    error_code_mapping: dict[str, str] | None = None,
    max_steps: int | None = None,
    complete_on_download: bool = False,
    download_suffix: str | None = None,
    totp_verification_url: str | None = None,
    totp_identifier: str | None = None,
    complete_verification: bool = True,
    include_action_history_in_verification: bool = False,
) -> None:
    """
    Regenerate the script block after a successful AI fallback and create a new script version.

    Only the specific block that fell back to AI is regenerated; all other blocks remain unchanged.
    1. get the latest cached script for the workflow
    2. create a completely new script, with only the current block's script being different as it's newly generated.

    Raises:
        Exception: re-raises any failure after logging, so the caller can decide
            whether script regeneration failure should abort the fallback.
    """
    try:
        # Get the current script for this workflow and cache key value
        # Render the cache_key_value from workflow run parameters (same logic as generate_script_for_workflow)
        cache_key_value = ""
        if workflow.cache_key:
            try:
                parameter_tuples = await app.DATABASE.get_workflow_run_parameters(workflow_run_id=workflow_run_id)
                parameters = {wf_param.key: run_param.value for wf_param, run_param in parameter_tuples}
                cache_key_value = jinja_sandbox_env.from_string(workflow.cache_key).render(parameters)
            except Exception as e:
                LOG.warning("Failed to render cache key for script regeneration", error=str(e), exc_info=True)
                # Fallback to using cache_key as cache_key_value
                cache_key_value = cache_key
        if not cache_key_value:
            cache_key_value = cache_key  # Fallback

        existing_scripts = await app.DATABASE.get_workflow_scripts_by_cache_key_value(
            organization_id=organization_id,
            workflow_permanent_id=workflow.workflow_permanent_id,
            cache_key_value=cache_key_value,
            cache_key=workflow.cache_key,
        )
        if not existing_scripts:
            LOG.error("No existing script found to regenerate", cache_key=cache_key, cache_key_value=cache_key_value)
            return
        # NOTE(review): assumes the first entry is the latest version — confirm ordering in the DB query
        current_script = existing_scripts[0]
        LOG.info(
            "Regenerating script block after AI fallback",
            script_id=current_script.script_id,
            script_version=current_script.version,
            cache_key=cache_key,
            cache_key_value=cache_key_value,
        )

        # Create a new script version
        new_script = await app.DATABASE.create_script(
            organization_id=organization_id,
            run_id=workflow_run_id,
            script_id=current_script.script_id,  # Use same script_id for versioning
            version=current_script.version + 1,
        )
        # deprecate the current workflow script
        await app.DATABASE.delete_workflow_cache_key_value(
            organization_id=organization_id,
            workflow_permanent_id=workflow.workflow_permanent_id,
            cache_key_value=cache_key_value,
        )
        # Create workflow script mapping for the new version
        await app.DATABASE.create_workflow_script(
            organization_id=organization_id,
            script_id=new_script.script_id,
            workflow_permanent_id=workflow.workflow_permanent_id,
            cache_key=workflow.cache_key or "",
            cache_key_value=cache_key_value,
            workflow_id=workflow.workflow_id,
            workflow_run_id=workflow_run_id,
        )

        # Get all existing script blocks from the previous version
        existing_script_blocks = await app.DATABASE.get_script_blocks_by_script_revision_id(
            script_revision_id=current_script.script_revision_id,
            organization_id=organization_id,
        )

        # Copy all existing script blocks to the new version (except the one we're regenerating)
        block_file_contents = []
        starter_block_file_content_bytes = b""
        block_file_content: bytes | str
        for existing_block in existing_script_blocks:
            # BUGFIX: reset per iteration. Previously this variable was initialized once
            # before the loop, so a failed artifact/script-file lookup left the previous
            # block's content in place and it was silently written under the wrong label.
            block_file_content = ""
            if existing_block.script_block_label == cache_key:
                # This is the block that fell back to AI — regenerate it from the task's actions
                block_file_content = await _generate_block_code_from_task(
                    block_type=block_type,
                    cache_key=cache_key,
                    task_id=task_id,
                    script_step_id=script_step_id,
                    ai_step_id=ai_step_id,
                    organization_id=organization_id,
                    workflow=workflow,
                    workflow_run_id=workflow_run_id,
                )
            else:
                # Copy the existing block to the new version
                # Get the script file content for this block and copy a new script block for it
                if existing_block.script_file_id:
                    script_file = await app.DATABASE.get_script_file_by_id(
                        script_revision_id=current_script.script_revision_id,
                        file_id=existing_block.script_file_id,
                        organization_id=organization_id,
                    )
                    if script_file and script_file.artifact_id:
                        # Retrieve the artifact content
                        artifact = await app.DATABASE.get_artifact_by_id(script_file.artifact_id, organization_id)
                        if artifact:
                            file_content = await app.ARTIFACT_MANAGER.retrieve_artifact(artifact)
                            if file_content:
                                block_file_content = file_content
                            else:
                                LOG.warning(
                                    "Failed to retrieve artifact content for existing block",
                                    block_label=existing_block.script_block_label,
                                )
                        else:
                            LOG.warning(
                                "Artifact not found for existing block", block_label=existing_block.script_block_label
                            )
                    else:
                        LOG.warning(
                            "Script file or artifact not found for existing block",
                            block_label=existing_block.script_block_label,
                        )
                else:
                    LOG.warning("No script file ID for existing block", block_label=existing_block.script_block_label)

            if not block_file_content:
                LOG.warning(
                    "No block file content found for existing block", block_label=existing_block.script_block_label
                )
                continue
            await create_script_block(
                block_code=block_file_content,
                script_revision_id=new_script.script_revision_id,
                script_id=new_script.script_id,
                organization_id=organization_id,
                block_name=existing_block.script_block_label,
            )
            block_file_content_bytes = (
                block_file_content if isinstance(block_file_content, bytes) else block_file_content.encode("utf-8")
            )
            # The starter block must come first in the assembled main.py
            if existing_block.script_block_label == settings.WORKFLOW_START_BLOCK_LABEL:
                starter_block_file_content_bytes = block_file_content_bytes
            else:
                block_file_contents.append(block_file_content_bytes)

        if starter_block_file_content_bytes:
            block_file_contents.insert(0, starter_block_file_content_bytes)
        else:
            LOG.error("Starter block file content not found")

        # 4) Persist script and files, then record mapping
        python_src = "\n\n".join([content.decode("utf-8") for content in block_file_contents])
        content_bytes = python_src.encode("utf-8")
        content_b64 = base64.b64encode(content_bytes).decode("utf-8")
        files = [
            ScriptFileCreate(
                path="main.py",
                content=content_b64,
                encoding=FileEncoding.BASE64,
                mime_type="text/x-python",
            )
        ]
        # Upload script file(s) as artifacts and create rows
        await build_file_tree(
            files=files,
            organization_id=workflow.organization_id,
            script_id=new_script.script_id,
            script_version=new_script.version,
            script_revision_id=new_script.script_revision_id,
        )
    except Exception as e:
        LOG.error("Failed to regenerate script block after AI fallback", error=str(e), exc_info=True)
        raise
async def _get_block_definition_by_label(
    label: str, workflow: Workflow, task_id: str, organization_id: str
) -> dict[str, Any] | None:
    """
    Build a dict describing the workflow block with the given label.

    The block's own model dump is merged with the task's fields (task fields never
    overwrite block fields), then annotated with run-block execution metadata.
    Returns ``None`` when no block with that label exists in the workflow definition.
    """
    matched = next(
        (block for block in workflow.workflow_definition.blocks if block.label == label),
        None,
    )
    block_dump = matched.model_dump() if matched is not None else None
    if not block_dump:
        return None

    task = await app.DATABASE.get_task(task_id=task_id, organization_id=organization_id)
    if task:
        # Merge in task fields without clobbering anything the block already defines
        for key, value in task.model_dump().items():
            if key not in block_dump:
                block_dump[key] = value
        # Add run block execution metadata
        block_dump["task_id"] = task_id
        block_dump["output"] = task.extracted_information
    return block_dump
async def _generate_block_code_from_task(
    block_type: BlockType,
    cache_key: str,
    task_id: str,
    script_step_id: str,
    ai_step_id: str,
    organization_id: str,
    workflow: Workflow,
    workflow_run_id: str,
) -> str:
    """
    Generate python source for a single script block from the task's recorded actions.

    Returns the rendered block function code, or an empty string when the block
    definition cannot be resolved, no actions are cacheable, or code generation fails.
    """
    block_data = await _get_block_definition_by_label(cache_key, workflow, task_id, organization_id)
    if not block_data:
        return ""

    try:
        # Now regenerate only the specific block that fell back to AI
        task_actions = await app.DATABASE.get_task_actions_hydrated(
            task_id=task_id,
            organization_id=organization_id,
        )

        # Filter actions by step_id and exclude the final action that failed before ai fallback
        actions_to_cache: list[dict[str, Any]] = []
        last_index = len(task_actions) - 1
        for idx, action in enumerate(task_actions):
            # if this action is the last action of the script step, right before ai fallback, we should not include it
            is_failed_boundary_action = (
                idx < last_index
                and action.step_id == script_step_id
                and task_actions[idx + 1].step_id == ai_step_id
            )
            if is_failed_boundary_action:
                continue
            dumped = action.model_dump()
            dumped["xpath"] = action.get_xpath()
            actions_to_cache.append(dumped)

        if not actions_to_cache:
            LOG.warning("No actions found in successful step for script block regeneration")
            return ""

        # Generate the new block function, then render it via a throwaway libcst module
        block_fn_def = _build_block_fn(block_data, actions_to_cache)
        return cst.Module(body=[block_fn_def]).code
    except Exception as block_gen_error:
        LOG.error("Failed to generate block function", error=str(block_gen_error), exc_info=True)
        # Even if block generation fails, we've created the new script version
        # which can be useful for debugging
        return ""
async def run_task(
prompt: str,
url: str | None = None,
@@ -610,6 +954,7 @@ async def run_task(
except Exception as e:
LOG.exception("Failed to run task block. Falling back to AI run.")
await _fallback_to_ai_run(
block_type=BlockType.TASK,
cache_key=cache_key,
prompt=prompt,
url=url,
@@ -668,6 +1013,7 @@ async def download(
except Exception as e:
LOG.exception("Failed to run download block. Falling back to AI run.")
await _fallback_to_ai_run(
block_type=BlockType.FILE_DOWNLOAD,
cache_key=cache_key,
prompt=prompt,
url=url,
@@ -726,6 +1072,7 @@ async def action(
except Exception as e:
LOG.exception("Failed to run action block. Falling back to AI run.")
await _fallback_to_ai_run(
block_type=BlockType.ACTION,
cache_key=cache_key,
prompt=prompt,
url=url,
@@ -781,8 +1128,9 @@ async def login(
)
except Exception as e:
LOG.exception("Failed to run login block. Falling back to AI run.")
LOG.exception("Failed to run login block")
await _fallback_to_ai_run(
block_type=BlockType.LOGIN,
cache_key=cache_key,
prompt=prompt,
url=url,