Skip to content

Commit ac4d6a5

Browse files
committed
Add tests
1 parent 5af75af commit ac4d6a5

File tree

1 file changed

+91
-0
lines changed

1 file changed

+91
-0
lines changed

test_selectlib.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
#!/usr/bin/env python
2+
import unittest
3+
import random
4+
5+
# Import the compiled selectlib module.
6+
import selectlib
7+
8+
class TestQuickselect(unittest.TestCase):
9+
10+
def sorted_index_check(self, values, k):
11+
"""
12+
Helper function: Given a list and a target index k,
13+
use quickselect to partition the list in place, then verify
14+
that the element at index k equals the kth smallest element.
15+
"""
16+
# Create a copy of the original list to compute the sorted target.
17+
expected = sorted(values)
18+
# Call quickselect: this mutates the list in-place.
19+
selectlib.quickselect(values, k)
20+
# Check that the element at index k is what we expect.
21+
self.assertEqual(values[k], expected[k])
22+
23+
# Additionally, verify that all elements before index k are less than or equal
24+
# to the kth element, and all elements after index k are greater than or equal.
25+
kth_value = values[k]
26+
for item in values[:k]:
27+
self.assertLessEqual(item, kth_value)
28+
for item in values[k+1:]:
29+
self.assertGreaterEqual(item, kth_value)
30+
31+
def test_ordered_list(self):
32+
# Test on a sorted list.
33+
values = list(range(10))
34+
k = 5
35+
selectlib.quickselect(values, k)
36+
self.assertEqual(values[k], 5)
37+
# Check partition condition.
38+
for item in values[:k]:
39+
self.assertLessEqual(item, values[k])
40+
for item in values[k+1:]:
41+
self.assertGreaterEqual(item, values[k])
42+
43+
def test_reversed_list(self):
44+
# Test on a reverse-sorted list.
45+
values = list(range(10, 0, -1))
46+
k = 3
47+
self.sorted_index_check(values, k)
48+
49+
def test_random_list(self):
50+
# Test on a list of random integers.
51+
values = [random.randint(0, 100) for _ in range(20)]
52+
k = random.randint(0, len(values) - 1)
53+
self.sorted_index_check(values, k)
54+
55+
def test_with_duplicates(self):
56+
# Test on a list with duplicate values.
57+
values = [5, 1, 3, 5, 2, 5, 4, 1, 3]
58+
k = 4
59+
self.sorted_index_check(values, k)
60+
61+
def test_with_key_function(self):
62+
# Test the 'key' argument.
63+
# In this example, we use a simple key that returns the negative of the value,
64+
# effectively partitioning to find the kth largest element.
65+
values = [random.randint(0, 100) for _ in range(15)]
66+
k = 7 # kth largest element if we sort descending
67+
# Make a copy for expected result.
68+
expected = sorted(values, key=lambda x: -x)
69+
# When using a key, quickselect should partition based on the key.
70+
selectlib.quickselect(values, k, key=lambda x: -x)
71+
self.assertEqual(values[k], expected[k])
72+
kth_value = values[k]
73+
# Check that all prior items have keys less than or equal to the kth item.
74+
for item in values[:k]:
75+
self.assertLessEqual(-item, -kth_value)
76+
for item in values[k+1:]:
77+
self.assertGreaterEqual(-item, -kth_value)
78+
79+
def test_non_list_input(self):
80+
# Test that providing a non-list as values raises a TypeError.
81+
with self.assertRaises(TypeError):
82+
selectlib.quickselect("not a list", 0)
83+
84+
def test_out_of_range_index(self):
85+
# Test that an out-of-range index raises an IndexError.
86+
values = [3, 1, 2]
87+
with self.assertRaises(IndexError):
88+
selectlib.quickselect(values, 5)
89+
90+
if __name__ == '__main__':
91+
unittest.main()

0 commit comments

Comments
 (0)