From 635fae4121e32e75d3c70a17df64dadc0124ea7a Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Tue, 22 Aug 2023 10:37:52 +0200 Subject: [PATCH] unit test --- rdt/transformers/base.py | 6 +-- tests/integration/test_transformers.py | 2 +- tests/unit/transformers/test_base.py | 70 ++++++++++++++++++++++++-- 3 files changed, 69 insertions(+), 9 deletions(-) diff --git a/rdt/transformers/base.py b/rdt/transformers/base.py index b3c9aa2e2..fffe92466 100644 --- a/rdt/transformers/base.py +++ b/rdt/transformers/base.py @@ -273,7 +273,7 @@ def _store_columns(self, columns, data): if missing: raise KeyError(f'Columns {missing} were not present in the data.') - self.columns = columns + self.columns = column_names @staticmethod def _get_columns_data(data, columns): @@ -490,7 +490,7 @@ class BaseMultiColumnTransformer(BaseTransformer): """Base class for all multi column transformers. The ``BaseMultiColumnTransformer`` class contains methods that must be implemented - in order to create a new mulit column transformer. + in order to create a new multi column transformer. """ def get_input_column(self): @@ -499,7 +499,7 @@ def get_input_column(self): Raise an error because for multi column transformers, ``get_input_columns`` must be used instead. """ - raise ValueError( + raise NotImplementedError( 'MultiColumnTransformers does not have a single input column.' 'Please use ``get_input_columns`` instead.' ) diff --git a/tests/integration/test_transformers.py b/tests/integration/test_transformers.py index 4ed5021df..a8317b685 100644 --- a/tests/integration/test_transformers.py +++ b/tests/integration/test_transformers.py @@ -69,7 +69,7 @@ def _is_valid_transformer(transformer_name): """Determine if transformer should be tested or not.""" invalid_names = [ 'IdentityTransformer', 'Dummy', 'OrderedLabelEncoder', 'CustomLabelEncoder', - 'OrderedUniformEncoder', + 'OrderedUniformEncoder', 'BaseMultiColumnTransformer' ] return all(invalid_name not in transformer_name for invalid_name in invalid_names) diff --git a/tests/unit/transformers/test_base.py b/tests/unit/transformers/test_base.py index e431156b7..2754fa2f8 100644 --- a/tests/unit/transformers/test_base.py +++ b/tests/unit/transformers/test_base.py @@ -7,7 +7,7 @@ import pytest from rdt.errors import TransformerInputError -from rdt.transformers import BaseTransformer, NullTransformer +from rdt.transformers import BaseMultiColumnTransformer, BaseTransformer, NullTransformer from rdt.transformers.base import random_state, set_random_states @@ -1278,13 +1278,73 @@ def _reverse_transform(self, data): class TestBaseMultiColumnTransformer: def test_get_input_column(self): - pass + """Test the ``get_input_column`` method. + + When the ``get_input_column`` method is called, it should raise a ``NotImplementedError``. + """ + # Setup + expected_message = ( + 'MultiColumnTransformers does not have a single input column.' + 'Please use ``get_input_columns`` instead.' + ) + + # Run and Assert + with pytest.raises(NotImplementedError, match=expected_message): + BaseMultiColumnTransformer().get_input_column() def test_get_input_columns(self): - pass + """Test the ``get_input_columns`` method.""" + # Setup + transformer = BaseMultiColumnTransformer() + transformer.columns = ['a', 'b', 'c'] + + # Run + output = transformer.get_input_columns() + + # Assert + assert output == ['a', 'b', 'c'] def test__fit(self): - pass + """Test the ``_fit`` method. + + Check that an error is raised when the ``_fit`` method is called. + """ + # Setup + transformer = BaseMultiColumnTransformer() + + # Run and Assert + with pytest.raises(NotImplementedError): + transformer._fit(None, None) def test_fit(self): - pass + """Test the ``fit`` method.""" + # Setup + transformer = BaseMultiColumnTransformer() + data = Mock() + columns_to_sdtypes = { + 'a': 'numerical', + 'b': 'categorical', + 'c': 'boolean' + } + data_transformer = pd.DataFrame({ + 'a': [1, 2, 3], + 'b': ['a', 'b', 'c'], + }) + transformer.columns = ['a', 'b'] + transformer._store_columns = Mock() + transformer._get_columns_data = Mock(return_value=data_transformer) + transformer._set_seed = Mock() + transformer._fit = Mock() + transformer._build_output_columns = Mock() + + # Run + transformer.fit(data, columns_to_sdtypes) + + # Assert + transformer._store_columns.assert_called_once_with( + list(columns_to_sdtypes.keys()), data + ) + transformer._set_seed.assert_called_once_with(data) + transformer._get_columns_data.assert_called_once_with(data, ['a', 'b']) + transformer._fit.assert_called_once_with(data_transformer, columns_to_sdtypes) + transformer._build_output_columns.assert_called_once_with(data)