Skip to content

Commit 0411411

Browse files
authored
Rework integer overflow path in math.prod and add more tests (GH-11809)
The overflow check relied on undefined behaviour: it inspected the result of the multiplication, but once the overflow has already happened, any operation on that result is itself undefined behaviour. The check is now performed before the multiplication is carried out. Some extra tests that exercise the code paths related to this are also added.
1 parent 62fa51f commit 0411411

File tree

2 files changed

+137
-40
lines changed

2 files changed

+137
-40
lines changed

Lib/test/test_math.py

Lines changed: 86 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1595,6 +1595,92 @@ def test_mtestfile(self):
15951595
self.fail('Failures in test_mtestfile:\n ' +
15961596
'\n '.join(failures))
15971597

1598+
def test_prod(self):
    # Exercises math.prod: empty and small inputs, the integer and float
    # fast paths (including the points where they have to be abandoned),
    # rejected argument types, zeros, big operands, special float values
    # and preservation of the result type.
    prod = math.prod
    self.assertEqual(prod([]), 1)
    self.assertEqual(prod([], start=5), 5)
    self.assertEqual(prod(list(range(2, 8))), 5040)
    self.assertEqual(prod(iter(list(range(2, 8)))), 5040)
    self.assertEqual(prod(range(1, 10), start=10), 3628800)

    self.assertEqual(prod([1, 2, 3, 4, 5]), 120)
    self.assertEqual(prod([1.0, 2.0, 3.0, 4.0, 5.0]), 120.0)
    self.assertEqual(prod([1, 2, 3, 4.0, 5.0]), 120.0)
    self.assertEqual(prod([1.0, 2.0, 3.0, 4, 5]), 120.0)

    # Test overflow in fast-path for integers
    self.assertEqual(prod([1, 1, 2**32, 1, 1]), 2**32)
    # 2**32 still fits in a 64-bit C long, so also use operands that
    # overflow whatever the platform's long width is: an item that is
    # itself too wide, products that cross the 32- and 64-bit limits,
    # and products that land exactly on the most negative value.
    self.assertEqual(prod([1, 1, 2**64, 1, 1]), 2**64)
    self.assertEqual(prod([2**16, 2**16, 2]), 2**33)
    self.assertEqual(prod([2**31, 2**31, 2]), 2**63)
    self.assertEqual(prod([2**31, 2**31, 2, 2]), 2**64)
    self.assertEqual(prod([-2**15, 2**16]), -2**31)
    self.assertEqual(prod([-2**31, 2**32]), -2**63)
    self.assertEqual(prod([-1, -2**31]), 2**31)
    self.assertEqual(prod([-1, -2**63]), 2**63)
    self.assertEqual(prod([3, 2**62, -5]), -15 * 2**62)
    # Test overflow in fast-path for floats
    self.assertEqual(prod([1.0, 1.0, 2**32, 1, 1]), float(2**32))
    # Same as above with an integer item too wide for any C long.
    self.assertEqual(prod([1.0, 1.0, 2**64, 1, 1]), float(2**64))
    self.assertIs(type(prod([1.0, 2**64])), float)

    self.assertRaises(TypeError, prod)
    self.assertRaises(TypeError, prod, 42)
    self.assertRaises(TypeError, prod, ['a', 'b', 'c'])
    self.assertRaises(TypeError, prod, ['a', 'b', 'c'], '')
    self.assertRaises(TypeError, prod, [b'a', b'c'], b'')
    values = [bytearray(b'a'), bytearray(b'b')]
    self.assertRaises(TypeError, prod, values, bytearray(b''))
    self.assertRaises(TypeError, prod, [[1], [2], [3]])
    self.assertRaises(TypeError, prod, [{2: 3}])
    self.assertRaises(TypeError, prod, [{2: 3}] * 2, {2: 3})
    self.assertRaises(TypeError, prod, [[1], [2], [3]], [])
    with self.assertRaises(TypeError):
        prod([10, 20], [30, 40])  # start is a keyword-only argument

    # A zero anywhere in the input makes the whole product zero.
    self.assertEqual(prod([0, 1, 2, 3]), 0)
    self.assertEqual(prod([1, 0, 2, 3]), 0)
    self.assertEqual(prod([1, 2, 3, 0]), 0)

    def _naive_prod(iterable, start=1):
        # Reference implementation: plain left-to-right multiplication.
        for elem in iterable:
            start *= elem
        return start

    # Big integers

    iterable = range(1, 10000)
    self.assertEqual(prod(iterable), _naive_prod(iterable))
    iterable = range(-10000, -1)
    self.assertEqual(prod(iterable), _naive_prod(iterable))
    iterable = range(-1000, 1000)
    self.assertEqual(prod(iterable), 0)

    # Big floats

    iterable = [float(x) for x in range(1, 1000)]
    self.assertEqual(prod(iterable), _naive_prod(iterable))
    iterable = [float(x) for x in range(-1000, -1)]
    self.assertEqual(prod(iterable), _naive_prod(iterable))
    # The running product has already overflowed to an infinity by the
    # time the 0.0 element is reached, and inf * 0.0 is NaN.
    iterable = [float(x) for x in range(-1000, 1000)]
    self.assertIsNaN(prod(iterable))

    # Float tests

    self.assertIsNaN(prod([1, 2, 3, float("nan"), 2, 3]))
    self.assertIsNaN(prod([1, 0, float("nan"), 2, 3]))
    self.assertIsNaN(prod([1, float("nan"), 0, 3]))
    self.assertIsNaN(prod([1, float("inf"), float("nan"), 3]))
    self.assertIsNaN(prod([1, float("-inf"), float("nan"), 3]))
    self.assertIsNaN(prod([1, float("nan"), float("inf"), 3]))
    self.assertIsNaN(prod([1, float("nan"), float("-inf"), 3]))

    self.assertEqual(prod([1, 2, 3, float('inf'), -3, 4]), float('-inf'))
    self.assertEqual(prod([1, 2, 3, float('-inf'), -3, 4]), float('inf'))

    self.assertIsNaN(prod([1, 2, 0, float('inf'), -3, 4]))
    self.assertIsNaN(prod([1, 2, 0, float('-inf'), -3, 4]))
    self.assertIsNaN(prod([1, 2, 3, float('inf'), -3, 0, 3]))
    self.assertIsNaN(prod([1, 2, 3, float('-inf'), -3, 0, 2]))

    # Type preservation

    self.assertIs(type(prod([1, 2, 3, 4, 5, 6])), int)
    self.assertIs(type(prod([1, 2.0, 3, 4, 5, 6])), float)
    self.assertIs(type(prod(range(1, 10000))), int)
    self.assertIs(type(prod(range(1, 10000), start=1.0)), float)
    self.assertIs(type(prod([1, decimal.Decimal(2.0), 3, 4, 5, 6])),
                  decimal.Decimal)
