diff --git a/tests/performance/test_performance.py b/tests/performance/test_performance.py index c3e27dca8..a20274c41 100644 --- a/tests/performance/test_performance.py +++ b/tests/performance/test_performance.py @@ -8,7 +8,8 @@ from rdt.performance.performance import evaluate_transformer_performance from rdt.performance.profiling import profile_transformer from rdt.transformers import get_transformers_by_type -from rdt.transformers.categorical import CustomLabelEncoder, OrderedLabelEncoder, OrderedUniformEncoder +from rdt.transformers.categorical import ( + CustomLabelEncoder, OrderedLabelEncoder, OrderedUniformEncoder) from rdt.transformers.numerical import ClusterBasedNormalizer SANDBOX_TRANSFORMERS = [ diff --git a/tests/unit/transformers/test_categorical.py b/tests/unit/transformers/test_categorical.py index c27ca6b6d..61a5983d4 100644 --- a/tests/unit/transformers/test_categorical.py +++ b/tests/unit/transformers/test_categorical.py @@ -299,19 +299,19 @@ def test__reverse_transform(self): def test__reverse_transform_nans(self): """Test ``_reverse_transform`` for data with NaNs.""" # Setup - data = pd.Series(['a', 'b', 'c', np.nan, 'c', 'b', 'b', 'a', 'b', np.nan]) + data = pd.Series(['a', 'b', 'NaN', np.nan, 'NaN', 'b', 'b', 'a', 'b', np.nan]) transformer = UniformEncoder() transformer.dtype = object transformer.frequencies = { 'a': 0.2, 'b': 0.4, - 'c': 0.2, + 'NaN': 0.2, np.nan: 0.2 } transformer.intervals = { 'a': [0, 0.2], 'b': [0.2, 0.6], - 'c': [0.6, 0.8], + 'NaN': [0.6, 0.8], np.nan: [0.8, 1] } @@ -458,16 +458,20 @@ def test__transform_error(self): If the data being transformed is not in ``self.order`` an error should be raised. """ # Setup - data = pd.Series([1, 2, 3, 2, 1, 4]) + data_error = pd.Series([1, 2, 3, 2, 1, 4]) + data = pd.Series([1, 2, 1, 2, 1, 1]) transformer = OrderedUniformEncoder(order=[2, 1]) # Run / Assert + transformer._fit(data) + transformer._transform(data) + message = re.escape( "Unknown categories '[3, 4]'. All possible categories must be defined in the " "'order' parameter." ) with pytest.raises(TransformerInputError, match=message): - transformer._transform(data) + transformer._transform(data_error) class TestFrequencyEncoder: