diff --git a/rdt/hyper_transformer.py b/rdt/hyper_transformer.py index e3eee078..724edbe6 100644 --- a/rdt/hyper_transformer.py +++ b/rdt/hyper_transformer.py @@ -24,6 +24,11 @@ class Config(dict): def __repr__(self): """Pretty print the dictionary.""" + transformers_repr = {} + for key, value in self['transformers'].items(): + transformed_key = str(key) + transformers_repr[transformed_key] = repr(value) + config = { 'sdtypes': self['sdtypes'], 'transformers': {str(k): repr(v) for k, v in self['transformers'].items()} @@ -244,7 +249,9 @@ def _validate_config(config): ) def _validate_update_columns(self, update_columns): - unknown_columns = self._subset(update_columns, self.field_sdtypes.keys(), not_in=True) + unknown_columns = self._subset( + flatten_column_list(update_columns), self.field_sdtypes.keys(), not_in=True + ) if unknown_columns: raise InvalidConfigError( f'Invalid column names: {unknown_columns}. These columns do not exist in the ' @@ -266,6 +273,7 @@ def set_config(self, config): self._validate_config(config) self.field_sdtypes.update(config['sdtypes']) self.field_transformers.update(config['transformers']) + self._multi_column_fields = self._create_multi_column_fields() self._modified_config = True if self._fitted: warnings.warn(self._REFIT_MESSAGE) @@ -327,6 +335,28 @@ def _warn_update_transformers_by_sdtype(self, transformer, transformer_name): 'parameters instead.', FutureWarning ) + def _remove_column_in_multi_column_fields(self, column): + """Remove a column that is part of a multi-column field. + + Remove the column from the tuple and modify the ``multi_column_fields`` + as well as the ``field_transformers`` dicts accordingly. + + Args: + column (str): + Column name to be updated. + """ + old_tuple = self._multi_column_fields.pop(column) + new_tuple = tuple(item for item in old_tuple if item != column) + + if len(new_tuple) == 1: + new_tuple, = new_tuple + self._multi_column_fields.pop(new_tuple, None) + else: + for col in new_tuple: + self._multi_column_fields[col] = new_tuple + + self.field_transformers[new_tuple] = self.field_transformers.pop(old_tuple) + def update_transformers_by_sdtype( self, sdtype, transformer=None, transformer_name=None, transformer_parameters=None): """Update the transformers for the specified ``sdtype``. @@ -351,6 +381,7 @@ def update_transformers_by_sdtype( self._warn_update_transformers_by_sdtype(transformer, transformer_name) transformer_instance = transformer + if transformer_name is not None: if transformer_parameters is not None: transformer_instance = \ @@ -362,6 +393,8 @@ def update_transformers_by_sdtype( for field, field_sdtype in self.field_sdtypes.items(): if field_sdtype == sdtype: self.field_transformers[field] = deepcopy(transformer_instance) + if field in self._multi_column_fields: + self._remove_column_in_multi_column_fields(field) self._modified_config = True @@ -421,13 +454,20 @@ def update_transformers(self, column_name_to_transformer): self._validate_transformers(column_name_to_transformer) for column_name, transformer in column_name_to_transformer.items(): - if transformer is not None: - current_sdtype = self.field_sdtypes.get(column_name) - if current_sdtype and current_sdtype not in transformer.get_supported_sdtypes(): - raise InvalidConfigError( - f"Column '{column_name}' is a {current_sdtype} column, which is " - f"incompatible with the '{transformer.get_name()}' transformer." - ) + columns = column_name if isinstance(column_name, tuple) else (column_name,) + for column in columns: + if transformer is not None: + col_sdtype = self.field_sdtypes.get(column) + if col_sdtype and col_sdtype not in transformer.get_supported_sdtypes(): + raise InvalidConfigError( + f"Column '{column}' is a {col_sdtype} column, which is " + f"incompatible with the '{transformer.get_name()}' transformer." + ) + + if len(columns) > 1 and column in self.field_transformers: + del self.field_transformers[column] + elif column in self._multi_column_fields: + self._remove_column_in_multi_column_fields(column) self.field_transformers[column_name] = transformer @@ -579,16 +619,13 @@ def _fit_field_transformer(self, data, field, transformer): self._transformers_sequence.append(transformer) data = transformer.transform(data) - output_columns = transformer.get_output_columns() next_transformers = transformer.get_next_transformers() - for output_name in output_columns: - output_field = self._multi_column_fields.get(output_name, output_name) - next_transformer = next_transformers[output_field] + for column_name, next_transformer in next_transformers.items(): # If the column is part of a multi-column field, and at least one column # isn't present in the data, then it should not fit the next transformer - if self._field_in_data(output_field, data): - data = self._fit_field_transformer(data, output_field, next_transformer) + if self._field_in_data(column_name, data): + data = self._fit_field_transformer(data, column_name, next_transformer) return data diff --git a/tests/integration/test_hyper_transformer.py b/tests/integration/test_hyper_transformer.py index 4fb3967c..ca4e1822 100644 --- a/tests/integration/test_hyper_transformer.py +++ b/tests/integration/test_hyper_transformer.py @@ -54,9 +54,6 @@ def _reverse_transform(self, data): return data.astype('datetime64[ns]') -TEST_DATA_INDEX = [4, 6, 3, 8, 'a', 1.0, 2.0, 3.0] - - class DummyMultiColumnTransformerNumerical(BaseMultiColumnTransformer): """Multi column transformer that takes categorical data.""" @@ -80,6 +77,9 @@ def _reverse_transform(self, data): return data.astype(str) +TEST_DATA_INDEX = [4, 6, 3, 8, 'a', 1.0, 2.0, 3.0] + + def get_input_data(): datetimes = pd.to_datetime([ '2010-02-01', @@ -1480,3 +1480,137 @@ def test_hypertransformer_with_mutli_column_transformer_end_to_end(self): pd.testing.assert_frame_equal(transformed_data, expected_transformed_data) pd.testing.assert_frame_equal(reverse_transformed_data, data_test) + + def test_update_transformers_single_to_multi_column(self): + """Test ``update_transformers`` to go from single to mutli column transformer.""" + # Setup + dict_config = { + 'sdtypes': { + 'A': 'categorical', + 'B': 'categorical', + 'C': 'boolean' + }, + 'transformers': { + 'A': None, + 'B': UniformEncoder(), + 'C': UniformEncoder() + } + } + config = Config(dict_config) + ht = HyperTransformer() + ht.set_config(config) + + # Run + ht.update_transformers({ + ('A', 'B'): DummyMultiColumnTransformerNumerical(), + }) + new_config = ht.get_config() + + # Assert + expected_config = Config({ + 'sdtypes': { + 'A': 'categorical', + 'B': 'categorical', + 'C': 'boolean' + }, + 'transformers': { + 'C': UniformEncoder(), + "('A', 'B')": DummyMultiColumnTransformerNumerical() + } + }) + + assert repr(new_config) == repr(expected_config) + + def test_update_transformers_multi_to_single_column(self): + """Test ``update_transformers`` to go from multi to single column transformer.""" + + # Setup + dict_config = { + 'sdtypes': { + 'A': 'categorical', + 'B': 'categorical', + 'C': 'boolean', + 'D': 'categorical', + 'E': 'categorical' + }, + 'transformers': { + 'A': UniformEncoder(), + ('B', 'C', 'D'): DummyMultiColumnTransformerNumerical(), + 'E': UniformEncoder() + } + } + + config = Config(dict_config) + ht = HyperTransformer() + ht.set_config(config) + + # Run + ht.update_transformers({ + ('A', 'B'): DummyMultiColumnTransformerNumerical(), + 'D': UniformEncoder() + }) + new_config = ht.get_config() + + # Assert + expected_config = Config({ + 'sdtypes': { + 'A': 'categorical', + 'B': 'categorical', + 'C': 'boolean', + 'D': 'categorical', + 'E': 'categorical' + }, + 'transformers': { + 'E': UniformEncoder(), + "('A', 'B')": DummyMultiColumnTransformerNumerical(), + 'C': DummyMultiColumnTransformerNumerical(), + 'D': UniformEncoder() + } + }) + + assert repr(new_config) == repr(expected_config) + + def test_update_transformers_by_sdtype_mutli_column(self): + """Test ``update_transformers_by_sdtype`` with mutli column transformers.""" + # Setup + dict_config = { + 'sdtypes': { + 'A': 'categorical', + 'B': 'categorical', + 'C': 'boolean', + 'D': 'categorical', + 'E': 'categorical' + }, + 'transformers': { + 'A': UniformEncoder(), + ('B', 'C', 'D'): DummyMultiColumnTransformerNumerical(), + 'E': UniformEncoder() + } + } + + config = Config(dict_config) + ht = HyperTransformer() + ht.set_config(config) + + # Run + ht.update_transformers_by_sdtype('boolean', transformer_name='LabelEncoder') + new_config = ht.get_config() + + # Assert + expected_config = Config({ + 'sdtypes': { + 'A': 'categorical', + 'B': 'categorical', + 'C': 'boolean', + 'D': 'categorical', + 'E': 'categorical' + }, + 'transformers': { + 'A': UniformEncoder(), + 'E': UniformEncoder(), + 'C': LabelEncoder(), + "('B', 'D')": DummyMultiColumnTransformerNumerical() + } + }) + + assert repr(new_config) == repr(expected_config) diff --git a/tests/unit/test_hyper_transformer.py b/tests/unit/test_hyper_transformer.py index b9e0924f..24d8c93e 100644 --- a/tests/unit/test_hyper_transformer.py +++ b/tests/unit/test_hyper_transformer.py @@ -12,7 +12,7 @@ TransformerProcessingError) from rdt.transformers import ( AnonymizedFaker, BaseMultiColumnTransformer, BinaryEncoder, FloatFormatter, FrequencyEncoder, - LabelEncoder, RegexGenerator, UnixTimestampEncoder) + LabelEncoder, RegexGenerator, UniformEncoder, UnixTimestampEncoder) from rdt.transformers.base import BaseTransformer from rdt.transformers.numerical import ClusterBasedNormalizer @@ -2236,6 +2236,107 @@ def test_update_transformers_by_sdtype_with_transformer_name_transformer_paramet assert isinstance(ht.field_transformers['categorical_column'], LabelEncoder) assert ht.field_transformers['categorical_column'].order_by == 'alphabetical' + def test_create_multi_column_fields(self): + """Test ``_create_multi_column_fields``.""" + # Setup + ht = HyperTransformer() + + ht.field_transformers = { + ('column1', 'column2'): 'transformer1', + ('column3', 'column4'): 'transformer2', + } + + # Run + self._multi_column_fields = ht._create_multi_column_fields() + + # Assert + expected_mapping = { + 'column1': ('column1', 'column2'), + 'column2': ('column1', 'column2'), + 'column3': ('column3', 'column4'), + 'column4': ('column3', 'column4'), + } + + assert self._multi_column_fields == expected_mapping + + def test_remove_column_in_multi_column_fields(self): + """Test ``_remove_column_in_multi_column_fields``.""" + # Setup + ht = HyperTransformer() + ht.field_transformers = { + ('column1', 'column2', 'column3'): 'transformer', + } + ht._multi_column_fields = { + 'column1': ('column1', 'column2', 'column3'), + 'column2': ('column1', 'column2', 'column3'), + 'column3': ('column1', 'column2', 'column3'), + } + # Run + ht._remove_column_in_multi_column_fields('column1') + + # Assert + expected_column_in_tuple = { + 'column2': ('column2', 'column3'), + 'column3': ('column2', 'column3'), + } + expected_field_transformers = {('column2', 'column3'): 'transformer'} + assert ht.field_transformers == expected_field_transformers + assert ht._multi_column_fields == expected_column_in_tuple + + def test_remove_column_in_multi_column_fields_single_column_left(self): + """Test ``_remove_column_in_multi_column_fields`` with one column left in the tuple.""" + # Setup + ht = HyperTransformer() + ht.field_transformers = { + ('column1', 'column2'): 'transformer', + } + ht._multi_column_fields = { + 'column1': ('column1', 'column2'), + 'column2': ('column1', 'column2'), + } + # Run + ht._remove_column_in_multi_column_fields('column1') + + # Assert + expected_column_in_tuple = {} + expected_field_transformers = {'column2': 'transformer'} + assert ht.field_transformers == expected_field_transformers + assert ht._multi_column_fields == expected_column_in_tuple + + def test_update_transformers_by_sdtype_with_multi_column_transformer(self): + """Test ``update_transformers_by_sdtype`` with columns use with a multi-column transformer. + """ + # Setup + ht = HyperTransformer() + ht.field_transformers = { + 'A': LabelEncoder(), + 'B': UniformEncoder(), + "('C', 'D')": None, + } + ht.field_sdtypes = { + 'A': 'categorical', + 'B': 'boolean', + 'C': 'categorical', + 'D': 'numerical' + } + + ht._multi_column_fields = { + 'C': ('C', 'D'), + 'D': ('C', 'D') + } + mock__remove_column_in_multi_column_fields = Mock() + ht._remove_column_in_multi_column_fields = mock__remove_column_in_multi_column_fields + + # Run + ht.update_transformers_by_sdtype( + 'categorical', + transformer_name='LabelEncoder', + ) + + # Assert + assert len(ht.field_transformers) == 4 + assert mock__remove_column_in_multi_column_fields.call_count == 1 + @patch('rdt.hyper_transformer.warnings') def test_update_transformers_fitted(self, mock_warnings): """Test update transformers. @@ -2282,6 +2383,83 @@ def test_update_transformers_fitted(self, mock_warnings): assert instance.field_transformers['my_column'] == transformer instance._validate_transformers.assert_called_once_with(column_name_to_transformer) + def test_update_transformers_multi_column(self): + """Test ``update_transformers`` with a multi-column transformer.""" + # Setup + ht = HyperTransformer() + ht.field_sdtypes = { + 'A': 'categorical', + 'B': 'boolean', + 'C': 'numerical', + } + ht.field_transformers = { + 'A': LabelEncoder(), + 'B': UniformEncoder(), + 'C': FloatFormatter(), + } + + column_name_to_transformer = { + ('A', 'B'): None, + 'C': None, + } + # Run + ht.update_transformers(column_name_to_transformer) + + # Assert + expected_field_transformers = { + ('A', 'B'): None, + 'C': None, + } + assert ht.field_transformers == expected_field_transformers + + def test_update_transformers_changing_multi_column_transformer(self): + """Test ``update_transformers`` when changing a multi column transformer.""" + # Setup + ht = HyperTransformer() + ht._multi_column_fields = { + 'A': ('A', 'B'), + 'B': ('A', 'B'), + } + ht.field_sdtypes = { + 'A': 'categorical', + 'B': 'boolean', + 'C': 'numerical', + } + ht.field_transformers = { + ('A', 'B'): None, + 'C': FloatFormatter(), + } + + column_name_to_transformer = { + 'A': UniformEncoder(), + } + + def side_effect(column): + ht._multi_column_fields = { + 'B': ('B',) + } + ht.field_transformers = { + 'C': FloatFormatter(), + 'B': None, + 'A': UniformEncoder() + } + + mock_remove_column_in_multi_column_fields = Mock() + mock_remove_column_in_multi_column_fields.side_effect = side_effect + ht._remove_column_in_multi_column_fields = mock_remove_column_in_multi_column_fields + + # Run + ht.update_transformers(column_name_to_transformer) + + # Assert + expected_field_transformers = { + 'C': FloatFormatter(), + 'B': None, + 'A': UniformEncoder() + } + mock_remove_column_in_multi_column_fields.assert_called_once_with('A') + assert str(ht.field_transformers) == str(expected_field_transformers) + @patch('rdt.hyper_transformer.warnings') def test_update_transformers_not_fitted(self, mock_warnings): """Test update transformers.