1683+
15981684
# Custom assertions.
15991685

16001686
def assertIsNaN(self, value):
@@ -1724,41 +1810,6 @@ def test_fractions(self):
17241810
self.assertAllClose(fraction_examples, rel_tol=1e-8)
17251811
self.assertAllNotClose(fraction_examples, rel_tol=1e-9)
17261812

1727-
def test_prod(self):
# Checks math.prod on empty input (with and without `start`), small
# int / float / mixed products, operands intended to leave the int and
# float fast paths, arguments that must raise TypeError, and inputs
# containing a zero.
1728-
prod = math.prod
1729-
self.assertEqual(prod([]), 1)
1730-
self.assertEqual(prod([], start=5), 5)
1731-
self.assertEqual(prod(list(range(2,8))), 5040)
1732-
self.assertEqual(prod(iter(list(range(2,8)))), 5040)
1733-
self.assertEqual(prod(range(1, 10), start=10), 3628800)
1734-
1735-
self.assertEqual(prod([1, 2, 3, 4, 5]), 120)
1736-
self.assertEqual(prod([1.0, 2.0, 3.0, 4.0, 5.0]), 120.0)
1737-
self.assertEqual(prod([1, 2, 3, 4.0, 5.0]), 120.0)
1738-
self.assertEqual(prod([1.0, 2.0, 3.0, 4, 5]), 120.0)
1739-
1740-
# Test overflow in fast-path for integers
# NOTE(review): 2**32 only exceeds a C long where long is 32 bits wide;
# confirm the overflow branch is also reached on 64-bit-long platforms.
1741-
self.assertEqual(prod([1, 1, 2**32, 1, 1]), 2**32)
1742-
# Test overflow in fast-path for floats
1743-
self.assertEqual(prod([1.0, 1.0, 2**32, 1, 1]), float(2**32))
1744-
1745-
# Missing iterable, non-iterable argument, and element types that
# cannot be multiplied together must all raise TypeError.
self.assertRaises(TypeError, prod)
1746-
self.assertRaises(TypeError, prod, 42)
1747-
self.assertRaises(TypeError, prod, ['a', 'b', 'c'])
1748-
self.assertRaises(TypeError, prod, ['a', 'b', 'c'], '')
1749-
self.assertRaises(TypeError, prod, [b'a', b'c'], b'')
1750-
values = [bytearray(b'a'), bytearray(b'b')]
1751-
self.assertRaises(TypeError, prod, values, bytearray(b''))
1752-
self.assertRaises(TypeError, prod, [[1], [2], [3]])
1753-
self.assertRaises(TypeError, prod, [{2:3}])
1754-
self.assertRaises(TypeError, prod, [{2:3}]*2, {2:3})
1755-
self.assertRaises(TypeError, prod, [[1], [2], [3]], [])
1756-
with self.assertRaises(TypeError):
1757-
prod([10, 20], [30, 40]) # start is a keyword-only argument
1758-
1759-
# A zero factor makes the whole product zero.
self.assertEqual(prod([0, 1, 2, 3]), 0)
1760-
self.assertEqual(prod([1, 0, 2, 3]), 0)
1761-
self.assertEqual(prod(range(10)), 0)
17621813

