Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extra anonymized columns are created during reverse_transform_subset #549

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 14 additions & 9 deletions rdt/hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,24 +798,29 @@ def _reverse_transform(self, data, prevent_subset):
raise NotFittedError(self._NOT_FIT_MESSAGE)

unknown_columns = self._subset(data.columns, self._output_columns, not_in=True)
if unknown_columns:
raise Error(
'There are unexpected column names in the data you are trying to transform. '
f'A reverse transform is not defined for {unknown_columns}.'
)

if prevent_subset:
contained = all(column in self._output_columns for column in data.columns)
is_subset = contained and len(data.columns) < len(self._output_columns)
if unknown_columns or is_subset:
if is_subset:
raise Error(
'There are unexpected columns in the data you are trying to transform. '
'You must provide a transformed dataset with all the columns from the '
'original data.'
)

elif unknown_columns:
raise Error(
'There are unexpected column names in the data you are trying to transform. '
f'A reverse transform is not defined for {unknown_columns}.'
)
for transformer in reversed(self._transformers_sequence):
data = transformer.reverse_transform(data, drop=False)

for transformer in reversed(self._transformers_sequence):
data = transformer.reverse_transform(data, drop=False)
else:
for transformer in reversed(self._transformers_sequence):
output_columns = transformer.get_output_columns()
if output_columns and set(output_columns).issubset(data.columns):
pvk-developer marked this conversation as resolved.
Show resolved Hide resolved
data = transformer.reverse_transform(data, drop=False)

reversed_columns = self._subset(self._input_columns, data.columns)

Expand Down
61 changes: 58 additions & 3 deletions tests/integration/test_hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
import pandas as pd
import pytest

from rdt import HyperTransformer
from rdt import HyperTransformer, get_demo
from rdt.errors import Error, NotFittedError
from rdt.transformers import (
DEFAULT_TRANSFORMERS, BaseTransformer, BinaryEncoder, FloatFormatter, FrequencyEncoder,
OneHotEncoder, UnixTimestampEncoder, get_default_transformer, get_default_transformers)
DEFAULT_TRANSFORMERS, AnonymizedFaker, BaseTransformer, BinaryEncoder, FloatFormatter,
FrequencyEncoder, OneHotEncoder, RegexGenerator, UnixTimestampEncoder, get_default_transformer,
get_default_transformers)


class DummyTransformerNumerical(BaseTransformer):
Expand Down Expand Up @@ -734,3 +735,57 @@ def test_hyper_transformer_with_supported_sdtypes():

for transformer in ht.get_config()['transformers'].values():
assert not isinstance(transformer, BinaryEncoder)


def test_hyper_transformer_reverse_transform_subset_and_generators():
"""Test the ``HyperTransformer`` with ``reverse_transform_subset``.

Test that when calling ``reverse_transform_subset`` and there are ``generators`` like
``AnonymizedFaker`` or ``RegexGenerator`` those are not being used in the ``subset``, and
also any other transformer which can't transform the given columns.

Setup:
- DataFrame with multiple datatypes.
- Instance of HyperTransformer.
- Add ``pii`` using ``AnonymizedFaker`` and ``RegexGenerator``.

Run:
- Use ``fit_transform`` then ``revese_transform_subsample``.

Assert:
- Assert that the ``reverse_transformed`` data does not contain any additional columns
but the expected one.
"""
# Setup
customers = get_demo()
customers['id'] = ['ID_a', 'ID_b', 'ID_c', 'ID_d', 'ID_e']

# Create a config
ht = HyperTransformer()
ht.detect_initial_config(customers)

# credit_card and id are pii and text columns
ht.update_sdtypes({
'credit_card': 'pii',
'id': 'text'
})

ht.update_transformers({
'credit_card': AnonymizedFaker(),
'id': RegexGenerator(regex_format='id_[a-z]')
})

# Run
ht.fit(customers)
transformed = ht.transform(customers)
reverse_transformed = ht.reverse_transform_subset(transformed[['last_login.value']])

# Assert
expected_transformed_columns = [
'last_login.value',
'email_optin.value',
'age.value',
'dollars_spent.value'
]
assert all(expected_transformed_columns == transformed.columns)
assert reverse_transformed.columns == ['last_login']
67 changes: 62 additions & 5 deletions tests/unit/test_hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1585,6 +1585,64 @@ def test_reverse_transform(self):
bool_transformer.reverse_transform.assert_called_once()
datetime_transformer.reverse_transform.assert_called_once()

def test_reverse_transform_subset_with_generators(self):
"""Test the ``reverse_transform`` method.

Tests that ``reverse_transform`` loops through the ``_transformers_sequence``
in reverse order and calls ``transformer.reverse_transform`` if they have
``output_columns``.

Setup:
- The ``_transformers_sequence`` will be hardcoded with a list
of transformer mocks and one of them is a ``generator``
(does not have ``output_columns``).
- The ``_output_columns`` will be hardcoded.
- The ``_input_columns`` will be hardcoded.

Input:
- A DataFrame of multiple sdtypes.

Output:
- The reverse transformed DataFrame with the correct columns dropped.

Side Effects:
- Only the transformers with ``get_output_columns`` will be called.
"""
# Setup
int_transformer = Mock()
float_transformer = Mock()
generator_transformer = Mock()
int_transformer.get_output_columns.return_value = ['integer.out.value']
float_transformer.get_output_columns.return_value = ['float.value']
generator_transformer.get_output_columns.return_value = []

reverse_transformed_data = self.get_transformed_data()
float_transformer.reverse_transform = lambda x, drop: x
int_transformer.reverse_transform.return_value = reverse_transformed_data

data = self.get_transformed_data(True)

ht = HyperTransformer()
ht._validate_config_exists = Mock()
ht._validate_config_exists.return_value = True
ht._fitted = True
ht._transformers_sequence = [
int_transformer,
float_transformer,
generator_transformer
]
ht._output_columns = list(data.columns)
expected = self.get_data()
ht._input_columns = list(expected.columns)

# Run
reverse_transformed = ht.reverse_transform_subset(data)

# Assert
pd.testing.assert_frame_equal(reverse_transformed, expected)
int_transformer.reverse_transform.assert_called_once()
generator_transformer.reverse_transform.assert_not_called()

def test_reverse_transform_raises_error_no_config(self):
"""Test that ``reverse_transform`` raises an error.

Expand Down Expand Up @@ -1660,8 +1718,7 @@ def test_reverse_transform_with_subset(self):

# Run / Assert
expected_msg = (
'There are unexpected columns in the data you are trying to transform. You must '
'provide a transformed dataset with all the columns from the original data.'
'You must provide a transformed dataset with all the columns from the original data.'
)
with pytest.raises(Error, match=expected_msg):
ht.reverse_transform(data)
Expand Down Expand Up @@ -1691,9 +1748,9 @@ def test_reverse_transform_with_unknown_columns(self):
data = pd.DataFrame({'col1': [1, 2], 'col2': [3, 4]})

# Run / Assert
expected_msg = (
'There are unexpected columns in the data you are trying to transform. You must '
'provide a transformed dataset with all the columns from the original data.'
expected_msg = re.escape(
'There are unexpected column names in the data you are trying to transform. '
"A reverse transform is not defined for ['col2']."
)
with pytest.raises(Error, match=expected_msg):
ht.reverse_transform(data)
Expand Down