Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve user warnings and logic for update_transformers and update_transformers_by_sdtype #695

Merged
merged 21 commits into from
Sep 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 51 additions & 14 deletions rdt/hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ class Config(dict):

def __repr__(self):
"""Pretty print the dictionary."""
transformers_repr = {}
amontanez24 marked this conversation as resolved.
Show resolved Hide resolved
for key, value in self['transformers'].items():
transformed_key = str(key)
transformers_repr[transformed_key] = repr(value)

config = {
'sdtypes': self['sdtypes'],
'transformers': {str(k): repr(v) for k, v in self['transformers'].items()}
Expand Down Expand Up @@ -244,7 +249,9 @@ def _validate_config(config):
)

def _validate_update_columns(self, update_columns):
unknown_columns = self._subset(update_columns, self.field_sdtypes.keys(), not_in=True)
unknown_columns = self._subset(
flatten_column_list(update_columns), self.field_sdtypes.keys(), not_in=True
)
if unknown_columns:
raise InvalidConfigError(
f'Invalid column names: {unknown_columns}. These columns do not exist in the '
Expand All @@ -266,6 +273,7 @@ def set_config(self, config):
self._validate_config(config)
self.field_sdtypes.update(config['sdtypes'])
self.field_transformers.update(config['transformers'])
self._multi_column_fields = self._create_multi_column_fields()
self._modified_config = True
if self._fitted:
warnings.warn(self._REFIT_MESSAGE)
Expand Down Expand Up @@ -327,6 +335,28 @@ def _warn_update_transformers_by_sdtype(self, transformer, transformer_name):
'parameters instead.', FutureWarning
)

def _remove_column_in_multi_column_fields(self, column):
"""Remove a column that is part of a multi-column field.

Remove the column from the tuple and modify the ``multi_column_fields``
as well as the ``field_transformers`` dicts accordingly.

Args:
column (str):
Column name to be updated.
"""
old_tuple = self._multi_column_fields.pop(column)
new_tuple = tuple(item for item in old_tuple if item != column)

if len(new_tuple) == 1:
new_tuple, = new_tuple
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the point of this line?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Its purpose is to avoid storing a single-element tuple as a key in the `field_transformers` dict. Without this line, one of the tests fails with the following error:
AssertionError: assert {('column2',): 'transformer'} == {'column2': 'transformer'}
The right-hand side is the expected dict.

self._multi_column_fields.pop(new_tuple, None)
else:
for col in new_tuple:
self._multi_column_fields[col] = new_tuple

self.field_transformers[new_tuple] = self.field_transformers.pop(old_tuple)

