Skip to content

Commit 7afe9ac

Browse files
Support gmpy2-like rounding modes in to_str() (mpmath#830)
Closes mpmath#757
1 parent ba7aac8 commit 7afe9ac

File tree

3 files changed

+32
-31
lines changed

3 files changed

+32
-31
lines changed

mpmath/libmp/libmpf.py

Lines changed: 24 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1098,7 +1098,7 @@ def to_digits_exp(s, dps, base=10):
10981098
exponent += len(digits) - fixdps - 1
10991099
return sign, digits, exponent
11001100

1101-
def round_digits(digits, dps, base, rounding=round_nearest):
1101+
def round_digits(sign, digits, dps, base, rounding=round_nearest):
11021102
'''
11031103
Returns the rounded digits, and the number of places the decimal point was
11041104
shifted.
@@ -1107,7 +1107,15 @@ def round_digits(digits, dps, base, rounding=round_nearest):
11071107
'''
11081108

11091109
assert len(digits) > dps
1110-
assert rounding in (round_nearest, round_up, round_down)
1110+
assert rounding in (round_nearest, round_up, round_down, round_ceiling,
1111+
round_floor)
1112+
1113+
if rounding == round_ceiling:
1114+
rounding = round_down if sign else round_up
1115+
elif rounding == round_floor:
1116+
rounding = round_up if sign else round_down
1117+
else:
1118+
rounding = rounding
11111119

11121120
exponent = 0
11131121

@@ -1170,9 +1178,9 @@ def round_digits(digits, dps, base, rounding=round_nearest):
11701178
return digits, exponent
11711179

11721180

1173-
11741181
def to_str(s, dps, strip_zeros=True, min_fixed=None, max_fixed=None,
1175-
show_zero_exponent=False, base=10, binary_exp=False):
1182+
show_zero_exponent=False, base=10, binary_exp=False,
1183+
rounding=round_nearest):
11761184
"""
11771185
Convert a raw mpf to a floating-point literal in the given base
11781186
with at most `dps` digits in the mantissa (not counting extra zeros
@@ -1202,6 +1210,13 @@ def to_str(s, dps, strip_zeros=True, min_fixed=None, max_fixed=None,
12021210
if base not in (2, 16):
12031211
raise ValueError("binary_exp option could be used for base 2 and 16")
12041212

1213+
1214+
if rounding not in (round_nearest, round_floor, round_ceiling, round_up,
1215+
round_down):
1216+
raise ValueError("rounding should be one of " +
1217+
", ".join([round_nearest, round_floor, round_ceiling,
1218+
round_up, round_down]) + ".")
1219+
12051220
if base == 2:
12061221
prefix = "0b"
12071222
elif base == 8:
@@ -1249,21 +1264,8 @@ def to_str(s, dps, strip_zeros=True, min_fixed=None, max_fixed=None,
12491264
n = int(digits, 16) >> shift
12501265
digits = hex(n)[2:]
12511266

1252-
# Rounding up kills some instances of "...99999"
1253-
if len(digits) > dps and digits[dps] in rnd_digs:
1254-
digits = digits[:dps]
1255-
i = dps - 1
1256-
dig = stddigits[base-1]
1257-
while i >= 0 and digits[i] == dig:
1258-
i -= 1
1259-
if i >= 0:
1260-
digits = digits[:i] + stddigits[int(digits[i], base) + 1] + \
1261-
'0' * (dps - i - 1)
1262-
else:
1263-
digits = '1' + '0' * (dps - 1)
1264-
exponent += 1
1265-
else:
1266-
digits = digits[:dps]
1267+
digits, exp_add = round_digits(s[0], digits, dps, base, rounding)
1268+
exponent += exp_add
12671269

12681270
# Prettify numbers close to unit magnitude
12691271
if not binary_exp and min_fixed < exponent < max_fixed:
@@ -1533,7 +1535,7 @@ def format_fixed(s,
15331535
if no_neg_0:
15341536
sign = '' if sign_spec == '-' else sign_spec
15351537
else:
1536-
digits, exp_add = round_digits(digits, dps, base, rounding)
1538+
digits, exp_add = round_digits(s[0], digits, dps, base, rounding)
15371539
exponent += exp_add
15381540

15391541
# Here we prepend the corresponding 0s to the digits string, according
@@ -1602,7 +1604,7 @@ def format_scientific(s,
16021604
if sign != '-' and sign_spec != '-':
16031605
sign = sign_spec
16041606

1605-
digits, exp_add = round_digits(digits, dps, base, rounding)
1607+
digits, exp_add = round_digits(s[0], digits, dps, base, rounding)
16061608
exponent += exp_add
16071609

16081610
if strip_zeros:
@@ -1645,12 +1647,7 @@ def format_mpf(num, format_spec, prec):
16451647
strip_last_zero = False
16461648
strip_zeros = False
16471649

1648-
if format_dict['rounding'] == round_ceiling:
1649-
rounding = round_down if num[0] else round_up
1650-
elif format_dict['rounding'] == round_floor:
1651-
rounding = round_up if num[0] else round_down
1652-
else:
1653-
rounding = format_dict['rounding']
1650+
rounding = format_dict['rounding']
16541651

16551652
if fmt_type == 'g':
16561653
if not format_dict['alternate']:

mpmath/tests/test_convert.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,12 @@ def test_to_str():
7676
assert to_str(from_str('0x1.4ace478p+33'), 7, base=16, binary_exp=True) == '0x1.4ace48p+33'
7777
assert to_str(from_str('0x1.4ace478p+33'), 5, base=16, binary_exp=True) == '0x1.4acep+33'
7878
assert to_str(from_str('1', base=16), 6, base=16, binary_exp=True) == '0x1.0'
79-
pytest.raises(ValueError, lambda: to_str(from_str('1', base=16),
80-
6, binary_exp=True))
79+
x = mpf('1234.567891')._mpf_
80+
pytest.raises(ValueError, lambda: to_str(x, 6, binary_exp=True))
81+
pytest.raises(ValueError, lambda: to_str(x, 6, rounding='Y'))
82+
assert to_str(x, 5, rounding='n') == '1234.6'
83+
assert to_str(x, 5, rounding='d') == '1234.5'
84+
assert to_str(x, 5, rounding='u') == '1234.6'
8185

8286
def test_pretty():
8387
mp.pretty = True

mpmath/tests/test_str.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@ def test_nstr():
1313
[-0.0299195971, 0.205663228, 0.64453125e-20]])
1414
assert nstr(m, 4, min_fixed=-inf) == \
1515
'''[ 0.75 0.1909 -0.02992]
16-
[ 0.1909 0.6563 0.2057]
16+
[ 0.1909 0.6562 0.2057]
1717
[-0.02992 0.2057 0.000000000000000000006445]'''
1818
assert nstr(m, 4) == \
1919
'''[ 0.75 0.1909 -0.02992]
20-
[ 0.1909 0.6563 0.2057]
20+
[ 0.1909 0.6562 0.2057]
2121
[-0.02992 0.2057 6.445e-21]'''
2222

2323
def test_matrix_repr():

0 commit comments

Comments
 (0)