diff --git a/pyproject.toml b/pyproject.toml index 1c9c26b2..04006a52 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,7 +49,9 @@ rdt = { main = 'rdt.cli.__main__:main' } [project.optional-dependencies] copulas = ['copulas>=0.11.0',] +pyarrow = ['pyarrow>=17.0.0'] test = [ + 'rdt[pyarrow]', 'rdt[copulas]', 'pytest>=3.4.2', @@ -58,7 +60,6 @@ test = [ 'rundoc>=0.4.3,<0.5', 'pytest-subtests>=0.5,<1.0', 'pytest-runner >= 2.11.1', - 'pyarrow >= 17.0.0', 'tomli>=2.0.0,<3', ] dev = [ diff --git a/tests/unit/transformers/test_numerical.py b/tests/unit/transformers/test_numerical.py index 0af6b449..47c34787 100644 --- a/tests/unit/transformers/test_numerical.py +++ b/tests/unit/transformers/test_numerical.py @@ -5,6 +5,7 @@ import copulas import numpy as np import pandas as pd +import pyarrow as pa import pytest from copulas import univariate from pandas.api.types import is_float_dtype @@ -44,6 +45,19 @@ def test__validate_values_within_bounds(self): # Run transformer._validate_values_within_bounds(data) + def test__validate_values_within_bounds_pyarrow(self): + """Test it works with pyarrow.""" + # Setup + try: + data = pd.Series(range(10), dtype=pd.ArrowDtype(pa.int64())) + except AttributeError: + data = pd.Series(range(10), dtype='int64[pyarrow]') + transformer = FloatFormatter() + transformer.computer_representation = 'UInt8' + + # Run + transformer._validate_values_within_bounds(data) + def test__validate_values_within_bounds_under_minimum(self): """Test the ``_validate_values_within_bounds`` method. diff --git a/tests/unit/transformers/test_utils.py b/tests/unit/transformers/test_utils.py index c2bccead..c29ed0ee 100644 --- a/tests/unit/transformers/test_utils.py +++ b/tests/unit/transformers/test_utils.py @@ -5,6 +5,7 @@ import numpy as np import pandas as pd +import pyarrow as pa import pytest from rdt.transformers.utils import ( @@ -225,6 +226,21 @@ def test_learn_rounding_digits_less_than_15_decimals(): assert output == 3 +def test_learn_rounding_digits_pyarrow(): + """Test it works with pyarrow.""" + # Setup + try: + data = pd.Series(range(10), dtype=pd.ArrowDtype(pa.int64())) + except AttributeError: + data = pd.Series(range(10), dtype='int64[pyarrow]') + + # Run + output = learn_rounding_digits(data) + + # Assert + assert output == 0 + + def test_learn_rounding_digits_negative_decimals_float(): """Test the learn_rounding_digits method with floats multiples of powers of 10.