diff --git a/rdt/hyper_transformer.py b/rdt/hyper_transformer.py index 214714e5..e3eee078 100644 --- a/rdt/hyper_transformer.py +++ b/rdt/hyper_transformer.py @@ -14,6 +14,7 @@ from rdt.transformers import ( BaseTransformer, get_class_by_transformer_name, get_default_transformer, get_transformers_by_type) +from rdt.transformers.utils import flatten_column_list LOGGER = logging.getLogger(__name__) @@ -25,7 +26,7 @@ def __repr__(self): """Pretty print the dictionary.""" config = { 'sdtypes': self['sdtypes'], - 'transformers': {k: repr(v) for k, v in self['transformers'].items()} + 'transformers': {str(k): repr(v) for k, v in self['transformers'].items()} } printed = json.dumps(config, indent=4) @@ -115,7 +116,7 @@ def __init__(self): self._specified_fields = set() self._validate_field_transformers() self._valid_output_sdtypes = self._DEFAULT_OUTPUT_SDTYPES - self._multi_column_fields = self._create_multi_column_fields() + self._multi_column_fields = {} self._transformers_sequence = [] self._output_columns = [] self._input_columns = [] @@ -205,7 +206,18 @@ def _validate_config(config): sdtypes = config['sdtypes'] transformers = config['transformers'] - if set(sdtypes.keys()) != set(transformers.keys()): + + sdtype_keys = sdtypes.keys() + transformer_keys = flatten_column_list(transformers.keys()) + + is_transformer_keys_unique = len(transformer_keys) == len(set(transformer_keys)) + if not is_transformer_keys_unique: + raise InvalidConfigError( + 'Error: Invalid config. Please provide unique keys for the sdtypes ' + 'and transformers.' + ) + + if set(sdtype_keys) != set(transformer_keys): raise InvalidConfigError( "The column names in the 'sdtypes' dictionary must match the " "column names in the 'transformers' dictionary." @@ -216,10 +228,14 @@ def _validate_config(config): mismatched_columns = [] for column_name, transformer in transformers.items(): - if transformer is not None: - sdtype = sdtypes.get(column_name) + if transformer is None: + continue + + columns = column_name if isinstance(column_name, tuple) else [column_name] + for column in columns: + sdtype = sdtypes.get(column) if sdtype not in transformer.get_supported_sdtypes(): - mismatched_columns.append(column_name) + mismatched_columns.append(column) if mismatched_columns: raise InvalidConfigError( @@ -519,6 +535,19 @@ def detect_initial_config(self, data): LOGGER.info('Config:') LOGGER.info(str(config)) + def _get_columns_to_sdtypes(self, field): + """Generate the ``columns_to_sdtypes`` dict for the given field. + + Args: + field (tuple): + Names of the column for the multi column trnasformer. + """ + columns_to_sdtypes = {} + for column in field: + columns_to_sdtypes[column] = self.field_sdtypes[column] + + return columns_to_sdtypes + def _fit_field_transformer(self, data, field, transformer): """Fit a transformer to its corresponding field. @@ -541,7 +570,12 @@ def _fit_field_transformer(self, data, field, transformer): self._output_columns.append(field) else: - transformer.fit(data, field) + if isinstance(field, tuple): + columns_to_sdtypes = self._get_columns_to_sdtypes(field) + transformer.fit(data, columns_to_sdtypes) + else: + transformer.fit(data, field) + self._transformers_sequence.append(transformer) data = transformer.transform(data) @@ -609,8 +643,8 @@ def fit(self, data): self._validate_detect_config_called(data) self._unfit() self._input_columns = list(data.columns) - for field in self._input_columns: - data = self._fit_field_transformer(data, field, self.field_transformers[field]) + for field_column, field_transformer in self.field_transformers.items(): + data = self._fit_field_transformer(data, field_column, field_transformer) self._validate_all_fields_fitted() self._fitted = True diff --git a/rdt/transformers/__init__.py b/rdt/transformers/__init__.py index 9ac5793a..09f28ae9 100644 --- a/rdt/transformers/__init__.py +++ b/rdt/transformers/__init__.py @@ -9,7 +9,7 @@ from functools import lru_cache from pathlib import Path -from rdt.transformers.base import BaseTransformer +from rdt.transformers.base import BaseMultiColumnTransformer, BaseTransformer from rdt.transformers.boolean import BinaryEncoder from rdt.transformers.categorical import ( CustomLabelEncoder, FrequencyEncoder, LabelEncoder, OneHotEncoder, OrderedLabelEncoder, @@ -22,6 +22,7 @@ __all__ = [ 'BaseTransformer', + 'BaseMultiColumnTransformer', 'BinaryEncoder', 'ClusterBasedNormalizer', 'CustomLabelEncoder', diff --git a/rdt/transformers/base.py b/rdt/transformers/base.py index d43f0c3b..4cf06f32 100644 --- a/rdt/transformers/base.py +++ b/rdt/transformers/base.py @@ -362,7 +362,7 @@ def _fit(self, columns_data): raise NotImplementedError() def _set_seed(self, data): - hash_value = self.get_input_column() + hash_value = self.columns[0] for value in data.head(5): hash_value += str(value) @@ -479,3 +479,113 @@ def reverse_transform(self, data): data = self._add_columns_to_data(data, reversed_data, self.columns) return data + + +class BaseMultiColumnTransformer(BaseTransformer): + """Base class for all multi column transformers. + + The ``BaseMultiColumnTransformer`` class contains methods that must be implemented + in order to create a new multi column transformer. + + Attributes: + columns_to_sdtypes (dict): + Dictionary mapping each column to its sdtype. + """ + + def __init__(self): + super().__init__() + self.columns_to_sdtypes = {} + + def get_input_column(self): + """Override ``get_input_column`` method from ``BaseTransformer``. + + Raise an error because for multi column transformers, ``get_input_columns`` + must be used instead. + """ + raise NotImplementedError( + 'MultiColumnTransformers does not have a single input column.' + 'Please use ``get_input_columns`` instead.' + ) + + def get_input_columns(self): + """Return input column name for transformer. + + Returns: + list: + Input column names. + """ + return self.columns + + def _get_prefix(self): + """Return the prefix of the output columns. + + Returns: + str: + Prefix of the output columns. + """ + raise NotImplementedError() + + def _get_output_to_property(self, property_): + self.column_prefix = self._get_prefix() + output = {} + for output_column, properties in self.output_properties.items(): + # if 'sdtype' is not in the dict, ignore the column + if property_ not in properties: + continue + + if self.column_prefix is None: + output[f'{output_column}'] = properties[property_] + else: + output[f'{self.column_prefix}.{output_column}'] = properties[property_] + + return output + + def _validate_columns_to_sdtypes(self, data, columns_to_sdtypes): + """Check that all the columns in ``columns_to_sdtypes`` are present in the data.""" + missing = set(columns_to_sdtypes.keys()) - set(data.columns) + if missing: + missing_to_print = ', '.join(missing) + raise ValueError(f'Columns ({missing_to_print}) are not present in the data.') + + def _fit(self, data): + """Fit the transformer to the data. + + Args: + data (pandas.DataFrame): + Data to transform. + """ + raise NotImplementedError() + + @random_state + def fit(self, data, columns_to_sdtypes): + """Fit the transformer to a ``column`` of the ``data``. + + Args: + data (pandas.DataFrame): + The entire table. + columns_to_sdtypes (dict): + Dictionary mapping each column to its sdtype. + """ + self._validate_columns_to_sdtypes(data, columns_to_sdtypes) + self.columns_to_sdtypes = columns_to_sdtypes + self._store_columns(list(self.columns_to_sdtypes.keys()), data) + self._set_seed(data) + columns_data = self._get_columns_data(data, self.columns) + self._fit(columns_data) + self._build_output_columns(data) + + def fit_transform(self, data, columns_to_sdtypes): + """Fit the transformer to a `column` of the `data` and then transform it. + + Args: + data (pandas.DataFrame): + The entire table. + columns_to_sdtypes (dict): + Dictionary mapping each column to its sdtype. + + Returns: + pd.DataFrame: + The entire table, containing the transformed data. + """ + self.fit(data, columns_to_sdtypes) + return self.transform(data) diff --git a/rdt/transformers/utils.py b/rdt/transformers/utils.py index 84a219a8..fe264764 100644 --- a/rdt/transformers/utils.py +++ b/rdt/transformers/utils.py @@ -163,3 +163,24 @@ def fill_nan_with_none(data): Original data with nan values replaced by None. """ return data.fillna(np.nan).replace([np.nan], [None]) + + +def flatten_column_list(column_list): + """Flatten a list of columns. + + Args: + column_list (list): + List of columns to flatten. + + Returns: + list: + Flattened list of columns. + """ + flattened = [] + for column in column_list: + if isinstance(column, tuple): + flattened.extend(column) + else: + flattened.append(column) + + return flattened diff --git a/tests/integration/test_hyper_transformer.py b/tests/integration/test_hyper_transformer.py index bb708109..4fb3967c 100644 --- a/tests/integration/test_hyper_transformer.py +++ b/tests/integration/test_hyper_transformer.py @@ -6,12 +6,14 @@ import pandas as pd import pytest -from rdt import HyperTransformer, get_demo +from rdt import get_demo from rdt.errors import ConfigNotSetError, InvalidConfigError, InvalidDataError, NotFittedError +from rdt.hyper_transformer import Config, HyperTransformer from rdt.transformers import ( - AnonymizedFaker, BaseTransformer, BinaryEncoder, ClusterBasedNormalizer, FloatFormatter, - FrequencyEncoder, LabelEncoder, OneHotEncoder, RegexGenerator, UniformEncoder, - UnixTimestampEncoder, get_default_transformer, get_default_transformers) + AnonymizedFaker, BaseMultiColumnTransformer, BaseTransformer, BinaryEncoder, + ClusterBasedNormalizer, FloatFormatter, FrequencyEncoder, LabelEncoder, OneHotEncoder, + RegexGenerator, UniformEncoder, UnixTimestampEncoder, get_default_transformer, + get_default_transformers) from rdt.transformers.datetime import OptimizedTimestampEncoder from rdt.transformers.numerical import GaussianNormalizer from rdt.transformers.pii.anonymizer import PseudoAnonymizedFaker @@ -55,6 +57,29 @@ def _reverse_transform(self, data): TEST_DATA_INDEX = [4, 6, 3, 8, 'a', 1.0, 2.0, 3.0] +class DummyMultiColumnTransformerNumerical(BaseMultiColumnTransformer): + """Multi column transformer that takes categorical data.""" + + SUPPORTED_SDTYPES = ['categorical', 'boolean'] + + def _fit(self, data): + self.output_properties = { + column: { + 'sdtype': 'numerical', + 'next_transformer': None, + } for column in self.columns + } + + def _get_prefix(self): + return None + + def _transform(self, data): + return data.astype(float) + + def _reverse_transform(self, data): + return data.astype(str) + + def get_input_data(): datetimes = pd.to_datetime([ '2010-02-01', @@ -1418,3 +1443,40 @@ def test_random_seed(self): # Assert pd.testing.assert_frame_equal(reversed1, reversed2) + + def test_hypertransformer_with_mutli_column_transformer_end_to_end(self): + """Test ``HyperTransformer`` with mutli column transformers end to end.""" + # Setup + data_test = pd.DataFrame({ + 'A': ['1.0', '2.0', '3.0'], + 'B': ['4.0', '5.0', '6.0'], + 'C': [True, False, True] + }) + dict_config = { + 'sdtypes': { + 'A': 'categorical', + 'B': 'categorical', + 'C': 'boolean' + }, + 'transformers': { + ('A', 'B'): DummyMultiColumnTransformerNumerical(), + 'C': UniformEncoder() + } + } + config = Config(dict_config) + ht = HyperTransformer() + ht.set_config(config) + + # Run + transformed_data = ht.fit_transform(data_test) + reverse_transformed_data = ht.reverse_transform(transformed_data) + + # Assert + expected_transformed_data = pd.DataFrame({ + 'A': [1.0, 2.0, 3.0], + 'B': [4.0, 5.0, 6.0], + 'C': [0.5892351646057272, 0.8615278122985615, 0.36493646501970534] + }) + + pd.testing.assert_frame_equal(transformed_data, expected_transformed_data) + pd.testing.assert_frame_equal(reverse_transformed_data, data_test) diff --git a/tests/integration/test_transformers.py b/tests/integration/test_transformers.py index 8a1cf130..5fc4d72a 100644 --- a/tests/integration/test_transformers.py +++ b/tests/integration/test_transformers.py @@ -69,7 +69,7 @@ def _is_valid_transformer(transformer_name): """Determine if transformer should be tested or not.""" invalid_names = [ 'IdentityTransformer', 'Dummy', 'OrderedLabelEncoder', 'CustomLabelEncoder', - 'OrderedUniformEncoder', + 'OrderedUniformEncoder', 'BaseMultiColumnTransformer' ] return all(invalid_name not in transformer_name for invalid_name in invalid_names) diff --git a/tests/integration/transformers/test_base.py b/tests/integration/transformers/test_base.py index de96dfb9..cff48696 100644 --- a/tests/integration/transformers/test_base.py +++ b/tests/integration/transformers/test_base.py @@ -3,7 +3,7 @@ import numpy as np import pandas as pd -from rdt.transformers.base import BaseTransformer +from rdt.transformers.base import BaseMultiColumnTransformer, BaseTransformer def test_dummy_transformer_series_output(): @@ -129,3 +129,177 @@ def _reverse_transform(self, data): }) pd.testing.assert_frame_equal(expected_transform, transformed) pd.testing.assert_frame_equal(reverse, data) + + +def test_multi_column_transformer_same_number_of_columns_input_output(): + """Test a multi-column transformer when the same of input and output columns.""" + # Setup + class AdditionTransformer(BaseMultiColumnTransformer): + """This transformer takes 3 columns and return the cumulative sum of each row.""" + def _fit(self, columns_data): + self.output_properties = { + f'{self.columns[0]}': {'sdtype': 'numerical'}, + f'{self.columns[0]}+{self.columns[1]}': {'sdtype': 'numerical'}, + f'{self.columns[0]}+{self.columns[1]}+{self.columns[2]}': {'sdtype': 'numerical'} + } + + def _get_prefix(self): + return None + + def _transform(self, data): + return data.cumsum(axis=1) + + def _reverse_transform(self, data): + result = data.diff(axis=1) + result.iloc[:, 0] = data.iloc[:, 0] + + return result.astype('int64') + + data_test = pd.DataFrame({ + 'col_1': [1, 2, 3], + 'col_2': [10, 20, 30], + 'col_3': [100, 200, 300] + }) + + columns_to_sdtypes = { + 'col_1': 'numerical', + 'col_2': 'numerical', + 'col_3': 'numerical' + } + transformer = AdditionTransformer() + + # Run + transformed = transformer.fit_transform(data_test, columns_to_sdtypes) + reverse = transformer.reverse_transform(transformed) + + # Assert + expected_transform = pd.DataFrame({ + 'col_1': [1, 2, 3], + 'col_1+col_2': [11, 22, 33], + 'col_1+col_2+col_3': [111, 222, 333] + }) + pd.testing.assert_frame_equal(expected_transform, transformed) + pd.testing.assert_frame_equal(reverse, data_test) + + +def test_multi_column_transformer_less_output_than_input_columns(): + """Test a multi-column transformer when the output has less columns than the input.""" + class ConcatenateTransformer(BaseMultiColumnTransformer): + """This transformer takes 4 columns and concatenate them into 2 columns. + The two first and last columns are concatenated together. + """ + def _fit(self, columns_data): + self.name_1 = self.columns[0] + '#' + self.columns[1] + self.name_2 = self.columns[2] + '#' + self.columns[3] + self.output_properties = { + f'{self.name_1}.concatenate_1': {'sdtype': 'categorical'}, + f'{self.name_2}.concatenate_2': {'sdtype': 'categorical'} + } + + def _get_prefix(self): + return None + + def _transform(self, data): + data[self.name_1] = data.iloc[:, 0] + '#' + data.iloc[:, 1] + data[self.name_2] = data.iloc[:, 2] + '#' + data.iloc[:, 3] + + return data.drop(columns=self.columns) + + def _reverse_transform(self, data): + result = data.copy() + column_names = list(data.columns) + + col1, col2 = column_names[0].split('#') + result[[col1, col2]] = result[column_names[0]].str.split('#', expand=True) + + col3, col4 = column_names[1].split('#') + result[[col3, col4]] = result[column_names[1]].str.split('#', expand=True) + + return result.drop(columns=column_names) + + data_test = pd.DataFrame({ + 'col_1': ['A', 'B', 'C'], + 'col_2': ['D', 'E', 'F'], + 'col_3': ['G', 'H', 'I'], + 'col_4': ['J', 'K', 'L'] + }) + + columns_to_sdtypes = { + 'col_1': 'categorical', + 'col_2': 'categorical', + 'col_3': 'categorical', + 'col_4': 'categorical' + } + transformer = ConcatenateTransformer() + + # Run + transformer.fit(data_test, columns_to_sdtypes) + transformed = transformer.transform(data_test) + reverse = transformer.reverse_transform(transformed) + + # Assert + expected_transform = pd.DataFrame({ + 'col_1#col_2.concatenate_1': ['A#D', 'B#E', 'C#F'], + 'col_3#col_4.concatenate_2': ['G#J', 'H#K', 'I#L'] + }) + pd.testing.assert_frame_equal(expected_transform, transformed) + pd.testing.assert_frame_equal(reverse, data_test) + + +def test_multi_column_transformer_more_output_than_input_columns(): + """Test a multi-column transformer when the output has more columns than the input.""" + class ExpandTransformer(BaseMultiColumnTransformer): + + def _fit(self, columns_data): + self.output_properties = { + f'{self.columns[0]}.first_part_1': {'sdtype': 'categorical'}, + f'{self.columns[0]}.second_part_1': {'sdtype': 'categorical'}, + f'{self.columns[1]}.first_part_2': {'sdtype': 'categorical'}, + f'{self.columns[1]}.second_part_2': {'sdtype': 'categorical'} + } + + def _get_prefix(self): + return None + + def _transform(self, data): + data[self.output_columns[0]] = data[self.columns[0]].str[0] + data[self.output_columns[1]] = data[self.columns[0]].str[1] + data[self.output_columns[2]] = data[self.columns[1]].str[0] + data[self.output_columns[3]] = data[self.columns[1]].str[1] + + return data.drop(columns=self.columns) + + def _reverse_transform(self, data): + result = data.copy() + reverse_1 = result[self.output_columns[0]] + result[self.output_columns[1]] + reverse_2 = result[self.output_columns[2]] + result[self.output_columns[3]] + result[self.columns[0]] = reverse_1 + result[self.columns[1]] = reverse_2 + + return result.drop(columns=self.output_columns) + + data_test = pd.DataFrame({ + 'col_1': ['AB', 'CD', 'EF'], + 'col_2': ['GH', 'IJ', 'KL'], + }) + + columns_to_sdtypes = { + 'col_1': 'categorical', + 'col_2': 'categorical' + } + transformer = ExpandTransformer() + + # Run + transformer.fit(data_test, columns_to_sdtypes) + transformed = transformer.transform(data_test) + reverse = transformer.reverse_transform(transformed) + + # Assert + expected_transform = pd.DataFrame({ + 'col_1.first_part_1': ['A', 'C', 'E'], + 'col_1.second_part_1': ['B', 'D', 'F'], + 'col_2.first_part_2': ['G', 'I', 'K'], + 'col_2.second_part_2': ['H', 'J', 'L'] + }) + pd.testing.assert_frame_equal(expected_transform, transformed) + pd.testing.assert_frame_equal(reverse, data_test) diff --git a/tests/unit/test_hyper_transformer.py b/tests/unit/test_hyper_transformer.py index dc31bc8a..b9e0924f 100644 --- a/tests/unit/test_hyper_transformer.py +++ b/tests/unit/test_hyper_transformer.py @@ -11,8 +11,8 @@ ConfigNotSetError, InvalidConfigError, InvalidDataError, NotFittedError, TransformerInputError, TransformerProcessingError) from rdt.transformers import ( - AnonymizedFaker, BinaryEncoder, FloatFormatter, FrequencyEncoder, LabelEncoder, RegexGenerator, - UnixTimestampEncoder) + AnonymizedFaker, BaseMultiColumnTransformer, BinaryEncoder, FloatFormatter, FrequencyEncoder, + LabelEncoder, RegexGenerator, UnixTimestampEncoder) from rdt.transformers.base import BaseTransformer from rdt.transformers.numerical import ClusterBasedNormalizer @@ -99,9 +99,8 @@ def test__validate_field_transformers(self): with pytest.raises(ValueError, match=error_msg): ht._validate_field_transformers() - @patch('rdt.hyper_transformer.HyperTransformer._create_multi_column_fields') @patch('rdt.hyper_transformer.HyperTransformer._validate_field_transformers') - def test___init__(self, validation_mock, multi_column_mock): + def test___init__(self, validation_mock): """Test create new instance of HyperTransformer""" # Run ht = HyperTransformer() @@ -109,6 +108,7 @@ def test___init__(self, validation_mock, multi_column_mock): # Asserts assert ht.field_sdtypes == {} assert ht.field_transformers == {} + assert ht._multi_column_fields == {} assert ht._specified_fields == set() assert ht._valid_output_sdtypes == ht._DEFAULT_OUTPUT_SDTYPES assert ht._transformers_sequence == [] @@ -117,7 +117,6 @@ def test___init__(self, validation_mock, multi_column_mock): assert ht._fitted_fields == set() assert ht._fitted is False assert ht._modified_config is False - multi_column_mock.assert_called_once() validation_mock.assert_called_once() def test__unfit(self): @@ -327,6 +326,30 @@ def test_detect_initial_config(self, logger_mock): call(expected_config) ]) + def test__get_columns_to_sdtypes(self): + """Test the ``_get_columns_to_sdtypes`` method.""" + # Setup + ht = HyperTransformer() + ht.field_sdtypes = { + 'col1': 'numerical', + 'col2': 'categorical', + 'col3': 'boolean', + 'col4': 'datetime', + } + + column_tuple = ('col1', 'col2', 'col3') + + # Run + columns_to_sdtypes = ht._get_columns_to_sdtypes(column_tuple) + + # Assert + expected_columns_to_sdtypes = { + 'col1': 'numerical', + 'col2': 'categorical', + 'col3': 'boolean', + } + assert columns_to_sdtypes == expected_columns_to_sdtypes + def test__fit_field_transformer(self): """Test the ``_fit_field_transformer`` method. @@ -480,6 +503,32 @@ def test__validate_config(self): with pytest.raises(InvalidConfigError, match=error_msg): HyperTransformer._validate_config(config) + def test_validate_config_not_unique_field(self): + """Test the ``_validate_config`` method when a column is repeated in a field.""" + # Setup + transformers = { + 'column1': FloatFormatter(), + 'column2': FrequencyEncoder(), + ('column2', 'column3'): None + } + sdtypes = { + 'column1': 'numerical', + 'column2': 'numerical', + 'column3': 'numerical' + } + config = { + 'sdtypes': sdtypes, + 'transformers': transformers + } + + # Run + error_msg = re.escape( + 'Error: Invalid config. Please provide unique keys for the sdtypes ' + 'and transformers.' + ) + with pytest.raises(InvalidConfigError, match=error_msg): + HyperTransformer._validate_config(config) + @patch('rdt.hyper_transformer.warnings') def test__validate_config_no_warning(self, warnings_mock): """Test the ``_validate_config`` method with no warning. @@ -499,11 +548,13 @@ def test__validate_config_no_warning(self, warnings_mock): # Setup transformers = { 'column1': FloatFormatter(), - 'column2': FrequencyEncoder() + 'column2': FrequencyEncoder(), + 'column3': None } sdtypes = { 'column1': 'numerical', - 'column2': 'categorical' + 'column2': 'categorical', + 'column3': 'numerical' } config = { 'sdtypes': sdtypes, @@ -971,6 +1022,56 @@ def test_fit(self): ht._validate_detect_config_called.assert_called_once() ht._unfit.assert_called_once() + def test_fit_with_multi_column_transformer(self): + """Test the ``fit`` method with a multi-column transformer.""" + # Setup + class MultiColumnTransformer(BaseMultiColumnTransformer): + def _fit(self, data): + self.output_properties = { + column: {'next_transformer': None, 'sdtype': 'numerical'} + for column in self.columns + } + + def _get_prefix(self): + return None + + def _transform(self, data): + return data + + def _reverse_transform(self, data): + return data + + field_transformers = { + ('col1', 'col2'): MultiColumnTransformer(), + 'col3': FloatFormatter() + } + field_sdtypes = { + 'col1': 'numerical', + 'col2': 'categorical', + 'col3': 'numerical' + } + + columns_to_sdtype = { + 'col1': 'numerical', + 'col2': 'categorical', + } + ht = HyperTransformer() + ht.field_transformers = field_transformers + ht.field_sdtypes = field_sdtypes + ht._get_columns_to_sdtypes = Mock(return_value=columns_to_sdtype) + + data = pd.DataFrame({ + 'col1': [1, 2, 3], + 'col2': ['a', 'b', 'c'], + 'col3': [1, 2, 3] + }) + + # Run + ht.fit(data) + + # Assert + ht._get_columns_to_sdtypes.assert_called_once() + def test_fit_warns(self): """Test it warns when different transformer instances produce the same column name. diff --git a/tests/unit/transformers/test_base.py b/tests/unit/transformers/test_base.py index edcfe799..f9a8fc11 100644 --- a/tests/unit/transformers/test_base.py +++ b/tests/unit/transformers/test_base.py @@ -7,7 +7,7 @@ import pytest from rdt.errors import TransformerInputError -from rdt.transformers import BaseTransformer, NullTransformer +from rdt.transformers import BaseMultiColumnTransformer, BaseTransformer, NullTransformer from rdt.transformers.base import random_state, set_random_states @@ -1272,3 +1272,221 @@ def _reverse_transform(self, data): 'b': [0.0, 0.0, 0.0], }) pd.testing.assert_frame_equal(transformed_data, expected_transformed) + + +class TestBaseMultiColumnTransformer: + + def test___init__(self): + """Test the ``__init__`` method.""" + # Setup + transformer = BaseMultiColumnTransformer() + + # Assert + assert transformer.columns_to_sdtypes == {} + + def test_get_input_column(self): + """Test the ``get_input_column`` method. + + When the ``get_input_column`` method is called, it should raise a ``NotImplementedError``. + """ + # Setup + expected_message = ( + 'MultiColumnTransformers does not have a single input column.' + 'Please use ``get_input_columns`` instead.' + ) + + # Run and Assert + with pytest.raises(NotImplementedError, match=expected_message): + BaseMultiColumnTransformer().get_input_column() + + def test_get_input_columns(self): + """Test the ``get_input_columns`` method.""" + # Setup + transformer = BaseMultiColumnTransformer() + transformer.columns = ['a', 'b', 'c'] + + # Run + output = transformer.get_input_columns() + + # Assert + assert output == ['a', 'b', 'c'] + + def test__get_prefixes(self): + """Test the ``_get_prefix`` method.""" + # Setup + transformer = BaseMultiColumnTransformer() + + # Run and Assert + with pytest.raises(NotImplementedError): + transformer._get_prefix() + + def test__get_output_to_property(self): + """Test the ``_get_output_to_property`` method.""" + # Setup + transformer = BaseMultiColumnTransformer() + transformer.output_properties = { + 'col_1': {'sdtype': 'numerical'}, + 'col_2': {'sdtype': 'categorical'}, + 'col_3': {'next_transformer': None}, + } + transformer._get_prefix = Mock(return_value='prefix') + + # Run + output = transformer._get_output_to_property('sdtype') + + # Assert + expected_output = { + 'prefix.col_1': 'numerical', + 'prefix.col_2': 'categorical', + } + assert output == expected_output + transformer._get_prefix.assert_called_once_with() + + def test__get_output_to_property_with_single_prefix(self): + """Test the ``_get_output_to_property`` method when the prefix is a single string.""" + # Setup + transformer = BaseMultiColumnTransformer() + transformer.output_properties = { + 'col_1': {'sdtype': 'numerical'}, + 'col_2': {'sdtype': 'categorical'}, + 'col_3': {'sdtype': 'boolean'}, + } + prefixes = 'prefix' + transformer._get_prefix = Mock(return_value=prefixes) + + # Run + output = transformer._get_output_to_property('sdtype') + + # Assert + expected_output = { + 'prefix.col_1': 'numerical', + 'prefix.col_2': 'categorical', + 'prefix.col_3': 'boolean', + } + assert output == expected_output + transformer._get_prefix.assert_called_once_with() + + def test__get_output_to_property_with_prefix_none(self): + """Test the ``_get_output_to_property`` method when the prefix is None.""" + # Setup + transformer = BaseMultiColumnTransformer() + transformer.output_properties = { + 'col_1': {'sdtype': 'numerical'}, + 'col_2': {'sdtype': 'categorical'}, + 'col_3': {'sdtype': 'boolean'}, + } + prefixes = None + transformer._get_prefix = Mock(return_value=prefixes) + + # Run + output = transformer._get_output_to_property('sdtype') + + # Assert + expected_output = { + 'col_1': 'numerical', + 'col_2': 'categorical', + 'col_3': 'boolean', + } + assert output == expected_output + transformer._get_prefix.assert_called_once_with() + + def test__validate_columns_to_sdtypes(self): + """Test the ``_validate_columns_to_sdtypes`` method. + + Test that an error with the missing column names is raised. + """ + # Setup + transformer = BaseMultiColumnTransformer() + data = pd.DataFrame({ + 'a': [1, 2, 3], + 'b': ['a', 'b', 'c'], + 'c': [True, False, True], + }) + columns_to_sdtypes = { + 'a': 'numerical', + 'b': 'categorical', + 'c': 'boolean', + } + + # Run and Assert + transformer._validate_columns_to_sdtypes(data, columns_to_sdtypes) + + wrong_columns_to_sdtypes = { + 'a': 'numerical', + 'b': 'categorical', + 'd': 'boolean', + } + expected_error_msg = re.escape( + 'Columns (d) are not present in the data.' + ) + with pytest.raises(ValueError, match=expected_error_msg): + transformer._validate_columns_to_sdtypes(data, wrong_columns_to_sdtypes) + + def test__fit(self): + """Test the ``_fit`` method. + + Check that an error is raised when the ``_fit`` method is called. + """ + # Setup + transformer = BaseMultiColumnTransformer() + + # Run and Assert + with pytest.raises(NotImplementedError): + transformer._fit(None) + + def test_fit(self): + """Test the ``fit`` method.""" + # Setup + transformer = BaseMultiColumnTransformer() + data = Mock() + data_transformer = pd.DataFrame({ + 'a': [1, 2, 3], + 'b': ['a', 'b', 'c'], + }) + columns_to_sdtypes = { + 'a': 'numerical', + 'b': 'categorical', + } + transformer.columns = ['a', 'b'] + + transformer._validate_columns_to_sdtypes = Mock() + transformer._store_columns = Mock() + transformer._get_columns_data = Mock(return_value=data_transformer) + transformer._set_seed = Mock() + transformer._fit = Mock() + transformer._build_output_columns = Mock() + + # Run + transformer.fit(data, columns_to_sdtypes) + + # Assert + transformer._validate_columns_to_sdtypes.assert_called_once_with(data, columns_to_sdtypes) + transformer._store_columns.assert_called_once_with( + ['a', 'b'], data + ) + transformer._set_seed.assert_called_once_with(data) + transformer._get_columns_data.assert_called_once_with(data, ['a', 'b']) + transformer._fit.assert_called_once_with(data_transformer) + transformer._build_output_columns.assert_called_once_with(data) + + def test_fit_transform(self): + """Test the ``fit_transform`` method.""" + # Setup + transformer = BaseMultiColumnTransformer() + columns_to_sdtypes = ('a', 'b', 'c') + data = pd.DataFrame({ + 'a': [1, 2, 3], + 'b': ['a', 'b', 'c'], + }) + transformer.columns = ['a', 'b'] + mock_fit = Mock() + mock_transform = Mock(return_value=data) + transformer.fit = mock_fit + transformer.transform = mock_transform + + # Run + transformer.fit_transform(data, columns_to_sdtypes) + + # Assert + mock_fit.assert_called_once_with(data, columns_to_sdtypes) + mock_transform.assert_called_once_with(data) diff --git a/tests/unit/transformers/test_utils.py b/tests/unit/transformers/test_utils.py index 2d6eceb0..475a32da 100644 --- a/tests/unit/transformers/test_utils.py +++ b/tests/unit/transformers/test_utils.py @@ -1,7 +1,7 @@ import sre_parse from sre_constants import MAXREPEAT -from rdt.transformers.utils import _any, _max_repeat, strings_from_regex +from rdt.transformers.utils import _any, _max_repeat, flatten_column_list, strings_from_regex def test_strings_from_regex_literal(): @@ -53,3 +53,16 @@ def test_strings_from_regex_very_large_regex(): assert size == 173689027553046619421110743915454114823342474255318764491341273608665169920 [next(generator) for _ in range(100_000)] + + +def test_flatten_column_list(): + """Test `flatten_column_list` function.""" + # Setup + column_list = ['column1', ('column2', 'column3'), 'column4', ('column5',), 'column6'] + + # Run + flattened_list = flatten_column_list(column_list) + + # Assert + expected_flattened_list = ['column1', 'column2', 'column3', 'column4', 'column5', 'column6'] + assert flattened_list == expected_flattened_list