Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
fealho committed Oct 8, 2024
1 parent 98255ba commit 85bcb01
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 1 deletion.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ rdt = { main = 'rdt.cli.__main__:main' }

[project.optional-dependencies]
copulas = ['copulas>=0.11.0',]
pyarrow = ['pyarrow>=17.0.0']
test = [
'rdt[pyarrow]',
'rdt[copulas]',

'pytest>=3.4.2',
Expand All @@ -58,7 +60,6 @@ test = [
'rundoc>=0.4.3,<0.5',
'pytest-subtests>=0.5,<1.0',
'pytest-runner >= 2.11.1',
'pyarrow >= 17.0.0',
'tomli>=2.0.0,<3',
]
dev = [
Expand Down
11 changes: 11 additions & 0 deletions tests/unit/transformers/test_numerical.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import copulas
import numpy as np
import pandas as pd
import pyarrow as pa
import pytest
from copulas import univariate
from pandas.api.types import is_float_dtype
Expand Down Expand Up @@ -44,6 +45,16 @@ def test__validate_values_within_bounds(self):
# Run
transformer._validate_values_within_bounds(data)

def test__validate_values_within_bounds_pyarrow(self):
"""Test it works with pyarrow."""
# Setup
data = pd.Series(range(10), dtype='int64[pyarrow]')
transformer = FloatFormatter()
transformer.computer_representation = 'UInt8'

# Run
transformer._validate_values_within_bounds(data)

def test__validate_values_within_bounds_under_minimum(self):
"""Test the ``_validate_values_within_bounds`` method.
Expand Down
13 changes: 13 additions & 0 deletions tests/unit/transformers/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import numpy as np
import pandas as pd
import pyarrow as pa
import pytest

from rdt.transformers.utils import (
Expand Down Expand Up @@ -225,6 +226,18 @@ def test_learn_rounding_digits_less_than_15_decimals():
assert output == 3


def test_learn_rounding_digits_pyarrow():
"""Test it works with pyarrow."""
# Setup
data = pd.Series(range(10), dtype='int64[pyarrow]')

# Run
output = learn_rounding_digits(data)

# Assert
assert output == 0


def test_learn_rounding_digits_negative_decimals_float():
"""Test the learn_rounding_digits method with floats multiples of powers of 10.
Expand Down

0 comments on commit 85bcb01

Please sign in to comment.