def update_transformers_by_sdtype(
self, sdtype, transformer=None, transformer_name=None, transformer_parameters=None):
"""Update the transformers for the specified ``sdtype``.
Expand All @@ -351,6 +381,7 @@ def update_transformers_by_sdtype(
self._warn_update_transformers_by_sdtype(transformer, transformer_name)

transformer_instance = transformer

if transformer_name is not None:
if transformer_parameters is not None:
transformer_instance = \
Expand All @@ -362,6 +393,8 @@ def update_transformers_by_sdtype(
for field, field_sdtype in self.field_sdtypes.items():
if field_sdtype == sdtype:
self.field_transformers[field] = deepcopy(transformer_instance)
if field in self._multi_column_fields:
self._remove_column_in_multi_column_fields(field)

self._modified_config = True

Expand Down Expand Up @@ -421,13 +454,20 @@ def update_transformers(self, column_name_to_transformer):
self._validate_transformers(column_name_to_transformer)

for column_name, transformer in column_name_to_transformer.items():
if transformer is not None:
current_sdtype = self.field_sdtypes.get(column_name)
if current_sdtype and current_sdtype not in transformer.get_supported_sdtypes():
raise InvalidConfigError(
f"Column '{column_name}' is a {current_sdtype} column, which is "
f"incompatible with the '{transformer.get_name()}' transformer."
)
columns = column_name if isinstance(column_name, tuple) else (column_name,)
for column in columns:
if transformer is not None:
col_sdtype = self.field_sdtypes.get(column)
if col_sdtype and col_sdtype not in transformer.get_supported_sdtypes():
raise InvalidConfigError(
f"Column '{column}' is a {col_sdtype} column, which is "
f"incompatible with the '{transformer.get_name()}' transformer."
)

if len(columns) > 1 and column in self.field_transformers:
del self.field_transformers[column]
elif column in self._multi_column_fields:
self._remove_column_in_multi_column_fields(column)

self.field_transformers[column_name] = transformer

Expand Down Expand Up @@ -579,16 +619,13 @@ def _fit_field_transformer(self, data, field, transformer):
self._transformers_sequence.append(transformer)
data = transformer.transform(data)

output_columns = transformer.get_output_columns()
next_transformers = transformer.get_next_transformers()
for output_name in output_columns:
output_field = self._multi_column_fields.get(output_name, output_name)
next_transformer = next_transformers[output_field]
for column_name, next_transformer in next_transformers.items():

# If the column is part of a multi-column field, and at least one column
# isn't present in the data, then it should not fit the next transformer
if self._field_in_data(output_field, data):
data = self._fit_field_transformer(data, output_field, next_transformer)
if self._field_in_data(column_name, data):
data = self._fit_field_transformer(data, column_name, next_transformer)

return data

Expand Down
140 changes: 137 additions & 3 deletions tests/integration/test_hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,6 @@ def _reverse_transform(self, data):
return data.astype('datetime64[ns]')


TEST_DATA_INDEX = [4, 6, 3, 8, 'a', 1.0, 2.0, 3.0]


class DummyMultiColumnTransformerNumerical(BaseMultiColumnTransformer):
"""Multi column transformer that takes categorical data."""

Expand All @@ -80,6 +77,9 @@ def _reverse_transform(self, data):
return data.astype(str)


TEST_DATA_INDEX = [4, 6, 3, 8, 'a', 1.0, 2.0, 3.0]


def get_input_data():
datetimes = pd.to_datetime([
'2010-02-01',
Expand Down Expand Up @@ -1432,3 +1432,137 @@ def test_hypertransformer_with_mutli_column_transformer_end_to_end(self):

pd.testing.assert_frame_equal(transformed_data, expected_transformed_data)
pd.testing.assert_frame_equal(reverse_transformed_data, data_test)

def test_update_transformers_single_to_multi_column(self):
    """Test ``update_transformers`` when going from single to multi column.

    Columns ``A`` and ``B`` start with their own entries in the config and are
    then assigned together, as the tuple ``('A', 'B')``, to a single multi
    column transformer. The expected config holds only the tuple entry plus
    the untouched ``C`` entry.
    """
    # Setup
    initial_config = Config({
        'sdtypes': {'A': 'categorical', 'B': 'categorical', 'C': 'boolean'},
        'transformers': {'A': None, 'B': UniformEncoder(), 'C': UniformEncoder()}
    })
    instance = HyperTransformer()
    instance.set_config(initial_config)

    # Run
    multi_column_update = {('A', 'B'): DummyMultiColumnTransformerNumerical()}
    instance.update_transformers(multi_column_update)
    resulting_config = instance.get_config()

    # Assert
    # The configs are compared through ``repr``, so the tuple key is given
    # here in its string form.
    expected_sdtypes = {'A': 'categorical', 'B': 'categorical', 'C': 'boolean'}
    expected_transformers = {
        'C': UniformEncoder(),
        "('A', 'B')": DummyMultiColumnTransformerNumerical()
    }
    expected_config = Config({
        'sdtypes': expected_sdtypes,
        'transformers': expected_transformers
    })

    assert repr(resulting_config) == repr(expected_config)

def test_update_transformers_multi_to_single_column(self):
    """Test ``update_transformers`` when going from multi to single column.

    The config starts with the multi column entry ``('B', 'C', 'D')``. The
    update moves ``B`` into a new ``('A', 'B')`` entry and gives ``D`` its own
    transformer. The expected config keeps the original multi column
    transformer on the remaining column ``C`` only.
    """
    # Setup
    initial_config = Config({
        'sdtypes': {
            'A': 'categorical',
            'B': 'categorical',
            'C': 'boolean',
            'D': 'categorical',
            'E': 'categorical'
        },
        'transformers': {
            'A': UniformEncoder(),
            ('B', 'C', 'D'): DummyMultiColumnTransformerNumerical(),
            'E': UniformEncoder()
        }
    })
    instance = HyperTransformer()
    instance.set_config(initial_config)

    # Run
    updates = {
        ('A', 'B'): DummyMultiColumnTransformerNumerical(),
        'D': UniformEncoder()
    }
    instance.update_transformers(updates)
    resulting_config = instance.get_config()

    # Assert
    # The configs are compared through ``repr``, so the tuple key is given
    # here in its string form.
    expected_sdtypes = {
        'A': 'categorical',
        'B': 'categorical',
        'C': 'boolean',
        'D': 'categorical',
        'E': 'categorical'
    }
    expected_transformers = {
        'E': UniformEncoder(),
        "('A', 'B')": DummyMultiColumnTransformerNumerical(),
        'C': DummyMultiColumnTransformerNumerical(),
        'D': UniformEncoder()
    }
    expected_config = Config({
        'sdtypes': expected_sdtypes,
        'transformers': expected_transformers
    })

    assert repr(resulting_config) == repr(expected_config)

def test_update_transformers_by_sdtype_mutli_column(self):
    """Test ``update_transformers_by_sdtype`` with multi column transformers.

    The only ``boolean`` column, ``C``, is part of the multi column entry
    ``('B', 'C', 'D')``. After updating the ``boolean`` sdtype by transformer
    name, the expected config has ``C`` on its own ``LabelEncoder`` and the
    multi column transformer reduced to ``('B', 'D')``.
    """
    # Setup
    initial_config = Config({
        'sdtypes': {
            'A': 'categorical',
            'B': 'categorical',
            'C': 'boolean',
            'D': 'categorical',
            'E': 'categorical'
        },
        'transformers': {
            'A': UniformEncoder(),
            ('B', 'C', 'D'): DummyMultiColumnTransformerNumerical(),
            'E': UniformEncoder()
        }
    })
    instance = HyperTransformer()
    instance.set_config(initial_config)

    # Run
    instance.update_transformers_by_sdtype('boolean', transformer_name='LabelEncoder')
    resulting_config = instance.get_config()

    # Assert
    # The configs are compared through ``repr``, so the tuple key is given
    # here in its string form.
    expected_sdtypes = {
        'A': 'categorical',
        'B': 'categorical',
        'C': 'boolean',
        'D': 'categorical',
        'E': 'categorical'
    }
    expected_transformers = {
        'A': UniformEncoder(),
        'E': UniformEncoder(),
        'C': LabelEncoder(),
        "('B', 'D')": DummyMultiColumnTransformerNumerical()
    }
    expected_config = Config({
        'sdtypes': expected_sdtypes,
        'transformers': expected_transformers
    })

    assert repr(resulting_config) == repr(expected_config)
Loading
Loading