Skip to content

Commit 25e1406

Browse files
committed
py/modmath: Add full checks for math domain errors.
This patch changes how most of the plain math functions are implemented: there are now two generic math wrapper functions that take a pointer to a math function (like sin, cos) and perform the necessary conversion to and from MicroPython types. This helps to reduce code size. The generic functions can also check for math domain errors in a generic way, by testing if the result is NaN or infinity combined with finite inputs. The result is that, with this patch, all math functions now have full domain error checking (even gamma and lgamma) and code size has decreased for most ports. Code size changes in bytes for those with the math module are: unix x64: -432 unix nanbox: -792 stm32: -88 esp8266: +12 Tests are also added to check domain errors are handled correctly.
1 parent f599a38 commit 25e1406

File tree

3 files changed

+132
-20
lines changed

3 files changed

+132
-20
lines changed

py/modmath.c

Lines changed: 45 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
*
44
* The MIT License (MIT)
55
*
6-
* Copyright (c) 2013, 2014 Damien P. George
6+
* Copyright (c) 2013-2017 Damien P. George
77
*
88
* Permission is hereby granted, free of charge, to any person obtaining a copy
99
* of this software and associated documentation files (the "Software"), to deal
@@ -39,14 +39,31 @@ STATIC NORETURN void math_error(void) {
3939
mp_raise_ValueError("math domain error");
4040
}
4141

