Skip to content

Commit

Permalink
Improve user warnings and logic for update_sdtype (#705)
Browse files Browse the repository at this point in the history
* update sdtype

* improve for boucle

* coverage
  • Loading branch information
R-Palazzo committed Oct 27, 2023
1 parent 0d0810f commit 3c90f6c
Show file tree
Hide file tree
Showing 3 changed files with 176 additions and 8 deletions.
27 changes: 19 additions & 8 deletions rdt/hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,14 +416,25 @@ def update_sdtypes(self, column_name_to_sdtype):

transformers_to_update = {}
for column, sdtype in column_name_to_sdtype.items():
if self.field_sdtypes.get(column) != sdtype:
current_transformer = self.field_transformers.get(column)
supported_sdtypes = []
if current_transformer:
supported_sdtypes = current_transformer.get_supported_sdtypes()

if sdtype not in supported_sdtypes:
transformers_to_update[column] = deepcopy(get_default_transformer(sdtype))
if self.field_sdtypes.get(column) == sdtype:
continue

column_key = self._multi_column_fields.get(column, column)
current_transformer = self.field_transformers.get(column_key)

if current_transformer:
supported_sdtypes = current_transformer.get_supported_sdtypes()
if sdtype in supported_sdtypes:
continue

warnings.warn(
f"Sdtype '{sdtype}' is incompatible with transformer "
f"'{current_transformer.get_name()}'. Assigning a new transformer to it."
)
if column in self._multi_column_fields:
self._remove_column_in_multi_column_fields(column)

transformers_to_update[column] = deepcopy(get_default_transformer(sdtype))

self.field_sdtypes.update(column_name_to_sdtype)
self.field_transformers.update(transformers_to_update)
Expand Down
48 changes: 48 additions & 0 deletions tests/integration/test_hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1703,3 +1703,51 @@ def test_remove_transformer_by_sdtype(self):
})

assert repr(new_config) == repr(expected_config)

def test_update_sdtype(self):
"""Test ``update_sdtypes`` with multi column transformer."""
# Setup
dict_config = {
'sdtypes': {
'A': 'categorical',
'B': 'categorical',
'C': 'boolean',
'D': 'categorical',
'E': 'categorical'
},
'transformers': {
'A': UniformEncoder(),
('B', 'C', 'D'): DummyMultiColumnTransformerNumerical(),
'E': UniformEncoder()
}
}

config = Config(dict_config)
ht = HyperTransformer()
ht.set_config(config)

# Run
ht.update_sdtypes({
'C': 'numerical',
'A': 'numerical'
})
new_config = ht.get_config()

# Assert
expected_config = Config({
'sdtypes': {
'A': 'numerical',
'B': 'categorical',
'C': 'numerical',
'D': 'categorical',
'E': 'categorical'
},
'transformers': {
'A': FloatFormatter(),
'E': UniformEncoder(),
"('B', 'D')": DummyMultiColumnTransformerNumerical(),
'C': FloatFormatter()
}
})

assert repr(new_config) == repr(expected_config)
109 changes: 109 additions & 0 deletions tests/unit/test_hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2947,6 +2947,115 @@ def test_update_sdtypes_different_sdtype_than_transformer(self, mock_warnings, m
assert instance.field_transformers == {'a': transformer}
mock_logger.info.assert_called_once_with(user_message)

def test_update_sdtypes_multi_column_with_supported_sdtypes(self):
"""Test the ``update_sdtypes`` method.
Test that the method works for column that are in a multi-column transformer.
In this case the multi column transformer supports the new sdtype so the transformer
should not be changed.
"""
# Setup
class DummyMultiColumnTransformer(BaseMultiColumnTransformer):
"""Dummy multi column transformer."""

SUPPORTED_SDTYPES = ['categorical', 'boolean']

ht = HyperTransformer()
ht.field_sdtypes = {
'column1': 'categorical',
'column2': 'categorical',
'column3': 'categorical',
'column4': 'categorical'
}
ht.field_transformers = {
'column1': UniformEncoder(),
('column2', 'column3'): DummyMultiColumnTransformer(),
'column4': None
}
ht._multi_column_fields = {
'column2': ('column2', 'column3'),
'column3': ('column2', 'column3')
}

# Run
ht.update_sdtypes(column_name_to_sdtype={
'column2': 'boolean',
'column1': 'boolean',
'column4': 'categorical'
})

# Assert
expected_field_sdtypes = {
'column1': 'boolean',
'column2': 'boolean',
'column3': 'categorical',
'column4': 'categorical'
}
expected_field_transformers = {
'column1': UniformEncoder(),
('column2', 'column3'): DummyMultiColumnTransformer(),
'column4': None
}
assert ht.field_sdtypes == expected_field_sdtypes
assert str(ht.field_transformers) == str(expected_field_transformers)

def test_update_sdtypes_multi_column_with_unsupported_sdtypes(self):
"""Test the ``update_sdtypes`` method.
Test that the method works for column that are in a multi-column transformer.
In this case the multi column transformer does not support the new sdtype so the
transformer should be changed to the default one.
"""
# Setup
class DummyMultiColumnTransformer(BaseMultiColumnTransformer):
"""Dummy multi column transformer."""

SUPPORTED_SDTYPES = ['categorical', 'boolean']

ht = HyperTransformer()
ht.field_sdtypes = {
'column1': 'categorical',
'column2': 'categorical',
'column3': 'categorical',
'column4': 'categorical'
}
ht.field_transformers = {
'column1': UniformEncoder(),
('column2', 'column3'): DummyMultiColumnTransformer(),
'column4': None
}
ht._multi_column_fields = {
'column2': ('column2', 'column3'),
'column3': ('column2', 'column3')
}

# Run
expected_warning = (
"Sdtype 'numerical' is incompatible with transformer 'DummyMultiColumnTransformer'."
' Assigning a new transformer to it.'
)
with pytest.warns(UserWarning, match=expected_warning):
ht.update_sdtypes(column_name_to_sdtype={
'column2': 'numerical',
'column1': 'boolean'
})

# Assert
expected_field_sdtypes = {
'column1': 'boolean',
'column2': 'numerical',
'column3': 'categorical',
'column4': 'categorical'
}
expected_field_transformers = {
'column1': UniformEncoder(),
'column4': None,
'column3': DummyMultiColumnTransformer(),
'column2': FloatFormatter(),
}
assert ht.field_sdtypes == expected_field_sdtypes
assert str(ht.field_transformers) == str(expected_field_transformers)

def test__validate_update_columns(self):
"""Test ``_validate_update_columns``.
Expand Down

0 comments on commit 3c90f6c

Please sign in to comment.