Add 116 unit tests for core utility modules (#4269)
Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: Shuchang Zheng <wintonzheng0325@gmail.com>
This commit is contained in:
243
tests/unit_tests/test_id_generator.py
Normal file
243
tests/unit_tests/test_id_generator.py
Normal file
@@ -0,0 +1,243 @@
|
||||
import re
|
||||
|
||||
import pytest
|
||||
|
||||
from skyvern.forge.sdk.db.id import (
|
||||
ACTION_PREFIX,
|
||||
ARTIFACT_PREFIX,
|
||||
BASE_EPOCH,
|
||||
CREDENTIAL_PREFIX,
|
||||
DEBUG_SESSION_PREFIX,
|
||||
FOLDER_PREFIX,
|
||||
ORG_PREFIX,
|
||||
SEQUENCE_BITS,
|
||||
SEQUENCE_MAX,
|
||||
STEP_PREFIX,
|
||||
TASK_PREFIX,
|
||||
TASK_V2_ID,
|
||||
THOUGHT_ID,
|
||||
TIMESTAMP_BITS,
|
||||
TOTP_CODE_PREFIX,
|
||||
USER_PREFIX,
|
||||
VERSION,
|
||||
VERSION_BITS,
|
||||
WORKER_ID_BITS,
|
||||
WORKFLOW_PREFIX,
|
||||
WORKFLOW_RUN_BLOCK_PREFIX,
|
||||
WORKFLOW_RUN_PREFIX,
|
||||
_mask_shift,
|
||||
current_time,
|
||||
current_time_ms,
|
||||
generate_action_id,
|
||||
generate_artifact_id,
|
||||
generate_credential_id,
|
||||
generate_debug_session_id,
|
||||
generate_folder_id,
|
||||
generate_id,
|
||||
generate_org_id,
|
||||
generate_step_id,
|
||||
generate_task_id,
|
||||
generate_task_v2_id,
|
||||
generate_thought_id,
|
||||
generate_totp_code_id,
|
||||
generate_user_id,
|
||||
generate_workflow_id,
|
||||
generate_workflow_run_block_id,
|
||||
generate_workflow_run_id,
|
||||
)
|
||||
|
||||
|
||||
class TestConstants:
    """Pin the values of the ID generator's module-level constants."""

    def test_base_epoch(self):
        """BASE_EPOCH is the Unix timestamp for 2022-06-20 00:00:00 UTC."""
        expected_epoch = 1655683200
        assert BASE_EPOCH == expected_epoch

    def test_version(self):
        """The version constant is 0."""
        expected_version = 0
        assert VERSION == expected_version

    def test_bit_allocation(self):
        """The four bit-width constants sum to exactly 64."""
        field_widths = (TIMESTAMP_BITS, WORKER_ID_BITS, SEQUENCE_BITS, VERSION_BITS)
        assert sum(field_widths) == 64

    def test_sequence_max(self):
        """SEQUENCE_MAX is 1023, i.e. 2**10 - 1."""
        expected_max = 1023
        assert SEQUENCE_MAX == expected_max
|
||||
|
||||
|
||||
class TestMaskShift:
    """Exercise the private ``_mask_shift`` helper with small hand-checked inputs."""

    def test_mask_shift_basic(self):
        """A value that fits inside the mask is only shifted."""
        # 5 is 0b101 and fits a 3-bit mask (0b111), so masking keeps 5;
        # shifting left by 2 gives 5 << 2 == 20.
        assert _mask_shift(5, 3, 2) == 20

    def test_mask_shift_overflow(self):
        """Bits above the mask width are dropped."""
        # 15 is 0b1111; a 2-bit mask (0b11) keeps 3, and a shift of 0 leaves it at 3.
        assert _mask_shift(15, 2, 0) == 3

    def test_mask_shift_zero(self):
        """Zero stays zero whatever the mask width and shift."""
        shifted = _mask_shift(0, 10, 5)
        assert shifted == 0

    def test_mask_shift_max_bits(self):
        """The largest 32-bit value passes through a 32-bit mask unchanged."""
        all_ones_32 = (1 << 32) - 1
        assert _mask_shift(all_ones_32, 32, 0) == all_ones_32
|
||||
|
||||
|
||||
class TestCurrentTime:
    """Sanity checks for ``current_time`` and ``current_time_ms``."""

    def test_current_time_is_integer(self):
        """current_time returns an int."""
        seconds = current_time()
        assert isinstance(seconds, int)

    def test_current_time_is_reasonable(self):
        """current_time lies after BASE_EPOCH and before the year 2100."""
        year_2100 = 4102444800  # 2100-01-01 00:00:00 UTC as a Unix timestamp
        now = current_time()
        assert BASE_EPOCH < now < year_2100

    def test_current_time_ms_is_integer(self):
        """current_time_ms returns an int."""
        millis = current_time_ms()
        assert isinstance(millis, int)

    def test_current_time_ms_is_milliseconds(self):
        """current_time_ms is about 1000 times current_time."""
        seconds = current_time()
        millis = current_time_ms()
        # The two clocks are read one after the other, so tolerate up to
        # two seconds of difference between them.
        drift = abs(millis - seconds * 1000)
        assert drift < 2000
|
||||
|
||||
|
||||
class TestGenerateId:
    """Checks on the raw integer IDs produced by ``generate_id``."""

    def test_generate_id_returns_integer(self):
        """The generated ID is an int."""
        new_id = generate_id()
        assert isinstance(new_id, int)

    def test_generate_id_is_positive(self):
        """The generated ID is strictly greater than zero."""
        new_id = generate_id()
        assert new_id > 0

    def test_generate_id_is_64_bit(self):
        """The generated ID is below 2**64."""
        upper_bound = 1 << 64
        assert generate_id() < upper_bound

    def test_generate_id_uniqueness(self):
        """A batch of 1000 consecutive IDs contains no duplicates."""
        batch = [generate_id() for _ in range(1000)]
        distinct = set(batch)
        assert len(distinct) == 1000

    def test_generate_id_monotonic(self):
        """Three IDs generated back to back are strictly increasing."""
        first, second, third = generate_id(), generate_id(), generate_id()
        # Each consecutive pair must increase.
        assert first < second
        assert second < third
|
||||
|
||||
|
||||
class TestPrefixedIdGenerators:
    """Checks on the string IDs returned by the prefixed ``generate_*_id`` helpers."""

    @pytest.mark.parametrize(
        "generator,prefix",
        [
            (generate_workflow_id, WORKFLOW_PREFIX),
            (generate_workflow_run_id, WORKFLOW_RUN_PREFIX),
            (generate_workflow_run_block_id, WORKFLOW_RUN_BLOCK_PREFIX),
            (generate_task_id, TASK_PREFIX),
            (generate_step_id, STEP_PREFIX),
            (generate_artifact_id, ARTIFACT_PREFIX),
            (generate_user_id, USER_PREFIX),
            (generate_org_id, ORG_PREFIX),
            (generate_action_id, ACTION_PREFIX),
            (generate_task_v2_id, TASK_V2_ID),
            (generate_thought_id, THOUGHT_ID),
            (generate_totp_code_id, TOTP_CODE_PREFIX),
            (generate_credential_id, CREDENTIAL_PREFIX),
            (generate_debug_session_id, DEBUG_SESSION_PREFIX),
            (generate_folder_id, FOLDER_PREFIX),
        ],
    )
    def test_id_has_correct_prefix(self, generator, prefix):
        """Each generator's ID starts with its prefix constant plus an underscore."""
        expected_start = f"{prefix}_"
        new_id = generator()
        assert new_id.startswith(expected_start)

    @pytest.mark.parametrize(
        "generator",
        [
            generate_workflow_id,
            generate_workflow_run_id,
            generate_task_id,
            generate_step_id,
            generate_artifact_id,
            generate_user_id,
            generate_org_id,
            generate_action_id,
        ],
    )
    def test_id_format(self, generator):
        """The ID is lowercase letters/underscores, an underscore, then digits."""
        id_pattern = r"^[a-z_]+_\d+$"
        new_id = generator()
        assert re.match(id_pattern, new_id) is not None

    @pytest.mark.parametrize(
        "generator",
        [
            generate_workflow_id,
            generate_task_id,
            generate_org_id,
        ],
    )
    def test_id_uniqueness(self, generator):
        """One hundred consecutive IDs from the same generator are all distinct."""
        distinct_ids = {generator() for _ in range(100)}
        assert len(distinct_ids) == 100

    def test_different_generators_produce_different_prefixes(self):
        """Task, step and org IDs differ in the text before their first underscore."""
        leading_parts = [
            new_id.partition("_")[0]
            for new_id in (generate_task_id(), generate_step_id(), generate_org_id())
        ]
        # Three distinct values in the set means all pairs differ.
        assert len(set(leading_parts)) == 3
|
||||
|
||||
|
||||
class TestIdExtraction:
    """Checks that the trailing numeric component of a task ID can be recovered."""

    def test_extract_numeric_part(self):
        """The text after the last underscore parses as a positive integer."""
        _, _, digits = generate_task_id().rpartition("_")
        assert int(digits) > 0

    def test_numeric_part_is_valid_64bit(self):
        """The numeric component is below 2**64."""
        upper_bound = 1 << 64
        _, _, digits = generate_task_id().rpartition("_")
        assert int(digits) < upper_bound
|
||||
63
tests/unit_tests/test_string_util.py
Normal file
63
tests/unit_tests/test_string_util.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from skyvern.webeye.string_util import remove_whitespace
|
||||
|
||||
|
||||
class TestRemoveWhitespace:
    """Tests for the remove_whitespace function.

    The expectations follow the pattern ``[ \\n\\t]+`` noted in
    ``test_carriage_return_not_matched``: each run of spaces, tabs and
    newlines becomes a single space, and nothing is stripped from the ends.
    Multi-character runs are built with ``" " * n`` so the run length is
    explicit and cannot be lost if this file's whitespace is reformatted.
    """

    def test_remove_multiple_spaces(self):
        """A run of several spaces is collapsed to a single space."""
        spaced = "hello" + " " * 4 + "world"
        assert remove_whitespace(spaced) == "hello world"

    def test_remove_tabs(self):
        """A tab character is converted to a single space."""
        assert remove_whitespace("hello\tworld") == "hello world"

    def test_remove_newlines(self):
        """A newline character is converted to a single space."""
        assert remove_whitespace("hello\nworld") == "hello world"

    def test_remove_mixed_whitespace(self):
        """A mixed run of spaces, tabs and newlines collapses to one space."""
        mixed = "hello" + " " * 2 + "\t\n" + " " * 2 + "world"
        assert remove_whitespace(mixed) == "hello world"

    def test_leading_trailing_whitespace(self):
        """Leading and trailing runs are collapsed to one space, not removed."""
        padding = " " * 3
        assert remove_whitespace(padding + "hello" + padding) == " hello "

    def test_empty_string(self):
        """An empty string is returned unchanged."""
        assert remove_whitespace("") == ""

    def test_single_space(self):
        """A single space is returned unchanged."""
        assert remove_whitespace(" ") == " "

    def test_no_whitespace(self):
        """A string with no whitespace is returned unchanged."""
        assert remove_whitespace("hello") == "hello"

    def test_only_whitespace(self):
        """A string made only of whitespace collapses to a single space."""
        blank = " " * 2 + "\t\n" + " " * 2
        assert remove_whitespace(blank) == " "

    def test_multiline_text(self):
        """Line breaks followed by indentation collapse to single spaces."""
        indent = " " * 4
        input_text = "Hello\n" + indent + "World\n" + indent + "Test"
        assert remove_whitespace(input_text) == "Hello World Test"

    def test_preserves_non_whitespace_special_chars(self):
        """Punctuation and symbols are left untouched."""
        assert remove_whitespace("hello!@#$%^&*()world") == "hello!@#$%^&*()world"

    def test_unicode_text(self):
        """A run of spaces between non-ASCII words collapses to one space."""
        spaced = "你好" + " " * 3 + "世界"
        assert remove_whitespace(spaced) == "你好 世界"

    def test_carriage_return_not_matched(self):
        """A carriage return is left in place; this documents current behavior."""
        # Note: \r is not in the original regex pattern [ \n\t]+
        result = remove_whitespace("hello\rworld")
        assert result == "hello\rworld"
|
||||
91
tests/unit_tests/test_strings.py
Normal file
91
tests/unit_tests/test_strings.py
Normal file
@@ -0,0 +1,91 @@
|
||||
"""Tests for string utility functions."""
|
||||
|
||||
import string
|
||||
|
||||
from skyvern.utils.strings import RANDOM_STRING_POOL, generate_random_string
|
||||
|
||||
|
||||
class TestRandomStringPool:
    """Checks on the contents and size of RANDOM_STRING_POOL."""

    def test_pool_contains_letters(self):
        """Every ASCII letter, lower and upper case, is in the pool."""
        assert all(letter in RANDOM_STRING_POOL for letter in string.ascii_letters)

    def test_pool_contains_digits(self):
        """Every decimal digit is in the pool."""
        assert all(digit in RANDOM_STRING_POOL for digit in string.digits)

    def test_pool_size(self):
        """The pool holds 62 entries (26 * 2 letters + 10 digits)."""
        expected_size = 62
        assert len(RANDOM_STRING_POOL) == expected_size
|
||||
|
||||
|
||||
class TestGenerateRandomString:
    """Checks on the length, type and character content of generate_random_string."""

    def test_default_length(self):
        """Calling with no argument yields a 5-character string."""
        assert len(generate_random_string()) == 5

    def test_custom_length(self):
        """The requested length is honoured for several sizes."""
        for wanted in (1, 10, 50, 100):
            produced = generate_random_string(wanted)
            assert len(produced) == wanted

    def test_zero_length(self):
        """A length of zero yields the empty string."""
        assert generate_random_string(0) == ""

    def test_returns_string(self):
        """The return value is a str."""
        produced = generate_random_string()
        assert isinstance(produced, str)

    def test_only_alphanumeric(self):
        """Every character of a 100-character result comes from the pool."""
        produced = generate_random_string(100)
        assert all(ch in RANDOM_STRING_POOL for ch in produced)

    def test_randomness(self):
        """Ten 20-character strings are all different from each other."""
        # A collision at length 20 is statistically negligible.
        distinct = {generate_random_string(20) for _ in range(10)}
        assert len(distinct) == 10

    def test_distribution(self):
        """A 10000-character result uses more than 50 distinct characters."""
        # With this many draws, most of the 62 pool characters are
        # expected to show up at least once.
        seen_chars = set(generate_random_string(10000))
        assert len(seen_chars) > 50

    def test_contains_letters(self):
        """A 100-character result contains at least one ASCII letter."""
        produced = generate_random_string(100)
        assert not set(produced).isdisjoint(string.ascii_letters)

    def test_contains_digits(self):
        """A 100-character result contains at least one digit."""
        produced = generate_random_string(100)
        assert not set(produced).isdisjoint(string.digits)

    def test_large_length(self):
        """A request for 10000 characters returns exactly 10000."""
        big = 10000
        assert len(generate_random_string(big)) == big
|
||||
108
tests/unit_tests/test_token_counter.py
Normal file
108
tests/unit_tests/test_token_counter.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""Tests for token counter utility."""
|
||||
|
||||
from skyvern.utils.token_counter import count_tokens
|
||||
|
||||
|
||||
class TestCountTokens:
    """Checks on count_tokens across a range of input kinds."""

    def test_empty_string(self):
        """The empty string counts as zero tokens."""
        assert count_tokens("") == 0

    def test_single_word(self):
        """A single word gives a positive integer count."""
        n_tokens = count_tokens("hello")
        assert isinstance(n_tokens, int)
        assert n_tokens > 0

    def test_simple_sentence(self):
        """A short sentence gives a count between 1 and 9 inclusive."""
        n_tokens = count_tokens("Hello, world!")
        # "Hello, world!" typically tokenizes to about 4 tokens.
        assert 0 < n_tokens < 10

    def test_longer_text(self):
        """A long sentence counts more tokens than a two-letter word."""
        brief_count = count_tokens("Hi")
        lengthy_count = count_tokens("This is a much longer sentence with many more words in it.")
        assert brief_count < lengthy_count

    def test_returns_integer(self):
        """The return value is an int."""
        assert isinstance(count_tokens("test"), int)

    def test_whitespace_only(self):
        """A whitespace-only input still returns an int."""
        # Only the return type is pinned here, not the count itself.
        n_tokens = count_tokens(" ")
        assert isinstance(n_tokens, int)

    def test_special_characters(self):
        """Punctuation and symbols give a positive count."""
        assert count_tokens("!@#$%^&*()") > 0

    def test_numbers(self):
        """A digit string gives a positive count."""
        assert count_tokens("12345") > 0

    def test_unicode(self):
        """Non-ASCII text gives a positive count."""
        assert count_tokens("你好世界") > 0

    def test_mixed_content(self):
        """Letters, digits and symbols mixed together give a positive count."""
        assert count_tokens("Hello123!@#World") > 0

    def test_newlines(self):
        """Text containing newlines gives a positive count."""
        assert count_tokens("Hello\nWorld\nTest") > 0

    def test_code_snippet(self):
        """A small Python snippet counts more than five tokens."""
        code = '\ndef hello():\n    print("Hello, World!")\n'
        assert count_tokens(code) > 5

    def test_json_content(self):
        """A JSON document gives a positive count."""
        json_str = '{"key": "value", "number": 123}'
        assert count_tokens(json_str) > 0

    def test_url(self):
        """A URL gives a positive count."""
        assert count_tokens("https://www.example.com/path?query=value") > 0

    def test_consistency(self):
        """Counting the same text twice gives the same result."""
        text = "This is a test sentence."
        first_count = count_tokens(text)
        second_count = count_tokens(text)
        assert first_count == second_count

    def test_very_long_text(self):
        """A thousand repeated words count more than 100 tokens."""
        long_text = "word " * 1000
        assert count_tokens(long_text) > 100

    def test_token_count_approximation(self):
        """The count for English text lies between len/10 and len characters."""
        text = "This is a sample text for testing token count approximation."
        n_chars = len(text)
        n_tokens = count_tokens(text)
        # GPT-style tokenizers typically produce about 1 token per 4
        # characters, so these bounds are deliberately loose on both sides.
        assert n_chars / 10 < n_tokens < n_chars
|
||||
173
tests/unit_tests/test_url_validators_extended.py
Normal file
173
tests/unit_tests/test_url_validators_extended.py
Normal file
@@ -0,0 +1,173 @@
|
||||
"""Extended tests for URL validators module."""
|
||||
|
||||
import pytest
|
||||
|
||||
from skyvern.exceptions import InvalidUrl
|
||||
from skyvern.utils.url_validators import encode_url, prepend_scheme_and_validate_url
|
||||
|
||||
|
||||
class TestEncodeUrl:
    """Checks on how encode_url percent-encodes the parts of a URL."""

    def test_encode_simple_path(self):
        """Spaces in the path become %20."""
        encoded = encode_url("https://example.com/path with spaces")
        assert encoded == "https://example.com/path%20with%20spaces"

    def test_encode_query_params(self):
        """Spaces in the query string become %20."""
        encoded = encode_url("https://example.com/search?q=hello world")
        assert encoded == "https://example.com/search?q=hello%20world"

    def test_preserve_slashes_in_path(self):
        """Path separators are left as-is."""
        plain = "https://example.com/path/to/resource"
        assert encode_url(plain) == "https://example.com/path/to/resource"

    def test_preserve_existing_encoding(self):
        """Percent-escapes already in the URL are not encoded a second time."""
        pre_encoded = "https://example.com/path%20already%20encoded"
        assert encode_url(pre_encoded) == "https://example.com/path%20already%20encoded"

    def test_encode_special_characters(self):
        """Angle brackets in the path become %3C and %3E."""
        encoded = encode_url("https://example.com/path<with>brackets")
        assert encoded == "https://example.com/path%3Cwith%3Ebrackets"

    def test_preserve_query_structure(self):
        """The ``=`` and ``&`` delimiters of the query string survive encoding."""
        encoded = encode_url("https://example.com/search?key1=value1&key2=value2")
        for expected_piece in ("key1=value1", "key2=value2", "&"):
            assert expected_piece in encoded

    def test_empty_path(self):
        """A URL with no path is returned unchanged."""
        bare = "https://example.com"
        assert encode_url(bare) == "https://example.com"

    def test_encode_unicode_path(self):
        """Non-ASCII path characters are replaced by their UTF-8 percent-escapes."""
        encoded = encode_url("https://example.com/路径")
        # The raw characters must be gone while the host part stays intact.
        assert "路径" not in encoded
        assert "example.com/" in encoded
        # "路径" encoded as UTF-8 is the byte sequence E8 B7 AF E5 BE 84.
        assert "%E8%B7%AF%E5%BE%84" in encoded

    def test_fragment_preserved(self):
        """The ``#fragment`` part is kept in the output."""
        with_fragment = "https://example.com/page#section"
        assert encode_url(with_fragment) == "https://example.com/page#section"
|
||||
|
||||
|
||||
class TestPrependSchemeAndValidateUrl:
    """Checks on scheme handling and validation in prepend_scheme_and_validate_url."""

    def test_empty_url_returns_empty(self):
        """An empty input yields an empty string."""
        validated = prepend_scheme_and_validate_url("")
        assert validated == ""

    def test_https_url_unchanged(self):
        """An https URL is returned as given."""
        original = "https://example.com"
        assert prepend_scheme_and_validate_url(original) == original

    def test_http_url_unchanged(self):
        """An http URL is returned as given."""
        original = "http://example.com"
        assert prepend_scheme_and_validate_url(original) == original

    def test_no_scheme_gets_https(self):
        """A bare host name gets ``https://`` prepended."""
        validated = prepend_scheme_and_validate_url("example.com")
        assert validated == "https://example.com"

    def test_invalid_scheme_raises_error(self):
        """An ftp URL is rejected with InvalidUrl."""
        with pytest.raises(InvalidUrl):
            prepend_scheme_and_validate_url("ftp://example.com")

    def test_file_scheme_raises_error(self):
        """A file URL is rejected with InvalidUrl."""
        with pytest.raises(InvalidUrl):
            prepend_scheme_and_validate_url("file:///etc/passwd")

    def test_javascript_scheme_raises_error(self):
        """A javascript: URL is rejected with InvalidUrl."""
        with pytest.raises(InvalidUrl):
            prepend_scheme_and_validate_url("javascript:alert(1)")

    def test_valid_url_with_path(self):
        """An https URL with a multi-segment path is returned as given."""
        original = "https://example.com/path/to/resource"
        assert prepend_scheme_and_validate_url(original) == original

    def test_valid_url_with_query(self):
        """An https URL with a query string is returned as given."""
        original = "https://example.com/search?q=test"
        assert prepend_scheme_and_validate_url(original) == original

    def test_url_with_port(self):
        """An https URL with an explicit port is returned as given."""
        original = "https://example.com:8080/path"
        assert prepend_scheme_and_validate_url(original) == original

    def test_subdomain_url(self):
        """An https URL on a subdomain is returned as given."""
        original = "https://api.example.com/v1"
        assert prepend_scheme_and_validate_url(original) == original

    def test_invalid_url_raises_error(self):
        """Free text that is not a URL is rejected with InvalidUrl."""
        with pytest.raises(InvalidUrl):
            prepend_scheme_and_validate_url("not a valid url at all!!!")
|
||||
|
||||
|
||||
class TestEncodeUrlEdgeCases:
    """Less common URL shapes passed through encode_url."""

    def test_double_slashes_in_path(self):
        """Doubled path separators appear unchanged in the output."""
        encoded = encode_url("https://example.com//double//slashes")
        assert "//double//slashes" in encoded

    def test_url_with_credentials(self):
        """The ``user:pass@host`` authority portion appears unchanged in the output."""
        encoded = encode_url("https://user:pass@example.com/path")
        assert "user:pass@example.com" in encoded

    def test_very_long_url(self):
        """A path of 1000 segments is handled and the output is not shorter than the input."""
        many_segments = "/a" * 1000
        long_url = f"https://example.com{many_segments}"
        encoded = encode_url(long_url)
        assert len(long_url) <= len(encoded)

    def test_url_with_multiple_query_params(self):
        """The first and last of five query parameters appear in the output."""
        encoded = encode_url("https://example.com/search?a=1&b=2&c=3&d=4&e=5")
        for expected_pair in ("a=1", "e=5"):
            assert expected_pair in encoded
|
||||
Reference in New Issue
Block a user