Skip to content

Commit

Permalink
unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
R-Palazzo committed Aug 23, 2023
1 parent e598afe commit c8d0ef6
Show file tree
Hide file tree
Showing 2 changed files with 213 additions and 35 deletions.
68 changes: 34 additions & 34 deletions rdt/hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,39 @@ def _warn_update_transformers_by_sdtype(self, transformer, transformer_name):
'parameters instead.', FutureWarning
)

def _generate_column_in_tuple(self):
"""Generate a dict mapping columns in a tuple to the tuple itself."""
column_in_tuple = {}
for fields in self.field_transformers:
if isinstance(fields, tuple):
for column in fields:
column_in_tuple[column] = fields

return column_in_tuple

def _update_column_in_tuple(self, column, column_in_tuple):
"""Update the column in tuple dict.
Args:
column (str):
Column name to be updated.
column_in_tuple (dict):
Dict mapping columns in a tuple to the tuple itself.
"""
old_tuple = column_in_tuple.pop(column)
new_tuple = tuple(item for item in old_tuple if item != column)

if len(new_tuple) == 1:
new_tuple, = new_tuple
column_in_tuple.pop(new_tuple, None)
else:
for col in new_tuple:
column_in_tuple[col] = new_tuple

self.field_transformers[new_tuple] = self.field_transformers.pop(old_tuple)

return column_in_tuple

def update_transformers_by_sdtype(
self, sdtype, transformer=None, transformer_name=None, transformer_parameters=None):
"""Update the transformers for the specified ``sdtype``.
Expand Down Expand Up @@ -370,7 +403,7 @@ def update_transformers_by_sdtype(
if field_sdtype == sdtype:
self.field_transformers[field] = deepcopy(transformer_instance)
if field in column_in_tuple:
self._update_column_in_tuple(field, column_in_tuple)
column_in_tuple = self._update_column_in_tuple(field, column_in_tuple)

self._modified_config = True

Expand Down Expand Up @@ -412,39 +445,6 @@ def update_sdtypes(self, column_name_to_sdtype):
if self._fitted:
warnings.warn(self._REFIT_MESSAGE)

def _generate_column_in_tuple(self):
"""Generate a dict mapping columns in a tuple to the tuple itself."""
column_in_tuple = {}
for fields in self.field_transformers:
if isinstance(fields, tuple):
for column in fields:
column_in_tuple[column] = fields

return column_in_tuple

def _update_column_in_tuple(self, column, column_in_tuple):
"""Update the column in tuple dict.
Args:
column (str):
Column name to be updated.
column_in_tuple (dict):
Dict mapping columns in a tuple to the tuple itself.
"""
old_tuple = column_in_tuple[column]
new_tuple = tuple(
item for item in old_tuple if item != column
)
if len(new_tuple) == 1:
new_tuple = new_tuple[0]
del column_in_tuple[new_tuple]
else:
for col in new_tuple:
column_in_tuple[col] = new_tuple

self.field_transformers[new_tuple] = self.field_transformers[old_tuple]
del self.field_transformers[old_tuple]

def update_transformers(self, column_name_to_transformer):
"""Update any of the transformers assigned to each of the column names.
Expand Down
180 changes: 179 additions & 1 deletion tests/unit/test_hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
TransformerProcessingError)
from rdt.transformers import (
AnonymizedFaker, BinaryEncoder, FloatFormatter, FrequencyEncoder, LabelEncoder, RegexGenerator,
UnixTimestampEncoder)
UniformEncoder, UnixTimestampEncoder)
from rdt.transformers.base import BaseTransformer
from rdt.transformers.numerical import ClusterBasedNormalizer

Expand Down Expand Up @@ -2135,6 +2135,110 @@ def test_update_transformers_by_sdtype_with_transformer_name_transformer_paramet
assert isinstance(ht.field_transformers['categorical_column'], LabelEncoder)
assert ht.field_transformers['categorical_column'].order_by == 'alphabetical'

def test_generate_column_in_tuple(self):
"""Test ``_generate_column_in_tuple``."""
# Setup
ht = HyperTransformer()

ht.field_transformers = {
('column1', 'column2'): 'transformer1',
('column3', 'column4'): 'transformer2',
}

# Run
column_in_tuple = ht._generate_column_in_tuple()

# Assert
expected_mapping = {
'column1': ('column1', 'column2'),
'column2': ('column1', 'column2'),
'column3': ('column3', 'column4'),
'column4': ('column3', 'column4'),
}

assert column_in_tuple == expected_mapping

def test_update_column_in_tuple(self):
"""Test ``_update_column_in_tuple``."""
# Setup
column_in_tuple = {
'column1': ('column1', 'column2', 'column3'),
'column2': ('column1', 'column2', 'column3'),
'column3': ('column1', 'column2', 'column3'),
}
ht = HyperTransformer()
ht.field_transformers = {
('column1', 'column2', 'column3'): 'transformer',
}
# Run
result = ht._update_column_in_tuple('column1', column_in_tuple)

