diff --git a/rdt/transformers/utils.py b/rdt/transformers/utils.py index a0ddaac4..8be7609d 100644 --- a/rdt/transformers/utils.py +++ b/rdt/transformers/utils.py @@ -252,8 +252,11 @@ def learn_rounding_digits(data): int or None: Number of digits to round to. """ - # check if data has any decimals name = data.name + if isinstance(data.dtype, pd.ArrowDtype): + data = data.to_numpy() + + # check if data has any decimals roundable_data = data[~(np.isinf(data.astype(float)) | pd.isna(data))] # Doesn't contain numbers diff --git a/tests/unit/transformers/test_numerical.py b/tests/unit/transformers/test_numerical.py index a878075f..314b6206 100644 --- a/tests/unit/transformers/test_numerical.py +++ b/tests/unit/transformers/test_numerical.py @@ -5,7 +5,6 @@ import copulas import numpy as np import pandas as pd -import pyarrow as pa import pytest from copulas import univariate from pandas.api.types import is_float_dtype diff --git a/tests/unit/transformers/test_utils.py b/tests/unit/transformers/test_utils.py index 164b3a20..4a0ba2ea 100644 --- a/tests/unit/transformers/test_utils.py +++ b/tests/unit/transformers/test_utils.py @@ -5,7 +5,6 @@ import numpy as np import pandas as pd -import pyarrow as pa import pytest from rdt.transformers.utils import ( @@ -238,6 +237,18 @@ def test_learn_rounding_digits_pyarrow(): assert output == 0 +def test_learn_rounding_digits_pyarrow_float(): + """Test it learns the proper amount of digits with pyarrow.""" + # Setup + data = pd.Series([0.5, 0.19, 3], dtype='float64[pyarrow]') + + # Run + output = learn_rounding_digits(data) + + # Assert + assert output == 2 + + def test_learn_rounding_digits_negative_decimals_float(): """Test the learn_rounding_digits method with floats multiples of powers of 10.