Skip to content

Commit

Permalink
Support multi-column transformers (#692)
Browse files Browse the repository at this point in the history
* def + store_columns

* hypertransformer modif 1

* unit test

* docstring

* drop _store_columns change

* add integration_test + _get_output_to_property

* ordered_columns + prefix

* 'int' in integration test

* change prefix + columns_to_sdtype

* docstring

* columns_to_sdtypes

* validate column_to_sdtype ValueError + docstring

* fix _get_prefix

* move PR, hypertransformer with multi column

* break line
  • Loading branch information
R-Palazzo committed Oct 31, 2023
1 parent aebb0a9 commit 35ce859
Show file tree
Hide file tree
Showing 10 changed files with 760 additions and 26 deletions.
52 changes: 43 additions & 9 deletions rdt/hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from rdt.transformers import (
BaseTransformer, get_class_by_transformer_name, get_default_transformer,
get_transformers_by_type)
from rdt.transformers.utils import flatten_column_list

LOGGER = logging.getLogger(__name__)

Expand All @@ -25,7 +26,7 @@ def __repr__(self):
"""Pretty print the dictionary."""
config = {
'sdtypes': self['sdtypes'],
'transformers': {k: repr(v) for k, v in self['transformers'].items()}
'transformers': {str(k): repr(v) for k, v in self['transformers'].items()}
}

printed = json.dumps(config, indent=4)
Expand Down Expand Up @@ -115,7 +116,7 @@ def __init__(self):
self._specified_fields = set()
self._validate_field_transformers()
self._valid_output_sdtypes = self._DEFAULT_OUTPUT_SDTYPES
self._multi_column_fields = self._create_multi_column_fields()
self._multi_column_fields = {}
self._transformers_sequence = []
self._output_columns = []
self._input_columns = []
Expand Down Expand Up @@ -205,7 +206,18 @@ def _validate_config(config):

sdtypes = config['sdtypes']
transformers = config['transformers']
if set(sdtypes.keys()) != set(transformers.keys()):

sdtype_keys = sdtypes.keys()
transformer_keys = flatten_column_list(transformers.keys())

is_transformer_keys_unique = len(transformer_keys) == len(set(transformer_keys))
if not is_transformer_keys_unique:
raise InvalidConfigError(
'Error: Invalid config. Please provide unique keys for the sdtypes '
'and transformers.'
)

if set(sdtype_keys) != set(transformer_keys):
raise InvalidConfigError(
"The column names in the 'sdtypes' dictionary must match the "
"column names in the 'transformers' dictionary."
Expand All @@ -216,10 +228,14 @@ def _validate_config(config):

mismatched_columns = []
for column_name, transformer in transformers.items():
if transformer is not None:
sdtype = sdtypes.get(column_name)
if transformer is None:
continue

columns = column_name if isinstance(column_name, tuple) else [column_name]
for column in columns:
sdtype = sdtypes.get(column)
if sdtype not in transformer.get_supported_sdtypes():
mismatched_columns.append(column_name)
mismatched_columns.append(column)

if mismatched_columns:
raise InvalidConfigError(
Expand Down Expand Up @@ -519,6 +535,19 @@ def detect_initial_config(self, data):
LOGGER.info('Config:')
LOGGER.info(str(config))

def _get_columns_to_sdtypes(self, field):
"""Generate the ``columns_to_sdtypes`` dict for the given field.
Args:
field (tuple):
Names of the column for the multi column trnasformer.
"""
columns_to_sdtypes = {}
for column in field:
columns_to_sdtypes[column] = self.field_sdtypes[column]

return columns_to_sdtypes

def _fit_field_transformer(self, data, field, transformer):
"""Fit a transformer to its corresponding field.
Expand All @@ -541,7 +570,12 @@ def _fit_field_transformer(self, data, field, transformer):
self._output_columns.append(field)

else:
transformer.fit(data, field)
if isinstance(field, tuple):
columns_to_sdtypes = self._get_columns_to_sdtypes(field)
transformer.fit(data, columns_to_sdtypes)
else:
transformer.fit(data, field)

self._transformers_sequence.append(transformer)
data = transformer.transform(data)

Expand Down Expand Up @@ -609,8 +643,8 @@ def fit(self, data):
self._validate_detect_config_called(data)
self._unfit()
self._input_columns = list(data.columns)
for field in self._input_columns:
data = self._fit_field_transformer(data, field, self.field_transformers[field])
for field_column, field_transformer in self.field_transformers.items():
data = self._fit_field_transformer(data, field_column, field_transformer)

self._validate_all_fields_fitted()
self._fitted = True
Expand Down
3 changes: 2 additions & 1 deletion rdt/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from functools import lru_cache
from pathlib import Path

from rdt.transformers.base import BaseTransformer
from rdt.transformers.base import BaseMultiColumnTransformer, BaseTransformer
from rdt.transformers.boolean import BinaryEncoder
from rdt.transformers.categorical import (
CustomLabelEncoder, FrequencyEncoder, LabelEncoder, OneHotEncoder, OrderedLabelEncoder,
Expand All @@ -22,6 +22,7 @@

__all__ = [
'BaseTransformer',
'BaseMultiColumnTransformer',
'BinaryEncoder',
'ClusterBasedNormalizer',
'CustomLabelEncoder',
Expand Down
112 changes: 111 additions & 1 deletion rdt/transformers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ def _fit(self, columns_data):
raise NotImplementedError()

def _set_seed(self, data):
hash_value = self.get_input_column()
hash_value = self.columns[0]
for value in data.head(5):
hash_value += str(value)

Expand Down Expand Up @@ -479,3 +479,113 @@ def reverse_transform(self, data):
data = self._add_columns_to_data(data, reversed_data, self.columns)

return data


class BaseMultiColumnTransformer(BaseTransformer):
"""Base class for all multi column transformers.
The ``BaseMultiColumnTransformer`` class contains methods that must be implemented
in order to create a new multi column transformer.
Attributes:
columns_to_sdtypes (dict):
Dictionary mapping each column to its sdtype.
"""

def __init__(self):
super().__init__()
self.columns_to_sdtypes = {}

def get_input_column(self):
"""Override ``get_input_column`` method from ``BaseTransformer``.
Raise an error because for multi column transformers, ``get_input_columns``
must be used instead.
"""
raise NotImplementedError(
'MultiColumnTransformers does not have a single input column.'
'Please use ``get_input_columns`` instead.'
)

def get_input_columns(self):
"""Return input column name for transformer.
Returns:
list:
Input column names.
"""
return self.columns

def _get_prefix(self):
"""Return the prefix of the output columns.
Returns:
str:
Prefix of the output columns.
"""
raise NotImplementedError()

def _get_output_to_property(self, property_):
self.column_prefix = self._get_prefix()
output = {}
for output_column, properties in self.output_properties.items():
# if 'sdtype' is not in the dict, ignore the column
if property_ not in properties:
continue

if self.column_prefix is None:
output[f'{output_column}'] = properties[property_]
else:
output[f'{self.column_prefix}.{output_column}'] = properties[property_]

return output

def _validate_columns_to_sdtypes(self, data, columns_to_sdtypes):
"""Check that all the columns in ``columns_to_sdtypes`` are present in the data."""
missing = set(columns_to_sdtypes.keys()) - set(data.columns)
if missing:
missing_to_print = ', '.join(missing)
raise ValueError(f'Columns ({missing_to_print}) are not present in the data.')

def _fit(self, data):
"""Fit the transformer to the data.
Args:
data (pandas.DataFrame):
Data to transform.
"""
raise NotImplementedError()

@random_state
def fit(self, data, columns_to_sdtypes):
"""Fit the transformer to a ``column`` of the ``data``.
Args:
data (pandas.DataFrame):
The entire table.
columns_to_sdtypes (dict):
Dictionary mapping each column to its sdtype.
"""
self._validate_columns_to_sdtypes(data, columns_to_sdtypes)
self.columns_to_sdtypes = columns_to_sdtypes
self._store_columns(list(self.columns_to_sdtypes.keys()), data)
self._set_seed(data)
columns_data = self._get_columns_data(data, self.columns)
self._fit(columns_data)
self._build_output_columns(data)

def fit_transform(self, data, columns_to_sdtypes):
"""Fit the transformer to a `column` of the `data` and then transform it.
Args:
data (pandas.DataFrame):
The entire table.
columns_to_sdtypes (dict):
Dictionary mapping each column to its sdtype.
Returns:
pd.DataFrame:
The entire table, containing the transformed data.
"""
self.fit(data, columns_to_sdtypes)
return self.transform(data)
21 changes: 21 additions & 0 deletions rdt/transformers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,24 @@ def fill_nan_with_none(data):
Original data with nan values replaced by None.
"""
return data.fillna(np.nan).replace([np.nan], [None])


def flatten_column_list(column_list):
"""Flatten a list of columns.
Args:
column_list (list):
List of columns to flatten.
Returns:
list:
Flattened list of columns.
"""
flattened = []
for column in column_list:
if isinstance(column, tuple):
flattened.extend(column)
else:
flattened.append(column)

return flattened
70 changes: 66 additions & 4 deletions tests/integration/test_hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@
import pandas as pd
import pytest

from rdt import HyperTransformer, get_demo
from rdt import get_demo
from rdt.errors import ConfigNotSetError, InvalidConfigError, InvalidDataError, NotFittedError
from rdt.hyper_transformer import Config, HyperTransformer
from rdt.transformers import (
AnonymizedFaker, BaseTransformer, BinaryEncoder, ClusterBasedNormalizer, FloatFormatter,
FrequencyEncoder, LabelEncoder, OneHotEncoder, RegexGenerator, UniformEncoder,
UnixTimestampEncoder, get_default_transformer, get_default_transformers)
AnonymizedFaker, BaseMultiColumnTransformer, BaseTransformer, BinaryEncoder,
ClusterBasedNormalizer, FloatFormatter, FrequencyEncoder, LabelEncoder, OneHotEncoder,
RegexGenerator, UniformEncoder, UnixTimestampEncoder, get_default_transformer,
get_default_transformers)
from rdt.transformers.datetime import OptimizedTimestampEncoder
from rdt.transformers.numerical import GaussianNormalizer
from rdt.transformers.pii.anonymizer import PseudoAnonymizedFaker
Expand Down Expand Up @@ -55,6 +57,29 @@ def _reverse_transform(self, data):
TEST_DATA_INDEX = [4, 6, 3, 8, 'a', 1.0, 2.0, 3.0]


class DummyMultiColumnTransformerNumerical(BaseMultiColumnTransformer):
"""Multi column transformer that takes categorical data."""

SUPPORTED_SDTYPES = ['categorical', 'boolean']

def _fit(self, data):
self.output_properties = {
column: {
'sdtype': 'numerical',
'next_transformer': None,
} for column in self.columns
}

def _get_prefix(self):
return None

def _transform(self, data):
return data.astype(float)

def _reverse_transform(self, data):
return data.astype(str)


def get_input_data():
datetimes = pd.to_datetime([
'2010-02-01',
Expand Down Expand Up @@ -1418,3 +1443,40 @@ def test_random_seed(self):

# Assert
pd.testing.assert_frame_equal(reversed1, reversed2)

def test_hypertransformer_with_mutli_column_transformer_end_to_end(self):
"""Test ``HyperTransformer`` with mutli column transformers end to end."""
# Setup
data_test = pd.DataFrame({
'A': ['1.0', '2.0', '3.0'],
'B': ['4.0', '5.0', '6.0'],
'C': [True, False, True]
})
dict_config = {
'sdtypes': {
'A': 'categorical',
'B': 'categorical',
'C': 'boolean'
},
'transformers': {
('A', 'B'): DummyMultiColumnTransformerNumerical(),
'C': UniformEncoder()
}
}
config = Config(dict_config)
ht = HyperTransformer()
ht.set_config(config)

# Run
transformed_data = ht.fit_transform(data_test)
reverse_transformed_data = ht.reverse_transform(transformed_data)

# Assert
expected_transformed_data = pd.DataFrame({
'A': [1.0, 2.0, 3.0],
'B': [4.0, 5.0, 6.0],
'C': [0.5892351646057272, 0.8615278122985615, 0.36493646501970534]
})

pd.testing.assert_frame_equal(transformed_data, expected_transformed_data)
pd.testing.assert_frame_equal(reverse_transformed_data, data_test)
2 changes: 1 addition & 1 deletion tests/integration/test_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def _is_valid_transformer(transformer_name):
"""Determine if transformer should be tested or not."""
invalid_names = [
'IdentityTransformer', 'Dummy', 'OrderedLabelEncoder', 'CustomLabelEncoder',
'OrderedUniformEncoder',
'OrderedUniformEncoder', 'BaseMultiColumnTransformer'
]
return all(invalid_name not in transformer_name for invalid_name in invalid_names)

Expand Down
Loading

0 comments on commit 35ce859

Please sign in to comment.