Skip to content

Commit

Permalink
bpo-40282: Allow random.getrandbits(0) (GH-19539)
Browse files Browse the repository at this point in the history
  • Loading branch information
pitrou authored Apr 17, 2020
1 parent d7c657d commit 75a3378
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 44 deletions.
3 changes: 3 additions & 0 deletions Doc/library/random.rst
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ Bookkeeping functions
as an optional part of the API. When available, :meth:`getrandbits` enables
:meth:`randrange` to handle arbitrarily large ranges.

.. versionchanged:: 3.9
This method now accepts zero for *k*.


.. function:: randbytes(n)

Expand Down
6 changes: 4 additions & 2 deletions Lib/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,8 @@ def randint(self, a, b):
def _randbelow_with_getrandbits(self, n):
"Return a random int in the range [0,n). Raises ValueError if n==0."

if not n:
raise ValueError("Boundary cannot be zero")
getrandbits = self.getrandbits
k = n.bit_length() # don't use (n-1) here because n can be 1
r = getrandbits(k) # 0 <= r < 2**k
Expand Down Expand Up @@ -733,8 +735,8 @@ def random(self):

def getrandbits(self, k):
"""getrandbits(k) -> x. Generates an int with k random bits."""
if k <= 0:
raise ValueError('number of bits must be greater than zero')
if k < 0:
raise ValueError('number of bits must be non-negative')
numbytes = (k + 7) // 8 # bits / 8 and rounded up
x = int.from_bytes(_urandom(numbytes), 'big')
return x >> (numbytes * 8 - k) # trim excess bits
Expand Down
69 changes: 29 additions & 40 deletions Lib/test/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,31 @@ def test_gauss(self):
self.assertEqual(x1, x2)
self.assertEqual(y1, y2)

def test_getrandbits(self):
# Verify ranges
for k in range(1, 1000):
self.assertTrue(0 <= self.gen.getrandbits(k) < 2**k)
self.assertEqual(self.gen.getrandbits(0), 0)

# Verify all bits active
getbits = self.gen.getrandbits
for span in [1, 2, 3, 4, 31, 32, 32, 52, 53, 54, 119, 127, 128, 129]:
all_bits = 2**span-1
cum = 0
cpl_cum = 0
for i in range(100):
v = getbits(span)
cum |= v
cpl_cum |= all_bits ^ v
self.assertEqual(cum, all_bits)
self.assertEqual(cpl_cum, all_bits)

# Verify argument checking
self.assertRaises(TypeError, self.gen.getrandbits)
self.assertRaises(TypeError, self.gen.getrandbits, 1, 2)
self.assertRaises(ValueError, self.gen.getrandbits, -1)
self.assertRaises(TypeError, self.gen.getrandbits, 10.1)

def test_pickling(self):
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
state = pickle.dumps(self.gen, proto)
Expand Down Expand Up @@ -390,26 +415,6 @@ def test_randrange_errors(self):
raises(0, 42, 0)
raises(0, 42, 3.14159)

def test_genrandbits(self):
# Verify ranges
for k in range(1, 1000):
self.assertTrue(0 <= self.gen.getrandbits(k) < 2**k)

# Verify all bits active
getbits = self.gen.getrandbits
for span in [1, 2, 3, 4, 31, 32, 32, 52, 53, 54, 119, 127, 128, 129]:
cum = 0
for i in range(100):
cum |= getbits(span)
self.assertEqual(cum, 2**span-1)

# Verify argument checking
self.assertRaises(TypeError, self.gen.getrandbits)
self.assertRaises(TypeError, self.gen.getrandbits, 1, 2)
self.assertRaises(ValueError, self.gen.getrandbits, 0)
self.assertRaises(ValueError, self.gen.getrandbits, -1)
self.assertRaises(TypeError, self.gen.getrandbits, 10.1)

def test_randbelow_logic(self, _log=log, int=int):
# check bitcount transition points: 2**i and 2**(i+1)-1
# show that: k = int(1.001 + _log(n, 2))
Expand Down Expand Up @@ -629,34 +634,18 @@ def test_rangelimits(self):
self.assertEqual(set(range(start,stop)),
set([self.gen.randrange(start,stop) for i in range(100)]))

def test_genrandbits(self):
def test_getrandbits(self):
super().test_getrandbits()

# Verify cross-platform repeatability
self.gen.seed(1234567)
self.assertEqual(self.gen.getrandbits(100),
97904845777343510404718956115)
# Verify ranges
for k in range(1, 1000):
self.assertTrue(0 <= self.gen.getrandbits(k) < 2**k)

# Verify all bits active
getbits = self.gen.getrandbits
for span in [1, 2, 3, 4, 31, 32, 32, 52, 53, 54, 119, 127, 128, 129]:
cum = 0
for i in range(100):
cum |= getbits(span)
self.assertEqual(cum, 2**span-1)

# Verify argument checking
self.assertRaises(TypeError, self.gen.getrandbits)
self.assertRaises(TypeError, self.gen.getrandbits, 'a')
self.assertRaises(TypeError, self.gen.getrandbits, 1, 2)
self.assertRaises(ValueError, self.gen.getrandbits, 0)
self.assertRaises(ValueError, self.gen.getrandbits, -1)

def test_randrange_uses_getrandbits(self):
# Verify use of getrandbits by randrange
# Use same seed as in the cross-platform repeatability test
# in test_genrandbits above.
# in test_getrandbits above.
self.gen.seed(1234567)
# If randrange uses getrandbits, it should pick getrandbits(100)
# when called with a 100-bits stop argument.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Allow ``random.getrandbits(0)`` to succeed and to return 0.
7 changes: 5 additions & 2 deletions Modules/_randommodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -474,12 +474,15 @@ _random_Random_getrandbits_impl(RandomObject *self, int k)
uint32_t *wordarray;
PyObject *result;

if (k <= 0) {
if (k < 0) {
PyErr_SetString(PyExc_ValueError,
"number of bits must be greater than zero");
"number of bits must be non-negative");
return NULL;
}

if (k == 0)
return PyLong_FromLong(0);

if (k <= 32) /* Fast path */
return PyLong_FromUnsignedLong(genrand_uint32(self) >> (32 - k));

Expand Down

0 comments on commit 75a3378

Please sign in to comment.