# Assert
expected_column_in_tuple = {
'column2': ('column2', 'column3'),
'column3': ('column2', 'column3'),
}
expected_field_transformers = {('column2', 'column3'): 'transformer'}
assert ht.field_transformers == expected_field_transformers
assert result == expected_column_in_tuple

def test_update_column_in_tuple_single_column_left(self):
"""Test ``_update_column_in_tuple`` with a single column left in the tuple."""
# Setup
column_in_tuple = {
'column1': ('column1', 'column2'),
'column2': ('column1', 'column2'),
}
ht = HyperTransformer()
ht.field_transformers = {
('column1', 'column2'): 'transformer',
}
# Run
result = ht._update_column_in_tuple('column1', column_in_tuple)

# Assert
expected_column_in_tuple = {}
expected_field_transformers = {'column2': 'transformer'}
assert ht.field_transformers == expected_field_transformers
assert result == expected_column_in_tuple

def test_update_transformers_by_sdtype_with_multi_column_transformer(self):
"""Test ``update_transformers_by_sdtype`` with columns use with a multi-column transformer.
"""
# Setup
ht = HyperTransformer()
ht.field_transformers = {
'A': LabelEncoder(),
'B': UniformEncoder(),
"('C', 'D')": None,
}
ht.field_sdtypes = {
'A': 'categorical',
'B': 'boolean',
'C': 'categorical',
'D': 'numerical'
}

column_in_tuple = {
'C': ('C', 'D'),
'D': ('C', 'D')
}
mock__update_column_in_tuple = Mock()
mock__generate_column_in_tuple = Mock(return_value=column_in_tuple)
ht._update_column_in_tuple = mock__update_column_in_tuple
ht._generate_column_in_tuple = mock__generate_column_in_tuple

# Run
ht.update_transformers_by_sdtype(
'categorical',
transformer_name='LabelEncoder',
)

# Assert
assert len(ht.field_transformers) == 4
mock__generate_column_in_tuple.assert_called_once()
assert mock__update_column_in_tuple.call_count == 1

@patch('rdt.hyper_transformer.warnings')
def test_update_transformers_fitted(self, mock_warnings):
"""Test update transformers.
Expand Down Expand Up @@ -2168,6 +2272,9 @@ def test_update_transformers_fitted(self, mock_warnings):
'my_column': transformer
}

mock__generate_column_in_tuple = Mock(return_value={})
instance._generate_column_in_tuple = mock__generate_column_in_tuple

# Run
instance.update_transformers(column_name_to_transformer)

Expand All @@ -2180,6 +2287,77 @@ def test_update_transformers_fitted(self, mock_warnings):
mock_warnings.warn.assert_called_once_with(expected_message)
assert instance.field_transformers['my_column'] == transformer
instance._validate_transformers.assert_called_once_with(column_name_to_transformer)
mock__generate_column_in_tuple.assert_called_once()

def test_update_transformers_multi_column(self):
"""Test ``update_transformers`` with a multi-column transformer."""
# Setup
ht = HyperTransformer()
mock__generate_column_in_tuple = Mock(return_value={})
mock__update_column_in_tuple = Mock()
ht._generate_column_in_tuple = mock__generate_column_in_tuple
ht._update_column_in_tuple = mock__update_column_in_tuple
ht.field_sdtypes = {
'A': 'categorical',
'B': 'boolean',
'C': 'numerical',
}
ht.field_transformers = {
'A': LabelEncoder(),
'B': UniformEncoder(),
'C': FloatFormatter(),
}

column_name_to_transformer = {
('A', 'B'): None,
'C': None,
}
# Run
ht.update_transformers(column_name_to_transformer)

# Assert
expected_field_transformers = {
('A', 'B'): None,
'C': None,
}
ht.field_transformers == expected_field_transformers
mock__generate_column_in_tuple.assert_called_once()

def test_update_transformers_changing_multi_column_transformer(self):
"""Test ``update_transformers`` when changing a mulit column transformer."""
# Setup
ht = HyperTransformer()
mock__generate_column_in_tuple = Mock(return_value={
'A': ('A', 'B'),
'B': ('A', 'B'),
})
mock__update_column_in_tuple = Mock()
ht._generate_column_in_tuple = mock__generate_column_in_tuple
ht._update_column_in_tuple = mock__update_column_in_tuple
ht.field_sdtypes = {
'A': 'categorical',
'B': 'boolean',
'C': 'numerical',
}
ht.field_transformers = {
('A', 'B'): None,
'C': FloatFormatter(),
}

column_name_to_transformer = {
'A': UniformEncoder(),
}
# Run
ht.update_transformers(column_name_to_transformer)

# Assert
expected_field_transformers = {
'A': UniformEncoder(),
'B': None,
'C': FloatFormatter(),
}
ht.field_transformers == expected_field_transformers
mock__generate_column_in_tuple.assert_called_once()

@patch('rdt.hyper_transformer.warnings')
def test_update_transformers_not_fitted(self, mock_warnings):
Expand Down

0 comments on commit c8d0ef6

Please sign in to comment.