From 1a86850dfe2c674297b6f3a59b7df1e18e62bfbc Mon Sep 17 00:00:00 2001 From: R-Palazzo <116157184+R-Palazzo@users.noreply.github.com> Date: Thu, 14 Sep 2023 13:45:13 +0200 Subject: [PATCH] Improve user warnings and logic for update_sdtype (#705) * update sdtype * improve for boucle * coverage --- rdt/hyper_transformer.py | 27 +++-- tests/integration/test_hyper_transformer.py | 48 +++++++++ tests/unit/test_hyper_transformer.py | 109 ++++++++++++++++++++ 3 files changed, 176 insertions(+), 8 deletions(-) diff --git a/rdt/hyper_transformer.py b/rdt/hyper_transformer.py index ae8ffeccc..f4f34e29c 100644 --- a/rdt/hyper_transformer.py +++ b/rdt/hyper_transformer.py @@ -416,14 +416,25 @@ def update_sdtypes(self, column_name_to_sdtype): transformers_to_update = {} for column, sdtype in column_name_to_sdtype.items(): - if self.field_sdtypes.get(column) != sdtype: - current_transformer = self.field_transformers.get(column) - supported_sdtypes = [] - if current_transformer: - supported_sdtypes = current_transformer.get_supported_sdtypes() - - if sdtype not in supported_sdtypes: - transformers_to_update[column] = deepcopy(get_default_transformer(sdtype)) + if self.field_sdtypes.get(column) == sdtype: + continue + + column_key = self._multi_column_fields.get(column, column) + current_transformer = self.field_transformers.get(column_key) + + if current_transformer: + supported_sdtypes = current_transformer.get_supported_sdtypes() + if sdtype in supported_sdtypes: + continue + + warnings.warn( + f"Sdtype '{sdtype}' is incompatible with transformer " + f"'{current_transformer.get_name()}'. Assigning a new transformer to it." + ) + if column in self._multi_column_fields: + self._remove_column_in_multi_column_fields(column) + + transformers_to_update[column] = deepcopy(get_default_transformer(sdtype)) self.field_sdtypes.update(column_name_to_sdtype) self.field_transformers.update(transformers_to_update) diff --git a/tests/integration/test_hyper_transformer.py b/tests/integration/test_hyper_transformer.py index 5dcab338c..158746047 100644 --- a/tests/integration/test_hyper_transformer.py +++ b/tests/integration/test_hyper_transformer.py @@ -1655,3 +1655,51 @@ def test_remove_transformer_by_sdtype(self): }) assert repr(new_config) == repr(expected_config) + + def test_update_sdtype(self): + """Test ``update_sdtypes`` with multi column transformer.""" + # Setup + dict_config = { + 'sdtypes': { + 'A': 'categorical', + 'B': 'categorical', + 'C': 'boolean', + 'D': 'categorical', + 'E': 'categorical' + }, + 'transformers': { + 'A': UniformEncoder(), + ('B', 'C', 'D'): DummyMultiColumnTransformerNumerical(), + 'E': UniformEncoder() + } + } + + config = Config(dict_config) + ht = HyperTransformer() + ht.set_config(config) + + # Run + ht.update_sdtypes({ + 'C': 'numerical', + 'A': 'numerical' + }) + new_config = ht.get_config() + + # Assert + expected_config = Config({ + 'sdtypes': { + 'A': 'numerical', + 'B': 'categorical', + 'C': 'numerical', + 'D': 'categorical', + 'E': 'categorical' + }, + 'transformers': { + 'A': FloatFormatter(), + 'E': UniformEncoder(), + "('B', 'D')": DummyMultiColumnTransformerNumerical(), + 'C': FloatFormatter() + } + }) + + assert repr(new_config) == repr(expected_config) diff --git a/tests/unit/test_hyper_transformer.py b/tests/unit/test_hyper_transformer.py index 379024244..4dd24ab05 100644 --- a/tests/unit/test_hyper_transformer.py +++ b/tests/unit/test_hyper_transformer.py @@ -2947,6 +2947,115 @@ def test_update_sdtypes_different_sdtype_than_transformer(self, mock_warnings, m assert instance.field_transformers == {'a': transformer} mock_logger.info.assert_called_once_with(user_message) + def test_update_sdtypes_multi_column_with_supported_sdtypes(self): + """Test the ``update_sdtypes`` method. + + Test that the method works for column that are in a multi-column transformer. + In this case the multi column transformer supports the new sdtype so the transformer + should not be changed. + """ + # Setup + class DummyMultiColumnTransformer(BaseMultiColumnTransformer): + """Dummy multi column transformer.""" + + SUPPORTED_SDTYPES = ['categorical', 'boolean'] + + ht = HyperTransformer() + ht.field_sdtypes = { + 'column1': 'categorical', + 'column2': 'categorical', + 'column3': 'categorical', + 'column4': 'categorical' + } + ht.field_transformers = { + 'column1': UniformEncoder(), + ('column2', 'column3'): DummyMultiColumnTransformer(), + 'column4': None + } + ht._multi_column_fields = { + 'column2': ('column2', 'column3'), + 'column3': ('column2', 'column3') + } + + # Run + ht.update_sdtypes(column_name_to_sdtype={ + 'column2': 'boolean', + 'column1': 'boolean', + 'column4': 'categorical' + }) + + # Assert + expected_field_sdtypes = { + 'column1': 'boolean', + 'column2': 'boolean', + 'column3': 'categorical', + 'column4': 'categorical' + } + expected_field_transformers = { + 'column1': UniformEncoder(), + ('column2', 'column3'): DummyMultiColumnTransformer(), + 'column4': None + } + assert ht.field_sdtypes == expected_field_sdtypes + assert str(ht.field_transformers) == str(expected_field_transformers) + + def test_update_sdtypes_multi_column_with_unsupported_sdtypes(self): + """Test the ``update_sdtypes`` method. + + Test that the method works for column that are in a multi-column transformer. + In this case the multi column transformer does not support the new sdtype so the + transformer should be changed to the default one. + """ + # Setup + class DummyMultiColumnTransformer(BaseMultiColumnTransformer): + """Dummy multi column transformer.""" + + SUPPORTED_SDTYPES = ['categorical', 'boolean'] + + ht = HyperTransformer() + ht.field_sdtypes = { + 'column1': 'categorical', + 'column2': 'categorical', + 'column3': 'categorical', + 'column4': 'categorical' + } + ht.field_transformers = { + 'column1': UniformEncoder(), + ('column2', 'column3'): DummyMultiColumnTransformer(), + 'column4': None + } + ht._multi_column_fields = { + 'column2': ('column2', 'column3'), + 'column3': ('column2', 'column3') + } + + # Run + expected_warning = ( + "Sdtype 'numerical' is incompatible with transformer 'DummyMultiColumnTransformer'." + ' Assigning a new transformer to it.' + ) + with pytest.warns(UserWarning, match=expected_warning): + ht.update_sdtypes(column_name_to_sdtype={ + 'column2': 'numerical', + 'column1': 'boolean' + }) + + # Assert + expected_field_sdtypes = { + 'column1': 'boolean', + 'column2': 'numerical', + 'column3': 'categorical', + 'column4': 'categorical' + } + expected_field_transformers = { + 'column1': UniformEncoder(), + 'column4': None, + 'column3': DummyMultiColumnTransformer(), + 'column2': FloatFormatter(), + } + assert ht.field_sdtypes == expected_field_sdtypes + assert str(ht.field_transformers) == str(expected_field_transformers) + def test__validate_update_columns(self): """Test ``_validate_update_columns``.