Add schema validation and default value filling for extraction results (#4063)
Co-authored-by: Suchintan <suchintan@users.noreply.github.com>
This commit is contained in:
@@ -65,6 +65,7 @@ dependencies = [
|
||||
"pandas>=2.3.1,<3",
|
||||
"azure-identity>=1.24.0,<2",
|
||||
"azure-keyvault-secrets>=4.2.0,<5",
|
||||
"jsonschema>=4.25.1",
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
|
||||
@@ -14,6 +14,7 @@ from skyvern.core.script_generations.skyvern_page_ai import SkyvernPageAi
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.prompts import prompt_engine
|
||||
from skyvern.forge.sdk.api.files import validate_download_url
|
||||
from skyvern.forge.sdk.api.llm.schema_validator import validate_and_fill_extraction_result
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.forge.sdk.schemas.totp_codes import OTPType
|
||||
from skyvern.services.otp_service import poll_otp_value
|
||||
@@ -540,6 +541,14 @@ class RealSkyvernPageAi(SkyvernPageAi):
|
||||
screenshots=self.scraped_page.screenshots,
|
||||
prompt_name="extract-information",
|
||||
)
|
||||
|
||||
# Validate and fill missing fields based on schema
|
||||
if schema:
|
||||
result = validate_and_fill_extraction_result(
|
||||
extraction_result=result,
|
||||
schema=schema,
|
||||
)
|
||||
|
||||
if context and context.script_mode:
|
||||
print(f"\n✨ 📊 Extracted Information:\n{'-' * 50}")
|
||||
|
||||
|
||||
@@ -858,3 +858,10 @@ class AzureConfigurationError(AzureBaseError):
|
||||
class ScriptTerminationException(SkyvernException):
    """Raised to signal that a running script should terminate.

    The optional reason is forwarded to the base exception as its message.
    """

    def __init__(self, reason: str | None = None) -> None:
        super().__init__(reason)
|
||||
|
||||
|
||||
class InvalidSchemaError(SkyvernException):
    """Raised when an extraction schema fails JSON Schema validation.

    Carries the top-level message plus the individual validation error
    strings so callers can surface them.
    """

    def __init__(self, message: str, validation_errors: list[str] | None = None) -> None:
        self.message = message
        # Individual validator messages; normalized to an empty list when omitted.
        self.validation_errors = validation_errors or []
        super().__init__(self.message)
|
||||
|
||||
304
skyvern/forge/sdk/api/llm/schema_validator.py
Normal file
304
skyvern/forge/sdk/api/llm/schema_validator.py
Normal file
@@ -0,0 +1,304 @@
|
||||
from typing import Any
|
||||
|
||||
import structlog
|
||||
from jsonschema import Draft202012Validator
|
||||
from jsonschema.exceptions import SchemaError
|
||||
|
||||
# Module-level structlog logger used for all validation/fill diagnostics below.
LOG = structlog.get_logger()


# Maps a normalized (lowercase) JSON-schema type name to a zero-argument
# factory producing the default value used when a required field is missing.
# Note: "string" defaults to None (not "") so consumers can distinguish
# "absent" from "empty string".
_TYPE_DEFAULT_FACTORIES: dict[str, Any] = {
    "string": lambda: None,
    "number": lambda: 0,
    "integer": lambda: 0,
    "boolean": lambda: False,
    "array": list,
    "object": dict,
    "null": lambda: None,
}
|
||||
|
||||
|
||||
def _resolve_schema_type(schema_type: str | list[Any] | None, path: str) -> str | None:
|
||||
"""Normalize a schema type definition to a single string value."""
|
||||
if isinstance(schema_type, list):
|
||||
non_null_types = [str(t).lower() for t in schema_type if str(t).lower() != "null"]
|
||||
if not non_null_types:
|
||||
return "null"
|
||||
|
||||
if len(non_null_types) > 1:
|
||||
LOG.warning(
|
||||
"Multiple non-null types in schema, using first one",
|
||||
path=path,
|
||||
types=non_null_types,
|
||||
)
|
||||
return non_null_types[0]
|
||||
|
||||
return str(schema_type).lower() if schema_type is not None else None
|
||||
|
||||
|
||||
def get_default_value_for_type(schema_type: str | list[Any] | None, path: str = "root") -> Any:
    """Produce a placeholder value for a JSON-schema type declaration.

    The declaration is first normalized via ``_resolve_schema_type``; unknown
    or absent types yield ``None``.
    """
    resolved = _resolve_schema_type(schema_type, path)
    if resolved is None:
        return None

    make_default = _TYPE_DEFAULT_FACTORIES.get(resolved)
    if not callable(make_default):
        return None
    return make_default()
|
||||
|
||||
|
||||
def fill_missing_fields(data: Any, schema: dict[str, Any] | list | str | None, path: str = "root") -> Any:
    """
    Recursively fill missing fields in data based on the schema.

    Args:
        data: The data to validate and fill
        schema: The JSON schema to validate against
        path: Current path in the data structure (for logging)

    Returns:
        The data with missing fields filled with default values

    Note:
        When ``data`` is a dict it is mutated in place (missing required keys
        are inserted); pass a copy if the original must stay untouched.
        Mismatched data is coerced: non-dict data under an object schema
        becomes ``{}``, non-list data under an array schema becomes ``[]``.
    """
    if schema is None:
        return data

    # str/list schemas carry no property/type information we can enforce.
    if isinstance(schema, (str, list)):
        LOG.debug("Schema is permissive", path=path, schema=schema)
        return data

    schema_type = _resolve_schema_type(schema.get("type"), path)
    raw_schema_type = schema.get("type")

    if schema_type == "null" and data is None:
        LOG.debug("Data is None and schema allows null type, keeping as None", path=path)
        return None

    # Check if null is allowed in the schema type
    is_nullable = isinstance(raw_schema_type, list) and "null" in raw_schema_type

    # A schema with "properties" is treated as an object schema even when
    # "type" is absent.
    if schema_type == "object" or "properties" in schema:
        # If data is None and schema allows null, keep it as None
        if data is None and is_nullable:
            LOG.debug("Data is None and schema allows null, keeping as None", path=path)
            return None

        if not isinstance(data, dict):
            LOG.warning(
                "Expected object but got different type, creating empty object",
                path=path,
                data_type=type(data).__name__,
            )
            data = {}

        properties = schema.get("properties", {})
        required_fields = set(schema.get("required", []))

        for field_name, field_schema in properties.items():
            field_path = f"{path}.{field_name}"

            if field_name not in data:
                if field_name in required_fields:
                    # Prefer an explicit schema "default"; otherwise derive one
                    # from the declared type.
                    default_value = field_schema.get(
                        "default", get_default_value_for_type(field_schema.get("type"), field_path)
                    )
                    LOG.info(
                        "Filling missing required field with default value",
                        path=field_path,
                        default_value=default_value,
                    )
                    data[field_name] = default_value
                else:
                    LOG.debug("Skipping optional missing field", path=field_path)
                    continue

            # Recurse so nested objects/arrays are filled too (including a
            # just-inserted default value).
            data[field_name] = fill_missing_fields(data[field_name], field_schema, field_path)

        return data

    if schema_type == "array":
        # If data is None and schema allows null, keep it as None
        if data is None and is_nullable:
            LOG.debug("Data is None and schema allows null, keeping as None", path=path)
            return None

        if not isinstance(data, list):
            LOG.warning(
                "Expected array but got different type, creating empty array",
                path=path,
                data_type=type(data).__name__,
            )
            return []

        items_schema = schema.get("items")
        if not items_schema:
            return data

        return [fill_missing_fields(item, items_schema, f"{path}[{idx}]") for idx, item in enumerate(data)]

    # Scalar types (string/number/integer/boolean): nothing to fill.
    return data
|
||||
|
||||
|
||||
def validate_schema(schema: dict[str, Any] | list | str | None) -> bool:
|
||||
"""
|
||||
Validate that the schema itself is a valid JSON Schema.
|
||||
|
||||
Args:
|
||||
schema: The JSON schema to validate
|
||||
|
||||
Returns:
|
||||
True if the schema is valid, False otherwise
|
||||
"""
|
||||
if schema is None or isinstance(schema, (str, list)):
|
||||
return True
|
||||
|
||||
try:
|
||||
Draft202012Validator.check_schema(schema)
|
||||
return True
|
||||
except SchemaError as e:
|
||||
LOG.warning("Invalid JSON schema, will return data as-is", error=str(e), schema=schema)
|
||||
return False
|
||||
|
||||
|
||||
def validate_data_against_schema(data: Any, schema: dict[str, Any]) -> list[str]:
    """
    Validate data against a JSON schema using Draft202012Validator.

    Args:
        data: The data to validate
        schema: The JSON schema to validate against

    Returns:
        Human-readable "path: message" strings, one per violation
        (empty list if valid)
    """
    messages: list[str] = []
    for violation in Draft202012Validator(schema).iter_errors(data):
        if violation.path:
            location = ".".join(str(part) for part in violation.path)
        else:
            location = "root"
        messages.append(f"{location}: {violation.message}")
    return messages
|
||||
|
||||
|
||||
def _is_all_default_values(data: dict[str, Any], schema: dict[str, Any]) -> bool:
    """
    Report whether every schema property present in *data* holds its type's
    default value — the signature of a record that was rebuilt entirely from
    defaults out of invalid source data.

    Args:
        data: The data object to check
        schema: The schema defining the expected structure

    Returns:
        True if all values are defaults, False otherwise
    """
    if not isinstance(data, dict):
        return False

    properties = schema.get("properties", {})
    if not properties:
        return False

    # Compare only the properties actually present in the record.
    present = ((name, spec) for name, spec in properties.items() if name in data)
    for field_name, field_schema in present:
        resolved_type = _resolve_schema_type(field_schema.get("type"), f"check.{field_name}")
        if data[field_name] != get_default_value_for_type(resolved_type):
            # A single non-default value means the record carries real data.
            return False

    return True
|
||||
|
||||
|
||||
def _filter_invalid_array_items(data: list[Any], schema: dict[str, Any]) -> list[Any]:
    """
    Filter out array items that are all default values (created from invalid data like strings).

    Args:
        data: The array data to filter
        schema: The array schema

    Returns:
        Filtered array with invalid items removed
    """
    items_schema = schema.get("items")
    if not items_schema or not isinstance(items_schema, dict):
        return data

    # Only filter if items are objects. Normalize the declared type so that
    # variants such as ["null", "object"] or mixed case are recognized too
    # (previously only "object" and the exact list ["object", "null"] matched).
    if _resolve_schema_type(items_schema.get("type"), "items") != "object":
        return data

    filtered = []
    removed_count = 0

    for item in data:
        if isinstance(item, dict) and _is_all_default_values(item, items_schema):
            # All-default dicts are artifacts of filling an invalid item.
            removed_count += 1
            LOG.info("Filtering out invalid array item with all default values", item=item)
        else:
            filtered.append(item)

    if removed_count > 0:
        LOG.info(f"Removed {removed_count} invalid array items")

    return filtered
|
||||
|
||||
|
||||
def validate_and_fill_extraction_result(
    extraction_result: dict[str, Any],
    schema: dict[str, Any] | list | str | None,
) -> dict[str, Any]:
    """
    Validate extraction result against schema and fill missing fields with defaults.

    This function handles malformed JSON responses from LLMs by:
    1. Validating the schema itself is valid JSON Schema (returns data as-is if invalid)
    2. Filling in missing required fields with appropriate default values
    3. Validating the filled structure against the provided schema using jsonschema
    4. Preserving optional fields that are present

    Args:
        extraction_result: The extraction result from the LLM
        schema: The JSON schema that defines the expected structure

    Returns:
        The validated and filled extraction result, or the original data if schema is invalid
    """
    if schema is None:
        LOG.debug("No schema provided, returning extraction result as-is")
        return extraction_result

    if not validate_schema(schema):
        LOG.info("Schema is invalid, returning extraction result as-is without transformations")
        return extraction_result

    LOG.info("Validating and filling extraction result against schema")

    try:
        filled_result = fill_missing_fields(extraction_result, schema)

        # Filter out invalid array items if the schema is for an array.
        # Normalize the declared type so nullable variants such as
        # ["array", "null"] are handled too (previously only the exact
        # string "array" matched).
        if (
            isinstance(schema, dict)
            and _resolve_schema_type(schema.get("type"), "root") == "array"
            and isinstance(filled_result, list)
        ):
            filled_result = _filter_invalid_array_items(filled_result, schema)

        if isinstance(schema, dict):
            validation_errors = validate_data_against_schema(filled_result, schema)
            if validation_errors:
                # Residual errors are logged, not raised: callers get best-effort data.
                LOG.warning(
                    "Validation errors found after filling",
                    errors=validation_errors,
                )

        LOG.info("Successfully validated and filled extraction result")
        return filled_result
    except Exception as e:
        # Deliberately broad: validation/filling must never break the extraction flow.
        LOG.error(
            "Failed to validate and fill extraction result",
            error=str(e),
            exc_info=True,
        )
        return extraction_result
|
||||
@@ -67,6 +67,7 @@ from skyvern.forge.sdk.api.files import (
|
||||
)
|
||||
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory, LLMCallerManager
|
||||
from skyvern.forge.sdk.api.llm.exceptions import LLMProviderError
|
||||
from skyvern.forge.sdk.api.llm.schema_validator import validate_and_fill_extraction_result
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.forge.sdk.core.skyvern_context import current as skyvern_current
|
||||
from skyvern.forge.sdk.core.skyvern_context import ensure_context
|
||||
@@ -3799,6 +3800,13 @@ async def extract_information_for_navigation_goal(
|
||||
force_dict=False,
|
||||
)
|
||||
|
||||
# Validate and fill missing fields based on schema
|
||||
if task.extracted_information_schema:
|
||||
json_response = validate_and_fill_extraction_result(
|
||||
extraction_result=json_response,
|
||||
schema=task.extracted_information_schema,
|
||||
)
|
||||
|
||||
return ScrapeResult(
|
||||
scraped_data=json_response,
|
||||
)
|
||||
|
||||
520
tests/unit_tests/test_schema_validator.py
Normal file
520
tests/unit_tests/test_schema_validator.py
Normal file
@@ -0,0 +1,520 @@
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from skyvern.forge.sdk.api.llm.schema_validator import (
|
||||
fill_missing_fields,
|
||||
get_default_value_for_type,
|
||||
validate_and_fill_extraction_result,
|
||||
validate_data_against_schema,
|
||||
validate_schema,
|
||||
)
|
||||
|
||||
|
||||
class TestSchemaValidator:
    """Unit tests for the schema_validator helpers: default-value generation,
    recursive field filling, schema validity checks, jsonschema-backed data
    validation, and the end-to-end validate_and_fill_extraction_result flow."""

    @pytest.fixture
    def medication_schema(self) -> dict[str, Any]:
        """Schema for medication extraction data."""
        return {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "Medication Name": {"type": "string"},
                    "NDC": {"type": "string"},
                    "quantity": {"type": "string"},
                    "facility": {"type": "string"},
                    "recoverydate": {"type": ["string", "null"]},
                    "isAllocation": {"type": "boolean"},
                    "ErrorMessage": {"type": ["string", "null"]},
                },
                "required": [
                    "Medication Name",
                    "NDC",
                    "quantity",
                    "facility",
                    "recoverydate",
                    "isAllocation",
                    "ErrorMessage",
                ],
            },
        }

    @pytest.fixture
    def complete_medication_data(self) -> list[dict[str, Any]]:
        """Complete medication data with all required fields."""
        return [
            {
                "Medication Name": "ACETAMINOPHEN 1000MG-30ML CHRY LIQ 237ML",
                "NDC": "00904-7481-59",
                "quantity": "0",
                "facility": "CHI-IL",
                "recoverydate": None,
                "isAllocation": False,
                "ErrorMessage": None,
            },
            {
                "Medication Name": "ACETAMINOPHEN CHERRY 160MG/5ML SOL 473ML",
                "NDC": "00904-7014-16",
                "quantity": "100",
                "facility": "CHI-IL",
                "recoverydate": None,
                "isAllocation": False,
                "ErrorMessage": None,
            },
        ]

    @pytest.fixture
    def incomplete_medication_data(self) -> list[dict[str, Any]]:
        """Incomplete medication data missing some required fields."""
        return [
            {
                "Medication Name": "ACETAMINOPHEN 1000MG-30ML CHRY LIQ 237ML",
                "NDC": "00904-7481-59",
                "quantity": "0",
                "facility": "CHI-IL",
                # Missing: recoverydate, isAllocation, ErrorMessage
            },
            {
                "Medication Name": "AMOXICILLIN 500MG CAPSULE 500",
                # Missing: NDC, quantity, facility, recoverydate, isAllocation, ErrorMessage
            },
        ]

    def test_get_default_value_for_string(self) -> None:
        """Test default value generation for string type."""
        assert get_default_value_for_type("string") is None

    def test_get_default_value_for_boolean(self) -> None:
        """Test default value generation for boolean type."""
        assert get_default_value_for_type("boolean") is False

    def test_get_default_value_for_array(self) -> None:
        """Test default value generation for array type."""
        assert get_default_value_for_type("array") == []

    def test_get_default_value_for_object(self) -> None:
        """Test default value generation for object type."""
        assert get_default_value_for_type("object") == {}

    def test_get_default_value_for_null(self) -> None:
        """Test default value generation for null type."""
        assert get_default_value_for_type("null") is None

    def test_get_default_value_for_type_list_with_null(self) -> None:
        """Test default value generation for type list containing null."""
        assert get_default_value_for_type(["string", "null"]) is None
        assert get_default_value_for_type(["null", "string"]) is None

    def test_get_default_value_for_type_list_all_null(self) -> None:
        """Test default value generation for type list with only null."""
        assert get_default_value_for_type(["null"]) is None

    def test_get_default_value_for_uppercase_type(self) -> None:
        """Test default value generation for uppercase type names."""
        assert get_default_value_for_type("STRING") is None
        assert get_default_value_for_type("NUMBER") == 0
        assert get_default_value_for_type("INTEGER") == 0
        assert get_default_value_for_type("BOOLEAN") is False
        assert get_default_value_for_type("ARRAY") == []
        assert get_default_value_for_type("OBJECT") == {}
        assert get_default_value_for_type("NULL") is None

    def test_get_default_value_for_mixed_case_type(self) -> None:
        """Test default value generation for mixed case type names."""
        assert get_default_value_for_type("String") is None
        assert get_default_value_for_type("Boolean") is False
        assert get_default_value_for_type(["STRING", "null"]) is None
        assert get_default_value_for_type(["NULL", "STRING"]) is None

    def test_fill_missing_fields_complete_data(
        self, medication_schema: dict[str, Any], complete_medication_data: list[dict[str, Any]]
    ) -> None:
        """Test that complete data passes through unchanged."""
        result = fill_missing_fields(complete_medication_data, medication_schema)
        assert result == complete_medication_data

    def test_fill_missing_fields_incomplete_data(
        self, medication_schema: dict[str, Any], incomplete_medication_data: list[dict[str, Any]]
    ) -> None:
        """Test that missing required fields are filled with defaults."""
        result = fill_missing_fields(incomplete_medication_data, medication_schema)

        # First item should have missing fields filled
        assert result[0]["Medication Name"] == "ACETAMINOPHEN 1000MG-30ML CHRY LIQ 237ML"
        assert result[0]["NDC"] == "00904-7481-59"
        assert result[0]["quantity"] == "0"
        assert result[0]["facility"] == "CHI-IL"
        assert result[0]["recoverydate"] is None  # Default for ["string", "null"]
        assert result[0]["isAllocation"] is False  # Default for boolean
        assert result[0]["ErrorMessage"] is None  # Default for ["string", "null"]

        # Second item should have all missing fields filled
        assert result[1]["Medication Name"] == "AMOXICILLIN 500MG CAPSULE 500"
        assert result[1]["NDC"] is None  # Default for string
        assert result[1]["quantity"] is None  # Default for string
        assert result[1]["facility"] is None  # Default for string
        assert result[1]["recoverydate"] is None  # Default for ["string", "null"]
        assert result[1]["isAllocation"] is False  # Default for boolean
        assert result[1]["ErrorMessage"] is None  # Default for ["string", "null"]

    def test_fill_missing_fields_with_error_message(self, medication_schema: dict[str, Any]) -> None:
        """Test filling fields when ErrorMessage has a value."""
        data = [
            {
                "Medication Name": "TEST MEDICATION",
                "NDC": "12345-678-90",
                "quantity": "50",
                "facility": "TEST-FACILITY",
                "recoverydate": "2024-01-01",
                "isAllocation": True,
                "ErrorMessage": "Some error occurred",
            }
        ]

        result = fill_missing_fields(data, medication_schema)
        assert result[0]["ErrorMessage"] == "Some error occurred"

    def test_fill_missing_fields_empty_array(self, medication_schema: dict[str, Any]) -> None:
        """Test handling of empty array."""
        result = fill_missing_fields([], medication_schema)
        assert result == []

    def test_fill_missing_fields_invalid_data_type(self, medication_schema: dict[str, Any]) -> None:
        """Test handling when data is not an array."""
        # When data is not a list, it should be converted to empty array
        result = fill_missing_fields("not an array", medication_schema)
        assert result == []

    def test_fill_missing_fields_nested_object_missing_fields(self, medication_schema: dict[str, Any]) -> None:
        """Test that nested objects have missing fields filled."""
        data = [
            {
                "Medication Name": "TEST MEDICATION",
                # All other fields missing
            }
        ]

        result = fill_missing_fields(data, medication_schema)
        assert len(result) == 1
        assert "NDC" in result[0]
        assert "quantity" in result[0]
        assert "facility" in result[0]
        assert "recoverydate" in result[0]
        assert "isAllocation" in result[0]
        assert "ErrorMessage" in result[0]

    def test_validate_and_fill_extraction_result_with_schema(
        self, medication_schema: dict[str, Any], incomplete_medication_data: list[dict[str, Any]]
    ) -> None:
        """Test validate_and_fill_extraction_result with medication schema."""
        result = validate_and_fill_extraction_result(incomplete_medication_data, medication_schema)

        # Verify all required fields are present in all items
        for item in result:
            assert "Medication Name" in item
            assert "NDC" in item
            assert "quantity" in item
            assert "facility" in item
            assert "recoverydate" in item
            assert "isAllocation" in item
            assert "ErrorMessage" in item

    def test_validate_and_fill_extraction_result_no_schema(self) -> None:
        """Test that data passes through unchanged when no schema is provided."""
        data = {"some": "data"}
        result = validate_and_fill_extraction_result(data, None)
        assert result == data

    def test_validate_and_fill_extraction_result_with_exception(self, medication_schema: dict[str, Any]) -> None:
        """Test that original data is returned if validation fails."""
        # This should not raise an exception, but return original data
        invalid_data = "not a valid structure"
        result = validate_and_fill_extraction_result(invalid_data, medication_schema)
        # Should return empty array since invalid_data gets converted
        assert result == []

    def test_fill_missing_fields_preserves_existing_values(self, medication_schema: dict[str, Any]) -> None:
        """Test that existing values are preserved and not overwritten."""
        data = [
            {
                "Medication Name": "EXISTING NAME",
                "NDC": "EXISTING-NDC",
                "quantity": "999",
                "facility": "EXISTING-FACILITY",
                "recoverydate": "2024-12-31",
                "isAllocation": True,
                "ErrorMessage": "Existing error",
            }
        ]

        result = fill_missing_fields(data, medication_schema)

        # All original values should be preserved
        assert result[0]["Medication Name"] == "EXISTING NAME"
        assert result[0]["NDC"] == "EXISTING-NDC"
        assert result[0]["quantity"] == "999"
        assert result[0]["facility"] == "EXISTING-FACILITY"
        assert result[0]["recoverydate"] == "2024-12-31"
        assert result[0]["isAllocation"] is True
        assert result[0]["ErrorMessage"] == "Existing error"

    def test_fill_missing_fields_nullable_object_with_null(self) -> None:
        """Test handling of nullable object type when data is null."""
        schema = {
            "type": ["object", "null"],
            "properties": {
                "name": {"type": "string"},
                "age": {"type": "integer"},
            },
            "required": ["name", "age"],
        }

        # When data is null, it should remain null (valid for nullable type)
        result = fill_missing_fields(None, schema)
        assert result is None

    def test_fill_missing_fields_nullable_object_with_object(self) -> None:
        """Test handling of nullable object type when data is an object with missing fields."""
        schema = {
            "type": ["object", "null"],
            "properties": {
                "name": {"type": "string"},
                "age": {"type": "integer"},
            },
            "required": ["name", "age"],
        }

        # When data is an object (not null), missing required fields should be filled
        data = {"name": "John"}  # Missing 'age'
        result = fill_missing_fields(data, schema)

        # With the fix, missing required fields should be filled
        assert result == {"name": "John", "age": 0}
        assert "name" in result
        assert "age" in result

    def test_fill_missing_fields_nullable_array_with_null(self) -> None:
        """Test handling of nullable array type when data is null."""
        schema = {
            "type": ["array", "null"],
            "items": {
                "type": "object",
                "properties": {
                    "id": {"type": "string"},
                },
                "required": ["id"],
            },
        }

        # When data is null, it should remain null
        result = fill_missing_fields(None, schema)
        assert result is None

    def test_fill_missing_fields_nullable_array_with_array(self) -> None:
        """Test handling of nullable array type when data is an array."""
        schema = {
            "type": ["array", "null"],
            "items": {
                "type": "object",
                "properties": {
                    "id": {"type": "string"},
                    "name": {"type": "string"},
                },
                "required": ["id", "name"],
            },
        }

        # When data is an array (not null), items should be validated
        data = [{"id": "1"}]  # Missing 'name'
        result = fill_missing_fields(data, schema)

        # With the fix, missing fields in array items should be filled
        assert len(result) == 1
        assert result[0] == {"id": "1", "name": None}
        assert "id" in result[0]
        assert "name" in result[0]

    def test_validate_schema_valid(self) -> None:
        """Test that valid schemas return True."""
        valid_schema = {
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "age": {"type": "integer"},
            },
            "required": ["name"],
        }
        assert validate_schema(valid_schema) is True

    def test_validate_schema_invalid(self) -> None:
        """Test that invalid schemas return False."""
        invalid_schema = {
            "type": "object",
            "properties": {
                "name": {
                    "type": "string",
                    "minLength": "not_a_number",  # Should be a number
                }
            },
        }
        # Should return False for invalid schema
        assert validate_schema(invalid_schema) is False

    def test_validate_schema_none(self) -> None:
        """Test that None schema is considered valid."""
        assert validate_schema(None) is True

    def test_validate_schema_string(self) -> None:
        """Test that string schema is considered valid (permissive)."""
        assert validate_schema("some_string") is True

    def test_validate_schema_list(self) -> None:
        """Test that list schema is considered valid (permissive)."""
        assert validate_schema([]) is True

    def test_validate_data_against_schema_valid(self) -> None:
        """Test validation of valid data against schema."""
        schema = {
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "age": {"type": "integer"},
            },
            "required": ["name", "age"],
        }
        data = {"name": "John", "age": 30}
        errors = validate_data_against_schema(data, schema)
        assert errors == []

    def test_validate_data_against_schema_missing_required(self) -> None:
        """Test validation when required fields are missing."""
        schema = {
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "age": {"type": "integer"},
            },
            "required": ["name", "age"],
        }
        data = {"name": "John"}  # Missing 'age'
        errors = validate_data_against_schema(data, schema)
        assert len(errors) > 0
        assert any("age" in error for error in errors)

    def test_validate_data_against_schema_wrong_type(self) -> None:
        """Test validation when data has wrong type."""
        schema = {
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "age": {"type": "integer"},
            },
        }
        data = {"name": "John", "age": "thirty"}  # age should be integer
        errors = validate_data_against_schema(data, schema)
        assert len(errors) > 0
        assert any("age" in error for error in errors)

    def test_validate_data_against_schema_array(self) -> None:
        """Test validation of array data."""
        schema = {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "id": {"type": "string"},
                },
                "required": ["id"],
            },
        }
        data = [{"id": "1"}, {"id": "2"}]
        errors = validate_data_against_schema(data, schema)
        assert errors == []

    def test_validate_and_fill_with_jsonschema_validation(self, medication_schema: dict[str, Any]) -> None:
        """Test that validate_and_fill uses jsonschema for validation."""
        # Data with all required fields filled correctly
        data = [
            {
                "Medication Name": "TEST MED",
                "NDC": "12345",
                "quantity": "10",
                "facility": "TEST",
                "recoverydate": None,
                "isAllocation": False,
                "ErrorMessage": None,
            }
        ]

        result = validate_and_fill_extraction_result(data, medication_schema)
        assert result == data

    def test_validate_and_fill_with_invalid_schema(self) -> None:
        """Test that validate_and_fill returns data as-is for invalid schemas."""
        # Create a schema that will fail Draft202012Validator.check_schema
        invalid_schema = {
            "type": "object",
            "properties": {
                "name": {
                    "type": "string",
                    "minLength": "not_a_number",  # Should be a number
                }
            },
        }

        data = {"name": "test"}
        # Should return data as-is without transformations when schema is invalid
        result = validate_and_fill_extraction_result(data, invalid_schema)
        assert result == data

    def test_filter_invalid_array_items_with_string(self, medication_schema: dict[str, Any]) -> None:
        """Test that array items created from invalid data (strings) are filtered out."""
        # Simulate LLM response with a string mixed in the array
        data = [
            {
                "Medication Name": "ACETAMINOPHEN 500MG",
                "NDC": "12345-678-90",
                "quantity": "100",
                "facility": "TEST-FACILITY",
                "recoverydate": None,
                "isAllocation": False,
                "ErrorMessage": None,
            },
            "This is an invalid string that should be filtered out",
            {
                "Medication Name": "IBUPROFEN 200MG",
                "NDC": "98765-432-10",
                "quantity": "50",
                "facility": "TEST-FACILITY",
                "recoverydate": "2024-01-01",
                "isAllocation": True,
                "ErrorMessage": None,
            },
        ]

        result = validate_and_fill_extraction_result(data, medication_schema)

        # Should have only 2 valid records (the string should be filtered out)
        assert len(result) == 2
        assert result[0]["Medication Name"] == "ACETAMINOPHEN 500MG"
        assert result[1]["Medication Name"] == "IBUPROFEN 200MG"

    def test_filter_invalid_array_items_preserves_valid_defaults(self, medication_schema: dict[str, Any]) -> None:
        """Test that records with some valid data are preserved even if some fields are defaults."""
        data = [
            {
                "Medication Name": "VALID MEDICATION",
                "NDC": "12345-678-90",
                # Missing other required fields - should be filled with defaults but NOT filtered
            }
        ]

        result = validate_and_fill_extraction_result(data, medication_schema)

        # Should preserve the record because it has meaningful data
        assert len(result) == 1
        assert result[0]["Medication Name"] == "VALID MEDICATION"
        assert result[0]["NDC"] == "12345-678-90"
        assert result[0]["quantity"] is None  # Filled with default
        assert result[0]["facility"] is None  # Filled with default