17631814
def test_main():
17641815
from doctest import DocFileSuite

Modules/mathmodule.c

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2493,6 +2493,55 @@ math_isclose_impl(PyObject *module, double a, double b, double rel_tol,
24932493
(diff <= abs_tol));
24942494
}
24952495

2496+
static inline int
_check_long_mult_overflow(long a, long b) {

    /* Return 1 if a * b does not fit in a C long, 0 if it does.

       The technique is the one used by Python 2's int_mul.  Two versions
       of the product are computed and compared:

       - The wrapped integer product.  The multiplication is carried out
         on unsigned long, so it wraps around instead of overflowing a
         signed type.  The value obtained is either exactly the true
         product or the true product plus i*2**n for some non-zero
         integer i, where n is the number of bits in a long -- that is,
         either right or *way* off.

       - The double product (double)a * (double)b.  It can pick up
         rounding errors from the two casts and from the multiplication,
         but it cannot run out of range: even 256-bit longs would give a
         product well inside the dynamic range of a double, so its
         leading 50 or so bits are correct.

       When the two agree, exactly or to within a small relative error,
       the integer product did not overflow.  When they disagree, only
       the integer product can have lost a catastrophic amount of
       information, so it is the one that overflowed.
    */

    const long wrapped = (long)((unsigned long)a * b);
    const double approx = (double)a * (double)b;
    const double wrapped_as_double = (double)wrapped;
    double error;
    double magnitude;

    /* Exact agreement: the common, non-overflowing case. */
    if (wrapped_as_double == approx) {
        return 0;
    }

    /* error = |wrapped - approx|, magnitude = |approx|. */
    error = wrapped_as_double - approx;
    if (!(error >= 0.0)) {
        error = -error;
    }
    magnitude = approx;
    if (!(magnitude >= 0.0)) {
        magnitude = -magnitude;
    }

    /* A discrepancy of at most 1/32 of the product is rounding noise;
       anything larger means the integer product wrapped around. */
    return (32.0 * error <= magnitude) ? 0 : 1;
}
24962545

24972546
/*[clinic input]
24982547
math.prod
@@ -2558,11 +2607,8 @@ math_prod_impl(PyObject *module, PyObject *iterable, PyObject *start)
25582607
}
25592608
if (PyLong_CheckExact(item)) {
25602609
long b = PyLong_AsLongAndOverflow(item, &overflow);
2561-
long x = i_result * b;
2562-
/* Continue if there is no overflow */
2563-
if (overflow == 0
2564-
&& x < LONG_MAX && x > LONG_MIN
2565-
&& !(b != 0 && x / b != i_result)) {
2610+
if (overflow == 0 && !_check_long_mult_overflow(i_result, b)) {
2611+
long x = i_result * b;
25662612
i_result = x;
25672613
Py_DECREF(item);
25682614
continue;

0 commit comments

Comments
 (0)