Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support multi-column transformers #692

Merged
52 changes: 43 additions & 9 deletions rdt/hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from rdt.transformers import (
BaseTransformer, get_class_by_transformer_name, get_default_transformer,
get_transformers_by_type)
from rdt.transformers.utils import flatten_column_list

LOGGER = logging.getLogger(__name__)

Expand All @@ -25,7 +26,7 @@ def __repr__(self):
"""Pretty print the dictionary."""
config = {
'sdtypes': self['sdtypes'],
'transformers': {k: repr(v) for k, v in self['transformers'].items()}
'transformers': {str(k): repr(v) for k, v in self['transformers'].items()}
}

printed = json.dumps(config, indent=4)
Expand Down Expand Up @@ -115,7 +116,7 @@ def __init__(self):
self._specified_fields = set()
self._validate_field_transformers()
self._valid_output_sdtypes = self._DEFAULT_OUTPUT_SDTYPES
self._multi_column_fields = self._create_multi_column_fields()
self._multi_column_fields = {}
self._transformers_sequence = []
self._output_columns = []
self._input_columns = []
Expand Down Expand Up @@ -205,7 +206,18 @@ def _validate_config(config):

sdtypes = config['sdtypes']
transformers = config['transformers']
if set(sdtypes.keys()) != set(transformers.keys()):

sdtype_keys = sdtypes.keys()
transformer_keys = flatten_column_list(transformers.keys())

is_transformer_keys_unique = len(transformer_keys) == len(set(transformer_keys))
if not is_transformer_keys_unique:
raise InvalidConfigError(
'Error: Invalid config. Please provide unique keys for the sdtypes '
'and transformers.'
)

