Add schema validation and default value filling for extraction results (#4063)
Co-authored-by: Suchintan <suchintan@users.noreply.github.com>
This commit is contained in:
304
skyvern/forge/sdk/api/llm/schema_validator.py
Normal file
304
skyvern/forge/sdk/api/llm/schema_validator.py
Normal file
@@ -0,0 +1,304 @@
|
||||
from typing import Any
|
||||
|
||||
import structlog
|
||||
from jsonschema import Draft202012Validator
|
||||
from jsonschema.exceptions import SchemaError
|
||||
|
||||
# Module-level structlog logger used by every helper in this file.
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
_TYPE_DEFAULT_FACTORIES: dict[str, Any] = {
|
||||
"string": lambda: None,
|
||||
"number": lambda: 0,
|
||||
"integer": lambda: 0,
|
||||
"boolean": lambda: False,
|
||||
"array": list,
|
||||
"object": dict,
|
||||
"null": lambda: None,
|
||||
}
|
||||
|
||||
|
||||
def _resolve_schema_type(schema_type: str | list[Any] | None, path: str) -> str | None:
|
||||
"""Normalize a schema type definition to a single string value."""
|
||||
if isinstance(schema_type, list):
|
||||
non_null_types = [str(t).lower() for t in schema_type if str(t).lower() != "null"]
|
||||
if not non_null_types:
|
||||
return "null"
|
||||
|
||||
if len(non_null_types) > 1:
|
||||
LOG.warning(
|
||||
"Multiple non-null types in schema, using first one",
|
||||
path=path,
|
||||
types=non_null_types,
|
||||
)
|
||||
return non_null_types[0]
|
||||
|
||||
return str(schema_type).lower() if schema_type is not None else None
|
||||
|
||||
|
||||
def get_default_value_for_type(schema_type: str | list[Any] | None, path: str = "root") -> Any:
    """Return the placeholder value for a field of the given JSON Schema type.

    Args:
        schema_type: Raw ``type`` value from a schema (name, list of names, or None).
        path: Location in the data structure, forwarded for logging.

    Returns:
        A freshly built value from ``_TYPE_DEFAULT_FACTORIES`` for the resolved
        type, or None when the type is absent or has no registered factory.
    """
    resolved = _resolve_schema_type(schema_type, path)
    if resolved is not None:
        make_default = _TYPE_DEFAULT_FACTORIES.get(resolved)
        if callable(make_default):
            return make_default()
    return None
|
||||
|
||||
|
||||
def fill_missing_fields(data: Any, schema: dict[str, Any] | list | str | None, path: str = "root") -> Any:
|
||||
"""
|
||||
Recursively fill missing fields in data based on the schema.
|
||||
|
||||
Args:
|
||||
data: The data to validate and fill
|
||||
schema: The JSON schema to validate against
|
||||
path: Current path in the data structure (for logging)
|
||||
|
||||
Returns:
|
||||
The data with missing fields filled with default values
|
||||
"""
|
||||
if schema is None:
|
||||
return data
|
||||
|
||||
if isinstance(schema, (str, list)):
|
||||
LOG.debug("Schema is permissive", path=path, schema=schema)
|
||||
return data
|
||||
|
||||
schema_type = _resolve_schema_type(schema.get("type"), path)
|
||||
raw_schema_type = schema.get("type")
|
||||
|
||||
if schema_type == "null" and data is None:
|
||||
LOG.debug("Data is None and schema allows null type, keeping as None", path=path)
|
||||
return None
|
||||
|
||||
# Check if null is allowed in the schema type
|
||||
is_nullable = isinstance(raw_schema_type, list) and "null" in raw_schema_type
|
||||
|
||||
if schema_type == "object" or "properties" in schema:
|
||||
# If data is None and schema allows null, keep it as None
|
||||
if data is None and is_nullable:
|
||||
LOG.debug("Data is None and schema allows null, keeping as None", path=path)
|
||||
return None
|
||||
|
||||
if not isinstance(data, dict):
|
||||
LOG.warning(
|
||||
"Expected object but got different type, creating empty object",
|
||||
path=path,
|
||||
data_type=type(data).__name__,
|
||||
)
|
||||
data = {}
|
||||
|
||||
properties = schema.get("properties", {})
|
||||
required_fields = set(schema.get("required", []))
|
||||
|
||||
for field_name, field_schema in properties.items():
|
||||
field_path = f"{path}.{field_name}"
|
||||
|
||||
if field_name not in data:
|
||||
if field_name in required_fields:
|
||||
default_value = field_schema.get(
|
||||
"default", get_default_value_for_type(field_schema.get("type"), field_path)
|
||||
)
|
||||
LOG.info(
|
||||
"Filling missing required field with default value",
|
||||
path=field_path,
|
||||
default_value=default_value,
|
||||
)
|
||||
data[field_name] = default_value
|
||||
else:
|
||||
LOG.debug("Skipping optional missing field", path=field_path)
|
||||
continue
|
||||
|
||||
data[field_name] = fill_missing_fields(data[field_name], field_schema, field_path)
|
||||
|
||||
return data
|
||||
|
||||
if schema_type == "array":
|
||||
# If data is None and schema allows null, keep it as None
|
||||
if data is None and is_nullable:
|
||||
LOG.debug("Data is None and schema allows null, keeping as None", path=path)
|
||||
return None
|
||||
|
||||
if not isinstance(data, list):
|
||||
LOG.warning(
|
||||
"Expected array but got different type, creating empty array",
|
||||
path=path,
|
||||
data_type=type(data).__name__,
|
||||
)
|
||||
return []
|
||||
|
||||
items_schema = schema.get("items")
|
||||
if not items_schema:
|
||||
return data
|
||||
|
||||
return [fill_missing_fields(item, items_schema, f"{path}[{idx}]") for idx, item in enumerate(data)]
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def validate_schema(schema: dict[str, Any] | list | str | None) -> bool:
|
||||
"""
|
||||
Validate that the schema itself is a valid JSON Schema.
|
||||
|
||||
Args:
|
||||
schema: The JSON schema to validate
|
||||
|
||||
Returns:
|
||||
True if the schema is valid, False otherwise
|
||||
"""
|
||||
if schema is None or isinstance(schema, (str, list)):
|
||||
return True
|
||||
|
||||
try:
|
||||
Draft202012Validator.check_schema(schema)
|
||||
return True
|
||||
except SchemaError as e:
|
||||
LOG.warning("Invalid JSON schema, will return data as-is", error=str(e), schema=schema)
|
||||
return False
|
||||
|
||||
|
||||
def validate_data_against_schema(data: Any, schema: dict[str, Any]) -> list[str]:
    """
    Collect every Draft 2020-12 validation error for data under the schema.

    Args:
        data: The data to validate
        schema: The JSON schema to validate against

    Returns:
        One ``"<location>: <message>"`` string per error, where the location
        is the dotted path to the offending value or ``"root"`` when the path
        is empty. The list is empty when the data is valid.
    """

    def _describe(issue: Any) -> str:
        # Path elements may be ints (array indices), hence the str() conversion.
        location = ".".join(str(part) for part in issue.path) if issue.path else "root"
        return f"{location}: {issue.message}"

    checker = Draft202012Validator(schema)
    return [_describe(issue) for issue in checker.iter_errors(data)]
|
||||
|
||||
|
||||
def _is_all_default_values(data: dict[str, Any], schema: dict[str, Any]) -> bool:
|
||||
"""
|
||||
Check if a dict contains only default values (indicating it was created from invalid data).
|
||||
|
||||
Args:
|
||||
data: The data object to check
|
||||
schema: The schema defining the expected structure
|
||||
|
||||
Returns:
|
||||
True if all values are defaults, False otherwise
|
||||
"""
|
||||
if not isinstance(data, dict):
|
||||
return False
|
||||
|
||||
properties = schema.get("properties", {})
|
||||
if not properties:
|
||||
return False
|
||||
|
||||
# Check each property against its default value
|
||||
for field_name, field_schema in properties.items():
|
||||
if field_name not in data:
|
||||
continue
|
||||
|
||||
field_value = data[field_name]
|
||||
field_type = _resolve_schema_type(field_schema.get("type"), f"check.{field_name}")
|
||||
default_value = get_default_value_for_type(field_type)
|
||||
|
||||
# If any field has a non-default value, the record is meaningful
|
||||
if field_value != default_value:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _filter_invalid_array_items(data: list[Any], schema: dict[str, Any]) -> list[Any]:
|
||||
"""
|
||||
Filter out array items that are all default values (created from invalid data like strings).
|
||||
|
||||
Args:
|
||||
data: The array data to filter
|
||||
schema: The array schema
|
||||
|
||||
Returns:
|
||||
Filtered array with invalid items removed
|
||||
"""
|
||||
items_schema = schema.get("items")
|
||||
if not items_schema or not isinstance(items_schema, dict):
|
||||
return data
|
||||
|
||||
# Only filter if items are objects
|
||||
if items_schema.get("type") not in ("object", ["object", "null"]):
|
||||
return data
|
||||
|
||||
filtered = []
|
||||
removed_count = 0
|
||||
|
||||
for item in data:
|
||||
if isinstance(item, dict) and _is_all_default_values(item, items_schema):
|
||||
removed_count += 1
|
||||
LOG.info("Filtering out invalid array item with all default values", item=item)
|
||||
else:
|
||||
filtered.append(item)
|
||||
|
||||
if removed_count > 0:
|
||||
LOG.info(f"Removed {removed_count} invalid array items")
|
||||
|
||||
return filtered
|
||||
|
||||
|
||||
def validate_and_fill_extraction_result(
    extraction_result: dict[str, Any],
    schema: dict[str, Any] | list | str | None,
) -> dict[str, Any]:
    """
    Validate extraction result against schema and fill missing fields with defaults.

    This function handles malformed JSON responses from LLMs by:
    1. Validating the schema itself is valid JSON Schema (returns data as-is if invalid)
    2. Filling in missing required fields with appropriate default values
    3. For array schemas (including nullable ones), dropping items that consist
       only of default values
    4. Validating the filled structure against the provided schema using jsonschema
    5. Preserving optional fields that are present

    Validation errors found after filling are logged as a warning and the
    filled result is still returned. The success message is logged only when
    no errors remain.

    NOTE(review): ``fill_missing_fields`` updates dicts in place, so on the
    exception path the returned ``extraction_result`` may already be partially
    filled — confirm callers do not rely on it being untouched.

    Args:
        extraction_result: The extraction result from the LLM
        schema: The JSON schema that defines the expected structure

    Returns:
        The validated and filled extraction result, or the original data if the
        schema is missing or invalid, or if filling raises
    """
    if schema is None:
        LOG.debug("No schema provided, returning extraction result as-is")
        return extraction_result

    if not validate_schema(schema):
        LOG.info("Schema is invalid, returning extraction result as-is without transformations")
        return extraction_result

    LOG.info("Validating and filling extraction result against schema")

    # Deliberately broad best-effort boundary: any failure while repairing the
    # result is logged and the caller gets the input back instead of an exception.
    try:
        filled_result = fill_missing_fields(extraction_result, schema)

        if isinstance(schema, dict):
            # Filter out invalid array items if the schema is for an array.
            # The type is resolved through the shared helper so that nullable
            # arrays such as ["array", "null"] are recognized as well.
            if isinstance(filled_result, list) and _resolve_schema_type(schema.get("type"), "root") == "array":
                filled_result = _filter_invalid_array_items(filled_result, schema)

            validation_errors = validate_data_against_schema(filled_result, schema)
            if validation_errors:
                LOG.warning(
                    "Validation errors found after filling",
                    errors=validation_errors,
                )
                # Do not claim success below when errors remain.
                return filled_result

        LOG.info("Successfully validated and filled extraction result")
        return filled_result
    except Exception as e:
        LOG.error(
            "Failed to validate and fill extraction result",
            error=str(e),
            exc_info=True,
        )
        return extraction_result
|
||||
Reference in New Issue
Block a user