diff --git a/rdt/transformers/utils.py b/rdt/transformers/utils.py
index 6a851d98..9544ff1e 100644
--- a/rdt/transformers/utils.py
+++ b/rdt/transformers/utils.py
@@ -252,11 +252,11 @@ def learn_rounding_digits(data):
         int or None:
             Number of digits to round to.
     """
-    # check it ends with pyarrow
-    if str(data.dtype).endswith("[pyarrow]"):
-        data = data.to_numpy()
-
     # check if data has any decimals
+    # Save the name now: ``to_numpy`` returns an ndarray, which has no ``name``.
+    name = data.name
+    if str(data.dtype).endswith('[pyarrow]'):
+        data = data.to_numpy()
     roundable_data = data[~(np.isinf(data.astype(float)) | pd.isna(data))]
 
     # Doesn't contain numbers
@@ -276,7 +276,7 @@ def learn_rounding_digits(data):
     # Can't round, not equal after MAX_DECIMALS digits of precision
     LOGGER.info(
         "No rounding scheme detected for column '%s'. Data will not be rounded.",
-        data.name,
+        name,
     )
     return None
 
diff --git a/tests/unit/transformers/test_utils.py b/tests/unit/transformers/test_utils.py
index a99843f9..d979b14a 100644
--- a/tests/unit/transformers/test_utils.py
+++ b/tests/unit/transformers/test_utils.py
@@ -1,7 +1,7 @@
 import sre_parse
 import warnings
 from sre_constants import MAXREPEAT
-from unittest.mock import patch
+from unittest.mock import Mock, patch
 
 import numpy as np
 import pandas as pd
@@ -329,6 +329,20 @@ def test_learn_rounding_digits_nullable_numerical_pandas_dtypes():
         assert output == expected_output[column]
 
 
+def test_learn_rounding_digits_pyarrow_to_numpy():
+    """Test that ``learn_rounding_digits`` works with pyarrow to numpy conversion."""
+    # Setup
+    data = Mock()
+    data.dtype = 'int64[pyarrow]'
+    data.to_numpy.return_value = np.array([1, 2, 3])
+
+    # Run
+    learn_rounding_digits(data)
+
+    # Assert
+    data.to_numpy.assert_called_once_with()
+
+
 def test_warn_dict():
     """Test that ``WarnDict`` will raise a warning when called with `text`."""
     # Setup