Skip to content

Commit

Permalink
Improve user warnings and logic for update_transformers and update_tr…
Browse files Browse the repository at this point in the history
…ansformers_by_sdtype (#695)

* modify update_transformer

* modify validate_config

* modify config printing

* add _generate_column_in_tuple + update transformer

* add flatten_column_list

* add _update_column_in_tuple method

* update_transformers_by_sdtype

* integration test

* unit tests

* rename

* docstring

* remove _create_multi_column_fields in the __init__

* fix rebasing

* use a multi column transformer in integration test

* assert + extra line

* rebase + test end to end with hypertransformer

* add _get_columns_to_sdtype

* rebase

* remove notebook

* undo remove unit test

* 'Remove the column from the tuple and' docstring
  • Loading branch information
R-Palazzo committed Oct 31, 2023
1 parent 248ec0c commit 05ee1cb
Show file tree
Hide file tree
Showing 3 changed files with 367 additions and 18 deletions.
65 changes: 51 additions & 14 deletions rdt/hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ class Config(dict):

def __repr__(self):
"""Pretty print the dictionary."""
transformers_repr = {}
for key, value in self['transformers'].items():
transformed_key = str(key)
transformers_repr[transformed_key] = repr(value)

config = {
'sdtypes': self['sdtypes'],
'transformers': {str(k): repr(v) for k, v in self['transformers'].items()}
Expand Down Expand Up @@ -244,7 +249,9 @@ def _validate_config(config):
)

def _validate_update_columns(self, update_columns):
unknown_columns = self._subset(update_columns, self.field_sdtypes.keys(), not_in=True)
unknown_columns = self._subset(
flatten_column_list(update_columns), self.field_sdtypes.keys(), not_in=True
)
if unknown_columns:
raise InvalidConfigError(
f'Invalid column names: {unknown_columns}. These columns do not exist in the '
Expand All @@ -266,6 +273,7 @@ def set_config(self, config):
self._validate_config(config)
self.field_sdtypes.update(config['sdtypes'])
self.field_transformers.update(config['transformers'])
self._multi_column_fields = self._create_multi_column_fields()
self._modified_config = True
if self._fitted:
warnings.warn(self._REFIT_MESSAGE)
Expand Down Expand Up @@ -327,6 +335,28 @@ def _warn_update_transformers_by_sdtype(self, transformer, transformer_name):
'parameters instead.', FutureWarning
)

def _remove_column_in_multi_column_fields(self, column):
"""Remove a column that is part of a multi-column field.
Remove the column from the tuple and modify the ``multi_column_fields``
as well as the ``field_transformers`` dicts accordingly.
Args:
column (str):
Column name to be updated.
"""
old_tuple = self._multi_column_fields.pop(column)
new_tuple = tuple(item for item in old_tuple if item != column)

if len(new_tuple) == 1:
new_tuple, = new_tuple
self._multi_column_fields.pop(new_tuple, None)
else:
for col in new_tuple:
self._multi_column_fields[col] = new_tuple

self.field_transformers[new_tuple] = self.field_transformers.pop(old_tuple)

def update_transformers_by_sdtype(
self, sdtype, transformer=None, transformer_name=None, transformer_parameters=None):
"""Update the transformers for the specified ``sdtype``.
Expand All @@ -351,6 +381,7 @@ def update_transformers_by_sdtype(
self._warn_update_transformers_by_sdtype(transformer, transformer_name)

transformer_instance = transformer

if transformer_name is not None:
if transformer_parameters is not None:
transformer_instance = \
Expand All @@ -362,6 +393,8 @@ def update_transformers_by_sdtype(
for field, field_sdtype in self.field_sdtypes.items():
if field_sdtype == sdtype:
self.field_transformers[field] = deepcopy(transformer_instance)
if field in self._multi_column_fields:
self._remove_column_in_multi_column_fields(field)

self._modified_config = True

Expand Down Expand Up @@ -421,13 +454,20 @@ def update_transformers(self, column_name_to_transformer):
self._validate_transformers(column_name_to_transformer)

for column_name, transformer in column_name_to_transformer.items():
if transformer is not None:
current_sdtype = self.field_sdtypes.get(column_name)
if current_sdtype and current_sdtype not in transformer.get_supported_sdtypes():
raise InvalidConfigError(
f"Column '{column_name}' is a {current_sdtype} column, which is "
f"incompatible with the '{transformer.get_name()}' transformer."
)
columns = column_name if isinstance(column_name, tuple) else (column_name,)
for column in columns:
if transformer is not None:
col_sdtype = self.field_sdtypes.get(column)
if col_sdtype and col_sdtype not in transformer.get_supported_sdtypes():
raise InvalidConfigError(
f"Column '{column}' is a {col_sdtype} column, which is "
f"incompatible with the '{transformer.get_name()}' transformer."
)

if len(columns) > 1 and column in self.field_transformers:
del self.field_transformers[column]
elif column in self._multi_column_fields:
self._remove_column_in_multi_column_fields(column)

self.field_transformers[column_name] = transformer

Expand Down Expand Up @@ -579,16 +619,13 @@ def _fit_field_transformer(self, data, field, transformer):
self._transformers_sequence.append(transformer)
data = transformer.transform(data)

output_columns = transformer.get_output_columns()
next_transformers = transformer.get_next_transformers()
for output_name in output_columns:
output_field = self._multi_column_fields.get(output_name, output_name)
next_transformer = next_transformers[output_field]
for column_name, next_transformer in next_transformers.items():

# If the column is part of a multi-column field, and at least one column
# isn't present in the data, then it should not fit the next transformer
if self._field_in_data(output_field, data):
data = self._fit_field_transformer(data, output_field, next_transformer)
if self._field_in_data(column_name, data):
data = self._fit_field_transformer(data, column_name, next_transformer)

return data

Expand Down
140 changes: 137 additions & 3 deletions tests/integration/test_hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,6 @@ def _reverse_transform(self, data):
return data.astype('datetime64[ns]')


TEST_DATA_INDEX = [4, 6, 3, 8, 'a', 1.0, 2.0, 3.0]


class DummyMultiColumnTransformerNumerical(BaseMultiColumnTransformer):
"""Multi column transformer that takes categorical data."""

Expand All @@ -80,6 +77,9 @@ def _reverse_transform(self, data):
return data.astype(str)


TEST_DATA_INDEX = [4, 6, 3, 8, 'a', 1.0, 2.0, 3.0]


def get_input_data():
datetimes = pd.to_datetime([
'2010-02-01',
Expand Down Expand Up @@ -1480,3 +1480,137 @@ def test_hypertransformer_with_mutli_column_transformer_end_to_end(self):

pd.testing.assert_frame_equal(transformed_data, expected_transformed_data)
pd.testing.assert_frame_equal(reverse_transformed_data, data_test)

def test_update_transformers_single_to_multi_column(self):
"""Test ``update_transformers`` to go from single to mutli column transformer."""
# Setup
dict_config = {
'sdtypes': {
'A': 'categorical',
'B': 'categorical',
'C': 'boolean'
},
'transformers': {
'A': None,
'B': UniformEncoder(),
'C': UniformEncoder()
}
}
config = Config(dict_config)
ht = HyperTransformer()
ht.set_config(config)

# Run
ht.update_transformers({
('A', 'B'): DummyMultiColumnTransformerNumerical(),
})
new_config = ht.get_config()

# Assert
expected_config = Config({
'sdtypes': {
'A': 'categorical',
'B': 'categorical',
'C': 'boolean'
},
'transformers': {
'C': UniformEncoder(),
"('A', 'B')": DummyMultiColumnTransformerNumerical()
}
})

assert repr(new_config) == repr(expected_config)

def test_update_transformers_multi_to_single_column(self):
"""Test ``update_transformers`` to go from multi to single column transformer."""

# Setup
dict_config = {
'sdtypes': {
'A': 'categorical',
'B': 'categorical',
'C': 'boolean',
'D': 'categorical',
'E': 'categorical'
},
'transformers': {
'A': UniformEncoder(),
('B', 'C', 'D'): DummyMultiColumnTransformerNumerical(),
'E': UniformEncoder()
}
}

config = Config(dict_config)
ht = HyperTransformer()
ht.set_config(config)

# Run
ht.update_transformers({
('A', 'B'): DummyMultiColumnTransformerNumerical(),
'D': UniformEncoder()
})
new_config = ht.get_config()

# Assert
expected_config = Config({
'sdtypes': {
'A': 'categorical',
'B': 'categorical',
'C': 'boolean',
'D': 'categorical',
'E': 'categorical'
},
'transformers': {
'E': UniformEncoder(),
"('A', 'B')": DummyMultiColumnTransformerNumerical(),
'C': DummyMultiColumnTransformerNumerical(),
'D': UniformEncoder()
}
})

assert repr(new_config) == repr(expected_config)

def test_update_transformers_by_sdtype_mutli_column(self):
"""Test ``update_transformers_by_sdtype`` with mutli column transformers."""
# Setup
dict_config = {
'sdtypes': {
'A': 'categorical',
'B': 'categorical',
'C': 'boolean',
'D': 'categorical',
'E': 'categorical'
},
'transformers': {
'A': UniformEncoder(),
('B', 'C', 'D'): DummyMultiColumnTransformerNumerical(),
'E': UniformEncoder()
}
}

config = Config(dict_config)
ht = HyperTransformer()
ht.set_config(config)

# Run
ht.update_transformers_by_sdtype('boolean', transformer_name='LabelEncoder')
new_config = ht.get_config()

# Assert
expected_config = Config({
'sdtypes': {
'A': 'categorical',
'B': 'categorical',
'C': 'boolean',
'D': 'categorical',
'E': 'categorical'
},
'transformers': {
'A': UniformEncoder(),
'E': UniformEncoder(),
'C': LabelEncoder(),
"('B', 'D')": DummyMultiColumnTransformerNumerical()
}
})

assert repr(new_config) == repr(expected_config)
Loading

0 comments on commit 05ee1cb

Please sign in to comment.