||||
4
uv.lock
generated
4
uv.lock
generated
@@ -1,5 +1,4 @@
|
||||
version = 1
|
||||
revision = 1
|
||||
requires-python = ">=3.11, <3.14"
|
||||
resolution-markers = [
|
||||
"python_full_version >= '3.13'",
|
||||
@@ -2095,7 +2094,6 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/13/3a/d61707803260d59520721fa326babfae25e9573a88d8b7b9cb54c5423a59/jiter-0.11.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:089f9df9f69532d1339e83142438668f52c97cd22ee2d1195551c2b1a9e6cf33", size = 313737 },
|
||||
{ url = "https://files.pythonhosted.org/packages/cd/cc/c9f0eec5d00f2a1da89f6bdfac12b8afdf8d5ad974184863c75060026457/jiter-0.11.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:29ed1fe69a8c69bf0f2a962d8d706c7b89b50f1332cd6b9fbda014f60bd03a03", size = 346183 },
|
||||
{ url = "https://files.pythonhosted.org/packages/a6/87/fc632776344e7aabbab05a95a0075476f418c5d29ab0f2eec672b7a1f0ac/jiter-0.11.0-cp313-cp313t-win_amd64.whl", hash = "sha256:a4d71d7ea6ea8786291423fe209acf6f8d398a0759d03e7f24094acb8ab686ba", size = 204225 },
|
||||
{ url = "https://files.pythonhosted.org/packages/70/f3/ce100253c80063a7b8b406e1d1562657fd4b9b4e1b562db40e68645342fb/jiter-0.11.0-graalpy311-graalpy242_311_native-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:902b43386c04739229076bd1c4c69de5d115553d982ab442a8ae82947c72ede7", size = 336380 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4774,6 +4772,7 @@ dependencies = [
|
||||
{ name = "httpx", extra = ["socks"] },
|
||||
{ name = "jinja2" },
|
||||
{ name = "json-repair" },
|
||||
{ name = "jsonschema" },
|
||||
{ name = "lark" },
|
||||
{ name = "libcst" },
|
||||
{ name = "litellm" },
|
||||
@@ -4865,6 +4864,7 @@ requires-dist = [
|
||||
{ name = "httpx", extras = ["socks"], specifier = ">=0.27.0" },
|
||||
{ name = "jinja2", specifier = ">=3.1.2,<4" },
|
||||
{ name = "json-repair", specifier = ">=0.34.0,<0.35" },
|
||||
{ name = "jsonschema", specifier = ">=4.25.1" },
|
||||
{ name = "lark", specifier = ">=1.2.2,<2" },
|
||||
{ name = "libcst", specifier = ">=1.8.2,<2" },
|
||||
{ name = "litellm", specifier = ">=1.75.8" },
|
||||
|
||||
Reference in New Issue
Block a user