Skip to content

Commit

Permalink
Add new test
Browse files Browse the repository at this point in the history
  • Loading branch information
fealho committed Oct 8, 2024
1 parent 85bcb01 commit 3e9bd7a
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 3 deletions.
5 changes: 4 additions & 1 deletion rdt/transformers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,8 +252,11 @@ def learn_rounding_digits(data):
int or None:
Number of digits to round to.
"""
# check if data has any decimals
name = data.name
if isinstance(data.dtype, pd.ArrowDtype):
data = data.to_numpy()

# check if data has any decimals
roundable_data = data[~(np.isinf(data.astype(float)) | pd.isna(data))]

# Doesn't contain numbers
Expand Down
1 change: 0 additions & 1 deletion tests/unit/transformers/test_numerical.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import copulas
import numpy as np
import pandas as pd
import pyarrow as pa
import pytest
from copulas import univariate
from pandas.api.types import is_float_dtype
Expand Down
13 changes: 12 additions & 1 deletion tests/unit/transformers/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import numpy as np
import pandas as pd
import pyarrow as pa
import pytest

from rdt.transformers.utils import (
Expand Down Expand Up @@ -238,6 +237,18 @@ def test_learn_rounding_digits_pyarrow():
assert output == 0


def test_learn_rounding_digits_pyarrow_float():
"""Test it learns the proper amount of digits with pyarrow."""
# Setup
data = pd.Series([0.5, 0.19, 3], dtype='float64[pyarrow]')

# Run
output = learn_rounding_digits(data)

# Assert
assert output == 2


def test_learn_rounding_digits_negative_decimals_float():
"""Test the learn_rounding_digits method with floats multiples of powers of 10.
Expand Down

0 comments on commit 3e9bd7a

Please sign in to comment.