Skip to content

Commit

Permalink
Update test
Browse files Browse the repository at this point in the history
  • Loading branch information
fealho committed Oct 9, 2024
1 parent cbe1abf commit a28f04a
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 3 deletions.
1 change: 1 addition & 0 deletions rdt/transformers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ def learn_rounding_digits(data):
int or None:
Number of digits to round to.
"""
# check it ends with pyarrow
if str(data.dtype).endswith("[pyarrow]"):
data = data.to_numpy()

Expand Down
5 changes: 4 additions & 1 deletion tests/unit/transformers/test_numerical.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ def test__validate_values_within_bounds(self):
def test__validate_values_within_bounds_pyarrow(self):
"""Test it works with pyarrow."""
# Setup
data = pd.Series(range(10), dtype='int64[pyarrow]')
try:
data = pd.Series(range(10), dtype='int64[pyarrow]')
except TypeError:
pytest.skip("Skipping as old numpy/pandas versions don't support arrow")
transformer = FloatFormatter()
transformer.computer_representation = 'UInt8'

Expand Down
10 changes: 8 additions & 2 deletions tests/unit/transformers/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,10 @@ def test_learn_rounding_digits_less_than_15_decimals():
def test_learn_rounding_digits_pyarrow():
"""Test it works with pyarrow."""
# Setup
data = pd.Series(range(10), dtype='int64[pyarrow]')
try:
data = pd.Series(range(10), dtype='int64[pyarrow]')
except TypeError:
pytest.skip("Skipping as old numpy/pandas versions don't support arrow")

# Run
output = learn_rounding_digits(data)
Expand All @@ -240,7 +243,10 @@ def test_learn_rounding_digits_pyarrow():
def test_learn_rounding_digits_pyarrow_float():
"""Test it learns the proper amount of digits with pyarrow."""
# Setup
data = pd.Series([0.5, 0.19, 3], dtype='float64[pyarrow]')
try:
data = pd.Series([0.5, 0.19, 3], dtype='float64[pyarrow]')
except TypeError:
pytest.skip("Skipping as old numpy/pandas versions don't support arrow")

# Run
output = learn_rounding_digits(data)
Expand Down

0 comments on commit a28f04a

Please sign in to comment.