Add schema validation and default value filling for extraction results (#4063)

Co-authored-by: Suchintan <suchintan@users.noreply.github.com>
This commit is contained in:
Mohamed Khalil
2025-11-28 09:10:41 +02:00
committed by GitHub
parent 362309dec3
commit b7ecdaafb7
7 changed files with 851 additions and 2 deletions

View File

@@ -0,0 +1,304 @@
import copy
from typing import Any

import structlog
from jsonschema import Draft202012Validator
from jsonschema.exceptions import SchemaError
# Module-level structlog logger used by the helpers in this module.
LOG = structlog.get_logger()
# Zero-argument factories that build the placeholder value for each JSON Schema
# type name. Factories are used instead of shared literals so that every lookup
# of "array" / "object" yields a fresh, independent container.
_TYPE_DEFAULT_FACTORIES: dict[str, Any] = {
    "string": lambda: None,  # a missing string becomes None, not ""
    "number": int,  # int() -> 0
    "integer": int,  # int() -> 0
    "boolean": bool,  # bool() -> False
    "array": lambda: [],
    "object": lambda: {},
    "null": lambda: None,
}
def _resolve_schema_type(schema_type: str | list[Any] | None, path: str) -> str | None:
"""Normalize a schema type definition to a single string value."""
if isinstance(schema_type, list):
non_null_types = [str(t).lower() for t in schema_type if str(t).lower() != "null"]
if not non_null_types:
return "null"
if len(non_null_types) > 1:
LOG.warning(
"Multiple non-null types in schema, using first one",
path=path,
types=non_null_types,
)
return non_null_types[0]
return str(schema_type).lower() if schema_type is not None else None
def get_default_value_for_type(schema_type: str | list[Any] | None, path: str = "root") -> Any:
    """Return the placeholder value used for a field of the given JSON schema type.

    Args:
        schema_type: The raw ``type`` entry of a schema (name, list of names, or None).
        path: Location in the data structure, forwarded for logging.

    Returns:
        The value built by the factory registered for the resolved type, or None
        when the type is absent or has no registered factory.
    """
    resolved = _resolve_schema_type(schema_type, path)
    if resolved is None:
        return None

    make_default = _TYPE_DEFAULT_FACTORIES.get(resolved)
    if not callable(make_default):
        # Unknown type name: there is no sensible placeholder.
        return None
    return make_default()
def fill_missing_fields(data: Any, schema: dict[str, Any] | list | str | None, path: str = "root") -> Any:
"""
Recursively fill missing fields in data based on the schema.
Args:
data: The data to validate and fill
schema: The JSON schema to validate against
path: Current path in the data structure (for logging)
Returns:
The data with missing fields filled with default values
"""
if schema is None:
return data
if isinstance(schema, (str, list)):
LOG.debug("Schema is permissive", path=path, schema=schema)
return data
schema_type = _resolve_schema_type(schema.get("type"), path)
raw_schema_type = schema.get("type")
if schema_type == "null" and data is None:
LOG.debug("Data is None and schema allows null type, keeping as None", path=path)
return None
# Check if null is allowed in the schema type
is_nullable = isinstance(raw_schema_type, list) and "null" in raw_schema_type
if schema_type == "object" or "properties" in schema:
# If data is None and schema allows null, keep it as None
if data is None and is_nullable:
LOG.debug("Data is None and schema allows null, keeping as None", path=path)
return None
if not isinstance(data, dict):
LOG.warning(
"Expected object but got different type, creating empty object",
path=path,
data_type=type(data).__name__,
)
data = {}
properties = schema.get("properties", {})
required_fields = set(schema.get("required", []))
for field_name, field_schema in properties.items():
field_path = f"{path}.{field_name}"
if field_name not in data:
if field_name in required_fields:
default_value = field_schema.get(
"default", get_default_value_for_type(field_schema.get("type"), field_path)
)
LOG.info(
"Filling missing required field with default value",
path=field_path,
default_value=default_value,
)
data[field_name] = default_value
else:
LOG.debug("Skipping optional missing field", path=field_path)
continue
data[field_name] = fill_missing_fields(data[field_name], field_schema, field_path)
return data
if schema_type == "array":
# If data is None and schema allows null, keep it as None
if data is None and is_nullable:
LOG.debug("Data is None and schema allows null, keeping as None", path=path)
return None
if not isinstance(data, list):
LOG.warning(
"Expected array but got different type, creating empty array",
path=path,
data_type=type(data).__name__,
)
return []
items_schema = schema.get("items")
if not items_schema:
return data
return [fill_missing_fields(item, items_schema, f"{path}[{idx}]") for idx, item in enumerate(data)]
return data
def validate_schema(schema: dict[str, Any] | list | str | None) -> bool:
"""
Validate that the schema itself is a valid JSON Schema.
Args:
schema: The JSON schema to validate
Returns:
True if the schema is valid, False otherwise
"""
if schema is None or isinstance(schema, (str, list)):
return True
try:
Draft202012Validator.check_schema(schema)
return True
except SchemaError as e:
LOG.warning("Invalid JSON schema, will return data as-is", error=str(e), schema=schema)
return False
def validate_data_against_schema(data: Any, schema: dict[str, Any]) -> list[str]:
    """
    Collect every validation error for ``data`` under ``schema`` (Draft 2020-12).

    Args:
        data: The data to validate
        schema: The JSON schema to validate against

    Returns:
        One ``"<path>: <message>"`` string per error, where ``<path>`` is the
        dot-joined location of the offending value or ``"root"`` when the error
        path is empty. The list is empty if the data is valid.
    """

    def _describe(err: Any) -> str:
        # An empty error path means the error applies to the top-level value.
        location = ".".join(map(str, err.path)) if err.path else "root"
        return f"{location}: {err.message}"

    return [_describe(err) for err in Draft202012Validator(schema).iter_errors(data)]
def _is_all_default_values(data: dict[str, Any], schema: dict[str, Any]) -> bool:
    """
    Check if a dict contains only default values (indicating it was created from invalid data).

    The expected default for each property mirrors what ``fill_missing_fields``
    inserts: the property's explicit ``default`` when the schema declares one,
    otherwise the placeholder for its type. Properties absent from ``data`` are
    ignored, so an empty dict counts as all-default when the schema declares
    any properties.

    Args:
        data: The data object to check
        schema: The schema defining the expected structure

    Returns:
        True if every present property equals its default, False otherwise
        (also False for non-dict data or a schema without properties)
    """
    if not isinstance(data, dict):
        return False

    properties = schema.get("properties", {})
    if not properties:
        return False

    for field_name, field_schema in properties.items():
        if field_name not in data:
            continue

        if isinstance(field_schema, dict) and "default" in field_schema:
            # Same source of truth as the fill step: an explicit default wins.
            expected = field_schema["default"]
        elif isinstance(field_schema, dict):
            expected = get_default_value_for_type(field_schema.get("type"), f"check.{field_name}")
        else:
            # Boolean subschema: no type information, so the fill step uses None.
            expected = None

        # If any field has a non-default value, the record is meaningful
        if data[field_name] != expected:
            return False

    return True
def _filter_invalid_array_items(data: list[Any], schema: dict[str, Any]) -> list[Any]:
    """
    Filter out array items that are all default values (created from invalid data like strings).

    Items are only filtered when the ``items`` schema describes an object, using
    the same rule as ``fill_missing_fields``: the resolved type is "object"
    (regardless of where "null" appears in a type list) or the schema declares
    ``properties``.

    Args:
        data: The array data to filter
        schema: The array schema

    Returns:
        A new list without the all-default object items, or ``data`` itself when
        the items schema is missing, not a dict, or does not describe objects
    """
    items_schema = schema.get("items")
    if not items_schema or not isinstance(items_schema, dict):
        return data

    # Only filter if items are objects
    describes_object = (
        "properties" in items_schema
        or _resolve_schema_type(items_schema.get("type"), "items") == "object"
    )
    if not describes_object:
        return data

    filtered = []
    removed_count = 0
    for item in data:
        if isinstance(item, dict) and _is_all_default_values(item, items_schema):
            removed_count += 1
            LOG.info("Filtering out invalid array item with all default values", item=item)
        else:
            filtered.append(item)

    if removed_count > 0:
        LOG.info(f"Removed {removed_count} invalid array items")

    return filtered
def validate_and_fill_extraction_result(
    extraction_result: dict[str, Any],
    schema: dict[str, Any] | list | str | None,
) -> dict[str, Any]:
    """
    Validate extraction result against schema and fill missing fields with defaults.

    This function handles malformed JSON responses from LLMs by:
    1. Validating the schema itself is valid JSON Schema (returns data as-is if invalid)
    2. Filling in missing required fields with appropriate default values
    3. Dropping all-default placeholder items when the top-level schema is an array
       (including nullable arrays such as ``["array", "null"]``)
    4. Validating the filled structure against the provided schema using jsonschema
    5. Preserving optional fields that are present

    Validation errors are logged, not raised. Filling updates object data in
    place, so ``extraction_result`` may be modified.

    Args:
        extraction_result: The extraction result from the LLM
        schema: The JSON schema that defines the expected structure

    Returns:
        The validated and filled extraction result, or the input data if the
        schema is missing, invalid, or processing fails
    """
    if schema is None:
        LOG.debug("No schema provided, returning extraction result as-is")
        return extraction_result

    if not validate_schema(schema):
        LOG.info("Schema is invalid, returning extraction result as-is without transformations")
        return extraction_result

    LOG.info("Validating and filling extraction result against schema")

    try:
        filled_result = fill_missing_fields(extraction_result, schema)

        validation_errors: list[str] = []
        if isinstance(schema, dict):
            # Filter out invalid array items if the schema is for an array. The type
            # is resolved the same way the fill step resolves it, so nullable array
            # schemas are filtered too.
            if (
                isinstance(filled_result, list)
                and _resolve_schema_type(schema.get("type"), "root") == "array"
            ):
                filled_result = _filter_invalid_array_items(filled_result, schema)

            validation_errors = validate_data_against_schema(filled_result, schema)

        if validation_errors:
            LOG.warning(
                "Validation errors found after filling",
                errors=validation_errors,
            )
        else:
            # Only claim success when the filled result actually validates.
            LOG.info("Successfully validated and filled extraction result")

        return filled_result

    except Exception as e:
        # Best-effort boundary: never let post-processing break the extraction flow.
        LOG.error(
            "Failed to validate and fill extraction result",
            error=str(e),
            exc_info=True,
        )
        return extraction_result