Skip to content

Commit

Permalink
make UniformEncoder the default for cat and boolean
Browse files Browse the repository at this point in the history
  • Loading branch information
R-Palazzo committed Aug 10, 2023
1 parent 1a0e24b commit 83f4112
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 30 deletions.
4 changes: 2 additions & 2 deletions rdt/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ def get_transformer_name(transformer):

DEFAULT_TRANSFORMERS = {
'numerical': FloatFormatter(),
'categorical': LabelEncoder(add_noise=True),
'boolean': LabelEncoder(add_noise=True),
'categorical': UniformEncoder(),
'boolean': UniformEncoder(),
'datetime': UnixTimestampEncoder(),
'text': RegexGenerator(),
'pii': AnonymizedFaker(),
Expand Down
2 changes: 1 addition & 1 deletion rdt/transformers/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def _transform(self, data):
# Helper applied to each element via data_with_none.map(...) in _transform:
# looks up the (low, high) pair stored in self.intervals for the given label
# and returns one np.random.uniform draw between those bounds, so repeated
# labels map to different values within the same interval.
# NOTE(review): assumes `np` is numpy (low inclusive, high exclusive) — confirm.
def map_labels(label):
return np.random.uniform(self.intervals[label][0], self.intervals[label][1])

return data_with_none.map(map_labels)
return data_with_none.map(map_labels).astype(float)

def _reverse_transform(self, data):
"""Convert float values back to the original categorical values.
Expand Down
40 changes: 20 additions & 20 deletions tests/integration/test_hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from rdt.errors import ConfigNotSetError, InvalidConfigError, InvalidDataError, NotFittedError
from rdt.transformers import (
AnonymizedFaker, BaseTransformer, BinaryEncoder, ClusterBasedNormalizer, FloatFormatter,
FrequencyEncoder, LabelEncoder, OneHotEncoder, RegexGenerator, UnixTimestampEncoder,
get_default_transformer, get_default_transformers)
FrequencyEncoder, LabelEncoder, OneHotEncoder, RegexGenerator, UniformEncoder,
UnixTimestampEncoder, get_default_transformer, get_default_transformers)
from rdt.transformers.datetime import OptimizedTimestampEncoder
from rdt.transformers.numerical import GaussianNormalizer
from rdt.transformers.pii.anonymizer import PseudoAnonymizedFaker
Expand Down Expand Up @@ -93,17 +93,17 @@ def get_transformed_data():
'integer': [1., 2., 1., 3., 1., 4., 2., 3.],
'float': [0.1, 0.2, 0.1, 0.2, 0.1, 0.4, 0.2, 0.3],
'categorical': [
0.9690758764963199, 0.8816575994729887, 1.1326495454234662, 1.7988488918189502,
0.9265972159030215, 1.885454600378942, 0.9280858691537548, 0.5093227924068265
0.6056724228102, 0.551035999670618, 0.6747435795337998, 0.9245683344321064,
0.5791232599393884, 0.9570454751421033, 0.5800536682210967, 0.3183267452542666
],
'bool': [
0.26161253184788935, 0.5735484647493089, 0.026673806296574787, 1.197229599974477,
0.8860641570557322, 0.33432787358513416, 1.1089412122841389, 0.6182653878449814
0.196209398885917, 0.4301613485619816, 0.02000535472243109, 0.7993073999936193,
0.6645481177917991, 0.25074590518885065, 0.7772353030710347, 0.46369904088373604
],
'datetime': datetimes,
'names': [
0.24180193241041126, 1.9297787196579723, 1.5617500744772101, 0.6811042561384157,
0.48017218468846856, 2.2867787591284823, 0.25476586891248476, 0.620052082101593
0.15112620775650704, 0.857444679914493, 0.7654375186193025, 0.42569016008650984,
0.30010761543029285, 0.9108473448910603, 0.15922866807030298, 0.3875325513134956
]
}, index=TEST_DATA_INDEX)

Expand Down Expand Up @@ -171,17 +171,17 @@ def test_hypertransformer_default_inputs():
'integer': [1., 2., 1., 3., 1., 4., 2., 3.],
'float': [0.1, 0.2, 0.1, 0.2, 0.1, 0.4, 0.2, 0.3],
'categorical': [
0.9690758764963199, 0.8816575994729887, 1.1326495454234662, 2.7988488918189502,
0.9265972159030215, 2.8854546003789423, 0.9280858691537548, 0.5093227924068265
0.6056724228102, 0.551035999670618, 0.6415811931779333, 0.9497122229547376,
0.5791232599393884, 0.9713636500947356, 0.5800536682210967, 0.3183267452542666
],
'bool': [
0.26161253184788935, 1.573548464749309, 0.026673806296574787, 2.1972295999744773,
0.8860641570557322, 1.334327873585134, 2.108941212284139, 0.6182653878449814
0.13080626592394468, 0.6433871161873272, 0.013336903148287393, 0.7993073999936193,
0.4430320785278661, 0.5835819683962835, 0.7772353030710347, 0.3091326939224907
],
'datetime': expected_datetimes,
'names': [
0.24180193241041126, 1.9297787196579723, 1.5617500744772101, 0.6811042561384157,
0.48017218468846856, 2.2867787591284823, 0.25476586891248476, 0.620052082101593
0.15112620775650704, 0.857444679914493, 0.7654375186193025, 0.42569016008650984,
0.30010761543029285, 0.9108473448910603, 0.15922866807030298, 0.3875325513134956
]
}, index=TEST_DATA_INDEX)
pd.testing.assert_frame_equal(transformed, expected_transformed)
Expand Down Expand Up @@ -212,10 +212,10 @@ def test_hypertransformer_default_inputs():

assert isinstance(ht.field_transformers['integer'], FloatFormatter)
assert isinstance(ht.field_transformers['float'], FloatFormatter)
assert isinstance(ht.field_transformers['categorical'], LabelEncoder)
assert isinstance(ht.field_transformers['bool'], LabelEncoder)
assert isinstance(ht.field_transformers['categorical'], UniformEncoder)
assert isinstance(ht.field_transformers['bool'], UniformEncoder)
assert isinstance(ht.field_transformers['datetime'], UnixTimestampEncoder)
assert isinstance(ht.field_transformers['names'], LabelEncoder)
assert isinstance(ht.field_transformers['names'], UniformEncoder)

get_default_transformers.cache_clear()
get_default_transformer.cache_clear()
Expand Down Expand Up @@ -254,10 +254,10 @@ def test_hypertransformer_field_transformers():
'transformers': {
'integer': FloatFormatter(missing_value_replacement='mean'),
'float': FloatFormatter(missing_value_replacement='mean'),
'categorical': LabelEncoder(add_noise=True),
'bool': LabelEncoder(add_noise=True),
'categorical': UniformEncoder(),
'bool': UniformEncoder(),
'datetime': DummyTransformerNotMLReady(),
'names': LabelEncoder(add_noise=True)
'names': UniformEncoder()
}
}

Expand Down
8 changes: 4 additions & 4 deletions tests/unit/test_hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,8 +296,8 @@ def test_detect_initial_config(self, logger_mock):
field_transformers = {k: repr(v) for (k, v) in ht.field_transformers.items()}
assert field_transformers == {
'col1': 'FloatFormatter()',
'col2': 'LabelEncoder(add_noise=True)',
'col3': 'LabelEncoder(add_noise=True)',
'col2': 'UniformEncoder()',
'col3': 'UniformEncoder()',
'col4': 'UnixTimestampEncoder()',
'col5': 'FloatFormatter()'
}
Expand All @@ -313,8 +313,8 @@ def test_detect_initial_config(self, logger_mock):
' },',
' "transformers": {',
' "col1": FloatFormatter(),',
' "col2": LabelEncoder(add_noise=True),',
' "col3": LabelEncoder(add_noise=True),',
' "col2": UniformEncoder(),',
' "col3": UniformEncoder(),',
' "col4": UnixTimestampEncoder(),',
' "col5": FloatFormatter()',
' }',
Expand Down
6 changes: 3 additions & 3 deletions tests/unit/transformers/test___init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest

from rdt.transformers import (
AnonymizedFaker, BinaryEncoder, FloatFormatter, LabelEncoder, RegexGenerator,
AnonymizedFaker, BinaryEncoder, FloatFormatter, RegexGenerator, UniformEncoder,
UnixTimestampEncoder, get_default_transformers, get_transformer_class, get_transformer_name)
from rdt.transformers.addons.identity.identity import IdentityTransformer

Expand Down Expand Up @@ -109,8 +109,8 @@ def test_get_default_transformers():
# Assert
expected_dict = {
'numerical': FloatFormatter,
'categorical': LabelEncoder,
'boolean': LabelEncoder,
'categorical': UniformEncoder,
'boolean': UniformEncoder,
'datetime': UnixTimestampEncoder,
'text': RegexGenerator,
'pii': AnonymizedFaker,
Expand Down

0 comments on commit 83f4112

Please sign in to comment.