Skip to content

Commit

Permalink
Merge locate_seq and locate
Browse files Browse the repository at this point in the history
  • Loading branch information
bbayles committed Jul 3, 2018
1 parent 8c49fff commit 44a3cc4
Showing 1 changed file with 10 additions and 22 deletions.
32 changes: 10 additions & 22 deletions more_itertools/more.py
Original file line number Diff line number Diff line change
Expand Up @@ -1463,7 +1463,7 @@ def count_cycle(iterable, n=None):
return ((i, item) for i in counter for item in iterable)


def locate(iterable, pred=bool):
def locate(iterable, pred=bool, n=None):
"""Yield the index of each item in *iterable* for which *pred* returns
``True``.
Expand All @@ -1473,18 +1473,17 @@ def locate(iterable, pred=bool):
[1, 2, 4]
Set *pred* to a custom function to, e.g., find the indexes for a particular
item:
item.
>>> list(locate(['a', 'b', 'c', 'b'], lambda x: x == 'b'))
[1, 3]
Use with :func:`windowed` to find the indexes of a sub-sequence:
If *n* is given, the argument given to the *pred* function will be a tuple
with containing *n* items. This enables searching for sub-sequences:
>>> from more_itertools import windowed
>>> iterable = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
>>> sub = [1, 2, 3]
>>> pred = lambda w: w == tuple(sub) # windowed() returns tuples
>>> list(locate(windowed(iterable, len(sub)), pred=pred))
>>> pred = lambda sub: sub == (1, 2, 3)
>>> list(locate(iterable, pred=pred, n=3))
[1, 5, 9]
Use with :func:`seekable` to find indexes and then retrieve the associated
Expand All @@ -1502,22 +1501,11 @@ def locate(iterable, pred=bool):
106
"""
return compress(count(), map(pred, iterable))


def locate_seq(iterable, seq):
"""Yield each index in *iterable* where the sequence *seq* begins.
>>> iterable = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
>>> seq = [1, 2, 3]
>>> list(locate_seq(iterable, seq))
[1, 5, 9]
if n is None:
return compress(count(), map(pred, iterable))

"""
seq = tuple(seq)
it = windowed(iterable, len(seq))
pred = lambda w: w == seq
return locate(it, pred=pred)
it = windowed(iterable, n)
return compress(count(), map(pred, it))


def lstrip(iterable, pred):
Expand Down

0 comments on commit 44a3cc4

Please sign in to comment.