42+
STATIC mp_obj_t math_generic_1(mp_obj_t x_obj, mp_float_t (*f)(mp_float_t)) {
43+
mp_float_t x = mp_obj_get_float(x_obj);
44+
mp_float_t ans = f(x);
45+
if ((isnan(ans) && !isnan(x)) || (isinf(ans) && !isinf(x))) {
46+
math_error();
47+
}
48+
return mp_obj_new_float(ans);
49+
}
50+
51+
STATIC mp_obj_t math_generic_2(mp_obj_t x_obj, mp_obj_t y_obj, mp_float_t (*f)(mp_float_t, mp_float_t)) {
52+
mp_float_t x = mp_obj_get_float(x_obj);
53+
mp_float_t y = mp_obj_get_float(y_obj);
54+
mp_float_t ans = f(x, y);
55+
if ((isnan(ans) && !isnan(x) && !isnan(y)) || (isinf(ans) && !isinf(x))) {
56+
math_error();
57+
}
58+
return mp_obj_new_float(ans);
59+
}
60+
4261
#define MATH_FUN_1(py_name, c_name) \
43-
STATIC mp_obj_t mp_math_ ## py_name(mp_obj_t x_obj) { return mp_obj_new_float(MICROPY_FLOAT_C_FUN(c_name)(mp_obj_get_float(x_obj))); } \
62+
STATIC mp_obj_t mp_math_ ## py_name(mp_obj_t x_obj) { \
63+
return math_generic_1(x_obj, MICROPY_FLOAT_C_FUN(c_name)); \
64+
} \
4465
STATIC MP_DEFINE_CONST_FUN_OBJ_1(mp_math_## py_name ## _obj, mp_math_ ## py_name);
4566

46-
#define MATH_FUN_2(py_name, c_name) \
47-
STATIC mp_obj_t mp_math_ ## py_name(mp_obj_t x_obj, mp_obj_t y_obj) { return mp_obj_new_float(MICROPY_FLOAT_C_FUN(c_name)(mp_obj_get_float(x_obj), mp_obj_get_float(y_obj))); } \
48-
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mp_math_## py_name ## _obj, mp_math_ ## py_name);
49-
5067
#define MATH_FUN_1_TO_BOOL(py_name, c_name) \
5168
STATIC mp_obj_t mp_math_ ## py_name(mp_obj_t x_obj) { return mp_obj_new_bool(c_name(mp_obj_get_float(x_obj))); } \
5269
STATIC MP_DEFINE_CONST_FUN_OBJ_1(mp_math_## py_name ## _obj, mp_math_ ## py_name);
@@ -55,23 +72,25 @@ STATIC NORETURN void math_error(void) {
5572
STATIC mp_obj_t mp_math_ ## py_name(mp_obj_t x_obj) { return mp_obj_new_int_from_float(MICROPY_FLOAT_C_FUN(c_name)(mp_obj_get_float(x_obj))); } \
5673
STATIC MP_DEFINE_CONST_FUN_OBJ_1(mp_math_## py_name ## _obj, mp_math_ ## py_name);
5774

58-
#define MATH_FUN_1_ERRCOND(py_name, c_name, error_condition) \
59-
STATIC mp_obj_t mp_math_ ## py_name(mp_obj_t x_obj) { \
60-
mp_float_t x = mp_obj_get_float(x_obj); \
61-
if (error_condition) { \
62-
math_error(); \
63-
} \
64-
return mp_obj_new_float(MICROPY_FLOAT_C_FUN(c_name)(x)); \
75+
#define MATH_FUN_2(py_name, c_name) \
76+
STATIC mp_obj_t mp_math_ ## py_name(mp_obj_t x_obj, mp_obj_t y_obj) { \
77+
return math_generic_2(x_obj, y_obj, MICROPY_FLOAT_C_FUN(c_name)); \
6578
} \
66-
STATIC MP_DEFINE_CONST_FUN_OBJ_1(mp_math_## py_name ## _obj, mp_math_ ## py_name);
79+
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mp_math_## py_name ## _obj, mp_math_ ## py_name);
80+
81+
#define MATH_FUN_2_FLT_INT(py_name, c_name) \
82+
STATIC mp_obj_t mp_math_ ## py_name(mp_obj_t x_obj, mp_obj_t y_obj) { \
83+
return mp_obj_new_float(MICROPY_FLOAT_C_FUN(c_name)(mp_obj_get_float(x_obj), mp_obj_get_int(y_obj))); \
84+
} \
85+
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mp_math_## py_name ## _obj, mp_math_ ## py_name);
6786

6887
#if MP_NEED_LOG2
6988
// 1.442695040888963407354163704 is 1/_M_LN2
7089
#define log2(x) (log(x) * 1.442695040888963407354163704)
7190
#endif
7291

7392
// sqrt(x): returns the square root of x
74-
MATH_FUN_1_ERRCOND(sqrt, sqrt, (x < (mp_float_t)0.0))
93+
MATH_FUN_1(sqrt, sqrt)
7594
// pow(x, y): returns x to the power of y
7695
MATH_FUN_2(pow, pow)
7796
// exp(x)
@@ -80,9 +99,9 @@ MATH_FUN_1(exp, exp)
8099
// expm1(x)
81100
MATH_FUN_1(expm1, expm1)
82101
// log2(x)
83-
MATH_FUN_1_ERRCOND(log2, log2, (x <= (mp_float_t)0.0))
102+
MATH_FUN_1(log2, log2)
84103
// log10(x)
85-
MATH_FUN_1_ERRCOND(log10, log10, (x <= (mp_float_t)0.0))
104+
MATH_FUN_1(log10, log10)
86105
// cosh(x)
87106
MATH_FUN_1(cosh, cosh)
88107
// sinh(x)
@@ -113,9 +132,15 @@ MATH_FUN_2(atan2, atan2)
113132
// ceil(x)
114133
MATH_FUN_1_TO_INT(ceil, ceil)
115134
// copysign(x, y)
116-
MATH_FUN_2(copysign, copysign)
135+
STATIC mp_float_t MICROPY_FLOAT_C_FUN(copysign_func)(mp_float_t x, mp_float_t y) {
136+
return MICROPY_FLOAT_C_FUN(copysign)(x, y);
137+
}
138+
MATH_FUN_2(copysign, copysign_func)
117139
// fabs(x)
118-
MATH_FUN_1(fabs, fabs)
140+
STATIC mp_float_t MICROPY_FLOAT_C_FUN(fabs_func)(mp_float_t x) {
141+
return MICROPY_FLOAT_C_FUN(fabs)(x);
142+
}
143+
MATH_FUN_1(fabs, fabs_func)
119144
// floor(x)
120145
MATH_FUN_1_TO_INT(floor, floor) //TODO: delegate to x.__floor__() if x is not a float
121146
// fmod(x, y)
@@ -129,7 +154,7 @@ MATH_FUN_1_TO_BOOL(isnan, isnan)
129154
// trunc(x)
130155
MATH_FUN_1_TO_INT(trunc, trunc)
131156
// ldexp(x, exp)
132-
MATH_FUN_2(ldexp, ldexp)
157+
MATH_FUN_2_FLT_INT(ldexp, ldexp)
133158
#if MICROPY_PY_MATH_SPECIAL_FUNCTIONS
134159
// erf(x): return the error function of x
135160
MATH_FUN_1(erf, erf)

tests/float/math_domain.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Tests domain errors in math functions
2+
3+
try:
4+
import math
5+
except ImportError:
6+
print("SKIP")
7+
raise SystemExit
8+
9+
inf = float('inf')
10+
nan = float('nan')
11+
12+
# single argument functions
13+
for name, f, args in (
14+
('fabs', math.fabs, ()),
15+
('ceil', math.ceil, ()),
16+
('floor', math.floor, ()),
17+
('trunc', math.trunc, ()),
18+
('sqrt', math.sqrt, (-1, 0)),
19+
('exp', math.exp, ()),
20+
('sin', math.sin, ()),
21+
('cos', math.cos, ()),
22+
('tan', math.tan, ()),
23+
('asin', math.asin, (-1.1, 1, 1.1)),
24+
('acos', math.acos, (-1.1, 1, 1.1)),
25+
('atan', math.atan, ()),
26+
('ldexp', lambda x: math.ldexp(x, 0), ()),
27+
('radians', math.radians, ()),
28+
('degrees', math.degrees, ()),
29+
):
30+
for x in args + (inf, nan):
31+
try:
32+
ans = f(x)
33+
print('%.4f' % ans)
34+
except ValueError:
35+
print(name, 'ValueError')
36+
except OverflowError:
37+
print(name, 'OverflowError')
38+
39+
# double argument functions
40+
for name, f, args in (
41+
('pow', math.pow, ((0, 2), (-1, 2), (0, -1), (-1, 2.3))),
42+
('fmod', math.fmod, ((1.2, inf), (1.2, 0), (inf, 1.2))),
43+
('atan2', math.atan2, ((0, 0),)),
44+
('copysign', math.copysign, ()),
45+
):
46+
for x in args + ((0, inf), (inf, 0), (inf, inf), (inf, nan), (nan, inf), (nan, nan)):
47+
try:
48+
ans = f(*x)
49+
print('%.4f' % ans)
50+
except ValueError:
51+
print(name, 'ValueError')

tests/float/math_domain_special.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Tests domain errors in special math functions
2+
3+
try:
4+
import math
5+
math.erf
6+
except (ImportError, AttributeError):
7+
print("SKIP")
8+
raise SystemExit
9+
10+
inf = float('inf')
11+
nan = float('nan')
12+
13+
# single argument functions
14+
for name, f, args in (
15+
('expm1', math.exp, ()),
16+
('log2', math.log2, (-1, 0)),
17+
('log10', math.log10, (-1, 0)),
18+
('sinh', math.sinh, ()),
19+
('cosh', math.cosh, ()),
20+
('tanh', math.tanh, ()),
21+
('asinh', math.asinh, ()),
22+
('acosh', math.acosh, (-1, 0.9, 1)),
23+
('atanh', math.atanh, (-1, 1)),
24+
('erf', math.erf, ()),
25+
('erfc', math.erfc, ()),
26+
('gamma', math.gamma, (-2, -1, 0, 1)),
27+
('lgamma', math.lgamma, (-2, -1, 0, 1)),
28+
):
29+
for x in args + (inf, nan):
30+
try:
31+
ans = f(x)
32+
print('%.4f' % ans)
33+
except ValueError:
34+
print(name, 'ValueError')
35+
except OverflowError:
36+
print(name, 'OverflowError')

0 commit comments

Comments
 (0)