Skip to content

Commit

Permalink
.skip: add the until param (#56)
Browse files Browse the repository at this point in the history
Co-Author: maximebonnal <bonnalmaxime@gmail.com>
  • Loading branch information
maximebonnal authored Jan 17, 2025
1 parent 7f9a9a0 commit 0df6fc9
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 21 deletions.
16 changes: 10 additions & 6 deletions streamable/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
CatchIterator,
ConcurrentFlattenIterator,
ConsecutiveDistinctIterator,
CountSkipIterator,
CountTruncateIterator,
DistinctIterator,
FlattenIterator,
Expand All @@ -29,18 +30,18 @@
IntervalThrottleIterator,
ObserveIterator,
OSConcurrentMapIterator,
PredicateSkipIterator,
PredicateTruncateIterator,
SkipIterator,
YieldsPerPeriodThrottleIterator,
)
from streamable.util.constants import NO_REPLACEMENT
from streamable.util.functiontools import wrap_error
from streamable.util.validationtools import (
validate_concurrency,
validate_count,
validate_group_interval,
validate_group_size,
validate_iterator,
validate_skip_args,
validate_throttle_interval,
validate_throttle_per_period,
validate_truncate_args,
Expand Down Expand Up @@ -165,12 +166,15 @@ def observe(iterator: Iterator[T], what: str) -> Iterator[T]:

def skip(
iterator: Iterator[T],
count: int,
count: Optional[int] = None,
until: Optional[Callable[[T], Any]] = None,
) -> Iterator[T]:
validate_iterator(iterator)
validate_count(count)
if count > 0:
iterator = SkipIterator(iterator, count)
validate_skip_args(count, until)
if until is not None:
iterator = PredicateSkipIterator(iterator, until)
elif count is not None:
iterator = CountSkipIterator(iterator, count)
return iterator


Expand Down
32 changes: 26 additions & 6 deletions streamable/iterators.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,9 @@ def __next__(self) -> T:


class DistinctIterator(Iterator[T]):
def __init__(self, iterator: Iterator[T], key: Optional[Callable[[T], Any]]) -> None:
def __init__(
self, iterator: Iterator[T], key: Optional[Callable[[T], Any]]
) -> None:
validate_iterator(iterator)
self.iterator = iterator
self.key = wrap_error(key, StopIteration) if key else None
Expand All @@ -113,7 +115,9 @@ def __next__(self) -> T:


class ConsecutiveDistinctIterator(Iterator[T]):
def __init__(self, iterator: Iterator[T], key: Optional[Callable[[T], Any]]) -> None:
def __init__(
self, iterator: Iterator[T], key: Optional[Callable[[T], Any]]
) -> None:
validate_iterator(iterator)
self.iterator = iterator
self.key = wrap_error(key, StopIteration) if key else None
Expand Down Expand Up @@ -276,7 +280,7 @@ def __next__(self) -> Tuple[U, List[T]]:
return next(self)


class SkipIterator(Iterator[T]):
class CountSkipIterator(Iterator[T]):
def __init__(self, iterator: Iterator[T], count: int) -> None:
validate_iterator(iterator)
validate_count(count)
Expand All @@ -285,7 +289,7 @@ def __init__(self, iterator: Iterator[T], count: int) -> None:
self._n_skipped = 0
self._done_skipping = False

def __next__(self):
def __next__(self) -> T:
if not self._done_skipping:
while self._n_skipped < self.count:
next(self.iterator)
Expand All @@ -295,6 +299,22 @@ def __next__(self):
return next(self.iterator)


class PredicateSkipIterator(Iterator[T]):
def __init__(self, iterator: Iterator[T], until: Callable[[T], Any]) -> None:
validate_iterator(iterator)
self.iterator = iterator
self.until = wrap_error(until, StopIteration)
self._done_skipping = False

def __next__(self) -> T:
elem = next(self.iterator)
if not self._done_skipping:
while not self.until(elem):
elem = next(self.iterator)
self._done_skipping = True
return elem


class CountTruncateIterator(Iterator[T]):
def __init__(self, iterator: Iterator[T], count: int) -> None:
validate_iterator(iterator)
Expand All @@ -303,7 +323,7 @@ def __init__(self, iterator: Iterator[T], count: int) -> None:
self.count = count
self._current_count = 0

def __next__(self):
def __next__(self) -> T:
if self._current_count == self.count:
raise StopIteration()
elem = next(self.iterator)
Expand All @@ -318,7 +338,7 @@ def __init__(self, iterator: Iterator[T], when: Callable[[T], Any]) -> None:
self.when = wrap_error(when, StopIteration)
self._satisfied = False

def __next__(self):
def __next__(self) -> T:
if self._satisfied:
raise StopIteration()
elem = next(self.iterator)
Expand Down
19 changes: 12 additions & 7 deletions streamable/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@
from streamable.util.loggertools import get_logger
from streamable.util.validationtools import (
validate_concurrency,
validate_count,
validate_group_interval,
validate_group_size,
validate_skip_args,
validate_throttle_interval,
validate_throttle_per_period,
validate_truncate_args,
Expand Down Expand Up @@ -429,18 +429,21 @@ def observe(self, what: str = "elements") -> "Stream[T]":
"""
return ObserveStream(self, what)

def skip(self, count: int) -> "Stream[T]":
def skip(
self, count: Optional[int] = None, until: Optional[Callable[[T], Any]] = None
) -> "Stream[T]":
"""
Skips the first `count` elements.
Skips the first `count` elements, or skips `until` a predicate becomes satisfied.
Args:
count (int): The number of elements to skip.
count (Optional[int], optional): The number of elements to skip. (by default: no count-based skipping)
until (Optional[Callable[[T], Any]], optional): Elements are skipped until the first one for which `until(elem)` is truthy. This element and all the subsequent ones will be yielded. (by default: no predicate-based skipping)
Returns:
Stream: A stream of the upstream elements remaining after skipping.
"""
validate_count(count)
return SkipStream(self, count)
validate_skip_args(count, until)
return SkipStream(self, count, until)

def throttle(
self,
Expand Down Expand Up @@ -679,10 +682,12 @@ class SkipStream(DownStream[T, T]):
def __init__(
self,
upstream: Stream[T],
count: int,
count: Optional[int],
until: Optional[Callable[[T], Any]],
) -> None:
super().__init__(upstream)
self._count = count
self._until = until

def accept(self, visitor: "Visitor[V]") -> V:
return visitor.visit_skip_stream(self)
Expand Down
12 changes: 12 additions & 0 deletions streamable/util/validationtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,15 @@ def validate_truncate_args(
raise ValueError("`count` and `when` cannot both be None")
else:
validate_count(count)


def validate_skip_args(
count: Optional[int] = None, until: Optional[Callable[[T], Any]] = None
) -> None:
if count is None:
if until is None:
raise ValueError("`count` and `until` cannot both be None")
else:
if until is not None:
raise ValueError("`count` and `until` cannot both be set")
validate_count(count)
1 change: 1 addition & 0 deletions streamable/visitors/iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def visit_skip_stream(self, stream: SkipStream[T]) -> Iterator[T]:
return functions.skip(
stream.upstream.accept(self),
stream._count,
stream._until,
)

def visit_throttle_stream(self, stream: ThrottleStream[T]) -> Iterator[T]:
Expand Down
4 changes: 3 additions & 1 deletion streamable/visitors/representation.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,9 @@ def visit_observe_stream(self, stream: ObserveStream[T]) -> str:
return stream.upstream.accept(self)

def visit_skip_stream(self, stream: SkipStream[T]) -> str:
self.methods_reprs.append(f"skip({self.to_string(stream._count)})")
self.methods_reprs.append(
f"skip({self.to_string(stream._count)}, until={self.to_string(stream._until)})"
)
return stream.upstream.accept(self)

def visit_throttle_stream(self, stream: ThrottleStream[T]) -> str:
Expand Down
30 changes: 29 additions & 1 deletion tests/test_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ class CustomCallable:
Stream(src)
.truncate(1024, when=lambda _: False)
.skip(10)
.skip(until=lambda _: True)
.distinct(lambda _: _)
.filter()
.map(lambda i: (i,))
Expand Down Expand Up @@ -258,7 +259,8 @@ class CustomCallable:
"""(
Stream(range(0, 256))
.truncate(count=1024, when=<lambda>)
.skip(10)
.skip(10, until=None)
.skip(None, until=<lambda>)
.distinct(<lambda>, consecutive_only=False)
.filter(bool)
.map(<lambda>, concurrency=1, ordered=True)
Expand Down Expand Up @@ -811,6 +813,20 @@ def test_skip(self) -> None:
):
Stream(src).skip(-1)

with self.assertRaisesRegex(
ValueError,
"`count` and `until` cannot both be set",
msg="`skip` must raise ValueError if both `count` and `until` are set",
):
Stream(src).skip(0, until=bool)

with self.assertRaisesRegex(
ValueError,
"`count` and `until` cannot both be None",
msg="`skip` must raise ValueError if both `count` and `until` are None",
):
Stream(src).skip()

for count in [0, 1, 3]:
self.assertEqual(
list(Stream(src).skip(count)),
Expand All @@ -828,6 +844,18 @@ def test_skip(self) -> None:
msg="`skip` must not count exceptions as skipped elements",
)

self.assertEqual(
list(Stream(src).skip(until=lambda n: n >= count)),
list(src)[count:],
msg="`skip` must yield starting from the first element satisfying `until`",
)

self.assertEqual(
list(Stream(src).skip(until=lambda n: False)),
[],
msg="`skip` must not yield any element if `until` is never satisfied",
)

def test_truncate(self) -> None:
with self.assertRaisesRegex(
ValueError,
Expand Down

0 comments on commit 0df6fc9

Please sign in to comment.