Skip to content

Commit

Permalink
Fix lzip(strict) for older Pythons
Browse files Browse the repository at this point in the history
  • Loading branch information
Suor committed Feb 4, 2023
1 parent faef75e commit 1de0e19
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 3 deletions.
54 changes: 51 additions & 3 deletions funcy/seqs.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sys
from itertools import islice, chain, tee, groupby, filterfalse, accumulate, \
takewhile as _takewhile, dropwhile as _dropwhile
from collections.abc import Sequence
Expand Down Expand Up @@ -426,9 +427,56 @@ def pairwise(seq):
next(b, None)
return zip(a, b)

def lzip(*seqs, strict=False):
"""List zip() version."""
return list(zip(*seqs, strict=strict))
if sys.version_info >= (3, 10):
def lzip(*seqs, strict=False):
"""List zip() version."""
return list(zip(*seqs, strict=strict))
else:
def lzip(*seqs, strict=False):
"""List zip() version."""
if strict and len(seqs) > 1:
return list(_zip_strict(*seqs))
return list(zip(*seqs))

def _zip_strict(*seqs):
try:
# Try compare lens if they are available and use a fast zip() builtin
len_1 = len(seqs[0])
for i, s in enumerate(seqs, start=1):
len_i = len(s)
if len_i != len_1:
short_i, long_i = (1, i) if len_1 < len_i else (i, 1)
raise _zip_strict_error(short_i, long_i)
except TypeError:
return _zip_strict_iters(*seqs)
else:
return zip(*seqs)

def _zip_strict_iters(*seqs):
iters = [iter(s) for s in seqs]
while True:
values, stop_i, val_i = [], 0, 0
for i, it in enumerate(iters, start=1):
try:
values.append(next(it))
if not val_i:
val_i = i
except StopIteration:
if not stop_i:
stop_i = i

if stop_i:
if val_i:
raise _zip_strict_error(stop_i, val_i)
break
yield tuple(values)

def _zip_strict_error(short_i, long_i):
if short_i == 1:
return ValueError("zip() argument %d is longer than argument 1" % long_i)
else:
start = "argument 1" if short_i == 2 else "argument 1-%d" % (short_i - 1)
return ValueError("zip() argument %d is shorter than %s" % (short_i, start))


def _reductions(f, seq, acc):
Expand Down
17 changes: 17 additions & 0 deletions tests/test_seqs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections.abc import Iterator
from operator import add
import sys
import pytest
from whatever import _

Expand Down Expand Up @@ -185,6 +186,22 @@ def test_with_next():
def test_pairwise():
assert list(pairwise(range(3))) == [(0, 1), (1, 2)]

def test_lzip():
assert lzip('12', 'xy') == [('1', 'x'), ('2', 'y')]
assert lzip('123', 'xy') == [('1', 'x'), ('2', 'y')]
assert lzip('12', 'xyz') == [('1', 'x'), ('2', 'y')]
assert lzip('12', iter('xyz')) == [('1', 'x'), ('2', 'y')]

def test_lzip_strict():
assert lzip('123', 'xy', strict=False) == [('1', 'x'), ('2', 'y')]
assert lzip('12', 'xy', strict=True) == [('1', 'x'), ('2', 'y')]
assert lzip('12', iter('xy'), strict=True) == [('1', 'x'), ('2', 'y')]
for wrap in (str, iter):
with pytest.raises(ValueError): lzip(wrap('123'), wrap('xy'), strict=True)
with pytest.raises(ValueError): lzip(wrap('12'), wrap('xyz'), wrap('abcd'), strict=True)
with pytest.raises(ValueError): lzip(wrap('123'), wrap('xy'), wrap('abcd'), strict=True)
with pytest.raises(ValueError): lzip(wrap('123'), wrap('xyz'), wrap('ab'), strict=True)


def test_reductions():
assert lreductions(add, []) == []
Expand Down

0 comments on commit 1de0e19

Please sign in to comment.