Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix learn_rounding_scheme for more than 14 digits #591

Merged
merged 2 commits into from
Dec 7, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions rdt/transformers/numerical.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,19 @@ def _learn_rounding_digits(data):
# check if data has any decimals
data = np.array(data)
roundable_data = data[~(np.isinf(data) | pd.isna(data))]
if ((roundable_data % 1) != 0).any():
if (roundable_data == roundable_data.round(MAX_DECIMALS)).all():
for decimal in range(MAX_DECIMALS + 1):
if (roundable_data == roundable_data.round(decimal)).all():
return decimal

return None
# Doesn't contain decimal digits
if ((roundable_data % 1) == 0).all():
return None

# Try to round to fewer digits
if (roundable_data == roundable_data.round(MAX_DECIMALS)).all():
for decimal in range(MAX_DECIMALS + 1):
if (roundable_data == roundable_data.round(decimal)).all():
return decimal

# Values still differ after rounding to MAX_DECIMALS digits, so cap the precision at MAX_DECIMALS
return MAX_DECIMALS

def _raise_out_of_bounds_error(self, value, name, bound_type, min_bound, max_bound):
raise ValueError(
Expand Down
8 changes: 4 additions & 4 deletions tests/unit/transformers/test_numerical.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ def test__learn_rounding_digits_more_than_15_decimals(self):
Input:
- An array that contains floats with more than 15 decimals.
Output:
- None
- 14
"""
data = np.random.random(size=10).round(20)

output = FloatFormatter._learn_rounding_digits(data)

assert output is None
assert output == 14

def test__learn_rounding_digits_less_than_15_decimals(self):
"""Test the _learn_rounding_digits method with less than 15 decimals.
Expand Down Expand Up @@ -298,7 +298,7 @@ def test__fit_learn_rounding_scheme_true_max_decimals(self):
Input:
- Series with a value that has 15 decimals
Side Effect:
- ``_rounding_digits`` is set to ``None``
- ``_rounding_digits`` is set to 14
"""
# Setup
data = pd.Series([0.000000000000001])
Expand All @@ -311,7 +311,7 @@ def test__fit_learn_rounding_scheme_true_max_decimals(self):
transformer._fit(data)

# Asserts
assert transformer._rounding_digits is None
assert transformer._rounding_digits == 14

def test__fit_learn_rounding_scheme_true_inf(self):
"""Test ``_fit`` with ``learn_rounding_scheme`` set to ``True``.
Expand Down