diff --git a/pyproject.toml b/pyproject.toml index 4eb839e0..ba13cfdd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,6 +65,7 @@ dependencies = [ "pandas>=2.3.1,<3", "azure-identity>=1.24.0,<2", "azure-keyvault-secrets>=4.2.0,<5", + "jsonschema>=4.25.1", ] [dependency-groups] diff --git a/skyvern/core/script_generations/real_skyvern_page_ai.py b/skyvern/core/script_generations/real_skyvern_page_ai.py index 34ae0a28..5c002b35 100644 --- a/skyvern/core/script_generations/real_skyvern_page_ai.py +++ b/skyvern/core/script_generations/real_skyvern_page_ai.py @@ -14,6 +14,7 @@ from skyvern.core.script_generations.skyvern_page_ai import SkyvernPageAi from skyvern.forge import app from skyvern.forge.prompts import prompt_engine from skyvern.forge.sdk.api.files import validate_download_url +from skyvern.forge.sdk.api.llm.schema_validator import validate_and_fill_extraction_result from skyvern.forge.sdk.core import skyvern_context from skyvern.forge.sdk.schemas.totp_codes import OTPType from skyvern.services.otp_service import poll_otp_value @@ -540,6 +541,14 @@ class RealSkyvernPageAi(SkyvernPageAi): screenshots=self.scraped_page.screenshots, prompt_name="extract-information", ) + + # Validate and fill missing fields based on schema + if schema: + result = validate_and_fill_extraction_result( + extraction_result=result, + schema=schema, + ) + if context and context.script_mode: print(f"\n✨ 📊 Extracted Information:\n{'-' * 50}") diff --git a/skyvern/exceptions.py b/skyvern/exceptions.py index 49eaeb05..780e2783 100644 --- a/skyvern/exceptions.py +++ b/skyvern/exceptions.py @@ -858,3 +858,10 @@ class AzureConfigurationError(AzureBaseError): class ScriptTerminationException(SkyvernException): def __init__(self, reason: str | None = None) -> None: super().__init__(reason) + + +class InvalidSchemaError(SkyvernException): + def __init__(self, message: str, validation_errors: list[str] | None = None): + self.message = message + self.validation_errors = validation_errors or [] + super().__init__(self.message) diff --git a/skyvern/forge/sdk/api/llm/schema_validator.py b/skyvern/forge/sdk/api/llm/schema_validator.py new file mode 100644 index 00000000..de17320e --- /dev/null +++ b/skyvern/forge/sdk/api/llm/schema_validator.py @@ -0,0 +1,304 @@ +from typing import Any + +import structlog +from jsonschema import Draft202012Validator +from jsonschema.exceptions import SchemaError + +LOG = structlog.get_logger() + + +_TYPE_DEFAULT_FACTORIES: dict[str, Any] = { + "string": lambda: None, + "number": lambda: 0, + "integer": lambda: 0, + "boolean": lambda: False, + "array": list, + "object": dict, + "null": lambda: None, +} + + +def _resolve_schema_type(schema_type: str | list[Any] | None, path: str) -> str | None: + """Normalize a schema type definition to a single string value.""" + if isinstance(schema_type, list): + non_null_types = [str(t).lower() for t in schema_type if str(t).lower() != "null"] + if not non_null_types: + return "null" + + if len(non_null_types) > 1: + LOG.warning( + "Multiple non-null types in schema, using first one", + path=path, + types=non_null_types, + ) + return non_null_types[0] + + return str(schema_type).lower() if schema_type is not None else None + + +def get_default_value_for_type(schema_type: str | list[Any] | None, path: str = "root") -> Any: + """Get a default value based on JSON schema type.""" + normalized_type = _resolve_schema_type(schema_type, path) + if normalized_type is None: + return None + + factory = _TYPE_DEFAULT_FACTORIES.get(normalized_type) + return factory() if callable(factory) else None + + +def fill_missing_fields(data: Any, schema: dict[str, Any] | list | str | None, path: str = "root") -> Any: + """ + Recursively fill missing fields in data based on the schema. + + Args: + data: The data to validate and fill + schema: The JSON schema to validate against + path: Current path in the data structure (for logging) + + Returns: + The data with missing fields filled with default values + """ + if schema is None: + return data + + if isinstance(schema, (str, list)): + LOG.debug("Schema is permissive", path=path, schema=schema) + return data + + schema_type = _resolve_schema_type(schema.get("type"), path) + raw_schema_type = schema.get("type") + + if schema_type == "null" and data is None: + LOG.debug("Data is None and schema allows null type, keeping as None", path=path) + return None + + # Check if null is allowed in the schema type + is_nullable = isinstance(raw_schema_type, list) and "null" in raw_schema_type + + if schema_type == "object" or "properties" in schema: + # If data is None and schema allows null, keep it as None + if data is None and is_nullable: + LOG.debug("Data is None and schema allows null, keeping as None", path=path) + return None + + if not isinstance(data, dict): + LOG.warning( + "Expected object but got different type, creating empty object", + path=path, + data_type=type(data).__name__, + ) + data = {} + + properties = schema.get("properties", {}) + required_fields = set(schema.get("required", [])) + + for field_name, field_schema in properties.items(): + field_path = f"{path}.{field_name}" + + if field_name not in data: + if field_name in required_fields: + default_value = field_schema.get( + "default", get_default_value_for_type(field_schema.get("type"), field_path) + ) + LOG.info( + "Filling missing required field with default value", + path=field_path, + default_value=default_value, + ) + data[field_name] = default_value + else: + LOG.debug("Skipping optional missing field", path=field_path) + continue + + data[field_name] = fill_missing_fields(data[field_name], field_schema, field_path) + + return data + + if schema_type == "array": + # If data is None and schema allows null, keep it as None + if data is None and is_nullable: + LOG.debug("Data is None and schema allows null, keeping as None", path=path) + return None + + if not isinstance(data, list): + LOG.warning( + "Expected array but got different type, creating empty array", + path=path, + data_type=type(data).__name__, + ) + return [] + + items_schema = schema.get("items") + if not items_schema: + return data + + return [fill_missing_fields(item, items_schema, f"{path}[{idx}]") for idx, item in enumerate(data)] + + return data + + +def validate_schema(schema: dict[str, Any] | list | str | None) -> bool: + """ + Validate that the schema itself is a valid JSON Schema. + + Args: + schema: The JSON schema to validate + + Returns: + True if the schema is valid, False otherwise + """ + if schema is None or isinstance(schema, (str, list)): + return True + + try: + Draft202012Validator.check_schema(schema) + return True + except SchemaError as e: + LOG.warning("Invalid JSON schema, will return data as-is", error=str(e), schema=schema) + return False + + +def validate_data_against_schema(data: Any, schema: dict[str, Any]) -> list[str]: + """ + Validate data against a JSON schema using Draft202012Validator. + + Args: + data: The data to validate + schema: The JSON schema to validate against + + Returns: + List of validation error messages (empty if valid) + """ + validator = Draft202012Validator(schema) + errors = [] + + for error in validator.iter_errors(data): + error_path = ".".join(str(p) for p in error.path) if error.path else "root" + errors.append(f"{error_path}: {error.message}") + + return errors + + +def _is_all_default_values(data: dict[str, Any], schema: dict[str, Any]) -> bool: + """ + Check if a dict contains only default values (indicating it was created from invalid data). + + Args: + data: The data object to check + schema: The schema defining the expected structure + + Returns: + True if all values are defaults, False otherwise + """ + if not isinstance(data, dict): + return False + + properties = schema.get("properties", {}) + if not properties: + return False + + # Check each property against its default value + for field_name, field_schema in properties.items(): + if field_name not in data: + continue + + field_value = data[field_name] + field_type = _resolve_schema_type(field_schema.get("type"), f"check.{field_name}") + default_value = get_default_value_for_type(field_type) + + # If any field has a non-default value, the record is meaningful + if field_value != default_value: + return False + + return True + + +def _filter_invalid_array_items(data: list[Any], schema: dict[str, Any]) -> list[Any]: + """ + Filter out array items that are all default values (created from invalid data like strings). + + Args: + data: The array data to filter + schema: The array schema + + Returns: + Filtered array with invalid items removed + """ + items_schema = schema.get("items") + if not items_schema or not isinstance(items_schema, dict): + return data + + # Only filter if items are objects + if items_schema.get("type") not in ("object", ["object", "null"]): + return data + + filtered = [] + removed_count = 0 + + for item in data: + if isinstance(item, dict) and _is_all_default_values(item, items_schema): + removed_count += 1 + LOG.info("Filtering out invalid array item with all default values", item=item) + else: + filtered.append(item) + + if removed_count > 0: + LOG.info(f"Removed {removed_count} invalid array items") + + return filtered + + +def validate_and_fill_extraction_result( + extraction_result: dict[str, Any], + schema: dict[str, Any] | list | str | None, +) -> dict[str, Any]: + """ + Validate extraction result against schema and fill missing fields with defaults. + + This function handles malformed JSON responses from LLMs by: + 1. Validating the schema itself is valid JSON Schema (returns data as-is if invalid) + 2. Filling in missing required fields with appropriate default values + 3. Validating the filled structure against the provided schema using jsonschema + 4. Preserving optional fields that are present + + Args: + extraction_result: The extraction result from the LLM + schema: The JSON schema that defines the expected structure + + Returns: + The validated and filled extraction result, or the original data if schema is invalid + """ + if schema is None: + LOG.debug("No schema provided, returning extraction result as-is") + return extraction_result + + if not validate_schema(schema): + LOG.info("Schema is invalid, returning extraction result as-is without transformations") + return extraction_result + + LOG.info("Validating and filling extraction result against schema") + + try: + filled_result = fill_missing_fields(extraction_result, schema) + + # Filter out invalid array items if the schema is for an array + if isinstance(schema, dict) and schema.get("type") == "array" and isinstance(filled_result, list): + filled_result = _filter_invalid_array_items(filled_result, schema) + + if isinstance(schema, dict): + validation_errors = validate_data_against_schema(filled_result, schema) + if validation_errors: + LOG.warning( + "Validation errors found after filling", + errors=validation_errors, + ) + + LOG.info("Successfully validated and filled extraction result") + return filled_result + except Exception as e: + LOG.error( + "Failed to validate and fill extraction result", + error=str(e), + exc_info=True, + ) + return extraction_result diff --git a/skyvern/webeye/actions/handler.py b/skyvern/webeye/actions/handler.py index 27253e0b..6d8f6822 100644 --- a/skyvern/webeye/actions/handler.py +++ b/skyvern/webeye/actions/handler.py @@ -67,6 +67,7 @@ from skyvern.forge.sdk.api.files import ( ) from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory, LLMCallerManager from skyvern.forge.sdk.api.llm.exceptions import LLMProviderError +from skyvern.forge.sdk.api.llm.schema_validator import validate_and_fill_extraction_result from skyvern.forge.sdk.core import skyvern_context from skyvern.forge.sdk.core.skyvern_context import current as skyvern_current from skyvern.forge.sdk.core.skyvern_context import ensure_context @@ -3799,6 +3800,13 @@ async def extract_information_for_navigation_goal( force_dict=False, ) + # Validate and fill missing fields based on schema + if task.extracted_information_schema: + json_response = validate_and_fill_extraction_result( + extraction_result=json_response, + schema=task.extracted_information_schema, + ) + return ScrapeResult( scraped_data=json_response, ) diff --git a/tests/unit_tests/test_schema_validator.py b/tests/unit_tests/test_schema_validator.py new file mode 100644 index 00000000..da3d027a --- /dev/null +++ b/tests/unit_tests/test_schema_validator.py @@ -0,0 +1,520 @@ +from typing import Any + +import pytest + +from skyvern.forge.sdk.api.llm.schema_validator import ( + fill_missing_fields, + get_default_value_for_type, + validate_and_fill_extraction_result, + validate_data_against_schema, + validate_schema, +) + + +class TestSchemaValidator: + @pytest.fixture + def medication_schema(self) -> dict[str, Any]: + """Schema for medication extraction data.""" + return { + "type": "array", + "items": { + "type": "object", + "properties": { + "Medication Name": {"type": "string"}, + "NDC": {"type": "string"}, + "quantity": {"type": "string"}, + "facility": {"type": "string"}, + "recoverydate": {"type": ["string", "null"]}, + "isAllocation": {"type": "boolean"}, + "ErrorMessage": {"type": ["string", "null"]}, + }, + "required": [ + "Medication Name", + "NDC", + "quantity", + "facility", + "recoverydate", + "isAllocation", + "ErrorMessage", + ], + }, + } + + @pytest.fixture + def complete_medication_data(self) -> list[dict[str, Any]]: + """Complete medication data with all required fields.""" + return [ + { + "Medication Name": "ACETAMINOPHEN 1000MG-30ML CHRY LIQ 237ML", + "NDC": "00904-7481-59", + "quantity": "0", + "facility": "CHI-IL", + "recoverydate": None, + "isAllocation": False, + "ErrorMessage": None, + }, + { + "Medication Name": "ACETAMINOPHEN CHERRY 160MG/5ML SOL 473ML", + "NDC": "00904-7014-16", + "quantity": "100", + "facility": "CHI-IL", + "recoverydate": None, + "isAllocation": False, + "ErrorMessage": None, + }, + ] + + @pytest.fixture + def incomplete_medication_data(self) -> list[dict[str, Any]]: + """Incomplete medication data missing some required fields.""" + return [ + { + "Medication Name": "ACETAMINOPHEN 1000MG-30ML CHRY LIQ 237ML", + "NDC": "00904-7481-59", + "quantity": "0", + "facility": "CHI-IL", + # Missing: recoverydate, isAllocation, ErrorMessage + }, + { + "Medication Name": "AMOXICILLIN 500MG CAPSULE 500", + # Missing: NDC, quantity, facility, recoverydate, isAllocation, ErrorMessage + }, + ] + + def test_get_default_value_for_string(self) -> None: + """Test default value generation for string type.""" + assert get_default_value_for_type("string") is None + + def test_get_default_value_for_boolean(self) -> None: + """Test default value generation for boolean type.""" + assert get_default_value_for_type("boolean") is False + + def test_get_default_value_for_array(self) -> None: + """Test default value generation for array type.""" + assert get_default_value_for_type("array") == [] + + def test_get_default_value_for_object(self) -> None: + """Test default value generation for object type.""" + assert get_default_value_for_type("object") == {} + + def test_get_default_value_for_null(self) -> None: + """Test default value generation for null type.""" + assert get_default_value_for_type("null") is None + + def test_get_default_value_for_type_list_with_null(self) -> None: + """Test default value generation for type list containing null.""" + assert get_default_value_for_type(["string", "null"]) is None + assert get_default_value_for_type(["null", "string"]) is None + + def test_get_default_value_for_type_list_all_null(self) -> None: + """Test default value generation for type list with only null.""" + assert get_default_value_for_type(["null"]) is None + + def test_get_default_value_for_uppercase_type(self) -> None: + """Test default value generation for uppercase type names.""" + assert get_default_value_for_type("STRING") is None + assert get_default_value_for_type("NUMBER") == 0 + assert get_default_value_for_type("INTEGER") == 0 + assert get_default_value_for_type("BOOLEAN") is False + assert get_default_value_for_type("ARRAY") == [] + assert get_default_value_for_type("OBJECT") == {} + assert get_default_value_for_type("NULL") is None + + def test_get_default_value_for_mixed_case_type(self) -> None: + """Test default value generation for mixed case type names.""" + assert get_default_value_for_type("String") is None + assert get_default_value_for_type("Boolean") is False + assert get_default_value_for_type(["STRING", "null"]) is None + assert get_default_value_for_type(["NULL", "STRING"]) is None + + def test_fill_missing_fields_complete_data( + self, medication_schema: dict[str, Any], complete_medication_data: list[dict[str, Any]] + ) -> None: + """Test that complete data passes through unchanged.""" + result = fill_missing_fields(complete_medication_data, medication_schema) + assert result == complete_medication_data + + def test_fill_missing_fields_incomplete_data( + self, medication_schema: dict[str, Any], incomplete_medication_data: list[dict[str, Any]] + ) -> None: + """Test that missing required fields are filled with defaults.""" + result = fill_missing_fields(incomplete_medication_data, medication_schema) + + # First item should have missing fields filled + assert result[0]["Medication Name"] == "ACETAMINOPHEN 1000MG-30ML CHRY LIQ 237ML" + assert result[0]["NDC"] == "00904-7481-59" + assert result[0]["quantity"] == "0" + assert result[0]["facility"] == "CHI-IL" + assert result[0]["recoverydate"] is None # Default for ["string", "null"] + assert result[0]["isAllocation"] is False # Default for boolean + assert result[0]["ErrorMessage"] is None # Default for ["string", "null"] + + # Second item should have all missing fields filled + assert result[1]["Medication Name"] == "AMOXICILLIN 500MG CAPSULE 500" + assert result[1]["NDC"] is None # Default for string + assert result[1]["quantity"] is None # Default for string + assert result[1]["facility"] is None # Default for string + assert result[1]["recoverydate"] is None # Default for ["string", "null"] + assert result[1]["isAllocation"] is False # Default for boolean + assert result[1]["ErrorMessage"] is None # Default for ["string", "null"] + + def test_fill_missing_fields_with_error_message(self, medication_schema: dict[str, Any]) -> None: + """Test filling fields when ErrorMessage has a value.""" + data = [ + { + "Medication Name": "TEST MEDICATION", + "NDC": "12345-678-90", + "quantity": "50", + "facility": "TEST-FACILITY", + "recoverydate": "2024-01-01", + "isAllocation": True, + "ErrorMessage": "Some error occurred", + } + ] + + result = fill_missing_fields(data, medication_schema) + assert result[0]["ErrorMessage"] == "Some error occurred" + + def test_fill_missing_fields_empty_array(self, medication_schema: dict[str, Any]) -> None: + """Test handling of empty array.""" + result = fill_missing_fields([], medication_schema) + assert result == [] + + def test_fill_missing_fields_invalid_data_type(self, medication_schema: dict[str, Any]) -> None: + """Test handling when data is not an array.""" + # When data is not a list, it should be converted to empty array + result = fill_missing_fields("not an array", medication_schema) + assert result == [] + + def test_fill_missing_fields_nested_object_missing_fields(self, medication_schema: dict[str, Any]) -> None: + """Test that nested objects have missing fields filled.""" + data = [ + { + "Medication Name": "TEST MEDICATION", + # All other fields missing + } + ] + + result = fill_missing_fields(data, medication_schema) + assert len(result) == 1 + assert "NDC" in result[0] + assert "quantity" in result[0] + assert "facility" in result[0] + assert "recoverydate" in result[0] + assert "isAllocation" in result[0] + assert "ErrorMessage" in result[0] + + def test_validate_and_fill_extraction_result_with_schema( + self, medication_schema: dict[str, Any], incomplete_medication_data: list[dict[str, Any]] + ) -> None: + """Test validate_and_fill_extraction_result with medication schema.""" + result = validate_and_fill_extraction_result(incomplete_medication_data, medication_schema) + + # Verify all required fields are present in all items + for item in result: + assert "Medication Name" in item + assert "NDC" in item + assert "quantity" in item + assert "facility" in item + assert "recoverydate" in item + assert "isAllocation" in item + assert "ErrorMessage" in item + + def test_validate_and_fill_extraction_result_no_schema(self) -> None: + """Test that data passes through unchanged when no schema is provided.""" + data = {"some": "data"} + result = validate_and_fill_extraction_result(data, None) + assert result == data + + def test_validate_and_fill_extraction_result_with_exception(self, medication_schema: dict[str, Any]) -> None: + """Test that original data is returned if validation fails.""" + # This should not raise an exception, but return original data + invalid_data = "not a valid structure" + result = validate_and_fill_extraction_result(invalid_data, medication_schema) + # Should return empty array since invalid_data gets converted + assert result == [] + + def test_fill_missing_fields_preserves_existing_values(self, medication_schema: dict[str, Any]) -> None: + """Test that existing values are preserved and not overwritten.""" + data = [ + { + "Medication Name": "EXISTING NAME", + "NDC": "EXISTING-NDC", + "quantity": "999", + "facility": "EXISTING-FACILITY", + "recoverydate": "2024-12-31", + "isAllocation": True, + "ErrorMessage": "Existing error", + } + ] + + result = fill_missing_fields(data, medication_schema) + + # All original values should be preserved + assert result[0]["Medication Name"] == "EXISTING NAME" + assert result[0]["NDC"] == "EXISTING-NDC" + assert result[0]["quantity"] == "999" + assert result[0]["facility"] == "EXISTING-FACILITY" + assert result[0]["recoverydate"] == "2024-12-31" + assert result[0]["isAllocation"] is True + assert result[0]["ErrorMessage"] == "Existing error" + + def test_fill_missing_fields_nullable_object_with_null(self) -> None: + """Test handling of nullable object type when data is null.""" + schema = { + "type": ["object", "null"], + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + }, + "required": ["name", "age"], + } + + # When data is null, it should remain null (valid for nullable type) + result = fill_missing_fields(None, schema) + assert result is None + + def test_fill_missing_fields_nullable_object_with_object(self) -> None: + """Test handling of nullable object type when data is an object with missing fields.""" + schema = { + "type": ["object", "null"], + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + }, + "required": ["name", "age"], + } + + # When data is an object (not null), missing required fields should be filled + data = {"name": "John"} # Missing 'age' + result = fill_missing_fields(data, schema) + + # With the fix, missing required fields should be filled + assert result == {"name": "John", "age": 0} + assert "name" in result + assert "age" in result + + def test_fill_missing_fields_nullable_array_with_null(self) -> None: + """Test handling of nullable array type when data is null.""" + schema = { + "type": ["array", "null"], + "items": { + "type": "object", + "properties": { + "id": {"type": "string"}, + }, + "required": ["id"], + }, + } + + # When data is null, it should remain null + result = fill_missing_fields(None, schema) + assert result is None + + def test_fill_missing_fields_nullable_array_with_array(self) -> None: + """Test handling of nullable array type when data is an array.""" + schema = { + "type": ["array", "null"], + "items": { + "type": "object", + "properties": { + "id": {"type": "string"}, + "name": {"type": "string"}, + }, + "required": ["id", "name"], + }, + } + + # When data is an array (not null), items should be validated + data = [{"id": "1"}] # Missing 'name' + result = fill_missing_fields(data, schema) + + # With the fix, missing fields in array items should be filled + assert len(result) == 1 + assert result[0] == {"id": "1", "name": None} + assert "id" in result[0] + assert "name" in result[0] + + def test_validate_schema_valid(self) -> None: + """Test that valid schemas return True.""" + valid_schema = { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + }, + "required": ["name"], + } + assert validate_schema(valid_schema) is True + + def test_validate_schema_invalid(self) -> None: + """Test that invalid schemas return False.""" + invalid_schema = { + "type": "object", + "properties": { + "name": { + "type": "string", + "minLength": "not_a_number", # Should be a number + } + }, + } + # Should return False for invalid schema + assert validate_schema(invalid_schema) is False + + def test_validate_schema_none(self) -> None: + """Test that None schema is considered valid.""" + assert validate_schema(None) is True + + def test_validate_schema_string(self) -> None: + """Test that string schema is considered valid (permissive).""" + assert validate_schema("some_string") is True + + def test_validate_schema_list(self) -> None: + """Test that list schema is considered valid (permissive).""" + assert validate_schema([]) is True + + def test_validate_data_against_schema_valid(self) -> None: + """Test validation of valid data against schema.""" + schema = { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + }, + "required": ["name", "age"], + } + data = {"name": "John", "age": 30} + errors = validate_data_against_schema(data, schema) + assert errors == [] + + def test_validate_data_against_schema_missing_required(self) -> None: + """Test validation when required fields are missing.""" + schema = { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + }, + "required": ["name", "age"], + } + data = {"name": "John"} # Missing 'age' + errors = validate_data_against_schema(data, schema) + assert len(errors) > 0 + assert any("age" in error for error in errors) + + def test_validate_data_against_schema_wrong_type(self) -> None: + """Test validation when data has wrong type.""" + schema = { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + }, + } + data = {"name": "John", "age": "thirty"} # age should be integer + errors = validate_data_against_schema(data, schema) + assert len(errors) > 0 + assert any("age" in error for error in errors) + + def test_validate_data_against_schema_array(self) -> None: + """Test validation of array data.""" + schema = { + "type": "array", + "items": { + "type": "object", + "properties": { + "id": {"type": "string"}, + }, + "required": ["id"], + }, + } + data = [{"id": "1"}, {"id": "2"}] + errors = validate_data_against_schema(data, schema) + assert errors == [] + + def test_validate_and_fill_with_jsonschema_validation(self, medication_schema: dict[str, Any]) -> None: + """Test that validate_and_fill uses jsonschema for validation.""" + # Data with all required fields filled correctly + data = [ + { + "Medication Name": "TEST MED", + "NDC": "12345", + "quantity": "10", + "facility": "TEST", + "recoverydate": None, + "isAllocation": False, + "ErrorMessage": None, + } + ] + + result = validate_and_fill_extraction_result(data, medication_schema) + assert result == data + + def test_validate_and_fill_with_invalid_schema(self) -> None: + """Test that validate_and_fill returns data as-is for invalid schemas.""" + # Create a schema that will fail Draft202012Validator.check_schema + invalid_schema = { + "type": "object", + "properties": { + "name": { + "type": "string", + "minLength": "not_a_number", # Should be a number + } + }, + } + + data = {"name": "test"} + # Should return data as-is without transformations when schema is invalid + result = validate_and_fill_extraction_result(data, invalid_schema) + assert result == data + + def test_filter_invalid_array_items_with_string(self, medication_schema: dict[str, Any]) -> None: + """Test that array items created from invalid data (strings) are filtered out.""" + # Simulate LLM response with a string mixed in the array + data = [ + { + "Medication Name": "ACETAMINOPHEN 500MG", + "NDC": "12345-678-90", + "quantity": "100", + "facility": "TEST-FACILITY", + "recoverydate": None, + "isAllocation": False, + "ErrorMessage": None, + }, + "This is an invalid string that should be filtered out", + { + "Medication Name": "IBUPROFEN 200MG", + "NDC": "98765-432-10", + "quantity": "50", + "facility": "TEST-FACILITY", + "recoverydate": "2024-01-01", + "isAllocation": True, + "ErrorMessage": None, + }, + ] + + result = validate_and_fill_extraction_result(data, medication_schema) + + # Should have only 2 valid records (the string should be filtered out) + assert len(result) == 2 + assert result[0]["Medication Name"] == "ACETAMINOPHEN 500MG" + assert result[1]["Medication Name"] == "IBUPROFEN 200MG" + + def test_filter_invalid_array_items_preserves_valid_defaults(self, medication_schema: dict[str, Any]) -> None: + """Test that records with some valid data are preserved even if some fields are defaults.""" + data = [ + { + "Medication Name": "VALID MEDICATION", + "NDC": "12345-678-90", + # Missing other required fields - should be filled with defaults but NOT filtered + } + ] + + result = validate_and_fill_extraction_result(data, medication_schema) + + # Should preserve the record because it has meaningful data + assert len(result) == 1 + assert result[0]["Medication Name"] == "VALID MEDICATION" + assert result[0]["NDC"] == "12345-678-90" + assert result[0]["quantity"] is None # Filled with default + assert result[0]["facility"] is None # Filled with default diff --git a/uv.lock b/uv.lock index 5a3c29cc..873809fd 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,4 @@ version = 1 -revision = 1 requires-python = ">=3.11, <3.14" resolution-markers = [ "python_full_version >= '3.13'", @@ -2095,7 +2094,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/13/3a/d61707803260d59520721fa326babfae25e9573a88d8b7b9cb54c5423a59/jiter-0.11.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:089f9df9f69532d1339e83142438668f52c97cd22ee2d1195551c2b1a9e6cf33", size = 313737 }, { url = "https://files.pythonhosted.org/packages/cd/cc/c9f0eec5d00f2a1da89f6bdfac12b8afdf8d5ad974184863c75060026457/jiter-0.11.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:29ed1fe69a8c69bf0f2a962d8d706c7b89b50f1332cd6b9fbda014f60bd03a03", size = 346183 }, { url = "https://files.pythonhosted.org/packages/a6/87/fc632776344e7aabbab05a95a0075476f418c5d29ab0f2eec672b7a1f0ac/jiter-0.11.0-cp313-cp313t-win_amd64.whl", hash = "sha256:a4d71d7ea6ea8786291423fe209acf6f8d398a0759d03e7f24094acb8ab686ba", size = 204225 }, - { url = "https://files.pythonhosted.org/packages/70/f3/ce100253c80063a7b8b406e1d1562657fd4b9b4e1b562db40e68645342fb/jiter-0.11.0-graalpy311-graalpy242_311_native-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:902b43386c04739229076bd1c4c69de5d115553d982ab442a8ae82947c72ede7", size = 336380 }, ] [[package]] @@ -4774,6 +4772,7 @@ dependencies = [ { name = "httpx", extra = ["socks"] }, { name = "jinja2" }, { name = "json-repair" }, + { name = "jsonschema" }, { name = "lark" }, { name = "libcst" }, { name = "litellm" }, @@ -4865,6 +4864,7 @@ requires-dist = [ { name = "httpx", extras = ["socks"], specifier = ">=0.27.0" }, { name = "jinja2", specifier = ">=3.1.2,<4" }, { name = "json-repair", specifier = ">=0.34.0,<0.35" }, + { name = "jsonschema", specifier = ">=4.25.1" }, { name = "lark", specifier = ">=1.2.2,<2" }, { name = "libcst", specifier = ">=1.8.2,<2" }, { name = "litellm", specifier = ">=1.75.8" },