From 9ad3f1f2b24b43e2c5ab57c69e7d6bf7597a1c77 Mon Sep 17 00:00:00 2001 From: ShiZai Date: Mon, 15 Dec 2025 07:06:56 +0800 Subject: [PATCH] Add 116 unit tests for core utility modules (#4269) Co-authored-by: Claude Co-authored-by: Shuchang Zheng --- tests/unit_tests/test_id_generator.py | 243 ++++++++++++++++++ tests/unit_tests/test_string_util.py | 63 +++++ tests/unit_tests/test_strings.py | 91 +++++++ tests/unit_tests/test_token_counter.py | 108 ++++++++ .../test_url_validators_extended.py | 173 +++++++++++++ 5 files changed, 678 insertions(+) create mode 100644 tests/unit_tests/test_id_generator.py create mode 100644 tests/unit_tests/test_string_util.py create mode 100644 tests/unit_tests/test_strings.py create mode 100644 tests/unit_tests/test_token_counter.py create mode 100644 tests/unit_tests/test_url_validators_extended.py diff --git a/tests/unit_tests/test_id_generator.py b/tests/unit_tests/test_id_generator.py new file mode 100644 index 00000000..e0530a58 --- /dev/null +++ b/tests/unit_tests/test_id_generator.py @@ -0,0 +1,243 @@ +import re + +import pytest + +from skyvern.forge.sdk.db.id import ( + ACTION_PREFIX, + ARTIFACT_PREFIX, + BASE_EPOCH, + CREDENTIAL_PREFIX, + DEBUG_SESSION_PREFIX, + FOLDER_PREFIX, + ORG_PREFIX, + SEQUENCE_BITS, + SEQUENCE_MAX, + STEP_PREFIX, + TASK_PREFIX, + TASK_V2_ID, + THOUGHT_ID, + TIMESTAMP_BITS, + TOTP_CODE_PREFIX, + USER_PREFIX, + VERSION, + VERSION_BITS, + WORKER_ID_BITS, + WORKFLOW_PREFIX, + WORKFLOW_RUN_BLOCK_PREFIX, + WORKFLOW_RUN_PREFIX, + _mask_shift, + current_time, + current_time_ms, + generate_action_id, + generate_artifact_id, + generate_credential_id, + generate_debug_session_id, + generate_folder_id, + generate_id, + generate_org_id, + generate_step_id, + generate_task_id, + generate_task_v2_id, + generate_thought_id, + generate_totp_code_id, + generate_user_id, + generate_workflow_id, + generate_workflow_run_block_id, + generate_workflow_run_id, +) + + +class TestConstants: + """Tests for ID generator constants.""" + + def test_base_epoch(self): + """BASE_EPOCH should be June 20, 2022 12:00 AM UTC.""" + assert BASE_EPOCH == 1655683200 + + def test_version(self): + """Version should be 0.""" + assert VERSION == 0 + + def test_bit_allocation(self): + """Total bits should add up to 64.""" + total_bits = TIMESTAMP_BITS + WORKER_ID_BITS + SEQUENCE_BITS + VERSION_BITS + assert total_bits == 64 + + def test_sequence_max(self): + """SEQUENCE_MAX should be 2^10 - 1 = 1023.""" + assert SEQUENCE_MAX == 1023 + + +class TestMaskShift: + """Tests for the _mask_shift helper function.""" + + def test_mask_shift_basic(self): + """Basic mask and shift operation.""" + # value=5 (binary: 101), mask=3 bits, shift=2 + # masked: 5 & 0b111 = 5 + # shifted: 5 << 2 = 20 + result = _mask_shift(5, 3, 2) + assert result == 20 + + def test_mask_shift_overflow(self): + """Value exceeding mask bits should be truncated.""" + # value=15 (binary: 1111), mask=2 bits + # masked: 15 & 0b11 = 3 + # shifted: 3 << 0 = 3 + result = _mask_shift(15, 2, 0) + assert result == 3 + + def test_mask_shift_zero(self): + """Zero value should return zero.""" + assert _mask_shift(0, 10, 5) == 0 + + def test_mask_shift_max_bits(self): + """Test with maximum number of bits.""" + value = (1 << 32) - 1 # Max 32-bit value + result = _mask_shift(value, 32, 0) + assert result == value + + +class TestCurrentTime: + """Tests for time-related functions.""" + + def test_current_time_is_integer(self): + """current_time should return an integer.""" + assert isinstance(current_time(), int) + + def test_current_time_is_reasonable(self): + """current_time should be a reasonable Unix timestamp.""" + now = current_time() + # Should be after BASE_EPOCH + assert now > BASE_EPOCH + # Should be within reasonable bounds (before year 2100) + assert now < 4102444800 + + def test_current_time_ms_is_integer(self): + """current_time_ms should return an integer.""" + assert isinstance(current_time_ms(), int) + + def test_current_time_ms_is_milliseconds(self): + """current_time_ms should be roughly 1000x current_time.""" + t = current_time() + t_ms = current_time_ms() + # Allow for timing differences + assert abs(t_ms - t * 1000) < 2000 + + +class TestGenerateId: + """Tests for the generate_id function.""" + + def test_generate_id_returns_integer(self): + """generate_id should return an integer.""" + assert isinstance(generate_id(), int) + + def test_generate_id_is_positive(self): + """generate_id should return a positive integer.""" + assert generate_id() > 0 + + def test_generate_id_is_64_bit(self): + """generate_id should fit in 64 bits.""" + id_val = generate_id() + assert id_val < (1 << 64) + + def test_generate_id_uniqueness(self): + """Multiple calls should generate unique IDs.""" + ids = [generate_id() for _ in range(1000)] + assert len(set(ids)) == 1000 + + def test_generate_id_monotonic(self): + """IDs generated in sequence should generally be increasing.""" + id1 = generate_id() + id2 = generate_id() + id3 = generate_id() + # Due to sequence counter, IDs should increase + # (within the same second, at least) + assert id1 < id2 < id3 + + +class TestPrefixedIdGenerators: + """Tests for all prefixed ID generator functions.""" + + @pytest.mark.parametrize( + "generator,prefix", + [ + (generate_workflow_id, WORKFLOW_PREFIX), + (generate_workflow_run_id, WORKFLOW_RUN_PREFIX), + (generate_workflow_run_block_id, WORKFLOW_RUN_BLOCK_PREFIX), + (generate_task_id, TASK_PREFIX), + (generate_step_id, STEP_PREFIX), + (generate_artifact_id, ARTIFACT_PREFIX), + (generate_user_id, USER_PREFIX), + (generate_org_id, ORG_PREFIX), + (generate_action_id, ACTION_PREFIX), + (generate_task_v2_id, TASK_V2_ID), + (generate_thought_id, THOUGHT_ID), + (generate_totp_code_id, TOTP_CODE_PREFIX), + (generate_credential_id, CREDENTIAL_PREFIX), + (generate_debug_session_id, DEBUG_SESSION_PREFIX), + (generate_folder_id, FOLDER_PREFIX), + ], + ) + def test_id_has_correct_prefix(self, generator, prefix): + """Generated ID should have the correct prefix.""" + generated_id = generator() + assert generated_id.startswith(f"{prefix}_") + + @pytest.mark.parametrize( + "generator", + [ + generate_workflow_id, + generate_workflow_run_id, + generate_task_id, + generate_step_id, + generate_artifact_id, + generate_user_id, + generate_org_id, + generate_action_id, + ], + ) + def test_id_format(self, generator): + """Generated ID should match expected format: prefix_integer.""" + generated_id = generator() + assert re.match(r"^[a-z_]+_\d+$", generated_id) + + @pytest.mark.parametrize( + "generator", + [ + generate_workflow_id, + generate_task_id, + generate_org_id, + ], + ) + def test_id_uniqueness(self, generator): + """Multiple calls should generate unique IDs.""" + ids = [generator() for _ in range(100)] + assert len(set(ids)) == 100 + + def test_different_generators_produce_different_prefixes(self): + """Different generators should produce IDs with different prefixes.""" + task_id = generate_task_id() + step_id = generate_step_id() + org_id = generate_org_id() + + assert task_id.split("_")[0] != step_id.split("_")[0] + assert task_id.split("_")[0] != org_id.split("_")[0] + assert step_id.split("_")[0] != org_id.split("_")[0] + + +class TestIdExtraction: + """Tests for extracting components from generated IDs.""" + + def test_extract_numeric_part(self): + """Should be able to extract numeric part from ID.""" + task_id = generate_task_id() + parts = task_id.split("_") + numeric_part = int(parts[-1]) + assert numeric_part > 0 + + def test_numeric_part_is_valid_64bit(self): + """Numeric part should be a valid 64-bit integer.""" + task_id = generate_task_id() + numeric_part = int(task_id.split("_")[-1]) + assert numeric_part < (1 << 64) diff --git a/tests/unit_tests/test_string_util.py b/tests/unit_tests/test_string_util.py new file mode 100644 index 00000000..26548b3b --- /dev/null +++ b/tests/unit_tests/test_string_util.py @@ -0,0 +1,63 @@ +from skyvern.webeye.string_util import remove_whitespace + + +class TestRemoveWhitespace: + """Tests for the remove_whitespace function.""" + + def test_remove_multiple_spaces(self): + """Multiple spaces should be collapsed to single space.""" + assert remove_whitespace("hello world") == "hello world" + + def test_remove_tabs(self): + """Tab characters should be converted to single space.""" + assert remove_whitespace("hello\tworld") == "hello world" + + def test_remove_newlines(self): + """Newline characters should be converted to single space.""" + assert remove_whitespace("hello\nworld") == "hello world" + + def test_remove_mixed_whitespace(self): + """Mixed whitespace (spaces, tabs, newlines) should be collapsed.""" + assert remove_whitespace("hello \t\n world") == "hello world" + + def test_leading_trailing_whitespace(self): + """Leading and trailing whitespace should be collapsed but not removed.""" + assert remove_whitespace(" hello ") == " hello " + + def test_empty_string(self): + """Empty string should return empty string.""" + assert remove_whitespace("") == "" + + def test_single_space(self): + """Single space should remain unchanged.""" + assert remove_whitespace(" ") == " " + + def test_no_whitespace(self): + """String without extra whitespace should remain unchanged.""" + assert remove_whitespace("hello") == "hello" + + def test_only_whitespace(self): + """String of only whitespace should collapse to single space.""" + assert remove_whitespace(" \t\n ") == " " + + def test_multiline_text(self): + """Multiline text should have all whitespace collapsed.""" + input_text = """Hello + World + Test""" + assert remove_whitespace(input_text) == "Hello World Test" + + def test_preserves_non_whitespace_special_chars(self): + """Non-whitespace special characters should be preserved.""" + assert remove_whitespace("hello!@#$%^&*()world") == "hello!@#$%^&*()world" + + def test_unicode_text(self): + """Unicode text with whitespace should work correctly.""" + assert remove_whitespace("你好 世界") == "你好 世界" + + def test_carriage_return_not_matched(self): + """Carriage return is not in the regex pattern, verify behavior.""" + # Note: \r is not in the original regex pattern [ \n\t]+ + # This test documents the current behavior + result = remove_whitespace("hello\rworld") + assert result == "hello\rworld" diff --git a/tests/unit_tests/test_strings.py b/tests/unit_tests/test_strings.py new file mode 100644 index 00000000..fe58e3c7 --- /dev/null +++ b/tests/unit_tests/test_strings.py @@ -0,0 +1,91 @@ +"""Tests for string utility functions.""" + +import string + +from skyvern.utils.strings import RANDOM_STRING_POOL, generate_random_string + + +class TestRandomStringPool: + """Tests for RANDOM_STRING_POOL constant.""" + + def test_pool_contains_letters(self): + """Pool should contain all ASCII letters.""" + for char in string.ascii_letters: + assert char in RANDOM_STRING_POOL + + def test_pool_contains_digits(self): + """Pool should contain all digits.""" + for char in string.digits: + assert char in RANDOM_STRING_POOL + + def test_pool_size(self): + """Pool should have expected size (26*2 + 10 = 62).""" + assert len(RANDOM_STRING_POOL) == 62 + + +class TestGenerateRandomString: + """Tests for generate_random_string function.""" + + def test_default_length(self): + """Default length should be 5.""" + result = generate_random_string() + assert len(result) == 5 + + def test_custom_length(self): + """Custom length should be respected.""" + for length in [1, 10, 50, 100]: + result = generate_random_string(length) + assert len(result) == length + + def test_zero_length(self): + """Zero length should return empty string.""" + result = generate_random_string(0) + assert result == "" + + def test_returns_string(self): + """Should return a string type.""" + result = generate_random_string() + assert isinstance(result, str) + + def test_only_alphanumeric(self): + """Result should only contain alphanumeric characters.""" + result = generate_random_string(100) + for char in result: + assert char in RANDOM_STRING_POOL + + def test_randomness(self): + """Multiple calls should produce different results (with high probability).""" + results = [generate_random_string(20) for _ in range(10)] + # All results should be unique (statistically extremely likely with length 20) + assert len(set(results)) == 10 + + def test_distribution(self): + """Characters should be reasonably distributed.""" + # Generate a long string and check distribution + result = generate_random_string(10000) + char_counts = {} + for char in result: + char_counts[char] = char_counts.get(char, 0) + 1 + + # Each character should appear at least once in 10000 characters + # (statistically extremely likely) + assert len(char_counts) > 50 # Most of the 62 chars should appear + + def test_contains_letters(self): + """Generated strings should typically contain letters.""" + # With 62 possible chars and length 100, very likely to have letters + result = generate_random_string(100) + has_letter = any(c in string.ascii_letters for c in result) + assert has_letter + + def test_contains_digits(self): + """Generated strings should typically contain digits.""" + # With 62 possible chars and length 100, very likely to have digits + result = generate_random_string(100) + has_digit = any(c in string.digits for c in result) + assert has_digit + + def test_large_length(self): + """Should handle large lengths.""" + result = generate_random_string(10000) + assert len(result) == 10000 diff --git a/tests/unit_tests/test_token_counter.py b/tests/unit_tests/test_token_counter.py new file mode 100644 index 00000000..3ebc0f31 --- /dev/null +++ b/tests/unit_tests/test_token_counter.py @@ -0,0 +1,108 @@ +"""Tests for token counter utility.""" + +from skyvern.utils.token_counter import count_tokens + + +class TestCountTokens: + """Tests for count_tokens function.""" + + def test_empty_string(self): + """Empty string should have 0 tokens.""" + assert count_tokens("") == 0 + + def test_single_word(self): + """Single word should return token count.""" + result = count_tokens("hello") + assert result > 0 + assert isinstance(result, int) + + def test_simple_sentence(self): + """Simple sentence should have reasonable token count.""" + result = count_tokens("Hello, world!") + assert result > 0 + # "Hello, world!" typically tokenizes to ~4 tokens + assert result < 10 + + def test_longer_text(self): + """Longer text should have more tokens.""" + short = count_tokens("Hi") + long = count_tokens("This is a much longer sentence with many more words in it.") + assert long > short + + def test_returns_integer(self): + """Should return an integer.""" + result = count_tokens("test") + assert isinstance(result, int) + + def test_whitespace_only(self): + """Whitespace should be tokenized.""" + result = count_tokens(" ") + # Whitespace is typically tokenized + assert isinstance(result, int) + + def test_special_characters(self): + """Special characters should be tokenized.""" + result = count_tokens("!@#$%^&*()") + assert result > 0 + + def test_numbers(self): + """Numbers should be tokenized.""" + result = count_tokens("12345") + assert result > 0 + + def test_unicode(self): + """Unicode characters should be tokenized.""" + result = count_tokens("你好世界") + assert result > 0 + + def test_mixed_content(self): + """Mixed content (text, numbers, special chars) should work.""" + result = count_tokens("Hello123!@#World") + assert result > 0 + + def test_newlines(self): + """Text with newlines should be tokenized.""" + result = count_tokens("Hello\nWorld\nTest") + assert result > 0 + + def test_code_snippet(self): + """Code snippets should be tokenized.""" + code = """ +def hello(): + print("Hello, World!") +""" + result = count_tokens(code) + assert result > 5 # Code should have multiple tokens + + def test_json_content(self): + """JSON content should be tokenized.""" + json_str = '{"key": "value", "number": 123}' + result = count_tokens(json_str) + assert result > 0 + + def test_url(self): + """URLs should be tokenized.""" + result = count_tokens("https://www.example.com/path?query=value") + assert result > 0 + + def test_consistency(self): + """Same input should always produce same output.""" + text = "This is a test sentence." + result1 = count_tokens(text) + result2 = count_tokens(text) + assert result1 == result2 + + def test_very_long_text(self): + """Very long text should be handled.""" + long_text = "word " * 1000 + result = count_tokens(long_text) + assert result > 100 # Should have many tokens + + def test_token_count_approximation(self): + """Token count should be roughly 1 token per 4 chars for English.""" + text = "This is a sample text for testing token count approximation." + result = count_tokens(text) + # GPT tokenizers typically produce ~1 token per 4 characters + char_count = len(text) + assert result > char_count / 10 # Very loose lower bound + assert result < char_count # Token count should be less than char count diff --git a/tests/unit_tests/test_url_validators_extended.py b/tests/unit_tests/test_url_validators_extended.py new file mode 100644 index 00000000..2e50f5b5 --- /dev/null +++ b/tests/unit_tests/test_url_validators_extended.py @@ -0,0 +1,173 @@ +"""Extended tests for URL validators module.""" + +import pytest + +from skyvern.exceptions import InvalidUrl +from skyvern.utils.url_validators import encode_url, prepend_scheme_and_validate_url + + +class TestEncodeUrl: + """Tests for encode_url function.""" + + def test_encode_simple_path(self): + """Simple path with spaces should be encoded.""" + url = "https://example.com/path with spaces" + result = encode_url(url) + assert result == "https://example.com/path%20with%20spaces" + + def test_encode_query_params(self): + """Query parameters with spaces should be encoded.""" + url = "https://example.com/search?q=hello world" + result = encode_url(url) + assert result == "https://example.com/search?q=hello%20world" + + def test_preserve_slashes_in_path(self): + """Slashes in path should be preserved.""" + url = "https://example.com/path/to/resource" + result = encode_url(url) + assert result == "https://example.com/path/to/resource" + + def test_preserve_existing_encoding(self): + """Already encoded characters should be preserved.""" + url = "https://example.com/path%20already%20encoded" + result = encode_url(url) + assert result == "https://example.com/path%20already%20encoded" + + def test_encode_special_characters(self): + """Special characters in path should be encoded.""" + url = "https://example.com/pathbrackets" + result = encode_url(url) + # Angle brackets should be percent-encoded + assert result == "https://example.com/path%3Cwith%3Ebrackets" + + def test_preserve_query_structure(self): + """Query string structure (= and &) should be preserved.""" + url = "https://example.com/search?key1=value1&key2=value2" + result = encode_url(url) + assert "key1=value1" in result + assert "key2=value2" in result + assert "&" in result + + def test_empty_path(self): + """URL with empty path should work.""" + url = "https://example.com" + result = encode_url(url) + assert result == "https://example.com" + + def test_encode_unicode_path(self): + """Unicode characters in path should be encoded.""" + url = "https://example.com/路径" + result = encode_url(url) + # Unicode should be percent-encoded, original characters should not appear + assert "路径" not in result + assert "example.com/" in result + # "路径" in UTF-8 is encoded as %E8%B7%AF%E5%BE%84 + assert "%E8%B7%AF%E5%BE%84" in result + + def test_fragment_preserved(self): + """URL fragments should be preserved.""" + url = "https://example.com/page#section" + result = encode_url(url) + # Fragment should be preserved in the output + assert result == "https://example.com/page#section" + + +class TestPrependSchemeAndValidateUrl: + """Tests for prepend_scheme_and_validate_url function.""" + + def test_empty_url_returns_empty(self): + """Empty URL should return empty string.""" + assert prepend_scheme_and_validate_url("") == "" + + def test_https_url_unchanged(self): + """URL with https scheme should remain unchanged.""" + url = "https://example.com" + result = prepend_scheme_and_validate_url(url) + assert result == url + + def test_http_url_unchanged(self): + """URL with http scheme should remain unchanged.""" + url = "http://example.com" + result = prepend_scheme_and_validate_url(url) + assert result == url + + def test_no_scheme_gets_https(self): + """URL without scheme should get https prepended.""" + url = "example.com" + result = prepend_scheme_and_validate_url(url) + assert result == "https://example.com" + + def test_invalid_scheme_raises_error(self): + """URL with invalid scheme should raise InvalidUrl.""" + with pytest.raises(InvalidUrl): + prepend_scheme_and_validate_url("ftp://example.com") + + def test_file_scheme_raises_error(self): + """URL with file scheme should raise InvalidUrl.""" + with pytest.raises(InvalidUrl): + prepend_scheme_and_validate_url("file:///etc/passwd") + + def test_javascript_scheme_raises_error(self): + """URL with javascript scheme should raise InvalidUrl.""" + with pytest.raises(InvalidUrl): + prepend_scheme_and_validate_url("javascript:alert(1)") + + def test_valid_url_with_path(self): + """Valid URL with path should work.""" + url = "https://example.com/path/to/resource" + result = prepend_scheme_and_validate_url(url) + assert result == url + + def test_valid_url_with_query(self): + """Valid URL with query parameters should work.""" + url = "https://example.com/search?q=test" + result = prepend_scheme_and_validate_url(url) + assert result == url + + def test_url_with_port(self): + """URL with port number should work.""" + url = "https://example.com:8080/path" + result = prepend_scheme_and_validate_url(url) + assert result == url + + def test_subdomain_url(self): + """URL with subdomain should work.""" + url = "https://api.example.com/v1" + result = prepend_scheme_and_validate_url(url) + assert result == url + + def test_invalid_url_raises_error(self): + """Completely invalid URL should raise InvalidUrl.""" + with pytest.raises(InvalidUrl): + prepend_scheme_and_validate_url("not a valid url at all!!!") + + +class TestEncodeUrlEdgeCases: + """Edge case tests for encode_url.""" + + def test_double_slashes_in_path(self): + """Double slashes in path should be preserved.""" + url = "https://example.com//double//slashes" + result = encode_url(url) + assert "//double//slashes" in result + + def test_url_with_credentials(self): + """URL with credentials should preserve the authority portion.""" + url = "https://user:pass@example.com/path" + result = encode_url(url) + # Full authority portion including credentials should be preserved + assert "user:pass@example.com" in result + + def test_very_long_url(self): + """Very long URL should be handled.""" + long_path = "/a" * 1000 + url = f"https://example.com{long_path}" + result = encode_url(url) + assert len(result) >= len(url) + + def test_url_with_multiple_query_params(self): + """URL with multiple query parameters should work.""" + url = "https://example.com/search?a=1&b=2&c=3&d=4&e=5" + result = encode_url(url) + assert "a=1" in result + assert "e=5" in result