script regeneration after ai fallback (#3330)
This commit is contained in:
@@ -7,12 +7,15 @@ import os
|
||||
from datetime import datetime
|
||||
from typing import Any, cast
|
||||
|
||||
import libcst as cst
|
||||
import structlog
|
||||
from fastapi import BackgroundTasks, HTTPException
|
||||
from jinja2.sandbox import SandboxedEnvironment
|
||||
|
||||
from skyvern.config import settings
|
||||
from skyvern.constants import GET_DOWNLOADED_FILES_TIMEOUT
|
||||
from skyvern.core.script_generations.constants import SCRIPT_TASK_BLOCKS
|
||||
from skyvern.core.script_generations.generate_script import _build_block_fn, create_script_block
|
||||
from skyvern.core.script_generations.script_run_context_manager import script_run_context_manager
|
||||
from skyvern.exceptions import ScriptNotFound, WorkflowRunNotFound
|
||||
from skyvern.forge import app
|
||||
@@ -22,8 +25,9 @@ from skyvern.forge.sdk.models import StepStatus
|
||||
from skyvern.forge.sdk.schemas.files import FileInfo
|
||||
from skyvern.forge.sdk.schemas.tasks import TaskOutput, TaskStatus
|
||||
from skyvern.forge.sdk.workflow.models.block import TaskBlock
|
||||
from skyvern.forge.sdk.workflow.models.workflow import Workflow
|
||||
from skyvern.schemas.runs import RunEngine
|
||||
from skyvern.schemas.scripts import CreateScriptResponse, FileNode, ScriptFileCreate
|
||||
from skyvern.schemas.scripts import CreateScriptResponse, FileEncoding, FileNode, ScriptFileCreate
|
||||
from skyvern.schemas.workflows import BlockStatus, BlockType
|
||||
|
||||
LOG = structlog.get_logger(__name__)
|
||||
@@ -446,6 +450,7 @@ async def _run_cached_function(cache_key: str) -> Any:
|
||||
|
||||
|
||||
async def _fallback_to_ai_run(
|
||||
block_type: BlockType,
|
||||
cache_key: str,
|
||||
prompt: str | None = None,
|
||||
url: str | None = None,
|
||||
@@ -475,32 +480,38 @@ async def _fallback_to_ai_run(
|
||||
and context.step_id
|
||||
):
|
||||
return
|
||||
organization_id = context.organization_id
|
||||
workflow_id = context.workflow_id
|
||||
workflow_run_id = context.workflow_run_id
|
||||
workflow_permanent_id = context.workflow_permanent_id
|
||||
task_id = context.task_id
|
||||
script_step_id = context.step_id
|
||||
try:
|
||||
organization_id = context.organization_id
|
||||
LOG.info(
|
||||
"Script trying to fallback to AI run",
|
||||
cache_key=cache_key,
|
||||
organization_id=organization_id,
|
||||
workflow_id=context.workflow_id,
|
||||
workflow_run_id=context.workflow_run_id,
|
||||
task_id=context.task_id,
|
||||
step_id=context.step_id,
|
||||
workflow_id=workflow_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
task_id=task_id,
|
||||
step_id=script_step_id,
|
||||
)
|
||||
# 1. fail the previous step
|
||||
previous_step = await app.DATABASE.update_step(
|
||||
step_id=context.step_id,
|
||||
task_id=context.task_id,
|
||||
step_id=script_step_id,
|
||||
task_id=task_id,
|
||||
organization_id=organization_id,
|
||||
status=StepStatus.failed,
|
||||
)
|
||||
# 2. create a new step for ai run
|
||||
ai_step = await app.DATABASE.create_step(
|
||||
task_id=context.task_id,
|
||||
task_id=task_id,
|
||||
organization_id=organization_id,
|
||||
order=previous_step.order + 1,
|
||||
retry_index=0,
|
||||
)
|
||||
context.step_id = ai_step.step_id
|
||||
ai_step_id = ai_step.step_id
|
||||
# 3. build the task block
|
||||
# 4. run execute_step
|
||||
organization = await app.DATABASE.get_organization(organization_id=organization_id)
|
||||
@@ -510,21 +521,35 @@ async def _fallback_to_ai_run(
|
||||
if not task:
|
||||
raise Exception(f"Task is missing task_id={context.task_id}")
|
||||
workflow = await app.DATABASE.get_workflow(workflow_id=context.workflow_id, organization_id=organization_id)
|
||||
if not workflow or not workflow.ai_fallback:
|
||||
if not workflow:
|
||||
return
|
||||
if not workflow.ai_fallback:
|
||||
LOG.info(
|
||||
"AI fallback is not enabled for the workflow",
|
||||
workflow_id=workflow_id,
|
||||
workflow_permanent_id=workflow_permanent_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
)
|
||||
return
|
||||
|
||||
# get the output_paramter
|
||||
output_parameter = workflow.get_output_parameter(cache_key)
|
||||
if not output_parameter:
|
||||
LOG.exception(
|
||||
"Output parameter not found for the workflow",
|
||||
workflow_id=workflow_id,
|
||||
workflow_permanent_id=workflow_permanent_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
)
|
||||
return
|
||||
LOG.info(
|
||||
"Script starting to fallback to AI run",
|
||||
cache_key=cache_key,
|
||||
organization_id=organization_id,
|
||||
workflow_id=context.workflow_id,
|
||||
workflow_run_id=context.workflow_run_id,
|
||||
task_id=context.task_id,
|
||||
step_id=context.step_id,
|
||||
workflow_id=workflow_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
task_id=task_id,
|
||||
step_id=script_step_id,
|
||||
)
|
||||
|
||||
task_block = TaskBlock(
|
||||
@@ -553,6 +578,7 @@ async def _fallback_to_ai_run(
|
||||
step=ai_step,
|
||||
task_block=task_block,
|
||||
)
|
||||
|
||||
# Update block status to completed if workflow block was created
|
||||
if workflow_run_block_id:
|
||||
await _update_workflow_block(
|
||||
@@ -562,6 +588,37 @@ async def _fallback_to_ai_run(
|
||||
step_id=context.step_id,
|
||||
label=cache_key,
|
||||
)
|
||||
|
||||
# 5. After successful AI execution, regenerate the script block and create new version
|
||||
try:
|
||||
await _regenerate_script_block_after_ai_fallback(
|
||||
block_type=block_type,
|
||||
cache_key=cache_key,
|
||||
task_id=context.task_id,
|
||||
script_step_id=ai_step_id,
|
||||
ai_step_id=ai_step_id,
|
||||
organization_id=organization_id,
|
||||
workflow=workflow,
|
||||
workflow_run_id=context.workflow_run_id,
|
||||
prompt=prompt,
|
||||
url=url,
|
||||
engine=engine,
|
||||
complete_criterion=complete_criterion,
|
||||
terminate_criterion=terminate_criterion,
|
||||
data_extraction_goal=data_extraction_goal,
|
||||
schema=schema,
|
||||
error_code_mapping=error_code_mapping,
|
||||
max_steps=max_steps,
|
||||
complete_on_download=complete_on_download,
|
||||
download_suffix=download_suffix,
|
||||
totp_verification_url=totp_verification_url,
|
||||
totp_identifier=totp_identifier,
|
||||
complete_verification=complete_verification,
|
||||
include_action_history_in_verification=include_action_history_in_verification,
|
||||
)
|
||||
except Exception as e:
|
||||
LOG.warning("Failed to regenerate script block after AI fallback", error=str(e), exc_info=True)
|
||||
# Don't fail the entire fallback process if script regeneration fails
|
||||
except Exception as e:
|
||||
LOG.warning("Failed to fallback to AI run", cache_key=cache_key, exc_info=True)
|
||||
# Update block status to failed if workflow block was created
|
||||
@@ -577,6 +634,293 @@ async def _fallback_to_ai_run(
|
||||
raise e
|
||||
|
||||
|
||||
async def _regenerate_script_block_after_ai_fallback(
|
||||
block_type: BlockType,
|
||||
cache_key: str,
|
||||
task_id: str,
|
||||
script_step_id: str,
|
||||
ai_step_id: str,
|
||||
organization_id: str,
|
||||
workflow: Workflow,
|
||||
workflow_run_id: str,
|
||||
prompt: str | None = None,
|
||||
url: str | None = None,
|
||||
engine: RunEngine = RunEngine.skyvern_v1,
|
||||
complete_criterion: str | None = None,
|
||||
terminate_criterion: str | None = None,
|
||||
data_extraction_goal: str | None = None,
|
||||
schema: dict[str, Any] | list | str | None = None,
|
||||
error_code_mapping: dict[str, str] | None = None,
|
||||
max_steps: int | None = None,
|
||||
complete_on_download: bool = False,
|
||||
download_suffix: str | None = None,
|
||||
totp_verification_url: str | None = None,
|
||||
totp_identifier: str | None = None,
|
||||
complete_verification: bool = True,
|
||||
include_action_history_in_verification: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Regenerate the script block after a successful AI fallback and create a new script version.
|
||||
Only the specific block that fell back to AI is regenerated; all other blocks remain unchanged.
|
||||
|
||||
1. get the latest cashed script for the workflow
|
||||
2. create a completely new script, with only the current block's script being different as it's newly generated.
|
||||
-
|
||||
"""
|
||||
try:
|
||||
# Get the current script for this workflow and cache key value
|
||||
# Render the cache_key_value from workflow run parameters (same logic as generate_script_for_workflow)
|
||||
cache_key_value = ""
|
||||
if workflow.cache_key:
|
||||
try:
|
||||
parameter_tuples = await app.DATABASE.get_workflow_run_parameters(workflow_run_id=workflow_run_id)
|
||||
parameters = {wf_param.key: run_param.value for wf_param, run_param in parameter_tuples}
|
||||
cache_key_value = jinja_sandbox_env.from_string(workflow.cache_key).render(parameters)
|
||||
except Exception as e:
|
||||
LOG.warning("Failed to render cache key for script regeneration", error=str(e), exc_info=True)
|
||||
# Fallback to using cache_key as cache_key_value
|
||||
cache_key_value = cache_key
|
||||
|
||||
if not cache_key_value:
|
||||
cache_key_value = cache_key # Fallback
|
||||
|
||||
existing_scripts = await app.DATABASE.get_workflow_scripts_by_cache_key_value(
|
||||
organization_id=organization_id,
|
||||
workflow_permanent_id=workflow.workflow_permanent_id,
|
||||
cache_key_value=cache_key_value,
|
||||
cache_key=workflow.cache_key,
|
||||
)
|
||||
|
||||
if not existing_scripts:
|
||||
LOG.error("No existing script found to regenerate", cache_key=cache_key, cache_key_value=cache_key_value)
|
||||
return
|
||||
|
||||
current_script = existing_scripts[0]
|
||||
LOG.info(
|
||||
"Regenerating script block after AI fallback",
|
||||
script_id=current_script.script_id,
|
||||
script_version=current_script.version,
|
||||
cache_key=cache_key,
|
||||
cache_key_value=cache_key_value,
|
||||
)
|
||||
|
||||
# Create a new script version
|
||||
new_script = await app.DATABASE.create_script(
|
||||
organization_id=organization_id,
|
||||
run_id=workflow_run_id,
|
||||
script_id=current_script.script_id, # Use same script_id for versioning
|
||||
version=current_script.version + 1,
|
||||
)
|
||||
|
||||
# deprecate the current workflow script
|
||||
await app.DATABASE.delete_workflow_cache_key_value(
|
||||
organization_id=organization_id,
|
||||
workflow_permanent_id=workflow.workflow_permanent_id,
|
||||
cache_key_value=cache_key_value,
|
||||
)
|
||||
|
||||
# Create workflow script mapping for the new version
|
||||
await app.DATABASE.create_workflow_script(
|
||||
organization_id=organization_id,
|
||||
script_id=new_script.script_id,
|
||||
workflow_permanent_id=workflow.workflow_permanent_id,
|
||||
cache_key=workflow.cache_key or "",
|
||||
cache_key_value=cache_key_value,
|
||||
workflow_id=workflow.workflow_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
)
|
||||
|
||||
# Get all existing script blocks from the previous version
|
||||
existing_script_blocks = await app.DATABASE.get_script_blocks_by_script_revision_id(
|
||||
script_revision_id=current_script.script_revision_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
|
||||
# Copy all existing script blocks to the new version (except the one we're regenerating)
|
||||
block_file_contents = []
|
||||
starter_block_file_content_bytes = b""
|
||||
block_file_content: bytes | str = ""
|
||||
for existing_block in existing_script_blocks:
|
||||
if existing_block.script_block_label == cache_key:
|
||||
# Skip this block - we'll regenerate it
|
||||
block_file_content = await _generate_block_code_from_task(
|
||||
block_type=block_type,
|
||||
cache_key=cache_key,
|
||||
task_id=task_id,
|
||||
script_step_id=script_step_id,
|
||||
ai_step_id=ai_step_id,
|
||||
organization_id=organization_id,
|
||||
workflow=workflow,
|
||||
workflow_run_id=workflow_run_id,
|
||||
)
|
||||
else:
|
||||
# Copy the existing block to the new version
|
||||
# Get the script file content for this block and copy a new script block for it
|
||||
if existing_block.script_file_id:
|
||||
script_file = await app.DATABASE.get_script_file_by_id(
|
||||
script_revision_id=current_script.script_revision_id,
|
||||
file_id=existing_block.script_file_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
|
||||
if script_file and script_file.artifact_id:
|
||||
# Retrieve the artifact content
|
||||
artifact = await app.DATABASE.get_artifact_by_id(script_file.artifact_id, organization_id)
|
||||
if artifact:
|
||||
file_content = await app.ARTIFACT_MANAGER.retrieve_artifact(artifact)
|
||||
if file_content:
|
||||
block_file_content = file_content
|
||||
else:
|
||||
LOG.warning(
|
||||
"Failed to retrieve artifact content for existing block",
|
||||
block_label=existing_block.script_block_label,
|
||||
)
|
||||
else:
|
||||
LOG.warning(
|
||||
"Artifact not found for existing block", block_label=existing_block.script_block_label
|
||||
)
|
||||
else:
|
||||
LOG.warning(
|
||||
"Script file or artifact not found for existing block",
|
||||
block_label=existing_block.script_block_label,
|
||||
)
|
||||
else:
|
||||
LOG.warning("No script file ID for existing block", block_label=existing_block.script_block_label)
|
||||
|
||||
if not block_file_content:
|
||||
LOG.warning(
|
||||
"No block file content found for existing block", block_label=existing_block.script_block_label
|
||||
)
|
||||
continue
|
||||
|
||||
await create_script_block(
|
||||
block_code=block_file_content,
|
||||
script_revision_id=new_script.script_revision_id,
|
||||
script_id=new_script.script_id,
|
||||
organization_id=organization_id,
|
||||
block_name=existing_block.script_block_label,
|
||||
)
|
||||
block_file_content_bytes = (
|
||||
block_file_content if isinstance(block_file_content, bytes) else block_file_content.encode("utf-8")
|
||||
)
|
||||
if existing_block.script_block_label == settings.WORKFLOW_START_BLOCK_LABEL:
|
||||
starter_block_file_content_bytes = block_file_content_bytes
|
||||
else:
|
||||
block_file_contents.append(block_file_content_bytes)
|
||||
|
||||
if starter_block_file_content_bytes:
|
||||
block_file_contents.insert(0, starter_block_file_content_bytes)
|
||||
else:
|
||||
LOG.error("Starter block file content not found")
|
||||
|
||||
# 4) Persist script and files, then record mapping
|
||||
python_src = "\n\n".join([block_file_content.decode("utf-8") for block_file_content in block_file_contents])
|
||||
content_bytes = python_src.encode("utf-8")
|
||||
content_b64 = base64.b64encode(content_bytes).decode("utf-8")
|
||||
files = [
|
||||
ScriptFileCreate(
|
||||
path="main.py",
|
||||
content=content_b64,
|
||||
encoding=FileEncoding.BASE64,
|
||||
mime_type="text/x-python",
|
||||
)
|
||||
]
|
||||
|
||||
# Upload script file(s) as artifacts and create rows
|
||||
await build_file_tree(
|
||||
files=files,
|
||||
organization_id=workflow.organization_id,
|
||||
script_id=new_script.script_id,
|
||||
script_version=new_script.version,
|
||||
script_revision_id=new_script.script_revision_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
LOG.error("Failed to regenerate script block after AI fallback", error=str(e), exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
async def _get_block_definition_by_label(
|
||||
label: str, workflow: Workflow, task_id: str, organization_id: str
|
||||
) -> dict[str, Any] | None:
|
||||
final_dump = None
|
||||
for block in workflow.workflow_definition.blocks:
|
||||
if block.label == label:
|
||||
final_dump = block.model_dump()
|
||||
break
|
||||
if not final_dump:
|
||||
return None
|
||||
|
||||
task = await app.DATABASE.get_task(task_id=task_id, organization_id=organization_id)
|
||||
if task:
|
||||
task_dump = task.model_dump()
|
||||
final_dump.update({k: v for k, v in task_dump.items() if k not in final_dump})
|
||||
|
||||
# Add run block execution metadata
|
||||
final_dump.update(
|
||||
{
|
||||
"task_id": task_id,
|
||||
"output": task.extracted_information,
|
||||
}
|
||||
)
|
||||
|
||||
return final_dump
|
||||
|
||||
|
||||
async def _generate_block_code_from_task(
|
||||
block_type: BlockType,
|
||||
cache_key: str,
|
||||
task_id: str,
|
||||
script_step_id: str,
|
||||
ai_step_id: str,
|
||||
organization_id: str,
|
||||
workflow: Workflow,
|
||||
workflow_run_id: str,
|
||||
) -> str:
|
||||
block_data = await _get_block_definition_by_label(cache_key, workflow, task_id, organization_id)
|
||||
if not block_data:
|
||||
return ""
|
||||
try:
|
||||
# Now regenerate only the specific block that fell back to AI
|
||||
task_actions = await app.DATABASE.get_task_actions_hydrated(
|
||||
task_id=task_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
|
||||
# Filter actions by step_id and exclude the final action that failed before ai fallback
|
||||
actions_to_cache = []
|
||||
for index, task_action in enumerate(task_actions):
|
||||
# if this action is the last action of the script step, right before ai fallback, we should not include it
|
||||
if (
|
||||
index < len(task_actions) - 1
|
||||
and task_action.step_id == script_step_id
|
||||
and task_actions[index + 1].step_id == ai_step_id
|
||||
):
|
||||
continue
|
||||
action_dump = task_action.model_dump()
|
||||
action_dump["xpath"] = task_action.get_xpath()
|
||||
actions_to_cache.append(action_dump)
|
||||
|
||||
if not actions_to_cache:
|
||||
LOG.warning("No actions found in successful step for script block regeneration")
|
||||
return ""
|
||||
|
||||
# Generate the new block function
|
||||
block_fn_def = _build_block_fn(block_data, actions_to_cache)
|
||||
|
||||
# Convert the FunctionDef to code using a temporary module
|
||||
temp_module = cst.Module(body=[block_fn_def])
|
||||
block_code = temp_module.code
|
||||
|
||||
return block_code
|
||||
|
||||
except Exception as block_gen_error:
|
||||
LOG.error("Failed to generate block function", error=str(block_gen_error), exc_info=True)
|
||||
# Even if block generation fails, we've created the new script version
|
||||
# which can be useful for debugging
|
||||
return ""
|
||||
|
||||
|
||||
async def run_task(
|
||||
prompt: str,
|
||||
url: str | None = None,
|
||||
@@ -610,6 +954,7 @@ async def run_task(
|
||||
except Exception as e:
|
||||
LOG.exception("Failed to run task block. Falling back to AI run.")
|
||||
await _fallback_to_ai_run(
|
||||
block_type=BlockType.TASK,
|
||||
cache_key=cache_key,
|
||||
prompt=prompt,
|
||||
url=url,
|
||||
@@ -668,6 +1013,7 @@ async def download(
|
||||
except Exception as e:
|
||||
LOG.exception("Failed to run download block. Falling back to AI run.")
|
||||
await _fallback_to_ai_run(
|
||||
block_type=BlockType.FILE_DOWNLOAD,
|
||||
cache_key=cache_key,
|
||||
prompt=prompt,
|
||||
url=url,
|
||||
@@ -726,6 +1072,7 @@ async def action(
|
||||
except Exception as e:
|
||||
LOG.exception("Failed to run action block. Falling back to AI run.")
|
||||
await _fallback_to_ai_run(
|
||||
block_type=BlockType.ACTION,
|
||||
cache_key=cache_key,
|
||||
prompt=prompt,
|
||||
url=url,
|
||||
@@ -781,8 +1128,9 @@ async def login(
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
LOG.exception("Failed to run login block. Falling back to AI run.")
|
||||
LOG.exception("Failed to run login block")
|
||||
await _fallback_to_ai_run(
|
||||
block_type=BlockType.LOGIN,
|
||||
cache_key=cache_key,
|
||||
prompt=prompt,
|
||||
url=url,
|
||||
|
||||
Reference in New Issue
Block a user