From e73a0fbefb803556202db7dc162ba86009dd0fd3 Mon Sep 17 00:00:00 2001 From: ebonnal Date: Fri, 17 Jan 2025 00:57:46 +0100 Subject: [PATCH] test_stream: use specialized assertions --- tests/test_stream.py | 112 +++++++++++++++++++++++-------------------- 1 file changed, 59 insertions(+), 53 deletions(-) diff --git a/tests/test_stream.py b/tests/test_stream.py index 9019325..af4f38d 100644 --- a/tests/test_stream.py +++ b/tests/test_stream.py @@ -154,7 +154,7 @@ def range_raising_at_exhaustion( class TestStream(unittest.TestCase): def test_init(self) -> None: stream = Stream(src) - self.assertEqual( + self.assertIs( stream._source, src, msg="The stream's `source` must be the source argument.", @@ -644,7 +644,7 @@ def test_flatten_concurrency(self) -> None: ).flatten(concurrency=2), times=3, ) - self.assertEqual( + self.assertListEqual( res, ["a", "b"] * iterable_size + ["c"] * iterable_size, msg="`flatten` should process 'a's and 'b's concurrently and then 'c's", @@ -828,13 +828,13 @@ def test_skip(self) -> None: Stream(src).skip() for count in [0, 1, 3]: - self.assertEqual( + self.assertListEqual( list(Stream(src).skip(count)), list(src)[count:], msg="`skip` must skip `count` elements", ) - self.assertEqual( + self.assertListEqual( list( Stream(map(throw_for_odd_func(TestError), src)) .skip(count) @@ -844,13 +844,13 @@ def test_skip(self) -> None: msg="`skip` must not count exceptions as skipped elements", ) - self.assertEqual( + self.assertListEqual( list(Stream(src).skip(until=lambda n: n >= count)), list(src)[count:], msg="`skip` must yield starting from the first element satisfying `until`", ) - self.assertEqual( + self.assertListEqual( list(Stream(src).skip(until=lambda n: False)), [], msg="`skip` must not yield any element if `until` is never satisfied", @@ -863,22 +863,22 @@ def test_truncate(self) -> None: ): Stream(src).truncate() - self.assertEqual( + self.assertListEqual( list(Stream(src).truncate(N * 2)), list(src), msg="`truncate` must be ok with count >= stream length", ) - self.assertEqual( + self.assertListEqual( list(Stream(src).truncate(2)), [0, 1], msg="`truncate` must be ok with count >= 1", ) - self.assertEqual( + self.assertListEqual( list(Stream(src).truncate(1)), [0], msg="`truncate` must be ok with count == 1", ) - self.assertEqual( + self.assertListEqual( list(Stream(src).truncate(0)), [], msg="`truncate` must be ok with count == 0", @@ -908,7 +908,7 @@ def test_truncate(self) -> None: ): next(raising_stream_iterator) - self.assertEqual(list(raising_stream_iterator), list(range(1, count + 1))) + self.assertListEqual(list(raising_stream_iterator), list(range(1, count + 1))) with self.assertRaises( StopIteration, @@ -917,7 +917,7 @@ def test_truncate(self) -> None: next(raising_stream_iterator) iter_truncated_on_predicate = iter(Stream(src).truncate(when=lambda n: n == 5)) - self.assertEqual( + self.assertListEqual( list(iter_truncated_on_predicate), list(Stream(src).truncate(5)), msg="`when` n == 5 must be equivalent to `count` = 5", @@ -934,13 +934,13 @@ def test_truncate(self) -> None: ): list(Stream(src).truncate(when=lambda _: 1 / 0)) - self.assertEqual( + self.assertListEqual( list(Stream(src).truncate(6, when=lambda n: n == 5)), list(range(5)), msg="`when` and `count` argument can be set at the same time, and the truncation should happen as soon as one or the other is satisfied.", ) - self.assertEqual( + self.assertListEqual( list(Stream(src).truncate(5, when=lambda n: n == 6)), list(range(5)), msg="`when` and `count` argument can be set at the same time, and the truncation should happen as soon as one or the other is satisfied.", @@ -1108,7 +1108,7 @@ def f(i): size=3, by=lambda n: throw(StopIteration) if n == 2 else n ) ) - self.assertEqual( + self.assertListEqual( [next(stream_iter), next(stream_iter)], [[0], [1]], msg="`group` should yield incomplete groups when `by` raises", @@ -1119,7 +1119,7 @@ def f(i): msg="`group` should raise and skip `elem` if `by(elem)` raises", ): next(stream_iter) - self.assertEqual( + self.assertListEqual( next(stream_iter), [3], msg="`group` should continue yielding after `by`'s exception has been raised.", @@ -1163,24 +1163,27 @@ def slow_first_elem(elem: int): time.sleep(super_slow_elem_pull_seconds) return elem - for stream, expected_elems in [ - ( - Stream(map(slow_first_elem, integers)).throttle( - interval=datetime.timedelta(seconds=interval_seconds) + for stream, expected_elems in cast( + List[Tuple[Stream, List]], + [ + ( + Stream(map(slow_first_elem, integers)).throttle( + interval=datetime.timedelta(seconds=interval_seconds) + ), + list(integers), ), - list(integers), - ), - ( - Stream(map(throw_func(TestError), map(slow_first_elem, integers))) - .throttle(interval=datetime.timedelta(seconds=interval_seconds)) - .catch(TestError), - [], - ), - ]: + ( + Stream(map(throw_func(TestError), map(slow_first_elem, integers))) + .throttle(interval=datetime.timedelta(seconds=interval_seconds)) + .catch(TestError), + [], + ), + ], + ): with self.subTest(stream=stream): duration, res = timestream(stream) - self.assertEqual( + self.assertListEqual( res, expected_elems, msg="`throttle` with `interval` must yield upstream elements", @@ -1212,21 +1215,24 @@ def slow_first_elem(elem: int): for N in [1, 10, 11]: integers = range(N) per_second = 2 - for stream, expected_elems in [ - ( - Stream(integers).throttle(per_second=per_second), - list(integers), - ), - ( - Stream(map(throw_func(TestError), integers)) - .throttle(per_second=per_second) - .catch(TestError), - [], - ), - ]: + for stream, expected_elems in cast( + List[Tuple[Stream, List]], + [ + ( + Stream(integers).throttle(per_second=per_second), + list(integers), + ), + ( + Stream(map(throw_func(TestError), integers)) + .throttle(per_second=per_second) + .catch(TestError), + [], + ), + ], + ): with self.subTest(N=N, stream=stream): duration, res = timestream(stream) - self.assertEqual( + self.assertListEqual( res, expected_elems, msg="`throttle` with `per_second` must yield upstream elements", @@ -1260,18 +1266,18 @@ def slow_first_elem(elem: int): ) def test_distinct(self) -> None: - self.assertEqual( + self.assertListEqual( list(Stream("abbcaabcccddd").distinct()), list("abcd"), msg="`distinct` should yield distinct elements", ) - self.assertEqual( + self.assertListEqual( list(Stream("aabbcccaabbcccc").distinct(consecutive_only=True)), list("abcabc"), msg="`distinct` should only remove the duplicates that are consecutive if `consecutive_only=True`", ) for consecutive_only in [True, False]: - self.assertEqual( + self.assertListEqual( list( Stream(["foo", "bar", "a", "b"]).distinct( len, consecutive_only=consecutive_only @@ -1280,7 +1286,7 @@ def test_distinct(self) -> None: ["foo", "a"], msg="`distinct` should yield the first encountered elem among duplicates", ) - self.assertEqual( + self.assertListEqual( list(Stream([]).distinct(consecutive_only=consecutive_only)), [], msg="`distinct` should yield zero elements on empty stream", @@ -1293,7 +1299,7 @@ def test_distinct(self) -> None: list(Stream([[1]]).distinct()) def test_catch(self) -> None: - self.assertEqual( + self.assertListEqual( list(Stream(src).catch(finally_raise=True)), list(src), msg="`catch` should yield elements in exception-less scenarios", @@ -1375,7 +1381,7 @@ def f(i): only_catched_errors_stream = Stream( map(lambda _: throw(TestError), range(2000)) ).catch(TestError) - self.assertEqual( + self.assertListEqual( list(only_catched_errors_stream), [], msg="When upstream raise exceptions without yielding any element, listing the stream must return empty list, without recursion issue.", @@ -1417,7 +1423,7 @@ def f(i): ) ) - self.assertEqual( + self.assertListEqual( list( Stream(map(lambda n: 1 / n, [0, 1, 2, 4])).catch( ZeroDivisionError, replacement=float("inf") @@ -1426,7 +1432,7 @@ def f(i): [float("inf"), 1, 0.5, 0.25], msg="`catch` should be able to yield a non-None replacement", ) - self.assertEqual( + self.assertListEqual( list( Stream(map(lambda n: 1 / n, [0, 1, 2, 4])).catch( ZeroDivisionError, replacement=cast(float, None) @@ -1487,7 +1493,7 @@ def test_call(self) -> None: stream, msg="`__call__` should return the stream.", ) - self.assertEqual( + self.assertListEqual( l, list(src), msg="`__call__` should exhaust the stream.", @@ -1496,7 +1502,7 @@ def test_call(self) -> None: def test_multiple_iterations(self) -> None: stream = Stream(src) for _ in range(3): - self.assertEqual( + self.assertListEqual( list(stream), list(src), msg="The first iteration over a stream should yield the same elements as any subsequent iteration on the same stream, even if it is based on a `source` returning an iterator that only support 1 iteration.",