Skip to content

Commit

Permalink
add _get_columns_to_sdtype
Browse files Browse the repository at this point in the history
  • Loading branch information
R-Palazzo committed Sep 8, 2023
1 parent 569b2dd commit daf2cbe
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 9 deletions.
19 changes: 18 additions & 1 deletion rdt/hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,18 @@ def detect_initial_config(self, data):
LOGGER.info('Config:')
LOGGER.info(str(config))

def _get_columns_to_sdtypes(self, field):
    """Generate the ``columns_to_sdtypes`` dict for the given field.

    Args:
        field (tuple):
            Names of the columns for the multi column transformer.

    Returns:
        dict:
            Mapping of each column name in ``field`` to the sdtype stored for
            it in ``self.field_sdtypes``, in the order the names appear in
            ``field``.

    Raises:
        KeyError:
            If a column in ``field`` has no entry in ``self.field_sdtypes``.
    """
    return {column: self.field_sdtypes[column] for column in field}

def _fit_field_transformer(self, data, field, transformer):
"""Fit a transformer to its corresponding field.
Expand All @@ -597,7 +609,12 @@ def _fit_field_transformer(self, data, field, transformer):
self._output_columns.append(field)

else:
transformer.fit(data, field)
if isinstance(field, tuple):
columns_to_sdtypes = self._get_columns_to_sdtypes(field)
transformer.fit(data, columns_to_sdtypes)
else:
transformer.fit(data, field)

self._transformers_sequence.append(transformer)
data = transformer.transform(data)

Expand Down
11 changes: 5 additions & 6 deletions tests/integration/test_hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,17 +59,16 @@ class DummyMultiColumnTransformerNumerical(BaseMultiColumnTransformer):

SUPPORTED_SDTYPES = ['categorical', 'boolean']

def _fit(self, data, ordered_columns):
def _fit(self, data):
self.output_properties = {
column: {
'sdtype': 'numerical',
'next_transformer': None,
} for column in self.columns
}

def _generate_prefixes(self, ordered_columns):
prefixes = {column: column for column in self.output_properties}
return prefixes
def _get_prefix(self):
return None

def _transform(self, data):
return data.astype(float)
Expand Down Expand Up @@ -1577,8 +1576,8 @@ def test_hypertransformer_with_mutli_column_transformer_end_to_end(self):

expected_transformed_data = pd.DataFrame({
'C': [0.5225768219566304, 0.7797813625043645, 0.31881544039752413],
'A.A': [1.0, 2.0, 3.0],
'B.B': [4.0, 5.0, 6.0]
'A': [1.0, 2.0, 3.0],
'B': [4.0, 5.0, 6.0]
})

assert repr(new_config) == repr(expected_config)
Expand Down
78 changes: 76 additions & 2 deletions tests/unit/test_hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
ConfigNotSetError, InvalidConfigError, InvalidDataError, NotFittedError, TransformerInputError,
TransformerProcessingError)
from rdt.transformers import (
AnonymizedFaker, BinaryEncoder, FloatFormatter, FrequencyEncoder, LabelEncoder, RegexGenerator,
UniformEncoder, UnixTimestampEncoder)
AnonymizedFaker, BaseMultiColumnTransformer, BinaryEncoder, FloatFormatter, FrequencyEncoder,
LabelEncoder, RegexGenerator, UniformEncoder, UnixTimestampEncoder)
from rdt.transformers.base import BaseTransformer
from rdt.transformers.numerical import ClusterBasedNormalizer

Expand Down Expand Up @@ -326,6 +326,30 @@ def test_detect_initial_config(self, logger_mock):
call(expected_config)
])

def test__get_columns_to_sdtypes(self):
    """Test that ``_get_columns_to_sdtypes`` only maps the requested columns.

    The instance knows the sdtypes of four columns, but only three of them
    are part of the requested tuple, so the fourth one must not appear in
    the returned mapping.
    """
    # Setup
    instance = HyperTransformer()
    instance.field_sdtypes = {
        'col1': 'numerical',
        'col2': 'categorical',
        'col3': 'boolean',
        'col4': 'datetime',
    }
    requested_columns = ('col1', 'col2', 'col3')

    # Run
    result = instance._get_columns_to_sdtypes(requested_columns)

    # Assert
    assert result == {
        'col1': 'numerical',
        'col2': 'categorical',
        'col3': 'boolean',
    }

def test__fit_field_transformer(self):
"""Test the ``_fit_field_transformer`` method.
Expand Down Expand Up @@ -998,6 +1022,56 @@ def test_fit(self):
ht._validate_detect_config_called.assert_called_once()
ht._unfit.assert_called_once()

def test_fit_with_multi_column_transformer(self):
    """Test that ``fit`` requests the sdtype mapping for a multi-column field.

    A pass-through multi-column transformer is assigned to the tuple
    ``('col1', 'col2')`` and ``_get_columns_to_sdtypes`` is replaced by a
    mock, so the test only checks that ``fit`` calls it exactly once.
    """
    # Setup
    class MultiColumnTransformer(BaseMultiColumnTransformer):
        # Minimal transformer: declares numerical outputs and leaves the data untouched.

        def _fit(self, data):
            properties = {}
            for name in self.columns:
                properties[name] = {'next_transformer': None, 'sdtype': 'numerical'}

            self.output_properties = properties

        def _get_prefix(self):
            return None

        def _transform(self, data):
            return data

        def _reverse_transform(self, data):
            return data

    multi_column_transformer = MultiColumnTransformer()
    float_formatter = FloatFormatter()

    ht = HyperTransformer()
    ht.field_transformers = {
        ('col1', 'col2'): multi_column_transformer,
        'col3': float_formatter,
    }
    ht.field_sdtypes = {
        'col1': 'numerical',
        'col2': 'categorical',
        'col3': 'numerical',
    }
    # The mock returns the mapping the real method would build for the tuple.
    ht._get_columns_to_sdtypes = Mock(return_value={
        'col1': 'numerical',
        'col2': 'categorical',
    })

    table = pd.DataFrame({
        'col1': [1, 2, 3],
        'col2': ['a', 'b', 'c'],
        'col3': [1, 2, 3],
    })

    # Run
    ht.fit(table)

    # Assert
    ht._get_columns_to_sdtypes.assert_called_once()

def test_fit_warns(self):
"""Test it warns when different transformer instances produce the same column name.
Expand Down

0 comments on commit daf2cbe

Please sign in to comment.