Skip to content

Commit bc09851

Browse files
pablogsalrhettinger
authored andcommitted
bpo-35606: Implement math.prod (GH-11359)
1 parent e9bc417 commit bc09851

File tree

6 files changed

+260
-1
lines changed

6 files changed

+260
-1
lines changed

Doc/library/math.rst

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,18 @@ Number-theoretic and representation functions
178178
of *x* and are floats.
179179

180180

181+
.. function:: prod(iterable, *, start=1)
182+
183+
Calculate the product of all the elements in the input *iterable*.
184+
The default *start* value for the product is ``1``.
185+
186+
When the iterable is empty, return the start value. This function is
187+
intended specifically for use with numeric values and may reject
188+
non-numeric types.
189+
190+
.. versionadded:: 3.8
191+
192+
181193
.. function:: remainder(x, y)
182194

183195
Return the IEEE 754-style remainder of *x* with respect to *y*. For

Doc/whatsnew/3.8.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,15 @@ json.tool
171171
Add option ``--json-lines`` to parse every input line as separate JSON object.
172172
(Contributed by Weipeng Hong in :issue:`31553`.)
173173

174+
175+
math
176+
----
177+
178+
Added new function, :func:`math.prod`, as analogous function to :func:`sum`
179+
that returns the product of a 'start' value (default: 1) times an iterable of
180+
numbers. (Contributed by Pablo Galindo in :issue:`issue35606`)
181+
182+
174183
os.path
175184
-------
176185

Lib/test/test_math.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1724,6 +1724,37 @@ def test_fractions(self):
17241724
self.assertAllClose(fraction_examples, rel_tol=1e-8)
17251725
self.assertAllNotClose(fraction_examples, rel_tol=1e-9)
17261726

1727+
def test_prod(self):
1728+
prod = math.prod
1729+
self.assertEqual(prod([]), 1)
1730+
self.assertEqual(prod([], start=5), 5)
1731+
self.assertEqual(prod(list(range(2,8))), 5040)
1732+
self.assertEqual(prod(iter(list(range(2,8)))), 5040)
1733+
self.assertEqual(prod(range(1, 10), start=10), 3628800)
1734+
1735+
self.assertEqual(prod([1, 2, 3, 4, 5]), 120)
1736+
self.assertEqual(prod([1.0, 2.0, 3.0, 4.0, 5.0]), 120.0)
1737+
self.assertEqual(prod([1, 2, 3, 4.0, 5.0]), 120.0)
1738+
self.assertEqual(prod([1.0, 2.0, 3.0, 4, 5]), 120.0)
1739+
1740+
# Test overflow in fast-path for integers
1741+
self.assertEqual(prod([1, 1, 2**32, 1, 1]), 2**32)
1742+
# Test overflow in fast-path for floats
1743+
self.assertEqual(prod([1.0, 1.0, 2**32, 1, 1]), float(2**32))
1744+
1745+
self.assertRaises(TypeError, prod)
1746+
self.assertRaises(TypeError, prod, 42)
1747+
self.assertRaises(TypeError, prod, ['a', 'b', 'c'])
1748+
self.assertRaises(TypeError, prod, ['a', 'b', 'c'], '')
1749+
self.assertRaises(TypeError, prod, [b'a', b'c'], b'')
1750+
values = [bytearray(b'a'), bytearray(b'b')]
1751+
self.assertRaises(TypeError, prod, values, bytearray(b''))
1752+
self.assertRaises(TypeError, prod, [[1], [2], [3]])
1753+
self.assertRaises(TypeError, prod, [{2:3}])
1754+
self.assertRaises(TypeError, prod, [{2:3}]*2, {2:3})
1755+
self.assertRaises(TypeError, prod, [[1], [2], [3]], [])
1756+
with self.assertRaises(TypeError):
1757+
prod([10, 20], [30, 40]) # start is a keyword-only argument
17271758

17281759
def test_main():
17291760
from doctest import DocFileSuite
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Implement :func:`math.prod` as analogous function to :func:`sum` that
2+
returns the product of a 'start' value (default: 1) times an iterable of
3+
numbers. Patch by Pablo Galindo.

Modules/clinic/mathmodule.c.h

Lines changed: 38 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Modules/mathmodule.c

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2494,6 +2494,172 @@ math_isclose_impl(PyObject *module, double a, double b, double rel_tol,
24942494
}
24952495

24962496

