From 83b3cfb6af329e08f6df9fbaab4a4a2cb8b5d7fd Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Sun, 31 Aug 2025 11:46:31 +0800 Subject: [PATCH] script regeneration after ai fallback (#3330) --- .../script_generations/generate_script.py | 9 +- .../core/script_generations/skyvern_page.py | 1 - skyvern/forge/sdk/db/client.py | 1 + skyvern/services/script_service.py | 378 +++++++++++++++++- 4 files changed, 369 insertions(+), 20 deletions(-) diff --git a/skyvern/core/script_generations/generate_script.py b/skyvern/core/script_generations/generate_script.py index c46b1ce1..8d5fb482 100644 --- a/skyvern/core/script_generations/generate_script.py +++ b/skyvern/core/script_generations/generate_script.py @@ -1187,7 +1187,7 @@ async def generate_workflow_script( async def create_script_block( - block_code: str, + block_code: str | bytes, script_revision_id: str, script_id: str, organization_id: str, @@ -1205,6 +1205,7 @@ async def create_script_block( block_name: Optional custom name for the block (defaults to function name) block_description: Optional description for the block """ + block_code_bytes = block_code if isinstance(block_code, bytes) else block_code.encode("utf-8") try: # Step 3: Create script block in database script_block = await app.DATABASE.create_script_block( @@ -1225,7 +1226,7 @@ async def create_script_block( script_id=script_id, script_version=1, # Assuming version 1 for now file_path=file_path, - data=block_code.encode("utf-8"), + data=block_code_bytes, ) # Create script file record @@ -1236,8 +1237,8 @@ async def create_script_block( file_path=file_path, file_name=file_name, file_type="file", - content_hash=f"sha256:{hashlib.sha256(block_code.encode('utf-8')).hexdigest()}", - file_size=len(block_code.encode("utf-8")), + content_hash=f"sha256:{hashlib.sha256(block_code_bytes).hexdigest()}", + file_size=len(block_code_bytes), mime_type="text/x-python", artifact_id=artifact_id, ) diff --git a/skyvern/core/script_generations/skyvern_page.py b/skyvern/core/script_generations/skyvern_page.py index 340c0ab7..d30ef442 100644 --- a/skyvern/core/script_generations/skyvern_page.py +++ b/skyvern/core/script_generations/skyvern_page.py @@ -278,7 +278,6 @@ class SkyvernPage: If the prompt generation or parsing fails for any reason we fall back to clicking the originally supplied ``xpath``. """ - new_xpath = xpath if intention and data: diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py index 9e17cdcb..cbc17601 100644 --- a/skyvern/forge/sdk/db/client.py +++ b/skyvern/forge/sdk/db/client.py @@ -3967,6 +3967,7 @@ class AgentDB: select(ScriptBlockModel) .filter_by(script_revision_id=script_revision_id) .filter_by(organization_id=organization_id) + .order_by(ScriptBlockModel.created_at.asc()) ) ).all() return [convert_to_script_block(record) for record in records] diff --git a/skyvern/services/script_service.py b/skyvern/services/script_service.py index 2af0e648..74c83028 100644 --- a/skyvern/services/script_service.py +++ b/skyvern/services/script_service.py @@ -7,12 +7,15 @@ import os from datetime import datetime from typing import Any, cast +import libcst as cst import structlog from fastapi import BackgroundTasks, HTTPException from jinja2.sandbox import SandboxedEnvironment +from skyvern.config import settings from skyvern.constants import GET_DOWNLOADED_FILES_TIMEOUT from skyvern.core.script_generations.constants import SCRIPT_TASK_BLOCKS +from skyvern.core.script_generations.generate_script import _build_block_fn, create_script_block from skyvern.core.script_generations.script_run_context_manager import script_run_context_manager from skyvern.exceptions import ScriptNotFound, WorkflowRunNotFound from skyvern.forge import app @@ -22,8 +25,9 @@ from skyvern.forge.sdk.models import StepStatus from skyvern.forge.sdk.schemas.files import FileInfo from skyvern.forge.sdk.schemas.tasks import TaskOutput, TaskStatus from skyvern.forge.sdk.workflow.models.block import TaskBlock +from skyvern.forge.sdk.workflow.models.workflow import Workflow from skyvern.schemas.runs import RunEngine -from skyvern.schemas.scripts import CreateScriptResponse, FileNode, ScriptFileCreate +from skyvern.schemas.scripts import CreateScriptResponse, FileEncoding, FileNode, ScriptFileCreate from skyvern.schemas.workflows import BlockStatus, BlockType LOG = structlog.get_logger(__name__) @@ -446,6 +450,7 @@ async def _run_cached_function(cache_key: str) -> Any: async def _fallback_to_ai_run( + block_type: BlockType, cache_key: str, prompt: str | None = None, url: str | None = None, @@ -475,32 +480,38 @@ async def _fallback_to_ai_run( and context.step_id ): return + organization_id = context.organization_id + workflow_id = context.workflow_id + workflow_run_id = context.workflow_run_id + workflow_permanent_id = context.workflow_permanent_id + task_id = context.task_id + script_step_id = context.step_id try: - organization_id = context.organization_id LOG.info( "Script trying to fallback to AI run", cache_key=cache_key, organization_id=organization_id, - workflow_id=context.workflow_id, - workflow_run_id=context.workflow_run_id, - task_id=context.task_id, - step_id=context.step_id, + workflow_id=workflow_id, + workflow_run_id=workflow_run_id, + task_id=task_id, + step_id=script_step_id, ) # 1. fail the previous step previous_step = await app.DATABASE.update_step( - step_id=context.step_id, - task_id=context.task_id, + step_id=script_step_id, + task_id=task_id, organization_id=organization_id, status=StepStatus.failed, ) # 2. create a new step for ai run ai_step = await app.DATABASE.create_step( - task_id=context.task_id, + task_id=task_id, organization_id=organization_id, order=previous_step.order + 1, retry_index=0, ) context.step_id = ai_step.step_id + ai_step_id = ai_step.step_id # 3. build the task block # 4. run execute_step organization = await app.DATABASE.get_organization(organization_id=organization_id) @@ -510,21 +521,35 @@ async def _fallback_to_ai_run( if not task: raise Exception(f"Task is missing task_id={context.task_id}") workflow = await app.DATABASE.get_workflow(workflow_id=context.workflow_id, organization_id=organization_id) - if not workflow or not workflow.ai_fallback: + if not workflow: + return + if not workflow.ai_fallback: + LOG.info( + "AI fallback is not enabled for the workflow", + workflow_id=workflow_id, + workflow_permanent_id=workflow_permanent_id, + workflow_run_id=workflow_run_id, + ) return # get the output_paramter output_parameter = workflow.get_output_parameter(cache_key) if not output_parameter: + LOG.exception( + "Output parameter not found for the workflow", + workflow_id=workflow_id, + workflow_permanent_id=workflow_permanent_id, + workflow_run_id=workflow_run_id, + ) return LOG.info( "Script starting to fallback to AI run", cache_key=cache_key, organization_id=organization_id, - workflow_id=context.workflow_id, - workflow_run_id=context.workflow_run_id, - task_id=context.task_id, - step_id=context.step_id, + workflow_id=workflow_id, + workflow_run_id=workflow_run_id, + task_id=task_id, + step_id=script_step_id, ) task_block = TaskBlock( @@ -553,6 +578,7 @@ async def _fallback_to_ai_run( step=ai_step, task_block=task_block, ) + # Update block status to completed if workflow block was created if workflow_run_block_id: await _update_workflow_block( @@ -562,6 +588,37 @@ async def _fallback_to_ai_run( step_id=context.step_id, label=cache_key, ) + + # 5. After successful AI execution, regenerate the script block and create new version + try: + await _regenerate_script_block_after_ai_fallback( + block_type=block_type, + cache_key=cache_key, + task_id=context.task_id, + script_step_id=ai_step_id, + ai_step_id=ai_step_id, + organization_id=organization_id, + workflow=workflow, + workflow_run_id=context.workflow_run_id, + prompt=prompt, + url=url, + engine=engine, + complete_criterion=complete_criterion, + terminate_criterion=terminate_criterion, + data_extraction_goal=data_extraction_goal, + schema=schema, + error_code_mapping=error_code_mapping, + max_steps=max_steps, + complete_on_download=complete_on_download, + download_suffix=download_suffix, + totp_verification_url=totp_verification_url, + totp_identifier=totp_identifier, + complete_verification=complete_verification, + include_action_history_in_verification=include_action_history_in_verification, + ) + except Exception as e: + LOG.warning("Failed to regenerate script block after AI fallback", error=str(e), exc_info=True) + # Don't fail the entire fallback process if script regeneration fails except Exception as e: LOG.warning("Failed to fallback to AI run", cache_key=cache_key, exc_info=True) # Update block status to failed if workflow block was created @@ -577,6 +634,293 @@ async def _fallback_to_ai_run( raise e +async def _regenerate_script_block_after_ai_fallback( + block_type: BlockType, + cache_key: str, + task_id: str, + script_step_id: str, + ai_step_id: str, + organization_id: str, + workflow: Workflow, + workflow_run_id: str, + prompt: str | None = None, + url: str | None = None, + engine: RunEngine = RunEngine.skyvern_v1, + complete_criterion: str | None = None, + terminate_criterion: str | None = None, + data_extraction_goal: str | None = None, + schema: dict[str, Any] | list | str | None = None, + error_code_mapping: dict[str, str] | None = None, + max_steps: int | None = None, + complete_on_download: bool = False, + download_suffix: str | None = None, + totp_verification_url: str | None = None, + totp_identifier: str | None = None, + complete_verification: bool = True, + include_action_history_in_verification: bool = False, +) -> None: + """ + Regenerate the script block after a successful AI fallback and create a new script version. + Only the specific block that fell back to AI is regenerated; all other blocks remain unchanged. + + 1. get the latest cashed script for the workflow + 2. create a completely new script, with only the current block's script being different as it's newly generated. + - + """ + try: + # Get the current script for this workflow and cache key value + # Render the cache_key_value from workflow run parameters (same logic as generate_script_for_workflow) + cache_key_value = "" + if workflow.cache_key: + try: + parameter_tuples = await app.DATABASE.get_workflow_run_parameters(workflow_run_id=workflow_run_id) + parameters = {wf_param.key: run_param.value for wf_param, run_param in parameter_tuples} + cache_key_value = jinja_sandbox_env.from_string(workflow.cache_key).render(parameters) + except Exception as e: + LOG.warning("Failed to render cache key for script regeneration", error=str(e), exc_info=True) + # Fallback to using cache_key as cache_key_value + cache_key_value = cache_key + + if not cache_key_value: + cache_key_value = cache_key # Fallback + + existing_scripts = await app.DATABASE.get_workflow_scripts_by_cache_key_value( + organization_id=organization_id, + workflow_permanent_id=workflow.workflow_permanent_id, + cache_key_value=cache_key_value, + cache_key=workflow.cache_key, + ) + + if not existing_scripts: + LOG.error("No existing script found to regenerate", cache_key=cache_key, cache_key_value=cache_key_value) + return + + current_script = existing_scripts[0] + LOG.info( + "Regenerating script block after AI fallback", + script_id=current_script.script_id, + script_version=current_script.version, + cache_key=cache_key, + cache_key_value=cache_key_value, + ) + + # Create a new script version + new_script = await app.DATABASE.create_script( + organization_id=organization_id, + run_id=workflow_run_id, + script_id=current_script.script_id, # Use same script_id for versioning + version=current_script.version + 1, + ) + + # deprecate the current workflow script + await app.DATABASE.delete_workflow_cache_key_value( + organization_id=organization_id, + workflow_permanent_id=workflow.workflow_permanent_id, + cache_key_value=cache_key_value, + ) + + # Create workflow script mapping for the new version + await app.DATABASE.create_workflow_script( + organization_id=organization_id, + script_id=new_script.script_id, + workflow_permanent_id=workflow.workflow_permanent_id, + cache_key=workflow.cache_key or "", + cache_key_value=cache_key_value, + workflow_id=workflow.workflow_id, + workflow_run_id=workflow_run_id, + ) + + # Get all existing script blocks from the previous version + existing_script_blocks = await app.DATABASE.get_script_blocks_by_script_revision_id( + script_revision_id=current_script.script_revision_id, + organization_id=organization_id, + ) + + # Copy all existing script blocks to the new version (except the one we're regenerating) + block_file_contents = [] + starter_block_file_content_bytes = b"" + block_file_content: bytes | str = "" + for existing_block in existing_script_blocks: + if existing_block.script_block_label == cache_key: + # Skip this block - we'll regenerate it + block_file_content = await _generate_block_code_from_task( + block_type=block_type, + cache_key=cache_key, + task_id=task_id, + script_step_id=script_step_id, + ai_step_id=ai_step_id, + organization_id=organization_id, + workflow=workflow, + workflow_run_id=workflow_run_id, + ) + else: + # Copy the existing block to the new version + # Get the script file content for this block and copy a new script block for it + if existing_block.script_file_id: + script_file = await app.DATABASE.get_script_file_by_id( + script_revision_id=current_script.script_revision_id, + file_id=existing_block.script_file_id, + organization_id=organization_id, + ) + + if script_file and script_file.artifact_id: + # Retrieve the artifact content + artifact = await app.DATABASE.get_artifact_by_id(script_file.artifact_id, organization_id) + if artifact: + file_content = await app.ARTIFACT_MANAGER.retrieve_artifact(artifact) + if file_content: + block_file_content = file_content + else: + LOG.warning( + "Failed to retrieve artifact content for existing block", + block_label=existing_block.script_block_label, + ) + else: + LOG.warning( + "Artifact not found for existing block", block_label=existing_block.script_block_label + ) + else: + LOG.warning( + "Script file or artifact not found for existing block", + block_label=existing_block.script_block_label, + ) + else: + LOG.warning("No script file ID for existing block", block_label=existing_block.script_block_label) + + if not block_file_content: + LOG.warning( + "No block file content found for existing block", block_label=existing_block.script_block_label + ) + continue + + await create_script_block( + block_code=block_file_content, + script_revision_id=new_script.script_revision_id, + script_id=new_script.script_id, + organization_id=organization_id, + block_name=existing_block.script_block_label, + ) + block_file_content_bytes = ( + block_file_content if isinstance(block_file_content, bytes) else block_file_content.encode("utf-8") + ) + if existing_block.script_block_label == settings.WORKFLOW_START_BLOCK_LABEL: + starter_block_file_content_bytes = block_file_content_bytes + else: + block_file_contents.append(block_file_content_bytes) + + if starter_block_file_content_bytes: + block_file_contents.insert(0, starter_block_file_content_bytes) + else: + LOG.error("Starter block file content not found") + + # 4) Persist script and files, then record mapping + python_src = "\n\n".join([block_file_content.decode("utf-8") for block_file_content in block_file_contents]) + content_bytes = python_src.encode("utf-8") + content_b64 = base64.b64encode(content_bytes).decode("utf-8") + files = [ + ScriptFileCreate( + path="main.py", + content=content_b64, + encoding=FileEncoding.BASE64, + mime_type="text/x-python", + ) + ] + + # Upload script file(s) as artifacts and create rows + await build_file_tree( + files=files, + organization_id=workflow.organization_id, + script_id=new_script.script_id, + script_version=new_script.version, + script_revision_id=new_script.script_revision_id, + ) + + except Exception as e: + LOG.error("Failed to regenerate script block after AI fallback", error=str(e), exc_info=True) + raise + + +async def _get_block_definition_by_label( + label: str, workflow: Workflow, task_id: str, organization_id: str +) -> dict[str, Any] | None: + final_dump = None + for block in workflow.workflow_definition.blocks: + if block.label == label: + final_dump = block.model_dump() + break + if not final_dump: + return None + + task = await app.DATABASE.get_task(task_id=task_id, organization_id=organization_id) + if task: + task_dump = task.model_dump() + final_dump.update({k: v for k, v in task_dump.items() if k not in final_dump}) + + # Add run block execution metadata + final_dump.update( + { + "task_id": task_id, + "output": task.extracted_information, + } + ) + + return final_dump + + +async def _generate_block_code_from_task( + block_type: BlockType, + cache_key: str, + task_id: str, + script_step_id: str, + ai_step_id: str, + organization_id: str, + workflow: Workflow, + workflow_run_id: str, +) -> str: + block_data = await _get_block_definition_by_label(cache_key, workflow, task_id, organization_id) + if not block_data: + return "" + try: + # Now regenerate only the specific block that fell back to AI + task_actions = await app.DATABASE.get_task_actions_hydrated( + task_id=task_id, + organization_id=organization_id, + ) + + # Filter actions by step_id and exclude the final action that failed before ai fallback + actions_to_cache = [] + for index, task_action in enumerate(task_actions): + # if this action is the last action of the script step, right before ai fallback, we should not include it + if ( + index < len(task_actions) - 1 + and task_action.step_id == script_step_id + and task_actions[index + 1].step_id == ai_step_id + ): + continue + action_dump = task_action.model_dump() + action_dump["xpath"] = task_action.get_xpath() + actions_to_cache.append(action_dump) + + if not actions_to_cache: + LOG.warning("No actions found in successful step for script block regeneration") + return "" + + # Generate the new block function + block_fn_def = _build_block_fn(block_data, actions_to_cache) + + # Convert the FunctionDef to code using a temporary module + temp_module = cst.Module(body=[block_fn_def]) + block_code = temp_module.code + + return block_code + + except Exception as block_gen_error: + LOG.error("Failed to generate block function", error=str(block_gen_error), exc_info=True) + # Even if block generation fails, we've created the new script version + # which can be useful for debugging + return "" + + async def run_task( prompt: str, url: str | None = None, @@ -610,6 +954,7 @@ async def run_task( except Exception as e: LOG.exception("Failed to run task block. Falling back to AI run.") await _fallback_to_ai_run( + block_type=BlockType.TASK, cache_key=cache_key, prompt=prompt, url=url, @@ -668,6 +1013,7 @@ async def download( except Exception as e: LOG.exception("Failed to run download block. Falling back to AI run.") await _fallback_to_ai_run( + block_type=BlockType.FILE_DOWNLOAD, cache_key=cache_key, prompt=prompt, url=url, @@ -726,6 +1072,7 @@ async def action( except Exception as e: LOG.exception("Failed to run action block. Falling back to AI run.") await _fallback_to_ai_run( + block_type=BlockType.ACTION, cache_key=cache_key, prompt=prompt, url=url, @@ -781,8 +1128,9 @@ async def login( ) except Exception as e: - LOG.exception("Failed to run login block. Falling back to AI run.") + LOG.exception("Failed to run login block") await _fallback_to_ai_run( + block_type=BlockType.LOGIN, cache_key=cache_key, prompt=prompt, url=url,