22# http://www-graphics.stanford.edu/~seander/bithacks.html#NextBitPermutation
33
44from bitarray import bitarray
5- from bitarray .util import ba2int , int2ba
5+ from bitarray .util import zeros , ones , ba2int , int2ba
66
77from math import comb
88
@@ -13,16 +13,18 @@ def all_perm(n, k, endian=None):
1313Return an iterator over all bitarrays of length `n` with `k` bits set to 1
1414in lexicographical order.
1515"""
16- n = int (n )
1716 if n < 0 :
1817 raise ValueError ("length must be >= 0" )
19- k = int (k )
20- if k < 0 or k > n :
21- raise ValueError ("number of set bits must be in range(0, n + 1)" )
2218
23- if k == 0 :
24- yield bitarray (n , endian )
25- return
19+ # error check inputs and handle edge cases
20+ if k <= 0 or k > n :
21+ if k == 0 :
22+ yield zeros (n , endian )
23+ return
24+ if k == n :
25+ yield ones (n , endian )
26+ return
27+ raise ValueError ("k must be in range 0 <= k <= n, got %s" % k )
2628
2729 v = (1 << k ) - 1
2830 for _ in range (comb (n , k )):
@@ -146,7 +148,11 @@ def test_all_perm_explicit(self):
146148 (1 , 1 , ['1' ]),
147149 (2 , 0 , ['00' ]),
148150 (2 , 1 , ['01' , '10' ]),
151+ (2 , 2 , ['11' ]),
152+ (3 , 0 , ['000' ]),
153+ (3 , 1 , ['001' , '010' , '100' ]),
149154 (3 , 2 , ['011' , '101' , '110' ]),
155+ (3 , 3 , ['111' ]),
150156 ]:
151157 self .assertEqual (list (all_perm (n , k , 'big' )),
152158 [bitarray (s ) for s in res ])
0 commit comments