From c8d0ef69fd614ee27acea2607ac0a8ae134d1860 Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Wed, 23 Aug 2023 18:38:34 +0200 Subject: [PATCH] unit tests --- rdt/hyper_transformer.py | 68 +++++----- tests/unit/test_hyper_transformer.py | 180 ++++++++++++++++++++++++++- 2 files changed, 213 insertions(+), 35 deletions(-) diff --git a/rdt/hyper_transformer.py b/rdt/hyper_transformer.py index a2a31efb3..3585d42e7 100644 --- a/rdt/hyper_transformer.py +++ b/rdt/hyper_transformer.py @@ -332,6 +332,39 @@ def _warn_update_transformers_by_sdtype(self, transformer, transformer_name): 'parameters instead.', FutureWarning ) + def _generate_column_in_tuple(self): + """Generate a dict mapping columns in a tuple to the tuple itself.""" + column_in_tuple = {} + for fields in self.field_transformers: + if isinstance(fields, tuple): + for column in fields: + column_in_tuple[column] = fields + + return column_in_tuple + + def _update_column_in_tuple(self, column, column_in_tuple): + """Update the column in tuple dict. + + Args: + column (str): + Column name to be updated. + column_in_tuple (dict): + Dict mapping columns in a tuple to the tuple itself. + """ + old_tuple = column_in_tuple.pop(column) + new_tuple = tuple(item for item in old_tuple if item != column) + + if len(new_tuple) == 1: + new_tuple, = new_tuple + column_in_tuple.pop(new_tuple, None) + else: + for col in new_tuple: + column_in_tuple[col] = new_tuple + + self.field_transformers[new_tuple] = self.field_transformers.pop(old_tuple) + + return column_in_tuple + def update_transformers_by_sdtype( self, sdtype, transformer=None, transformer_name=None, transformer_parameters=None): """Update the transformers for the specified ``sdtype``. @@ -370,7 +403,7 @@ def update_transformers_by_sdtype( if field_sdtype == sdtype: self.field_transformers[field] = deepcopy(transformer_instance) if field in column_in_tuple: - self._update_column_in_tuple(field, column_in_tuple) + column_in_tuple = self._update_column_in_tuple(field, column_in_tuple) self._modified_config = True @@ -412,39 +445,6 @@ def update_sdtypes(self, column_name_to_sdtype): if self._fitted: warnings.warn(self._REFIT_MESSAGE) - def _generate_column_in_tuple(self): - """Generate a dict mapping columns in a tuple to the tuple itself.""" - column_in_tuple = {} - for fields in self.field_transformers: - if isinstance(fields, tuple): - for column in fields: - column_in_tuple[column] = fields - - return column_in_tuple - - def _update_column_in_tuple(self, column, column_in_tuple): - """Update the column in tuple dict. - - Args: - column (str): - Column name to be updated. - column_in_tuple (dict): - Dict mapping columns in a tuple to the tuple itself. - """ - old_tuple = column_in_tuple[column] - new_tuple = tuple( - item for item in old_tuple if item != column - ) - if len(new_tuple) == 1: - new_tuple = new_tuple[0] - del column_in_tuple[new_tuple] - else: - for col in new_tuple: - column_in_tuple[col] = new_tuple - - self.field_transformers[new_tuple] = self.field_transformers[old_tuple] - del self.field_transformers[old_tuple] - def update_transformers(self, column_name_to_transformer): """Update any of the transformers assigned to each of the column names. diff --git a/tests/unit/test_hyper_transformer.py b/tests/unit/test_hyper_transformer.py index 3add05603..0a94c52f6 100644 --- a/tests/unit/test_hyper_transformer.py +++ b/tests/unit/test_hyper_transformer.py @@ -12,7 +12,7 @@ TransformerProcessingError) from rdt.transformers import ( AnonymizedFaker, BinaryEncoder, FloatFormatter, FrequencyEncoder, LabelEncoder, RegexGenerator, - UnixTimestampEncoder) + UniformEncoder, UnixTimestampEncoder) from rdt.transformers.base import BaseTransformer from rdt.transformers.numerical import ClusterBasedNormalizer @@ -2135,6 +2135,110 @@ def test_update_transformers_by_sdtype_with_transformer_name_transformer_paramet assert isinstance(ht.field_transformers['categorical_column'], LabelEncoder) assert ht.field_transformers['categorical_column'].order_by == 'alphabetical' + def test_generate_column_in_tuple(self): + """Test ``_generate_column_in_tuple``.""" + # Setup + ht = HyperTransformer() + + ht.field_transformers = { + ('column1', 'column2'): 'transformer1', + ('column3', 'column4'): 'transformer2', + } + + # Run + column_in_tuple = ht._generate_column_in_tuple() + + # Assert + expected_mapping = { + 'column1': ('column1', 'column2'), + 'column2': ('column1', 'column2'), + 'column3': ('column3', 'column4'), + 'column4': ('column3', 'column4'), + } + + assert column_in_tuple == expected_mapping + + def test_update_column_in_tuple(self): + """Test ``_update_column_in_tuple``.""" + # Setup + column_in_tuple = { + 'column1': ('column1', 'column2', 'column3'), + 'column2': ('column1', 'column2', 'column3'), + 'column3': ('column1', 'column2', 'column3'), + } + ht = HyperTransformer() + ht.field_transformers = { + ('column1', 'column2', 'column3'): 'transformer', + } + # Run + result = ht._update_column_in_tuple('column1', column_in_tuple) + + # Assert + expected_column_in_tuple = { + 'column2': ('column2', 'column3'), + 'column3': ('column2', 'column3'), + } + expected_field_transformers = {('column2', 'column3'): 'transformer'} + assert ht.field_transformers == expected_field_transformers + assert result == expected_column_in_tuple + + def test_update_column_in_tuple_single_column_left(self): + """Test ``_update_column_in_tuple`` with a single column left in the tuple.""" + # Setup + column_in_tuple = { + 'column1': ('column1', 'column2'), + 'column2': ('column1', 'column2'), + } + ht = HyperTransformer() + ht.field_transformers = { + ('column1', 'column2'): 'transformer', + } + # Run + result = ht._update_column_in_tuple('column1', column_in_tuple) + + # Assert + expected_column_in_tuple = {} + expected_field_transformers = {'column2': 'transformer'} + assert ht.field_transformers == expected_field_transformers + assert result == expected_column_in_tuple + + def test_update_transformers_by_sdtype_with_multi_column_transformer(self): + """Test ``update_transformers_by_sdtype`` with columns use with a multi-column transformer. + """ + # Setup + ht = HyperTransformer() + ht.field_transformers = { + 'A': LabelEncoder(), + 'B': UniformEncoder(), + "('C', 'D')": None, + } + ht.field_sdtypes = { + 'A': 'categorical', + 'B': 'boolean', + 'C': 'categorical', + 'D': 'numerical' + } + + column_in_tuple = { + 'C': ('C', 'D'), + 'D': ('C', 'D') + } + mock__update_column_in_tuple = Mock() + mock__generate_column_in_tuple = Mock(return_value=column_in_tuple) + ht._update_column_in_tuple = mock__update_column_in_tuple + ht._generate_column_in_tuple = mock__generate_column_in_tuple + + # Run + ht.update_transformers_by_sdtype( + 'categorical', + transformer_name='LabelEncoder', + ) + + # Assert + assert len(ht.field_transformers) == 4 + mock__generate_column_in_tuple.assert_called_once() + assert mock__update_column_in_tuple.call_count == 1 + @patch('rdt.hyper_transformer.warnings') def test_update_transformers_fitted(self, mock_warnings): """Test update transformers. @@ -2168,6 +2272,9 @@ def test_update_transformers_fitted(self, mock_warnings): 'my_column': transformer } + mock__generate_column_in_tuple = Mock(return_value={}) + instance._generate_column_in_tuple = mock__generate_column_in_tuple + # Run instance.update_transformers(column_name_to_transformer) @@ -2180,6 +2287,77 @@ def test_update_transformers_fitted(self, mock_warnings): mock_warnings.warn.assert_called_once_with(expected_message) assert instance.field_transformers['my_column'] == transformer instance._validate_transformers.assert_called_once_with(column_name_to_transformer) + mock__generate_column_in_tuple.assert_called_once() + + def test_update_transformers_multi_column(self): + """Test ``update_transformers`` with a multi-column transformer.""" + # Setup + ht = HyperTransformer() + mock__generate_column_in_tuple = Mock(return_value={}) + mock__update_column_in_tuple = Mock() + ht._generate_column_in_tuple = mock__generate_column_in_tuple + ht._update_column_in_tuple = mock__update_column_in_tuple + ht.field_sdtypes = { + 'A': 'categorical', + 'B': 'boolean', + 'C': 'numerical', + } + ht.field_transformers = { + 'A': LabelEncoder(), + 'B': UniformEncoder(), + 'C': FloatFormatter(), + } + + column_name_to_transformer = { + ('A', 'B'): None, + 'C': None, + } + # Run + ht.update_transformers(column_name_to_transformer) + + # Assert + expected_field_transformers = { + ('A', 'B'): None, + 'C': None, + } + ht.field_transformers == expected_field_transformers + mock__generate_column_in_tuple.assert_called_once() + + def test_update_transformers_changing_multi_column_transformer(self): + """Test ``update_transformers`` when changing a mulit column transformer.""" + # Setup + ht = HyperTransformer() + mock__generate_column_in_tuple = Mock(return_value={ + 'A': ('A', 'B'), + 'B': ('A', 'B'), + }) + mock__update_column_in_tuple = Mock() + ht._generate_column_in_tuple = mock__generate_column_in_tuple + ht._update_column_in_tuple = mock__update_column_in_tuple + ht.field_sdtypes = { + 'A': 'categorical', + 'B': 'boolean', + 'C': 'numerical', + } + ht.field_transformers = { + ('A', 'B'): None, + 'C': FloatFormatter(), + } + + column_name_to_transformer = { + 'A': UniformEncoder(), + } + # Run + ht.update_transformers(column_name_to_transformer) + + # Assert + expected_field_transformers = { + 'A': UniformEncoder(), + 'B': None, + 'C': FloatFormatter(), + } + ht.field_transformers == expected_field_transformers + mock__generate_column_in_tuple.assert_called_once() @patch('rdt.hyper_transformer.warnings') def test_update_transformers_not_fitted(self, mock_warnings):