Skip to content

Commit

Permalink
Extra anonymized columns are created during reverse_transform_subset (
Browse files Browse the repository at this point in the history
#549)

* Fix extra columns being created

* Address comments
  • Loading branch information
pvk-developer authored Sep 13, 2022
1 parent c3fcb99 commit e7e4132
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 17 deletions.
23 changes: 14 additions & 9 deletions rdt/hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,24 +798,29 @@ def _reverse_transform(self, data, prevent_subset):
raise NotFittedError(self._NOT_FIT_MESSAGE)

unknown_columns = self._subset(data.columns, self._output_columns, not_in=True)
if unknown_columns:
raise Error(
'There are unexpected column names in the data you are trying to transform. '
f'A reverse transform is not defined for {unknown_columns}.'
)

if prevent_subset:
contained = all(column in self._output_columns for column in data.columns)
is_subset = contained and len(data.columns) < len(self._output_columns)
if unknown_columns or is_subset:
if is_subset:
raise Error(
'There are unexpected columns in the data you are trying to transform. '
'You must provide a transformed dataset with all the columns from the '
'original data.'
)

elif unknown_columns:
raise Error(
'There are unexpected column names in the data you are trying to transform. '
f'A reverse transform is not defined for {unknown_columns}.'
)
for transformer in reversed(self._transformers_sequence):
data = transformer.reverse_transform(data, drop=False)

for transformer in reversed(self._transformers_sequence):
data = transformer.reverse_transform(data, drop=False)
else:
for transformer in reversed(self._transformers_sequence):
output_columns = transformer.get_output_columns()
if output_columns and set(output_columns).issubset(data.columns):
data = transformer.reverse_transform(data, drop=False)

reversed_columns = self._subset(self._input_columns, data.columns)

Expand Down
61 changes: 58 additions & 3 deletions tests/integration/test_hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
import pandas as pd
import pytest

from rdt import HyperTransformer
from rdt import HyperTransformer, get_demo
from rdt.errors import Error, NotFittedError
from rdt.transformers import (
DEFAULT_TRANSFORMERS, BaseTransformer, BinaryEncoder, FloatFormatter, FrequencyEncoder,
OneHotEncoder, UnixTimestampEncoder, get_default_transformer, get_default_transformers)
DEFAULT_TRANSFORMERS, AnonymizedFaker, BaseTransformer, BinaryEncoder, FloatFormatter,
FrequencyEncoder, OneHotEncoder, RegexGenerator, UnixTimestampEncoder, get_default_transformer,
get_default_transformers)


class DummyTransformerNumerical(BaseTransformer):
Expand Down Expand Up @@ -734,3 +735,57 @@ def test_hyper_transformer_with_supported_sdtypes():

for transformer in ht.get_config()['transformers'].values():
assert not isinstance(transformer, BinaryEncoder)


def test_hyper_transformer_reverse_transform_subset_and_generators():
"""Test the ``HyperTransformer`` with ``reverse_transform_subset``.
Test that when calling ``reverse_transform_subset`` and there are ``generators`` like
``AnonymizedFaker`` or ``RegexGenerator`` those are not being used in the ``subset``, and
also any other transformer which can't transform the given columns.
Setup:
- DataFrame with multiple datatypes.
- Instance of HyperTransformer.
- Add ``pii`` using ``AnonymizedFaker`` and ``RegexGenerator``.
Run:
- Use ``fit_transform`` then ``revese_transform_subsample``.
Assert:
- Assert that the ``reverse_transformed`` data does not contain any additional columns
but the expected one.
"""
# Setup
customers = get_demo()
customers['id'] = ['ID_a', 'ID_b', 'ID_c', 'ID_d', 'ID_e']

# Create a config
ht = HyperTransformer()
ht.detect_initial_config(customers)

# credit_card and id are pii and text columns
ht.update_sdtypes({
'credit_card': 'pii',
'id': 'text'
})

ht.update_transformers({
'credit_card': AnonymizedFaker(),
'id': RegexGenerator(regex_format='id_[a-z]')
})

# Run
ht.fit(customers)
transformed = ht.transform(customers)
reverse_transformed = ht.reverse_transform_subset(transformed[['last_login.value']])

# Assert
expected_transformed_columns = [
'last_login.value',
'email_optin.value',
'age.value',
'dollars_spent.value'
]
assert all(expected_transformed_columns == transformed.columns)
assert reverse_transformed.columns == ['last_login']
67 changes: 62 additions & 5 deletions tests/unit/test_hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1585,6 +1585,64 @@ def test_reverse_transform(self):
bool_transformer.reverse_transform.assert_called_once()
datetime_transformer.reverse_transform.assert_called_once()

def test_reverse_transform_subset_with_generators(self):
"""Test the ``reverse_transform`` method.
Tests that ``reverse_transform`` loops through the ``_transformers_sequence``
in reverse order and calls ``transformer.reverse_transform`` if they have
``output_columns``.
Setup:
- The ``_transformers_sequence`` will be hardcoded with a list
of transformer mocks and one of them is a ``generator``
(does not have ``output_columns``).
- The ``_output_columns`` will be hardcoded.
- The ``_input_columns`` will be hardcoded.
Input:
- A DataFrame of multiple sdtypes.
Output:
- The reverse transformed DataFrame with the correct columns dropped.
Side Effects:
- Only the transformers with ``get_output_columns`` will be called.
"""
# Setup
int_transformer = Mock()
float_transformer = Mock()
generator_transformer = Mock()
int_transformer.get_output_columns.return_value = ['integer.out.value']
float_transformer.get_output_columns.return_value = ['float.value']
generator_transformer.get_output_columns.return_value = []

reverse_transformed_data = self.get_transformed_data()
float_transformer.reverse_transform = lambda x, drop: x
int_transformer.reverse_transform.return_value = reverse_transformed_data

data = self.get_transformed_data(True)

ht = HyperTransformer()
ht._validate_config_exists = Mock()
ht._validate_config_exists.return_value = True
ht._fitted = True
ht._transformers_sequence = [
int_transformer,
float_transformer,
generator_transformer
]
ht._output_columns = list(data.columns)
expected = self.get_data()
ht._input_columns = list(expected.columns)

# Run
reverse_transformed = ht.reverse_transform_subset(data)

# Assert
pd.testing.assert_frame_equal(reverse_transformed, expected)
int_transformer.reverse_transform.assert_called_once()
generator_transformer.reverse_transform.assert_not_called()

def test_reverse_transform_raises_error_no_config(self):
"""Test that ``reverse_transform`` raises an error.
Expand Down Expand Up @@ -1660,8 +1718,7 @@ def test_reverse_transform_with_subset(self):

# Run / Assert
expected_msg = (
'There are unexpected columns in the data you are trying to transform. You must '
'provide a transformed dataset with all the columns from the original data.'
'You must provide a transformed dataset with all the columns from the original data.'
)
with pytest.raises(Error, match=expected_msg):
ht.reverse_transform(data)
Expand Down Expand Up @@ -1691,9 +1748,9 @@ def test_reverse_transform_with_unknown_columns(self):
data = pd.DataFrame({'col1': [1, 2], 'col2': [3, 4]})

# Run / Assert
expected_msg = (
'There are unexpected columns in the data you are trying to transform. You must '
'provide a transformed dataset with all the columns from the original data.'
expected_msg = re.escape(
'There are unexpected column names in the data you are trying to transform. '
"A reverse transform is not defined for ['col2']."
)
with pytest.raises(Error, match=expected_msg):
ht.reverse_transform(data)
Expand Down

0 comments on commit e7e4132

Please sign in to comment.