2497+
/*[clinic input]
2498+
math.prod
2499+
2500+
iterable: object
2501+
/
2502+
*
2503+
start: object(c_default="NULL") = 1
2504+
2505+
Calculate the product of all the elements in the input iterable.
2506+
2507+
The default start value for the product is 1.
2508+
2509+
When the iterable is empty, return the start value. This function is
2510+
intended specifically for use with numeric values and may reject
2511+
non-numeric types.
2512+
[clinic start generated code]*/
2513+
2514+
static PyObject *
2515+
math_prod_impl(PyObject *module, PyObject *iterable, PyObject *start)
2516+
/*[clinic end generated code: output=36153bedac74a198 input=4c5ab0682782ed54]*/
2517+
{
2518+
PyObject *result = start;
2519+
PyObject *temp, *item, *iter;
2520+
2521+
iter = PyObject_GetIter(iterable);
2522+
if (iter == NULL) {
2523+
return NULL;
2524+
}
2525+
2526+
if (result == NULL) {
2527+
result = PyLong_FromLong(1);
2528+
if (result == NULL) {
2529+
Py_DECREF(iter);
2530+
return NULL;
2531+
}
2532+
} else {
2533+
Py_INCREF(result);
2534+
}
2535+
#ifndef SLOW_PROD
2536+
/* Fast paths for integers keeping temporary products in C.
2537+
* Assumes all inputs are the same type.
2538+
* If the assumption fails, default to use PyObjects instead.
2539+
*/
2540+
if (PyLong_CheckExact(result)) {
2541+
int overflow;
2542+
long i_result = PyLong_AsLongAndOverflow(result, &overflow);
2543+
/* If this already overflowed, don't even enter the loop. */
2544+
if (overflow == 0) {
2545+
Py_DECREF(result);
2546+
result = NULL;
2547+
}
2548+
/* Loop over all the items in the iterable until we finish, we overflow
2549+
* or we found a non integer element */
2550+
while(result == NULL) {
2551+
item = PyIter_Next(iter);
2552+
if (item == NULL) {
2553+
Py_DECREF(iter);
2554+
if (PyErr_Occurred()) {
2555+
return NULL;
2556+
}
2557+
return PyLong_FromLong(i_result);
2558+
}
2559+
if (PyLong_CheckExact(item)) {
2560+
long b = PyLong_AsLongAndOverflow(item, &overflow);
2561+
long x = i_result * b;
2562+
/* Continue if there is no overflow */
2563+
if (overflow == 0
2564+
&& x < INT_MAX && x > INT_MIN
2565+
&& !(b != 0 && x / i_result != b)) {
2566+
i_result = x;
2567+
Py_DECREF(item);
2568+
continue;
2569+
}
2570+
}
2571+
/* Either overflowed or is not an int.
2572+
* Restore real objects and process normally */
2573+
result = PyLong_FromLong(i_result);
2574+
if (result == NULL) {
2575+
Py_DECREF(item);
2576+
Py_DECREF(iter);
2577+
return NULL;
2578+
}
2579+
temp = PyNumber_Multiply(result, item);
2580+
Py_DECREF(result);
2581+
Py_DECREF(item);
2582+
result = temp;
2583+
if (result == NULL) {
2584+
Py_DECREF(iter);
2585+
return NULL;
2586+
}
2587+
}
2588+
}
2589+
2590+
/* Fast paths for floats keeping temporary products in C.
2591+
* Assumes all inputs are the same type.
2592+
* If the assumption fails, default to use PyObjects instead.
2593+
*/
2594+
if (PyFloat_CheckExact(result)) {
2595+
double f_result = PyFloat_AS_DOUBLE(result);
2596+
Py_DECREF(result);
2597+
result = NULL;
2598+
while(result == NULL) {
2599+
item = PyIter_Next(iter);
2600+
if (item == NULL) {
2601+
Py_DECREF(iter);
2602+
if (PyErr_Occurred()) {
2603+
return NULL;
2604+
}
2605+
return PyFloat_FromDouble(f_result);
2606+
}
2607+
if (PyFloat_CheckExact(item)) {
2608+
f_result *= PyFloat_AS_DOUBLE(item);
2609+
Py_DECREF(item);
2610+
continue;
2611+
}
2612+
if (PyLong_CheckExact(item)) {
2613+
long value;
2614+
int overflow;
2615+
value = PyLong_AsLongAndOverflow(item, &overflow);
2616+
if (!overflow) {
2617+
f_result *= (double)value;
2618+
Py_DECREF(item);
2619+
continue;
2620+
}
2621+
}
2622+
result = PyFloat_FromDouble(f_result);
2623+
if (result == NULL) {
2624+
Py_DECREF(item);
2625+
Py_DECREF(iter);
2626+
return NULL;
2627+
}
2628+
temp = PyNumber_Multiply(result, item);
2629+
Py_DECREF(result);
2630+
Py_DECREF(item);
2631+
result = temp;
2632+
if (result == NULL) {
2633+
Py_DECREF(iter);
2634+
return NULL;
2635+
}
2636+
}
2637+
}
2638+
#endif
2639+
/* Consume rest of the iterable (if any) that could not be handled
2640+
* by specialized functions above.*/
2641+
for(;;) {
2642+
item = PyIter_Next(iter);
2643+
if (item == NULL) {
2644+
/* error, or end-of-sequence */
2645+
if (PyErr_Occurred()) {
2646+
Py_DECREF(result);
2647+
result = NULL;
2648+
}
2649+
break;
2650+
}
2651+
temp = PyNumber_Multiply(result, item);
2652+
Py_DECREF(result);
2653+
Py_DECREF(item);
2654+
result = temp;
2655+
if (result == NULL)
2656+
break;
2657+
}
2658+
Py_DECREF(iter);
2659+
return result;
2660+
}
2661+
2662+
24972663
static PyMethodDef math_methods[] = {
24982664
{"acos", math_acos, METH_O, math_acos_doc},
24992665
{"acosh", math_acosh, METH_O, math_acosh_doc},
@@ -2541,6 +2707,7 @@ static PyMethodDef math_methods[] = {
25412707
{"tan", math_tan, METH_O, math_tan_doc},
25422708
{"tanh", math_tanh, METH_O, math_tanh_doc},
25432709
MATH_TRUNC_METHODDEF
2710+
MATH_PROD_METHODDEF
25442711
{NULL, NULL} /* sentinel */
25452712
};
25462713

0 commit comments

Comments
 (0)