map failure page to error code if any, when Script run when task failed (#4149)

This commit is contained in:
Shuchang Zheng
2025-11-30 17:34:48 -08:00
committed by GitHub
parent 9d44997584
commit 76a61d23e6
3 changed files with 1666 additions and 1442 deletions

View File

@@ -2,6 +2,7 @@ import asyncio
import base64
import hashlib
import importlib.util
import json
import os
import uuid
from dataclasses import dataclass
@@ -18,8 +19,10 @@ from skyvern.constants import GET_DOWNLOADED_FILES_TIMEOUT
from skyvern.core.script_generations.constants import SCRIPT_TASK_BLOCKS
from skyvern.core.script_generations.generate_script import _build_block_fn, create_or_update_script_block
from skyvern.core.script_generations.script_skyvern_page import script_run_context_manager
from skyvern.errors.errors import UserDefinedError
from skyvern.exceptions import ScriptNotFound, ScriptTerminationException, StepTerminationError, WorkflowRunNotFound
from skyvern.forge import app
from skyvern.forge.prompts import prompt_engine
from skyvern.forge.sdk.artifact.models import ArtifactType
from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.db.enums import TaskType
@@ -58,6 +61,7 @@ from skyvern.schemas.scripts import (
ScriptStatus,
)
from skyvern.schemas.workflows import BlockStatus, BlockType, FileStorageType, FileType
from skyvern.webeye.scraper.scraper import ElementTreeFormat
LOG = structlog.get_logger()
jinja_sandbox_env = SandboxedEnvironment()
@@ -666,6 +670,90 @@ async def _run_cached_function(cached_fn: Callable) -> Any:
return await cached_fn(page=run_context.page, context=run_context)
async def _detect_user_defined_errors(
task: Task,
step: Step,
workflow_run_id: str,
error_code_mapping: dict[str, str],
prompt: str | None = None,
) -> list[UserDefinedError]:
"""
Detect user-defined errors using LLM when error_code_mapping is provided.
Returns a list of UserDefinedError objects if any errors are detected.
"""
try:
run_context = script_run_context_manager.ensure_run_context()
skyvern_page = run_context.page
scraped_page = await skyvern_page.scraped_page.refresh()
skyvern_page.scraped_page = scraped_page
current_url = scraped_page.url
# Build element tree
element_tree_format = ElementTreeFormat.HTML
elements = scraped_page.build_element_tree(element_tree_format)
screenshots = scraped_page.screenshots
# Build the prompt using the surface-user-defined-errors template
context = skyvern_context.current()
tz_info = datetime.now().astimezone().tzinfo
if context and context.tz_info:
tz_info = context.tz_info
prompt_name = "surface-user-defined-errors"
error_detection_prompt = prompt_engine.load_prompt(
prompt_name,
error_code_mapping_str=json.dumps(error_code_mapping),
navigation_goal=prompt or task.navigation_goal or "",
navigation_payload_str=json.dumps(task.navigation_payload or {}),
elements=elements,
current_url=current_url,
action_history=[],
local_datetime=datetime.now(tz_info).isoformat(),
reasoning=None,
)
# Call LLM to detect errors
json_response = await app.EXTRACTION_LLM_API_HANDLER(
prompt=error_detection_prompt,
screenshots=screenshots,
step=step,
prompt_name=prompt_name,
)
# Parse the response and extract errors
errors_list = json_response.get("errors", [])
user_defined_errors = []
for error_dict in errors_list:
try:
user_defined_error = UserDefinedError.model_validate(error_dict)
user_defined_errors.append(user_defined_error)
except Exception:
LOG.warning(
"Failed to validate user-defined error",
error_dict=error_dict,
)
LOG.info(
"Detected user-defined errors",
task_id=task.task_id,
step_id=step.step_id,
error_count=len(user_defined_errors),
errors=[e.error_code for e in user_defined_errors],
)
return user_defined_errors
except Exception as e:
LOG.exception(
"Failed to detect user-defined errors",
task_id=task.task_id,
step_id=step.step_id,
error=str(e),
)
return []
async def _fallback_to_ai_run(
block_type: BlockType,
cache_key: str,
@@ -746,13 +834,51 @@ async def _fallback_to_ai_run(
workflow_permanent_id=workflow_permanent_id,
workflow_run_id=workflow_run_id,
)
# If error_code_mapping is provided, detect user-defined errors using LLM
detected_errors: list[UserDefinedError] = []
if error_code_mapping:
LOG.info(
"Error code mapping provided, detecting user-defined errors",
workflow_run_id=workflow_run_id,
task_id=task_id,
)
detected_errors = await _detect_user_defined_errors(
task=task,
step=previous_step,
workflow_run_id=workflow_run_id,
error_code_mapping=error_code_mapping,
prompt=prompt,
)
# Update task errors if any errors were detected
if detected_errors:
task_errors = task.errors or []
task_errors.extend([error.model_dump() for error in detected_errors])
await app.DATABASE.update_task(
task_id=task_id,
organization_id=organization_id,
errors=task_errors,
)
LOG.info(
"Updated task with detected user-defined errors",
task_id=task_id,
error_codes=[e.error_code for e in detected_errors],
)
# Update workflow block with failure reason (include detected errors if any)
task_failure_reason = str(error)
if detected_errors:
error_codes = [e.error_code for e in detected_errors]
task_failure_reason = f"{task_failure_reason}. Detected errors: {', '.join(error_codes)}"
if workflow_run_block_id:
await _update_workflow_block(
workflow_run_block_id,
BlockStatus.failed,
task_id=task_id,
task_status=TaskStatus.failed,
failure_reason=str(error),
failure_reason=task_failure_reason,
step_id=script_step_id,
step_status=StepStatus.failed,
label=cache_key,
@@ -1194,6 +1320,7 @@ async def run_task(
cache_key: str | None = None,
engine: RunEngine = RunEngine.skyvern_v1,
model: dict[str, Any] | None = None,
error_code_mapping: dict[str, str] | None = None,
) -> dict[str, Any] | list | str | None:
cache_key = cache_key or label
cached_fn = script_run_context_manager.get_cached_fn(cache_key)
@@ -1239,6 +1366,7 @@ async def run_task(
totp_url=totp_url,
error=e,
workflow_run_block_id=workflow_run_block_id,
error_code_mapping=error_code_mapping,
)
return None
finally:
@@ -1278,6 +1406,7 @@ async def download(
label: str | None = None,
cache_key: str | None = None,
model: dict[str, Any] | None = None,
error_code_mapping: dict[str, str] | None = None,
) -> None:
cache_key = cache_key or label
cached_fn = script_run_context_manager.get_cached_fn(cache_key)
@@ -1320,6 +1449,7 @@ async def download(
complete_on_download=complete_on_download,
error=e,
workflow_run_block_id=workflow_run_block_id,
error_code_mapping=error_code_mapping,
)
finally:
context.prompt = None
@@ -1356,6 +1486,7 @@ async def action(
label: str | None = None,
cache_key: str | None = None,
model: dict[str, Any] | None = None,
error_code_mapping: dict[str, str] | None = None,
) -> None:
context: skyvern_context.SkyvernContext | None
cache_key = cache_key or label
@@ -1399,6 +1530,7 @@ async def action(
totp_url=totp_url,
error=e,
workflow_run_block_id=workflow_run_block_id,
error_code_mapping=error_code_mapping,
)
finally:
context.prompt = None
@@ -1432,6 +1564,7 @@ async def login(
label: str | None = None,
cache_key: str | None = None,
model: dict[str, Any] | None = None,
error_code_mapping: dict[str, str] | None = None,
) -> None:
context: skyvern_context.SkyvernContext | None
cache_key = cache_key or label
@@ -1475,6 +1608,7 @@ async def login(
totp_url=totp_url,
error=e,
workflow_run_block_id=workflow_run_block_id,
error_code_mapping=error_code_mapping,
)
finally:
context.prompt = None