Skip to content

Commit 89d871d

Browse files
authored
Support multi-column transformers (#692)
* def + store_columns * hypertransformer modif 1 * unit test * docstring * drop _store_columns change * add integration_test + _get_output_to_property * ordered_columns + prefix * 'int' in integration test * change prefix + columns_to_sdtype * docstring * columns_to_sdtypes * validate column_to_sdtype ValueError + docstring * fix _get_prefix * move PR, hypertransformer with multi column * break line
1 parent 636f4e2 commit 89d871d

File tree

10 files changed

+760
-26
lines changed

10 files changed

+760
-26
lines changed

rdt/hyper_transformer.py

Lines changed: 43 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from rdt.transformers import (
1515
BaseTransformer, get_class_by_transformer_name, get_default_transformer,
1616
get_transformers_by_type)
17+
from rdt.transformers.utils import flatten_column_list
1718

1819
LOGGER = logging.getLogger(__name__)
1920

@@ -25,7 +26,7 @@ def __repr__(self):
2526
"""Pretty print the dictionary."""
2627
config = {
2728
'sdtypes': self['sdtypes'],
28-
'transformers': {k: repr(v) for k, v in self['transformers'].items()}
29+
'transformers': {str(k): repr(v) for k, v in self['transformers'].items()}
2930
}
3031

3132
printed = json.dumps(config, indent=4)
@@ -115,7 +116,7 @@ def __init__(self):
115116
self._specified_fields = set()
116117
self._validate_field_transformers()
117118
self._valid_output_sdtypes = self._DEFAULT_OUTPUT_SDTYPES
118-
self._multi_column_fields = self._create_multi_column_fields()
119+
self._multi_column_fields = {}
119120
self._transformers_sequence = []
120121
self._output_columns = []
121122
self._input_columns = []
@@ -205,7 +206,18 @@ def _validate_config(config):
205206

206207
sdtypes = config['sdtypes']
207208
transformers = config['transformers']
208-
if set(sdtypes.keys()) != set(transformers.keys()):
209+
210+
sdtype_keys = sdtypes.keys()
211+
transformer_keys = flatten_column_list(transformers.keys())
212+
213+
is_transformer_keys_unique = len(transformer_keys) == len(set(transformer_keys))
214+
if not is_transformer_keys_unique:
215+
raise InvalidConfigError(
216+
'Error: Invalid config. Please provide unique keys for the sdtypes '
217+
'and transformers.'
218+
)
219+
220+
if set(sdtype_keys) != set(transformer_keys):
209221
raise InvalidConfigError(
210222
"The column names in the 'sdtypes' dictionary must match the "
211223
"column names in the 'transformers' dictionary."
@@ -216,10 +228,14 @@ def _validate_config(config):
216228

217229
mismatched_columns = []
218230
for column_name, transformer in transformers.items():
219-
if transformer is not None:
220-
sdtype = sdtypes.get(column_name)
231+
if transformer is None:
232+
continue
233+
234+
columns = column_name if isinstance(column_name, tuple) else [column_name]
235+
for column in columns:
236+
sdtype = sdtypes.get(column)
221237
if sdtype not in transformer.get_supported_sdtypes():
222-
mismatched_columns.append(column_name)
238+
mismatched_columns.append(column)
223239

224240
if mismatched_columns:
225241
raise InvalidConfigError(
@@ -519,6 +535,19 @@ def detect_initial_config(self, data):
519535
LOGGER.info('Config:')
520536
LOGGER.info(str(config))
521537

538+
def _get_columns_to_sdtypes(self, field):
    """Build the ``columns_to_sdtypes`` mapping for a multi column field.

    Args:
        field (tuple):
            Names of the columns handled by the multi column transformer.

    Returns:
        dict:
            Each column name of ``field``, in order, mapped to its entry in
            ``self.field_sdtypes``.
    """
    return {name: self.field_sdtypes[name] for name in field}
550+
522551
def _fit_field_transformer(self, data, field, transformer):
523552
"""Fit a transformer to its corresponding field.
524553
@@ -541,7 +570,12 @@ def _fit_field_transformer(self, data, field, transformer):
541570
self._output_columns.append(field)
542571

543572
else:
544-
transformer.fit(data, field)
573+
if isinstance(field, tuple):
574+
columns_to_sdtypes = self._get_columns_to_sdtypes(field)
575+
transformer.fit(data, columns_to_sdtypes)
576+
else:
577+
transformer.fit(data, field)
578+
545579
self._transformers_sequence.append(transformer)
546580
data = transformer.transform(data)
547581

@@ -603,8 +637,8 @@ def fit(self, data):
603637
self._validate_detect_config_called(data)
604638
self._unfit()
605639
self._input_columns = list(data.columns)
606-
for field in self._input_columns:
607-
data = self._fit_field_transformer(data, field, self.field_transformers[field])
640+
for field_column, field_transformer in self.field_transformers.items():
641+
data = self._fit_field_transformer(data, field_column, field_transformer)
608642

609643
self._validate_all_fields_fitted()
610644
self._fitted = True

rdt/transformers/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from functools import lru_cache
1010
from pathlib import Path
1111

12-
from rdt.transformers.base import BaseTransformer
12+
from rdt.transformers.base import BaseMultiColumnTransformer, BaseTransformer
1313
from rdt.transformers.boolean import BinaryEncoder
1414
from rdt.transformers.categorical import (
1515
CustomLabelEncoder, FrequencyEncoder, LabelEncoder, OneHotEncoder, OrderedLabelEncoder,
@@ -22,6 +22,7 @@
2222

2323
__all__ = [
2424
'BaseTransformer',
25+
'BaseMultiColumnTransformer',
2526
'BinaryEncoder',
2627
'ClusterBasedNormalizer',
2728
'CustomLabelEncoder',

rdt/transformers/base.py

Lines changed: 111 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,7 @@ def _fit(self, columns_data):
362362
raise NotImplementedError()
363363

364364
def _set_seed(self, data):
365-
hash_value = self.get_input_column()
365+
hash_value = self.columns[0]
366366
for value in data.head(5):
367367
hash_value += str(value)
368368

@@ -479,3 +479,113 @@ def reverse_transform(self, data):
479479
data = self._add_columns_to_data(data, reversed_data, self.columns)
480480

481481
return data
482+
483+
484+
class BaseMultiColumnTransformer(BaseTransformer):
    """Base class for all multi column transformers.

    The ``BaseMultiColumnTransformer`` class contains methods that must be implemented
    in order to create a new multi column transformer. Subclasses are expected to
    provide ``_fit`` and ``_get_prefix``; both raise ``NotImplementedError`` here.

    Attributes:
        columns_to_sdtypes (dict):
            Dictionary mapping each column to its sdtype.
    """

    def __init__(self):
        super().__init__()
        self.columns_to_sdtypes = {}

    def get_input_column(self):
        """Override ``get_input_column`` method from ``BaseTransformer``.

        Raise an error because for multi column transformers, ``get_input_columns``
        must be used instead.

        Raises:
            NotImplementedError:
                Always raised; this method is not supported on multi column transformers.
        """
        # The trailing space keeps the two adjacent literals from fusing into
        # "...input column.Please use...".
        raise NotImplementedError(
            'MultiColumnTransformers does not have a single input column. '
            'Please use ``get_input_columns`` instead.'
        )

    def get_input_columns(self):
        """Return the input column names for the transformer.

        Returns:
            list:
                Input column names.
        """
        return self.columns

    def _get_prefix(self):
        """Return the prefix of the output columns.

        Returns:
            str:
                Prefix of the output columns.
        """
        raise NotImplementedError()

    def _get_output_to_property(self, property_):
        """Map every output column name to the value of one of its properties.

        The prefix returned by ``_get_prefix`` is stored in ``self.column_prefix``.
        When it is not ``None``, output names are built as ``<prefix>.<output_column>``.

        Args:
            property_ (str):
                Name of the property to read from each entry of ``output_properties``.

        Returns:
            dict:
                Output column names mapped to the requested property value. Columns
                whose properties do not define ``property_`` are left out.
        """
        self.column_prefix = self._get_prefix()
        output = {}
        for output_column, properties in self.output_properties.items():
            # if the requested property is not in the dict, ignore the column
            if property_ not in properties:
                continue

            if self.column_prefix is None:
                output[f'{output_column}'] = properties[property_]
            else:
                output[f'{self.column_prefix}.{output_column}'] = properties[property_]

        return output

    def _validate_columns_to_sdtypes(self, data, columns_to_sdtypes):
        """Check that all the columns in ``columns_to_sdtypes`` are present in the data.

        Args:
            data (pandas.DataFrame):
                The table being validated against.
            columns_to_sdtypes (dict):
                Dictionary mapping each column to its sdtype.

        Raises:
            ValueError:
                If at least one key of ``columns_to_sdtypes`` is not a column of ``data``.
        """
        available = set(data.columns)
        # Walk the mapping in its own order so the message is deterministic, and
        # stringify the names so non-string column labels do not break ``join``.
        missing = [str(column) for column in columns_to_sdtypes if column not in available]
        if missing:
            missing_to_print = ', '.join(missing)
            raise ValueError(f'Columns ({missing_to_print}) are not present in the data.')

    def _fit(self, data):
        """Fit the transformer to the data.

        Args:
            data (pandas.DataFrame):
                Data to transform.
        """
        raise NotImplementedError()

    @random_state
    def fit(self, data, columns_to_sdtypes):
        """Fit the transformer to the columns of ``data`` listed in ``columns_to_sdtypes``.

        Args:
            data (pandas.DataFrame):
                The entire table.
            columns_to_sdtypes (dict):
                Dictionary mapping each column to its sdtype.

        Raises:
            ValueError:
                If a column of ``columns_to_sdtypes`` is not present in ``data``.
        """
        self._validate_columns_to_sdtypes(data, columns_to_sdtypes)
        self.columns_to_sdtypes = columns_to_sdtypes
        # NOTE(review): ``_store_columns``, ``_set_seed``, ``_get_columns_data`` and
        # ``_build_output_columns`` are inherited helpers; presumably ``_store_columns``
        # sets ``self.columns`` before it is read below -- confirm against ``BaseTransformer``.
        self._store_columns(list(self.columns_to_sdtypes.keys()), data)
        self._set_seed(data)
        columns_data = self._get_columns_data(data, self.columns)
        self._fit(columns_data)
        self._build_output_columns(data)

    def fit_transform(self, data, columns_to_sdtypes):
        """Fit the transformer to the given columns of ``data`` and then transform it.

        Args:
            data (pandas.DataFrame):
                The entire table.
            columns_to_sdtypes (dict):
                Dictionary mapping each column to its sdtype.

        Returns:
            pd.DataFrame:
                The entire table, containing the transformed data.
        """
        self.fit(data, columns_to_sdtypes)
        return self.transform(data)

rdt/transformers/utils.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,3 +163,24 @@ def fill_nan_with_none(data):
163163
Original data with nan values replaced by None.
164164
"""
165165
return data.fillna(np.nan).replace([np.nan], [None])
166+
167+
168+
def flatten_column_list(column_list):
    """Expand the tuple entries of ``column_list`` into one flat list.

    Args:
        column_list (list):
            List of columns to flatten. Entries that are tuples contribute each of
            their elements; every other entry is kept as a single item.

    Returns:
        list:
            Flattened list of columns, in the original order.
    """
    return [
        name
        for entry in column_list
        for name in (entry if isinstance(entry, tuple) else (entry,))
    ]

tests/integration/test_hyper_transformer.py

Lines changed: 66 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66
import pandas as pd
77
import pytest
88

9-
from rdt import HyperTransformer, get_demo
9+
from rdt import get_demo
1010
from rdt.errors import ConfigNotSetError, InvalidConfigError, InvalidDataError, NotFittedError
11+
from rdt.hyper_transformer import Config, HyperTransformer
1112
from rdt.transformers import (
12-
AnonymizedFaker, BaseTransformer, BinaryEncoder, ClusterBasedNormalizer, FloatFormatter,
13-
FrequencyEncoder, LabelEncoder, OneHotEncoder, RegexGenerator, UniformEncoder,
14-
UnixTimestampEncoder, get_default_transformer, get_default_transformers)
13+
AnonymizedFaker, BaseMultiColumnTransformer, BaseTransformer, BinaryEncoder,
14+
ClusterBasedNormalizer, FloatFormatter, FrequencyEncoder, LabelEncoder, OneHotEncoder,
15+
RegexGenerator, UniformEncoder, UnixTimestampEncoder, get_default_transformer,
16+
get_default_transformers)
1517
from rdt.transformers.datetime import OptimizedTimestampEncoder
1618
from rdt.transformers.numerical import GaussianNormalizer
1719
from rdt.transformers.pii.anonymizer import PseudoAnonymizedFaker
@@ -55,6 +57,29 @@ def _reverse_transform(self, data):
5557
TEST_DATA_INDEX = [4, 6, 3, 8, 'a', 1.0, 2.0, 3.0]
5658

5759

60+
class DummyMultiColumnTransformerNumerical(BaseMultiColumnTransformer):
    """Dummy multi column transformer for categorical and boolean columns.

    Every input column is declared as a ``numerical`` output with no next
    transformer. Transforming casts the data to ``float`` and reversing casts
    it back to ``str``.
    """

    SUPPORTED_SDTYPES = ['categorical', 'boolean']

    def _fit(self, data):
        # One independent properties dict per input column.
        properties = {}
        for name in self.columns:
            properties[name] = {'sdtype': 'numerical', 'next_transformer': None}

        self.output_properties = properties

    def _get_prefix(self):
        # No prefix: output columns keep the input column names.
        return None

    def _transform(self, data):
        as_float = data.astype(float)
        return as_float

    def _reverse_transform(self, data):
        as_text = data.astype(str)
        return as_text
81+
82+
5883
def get_input_data():
5984
datetimes = pd.to_datetime([
6085
'2010-02-01',
@@ -1370,3 +1395,40 @@ def test_random_seed(self):
13701395

13711396
# Assert
13721397
pd.testing.assert_frame_equal(reversed1, reversed2)
1398+
1399+
def test_hypertransformer_with_mutli_column_transformer_end_to_end(self):
    """Test ``HyperTransformer`` with multi column transformers end to end."""
    # Setup
    input_frame = pd.DataFrame({
        'A': ['1.0', '2.0', '3.0'],
        'B': ['4.0', '5.0', '6.0'],
        'C': [True, False, True]
    })
    sdtypes = {
        'A': 'categorical',
        'B': 'categorical',
        'C': 'boolean'
    }
    # The tuple key assigns a single transformer to columns 'A' and 'B' together.
    transformers = {
        ('A', 'B'): DummyMultiColumnTransformerNumerical(),
        'C': UniformEncoder()
    }
    config = Config({'sdtypes': sdtypes, 'transformers': transformers})
    ht = HyperTransformer()
    ht.set_config(config)

    # Run
    forward = ht.fit_transform(input_frame)
    backward = ht.reverse_transform(forward)

    # Assert
    expected_forward = pd.DataFrame({
        'A': [1.0, 2.0, 3.0],
        'B': [4.0, 5.0, 6.0],
        'C': [0.5892351646057272, 0.8615278122985615, 0.36493646501970534]
    })

    pd.testing.assert_frame_equal(forward, expected_forward)
    pd.testing.assert_frame_equal(backward, input_frame)

tests/integration/test_transformers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def _is_valid_transformer(transformer_name):
6969
"""Determine if transformer should be tested or not."""
7070
invalid_names = [
7171
'IdentityTransformer', 'Dummy', 'OrderedLabelEncoder', 'CustomLabelEncoder',
72-
'OrderedUniformEncoder',
72+
'OrderedUniformEncoder', 'BaseMultiColumnTransformer'
7373
]
7474
return all(invalid_name not in transformer_name for invalid_name in invalid_names)
7575

0 commit comments

Comments
 (0)