if set(sdtype_keys) != set(transformer_keys):
raise InvalidConfigError(
"The column names in the 'sdtypes' dictionary must match the "
"column names in the 'transformers' dictionary."
Expand All @@ -216,10 +228,14 @@ def _validate_config(config):

mismatched_columns = []
for column_name, transformer in transformers.items():
if transformer is not None:
sdtype = sdtypes.get(column_name)
if transformer is None:
continue

columns = column_name if isinstance(column_name, tuple) else [column_name]
for column in columns:
sdtype = sdtypes.get(column)
if sdtype not in transformer.get_supported_sdtypes():
mismatched_columns.append(column_name)
mismatched_columns.append(column)

if mismatched_columns:
raise InvalidConfigError(
Expand Down Expand Up @@ -519,6 +535,19 @@ def detect_initial_config(self, data):
LOGGER.info('Config:')
LOGGER.info(str(config))

def _get_columns_to_sdtypes(self, field):
"""Generate the ``columns_to_sdtypes`` dict for the given field.

Args:
field (tuple):
Names of the column for the multi column trnasformer.
"""
columns_to_sdtypes = {}
for column in field:
columns_to_sdtypes[column] = self.field_sdtypes[column]

return columns_to_sdtypes
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Break line above would be appreciated.


def _fit_field_transformer(self, data, field, transformer):
"""Fit a transformer to its corresponding field.

Expand All @@ -541,7 +570,12 @@ def _fit_field_transformer(self, data, field, transformer):
self._output_columns.append(field)

else:
transformer.fit(data, field)
if isinstance(field, tuple):
columns_to_sdtypes = self._get_columns_to_sdtypes(field)
transformer.fit(data, columns_to_sdtypes)
else:
transformer.fit(data, field)

self._transformers_sequence.append(transformer)
data = transformer.transform(data)

Expand Down Expand Up @@ -603,8 +637,8 @@ def fit(self, data):
self._validate_detect_config_called(data)
self._unfit()
self._input_columns = list(data.columns)
for field in self._input_columns:
data = self._fit_field_transformer(data, field, self.field_transformers[field])
for field_column, field_transformer in self.field_transformers.items():
data = self._fit_field_transformer(data, field_column, field_transformer)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is going to work alone. The _fit_field_transformer method doesn't know to pass the columns as a dict in the multi column case

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the multi column case we pass the columns as a tuple or as a dict? I supposed it was a tuple and so it was good for me based on the docstring of _fit_field_transformer(self, data, field, transformer):

field (str or tuple):
                Name of column or tuple of columns in data that will be transformed

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the base class you made fit takes in data and a dictionary. It needs to be a dict otherwise the values will always have to be passed in a specific order

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@R-Palazzo I think this is still the case. The issue states that the HyperTransformer should run with a multi-column transformer too.


self._validate_all_fields_fitted()
self._fitted = True
Expand Down
3 changes: 2 additions & 1 deletion rdt/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from functools import lru_cache
from pathlib import Path

from rdt.transformers.base import BaseTransformer
from rdt.transformers.base import BaseMultiColumnTransformer, BaseTransformer
from rdt.transformers.boolean import BinaryEncoder
from rdt.transformers.categorical import (
CustomLabelEncoder, FrequencyEncoder, LabelEncoder, OneHotEncoder, OrderedLabelEncoder,
Expand All @@ -22,6 +22,7 @@

__all__ = [
'BaseTransformer',
'BaseMultiColumnTransformer',
'BinaryEncoder',
'ClusterBasedNormalizer',
'CustomLabelEncoder',
Expand Down
112 changes: 111 additions & 1 deletion rdt/transformers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ def _fit(self, columns_data):
raise NotImplementedError()

def _set_seed(self, data):
hash_value = self.get_input_column()
hash_value = self.columns[0]
for value in data.head(5):
hash_value += str(value)

Expand Down Expand Up @@ -479,3 +479,113 @@ def reverse_transform(self, data):
data = self._add_columns_to_data(data, reversed_data, self.columns)

return data


class BaseMultiColumnTransformer(BaseTransformer):
"""Base class for all multi column transformers.

The ``BaseMultiColumnTransformer`` class contains methods that must be implemented
in order to create a new multi column transformer.

Attributes:
columns_to_sdtypes (dict):
Dictionary mapping each column to its sdtype.
"""

def __init__(self):
super().__init__()
self.columns_to_sdtypes = {}

def get_input_column(self):
"""Override ``get_input_column`` method from ``BaseTransformer``.

Raise an error because for multi column transformers, ``get_input_columns``
must be used instead.
"""
raise NotImplementedError(
'MultiColumnTransformers does not have a single input column.'
'Please use ``get_input_columns`` instead.'
)

def get_input_columns(self):
"""Return input column name for transformer.

Returns:
list:
Input column names.
"""
return self.columns

def _get_prefix(self):
"""Return the prefix of the output columns.

Returns:
str:
Prefix of the output columns.
"""
raise NotImplementedError()

def _get_output_to_property(self, property_):
self.column_prefix = self._get_prefix()
output = {}
for output_column, properties in self.output_properties.items():
# if 'sdtype' is not in the dict, ignore the column
if property_ not in properties:
continue

if self.column_prefix is None:
output[f'{output_column}'] = properties[property_]
else:
output[f'{self.column_prefix}.{output_column}'] = properties[property_]

return output

def _validate_columns_to_sdtypes(self, data, columns_to_sdtypes):
"""Check that all the columns in ``columns_to_sdtypes`` are present in the data."""
missing = set(columns_to_sdtypes.keys()) - set(data.columns)
if missing:
missing_to_print = ', '.join(missing)
raise ValueError(f'Columns ({missing_to_print}) are not present in the data.')

def _fit(self, data):
"""Fit the transformer to the data.

Args:
data (pandas.DataFrame):
Data to transform.
"""
raise NotImplementedError()

@random_state
def fit(self, data, columns_to_sdtypes):
"""Fit the transformer to a ``column`` of the ``data``.

Args:
data (pandas.DataFrame):
The entire table.
columns_to_sdtypes (dict):
Dictionary mapping each column to its sdtype.
"""
self._validate_columns_to_sdtypes(data, columns_to_sdtypes)
self.columns_to_sdtypes = columns_to_sdtypes
self._store_columns(list(self.columns_to_sdtypes.keys()), data)
self._set_seed(data)
columns_data = self._get_columns_data(data, self.columns)
self._fit(columns_data)
self._build_output_columns(data)

def fit_transform(self, data, columns_to_sdtypes):
"""Fit the transformer to a `column` of the `data` and then transform it.

Args:
data (pandas.DataFrame):
The entire table.
columns_to_sdtypes (dict):
Dictionary mapping each column to its sdtype.

Returns:
pd.DataFrame:
The entire table, containing the transformed data.
"""
self.fit(data, columns_to_sdtypes)
return self.transform(data)
21 changes: 21 additions & 0 deletions rdt/transformers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,24 @@ def fill_nan_with_none(data):
Original data with nan values replaced by None.
"""
return data.fillna(np.nan).replace([np.nan], [None])


def flatten_column_list(column_list):
"""Flatten a list of columns.

Args:
column_list (list):
List of columns to flatten.

Returns:
list:
Flattened list of columns.
"""
flattened = []
for column in column_list:
if isinstance(column, tuple):
flattened.extend(column)
else:
flattened.append(column)

return flattened
70 changes: 66 additions & 4 deletions tests/integration/test_hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@
import pandas as pd
import pytest

from rdt import HyperTransformer, get_demo
from rdt import get_demo
from rdt.errors import ConfigNotSetError, InvalidConfigError, InvalidDataError, NotFittedError
from rdt.hyper_transformer import Config, HyperTransformer
from rdt.transformers import (
AnonymizedFaker, BaseTransformer, BinaryEncoder, ClusterBasedNormalizer, FloatFormatter,
FrequencyEncoder, LabelEncoder, OneHotEncoder, RegexGenerator, UniformEncoder,
UnixTimestampEncoder, get_default_transformer, get_default_transformers)
AnonymizedFaker, BaseMultiColumnTransformer, BaseTransformer, BinaryEncoder,
ClusterBasedNormalizer, FloatFormatter, FrequencyEncoder, LabelEncoder, OneHotEncoder,
RegexGenerator, UniformEncoder, UnixTimestampEncoder, get_default_transformer,
get_default_transformers)
from rdt.transformers.datetime import OptimizedTimestampEncoder
from rdt.transformers.numerical import GaussianNormalizer
from rdt.transformers.pii.anonymizer import PseudoAnonymizedFaker
Expand Down Expand Up @@ -55,6 +57,29 @@ def _reverse_transform(self, data):
TEST_DATA_INDEX = [4, 6, 3, 8, 'a', 1.0, 2.0, 3.0]


class DummyMultiColumnTransformerNumerical(BaseMultiColumnTransformer):
"""Multi column transformer that takes categorical data."""

SUPPORTED_SDTYPES = ['categorical', 'boolean']

def _fit(self, data):
self.output_properties = {
column: {
'sdtype': 'numerical',
'next_transformer': None,
} for column in self.columns
}

def _get_prefix(self):
return None

def _transform(self, data):
return data.astype(float)

def _reverse_transform(self, data):
return data.astype(str)


def get_input_data():
datetimes = pd.to_datetime([
'2010-02-01',
Expand Down Expand Up @@ -1370,3 +1395,40 @@ def test_random_seed(self):

# Assert
pd.testing.assert_frame_equal(reversed1, reversed2)

def test_hypertransformer_with_mutli_column_transformer_end_to_end(self):
"""Test ``HyperTransformer`` with mutli column transformers end to end."""
# Setup
data_test = pd.DataFrame({
'A': ['1.0', '2.0', '3.0'],
'B': ['4.0', '5.0', '6.0'],
'C': [True, False, True]
})
dict_config = {
'sdtypes': {
'A': 'categorical',
'B': 'categorical',
'C': 'boolean'
},
'transformers': {
('A', 'B'): DummyMultiColumnTransformerNumerical(),
'C': UniformEncoder()
}
}
config = Config(dict_config)
ht = HyperTransformer()
ht.set_config(config)

# Run
transformed_data = ht.fit_transform(data_test)
reverse_transformed_data = ht.reverse_transform(transformed_data)

# Assert
expected_transformed_data = pd.DataFrame({
'A': [1.0, 2.0, 3.0],
'B': [4.0, 5.0, 6.0],
'C': [0.5892351646057272, 0.8615278122985615, 0.36493646501970534]
})

pd.testing.assert_frame_equal(transformed_data, expected_transformed_data)
pd.testing.assert_frame_equal(reverse_transformed_data, data_test)
2 changes: 1 addition & 1 deletion tests/integration/test_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def _is_valid_transformer(transformer_name):
"""Determine if transformer should be tested or not."""
invalid_names = [
'IdentityTransformer', 'Dummy', 'OrderedLabelEncoder', 'CustomLabelEncoder',
'OrderedUniformEncoder',
'OrderedUniformEncoder', 'BaseMultiColumnTransformer'
]
return all(invalid_name not in transformer_name for invalid_name in invalid_names)

Expand Down
Loading
Loading