Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
fealho committed Sep 6, 2023
1 parent 5d7f8b7 commit c8c46b5
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 19 deletions.
38 changes: 19 additions & 19 deletions rdt/transformers/numerical.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,9 +252,10 @@ class GaussianNormalizer(FloatFormatter):
* ``gamma``: Use a Gamma distribution.
* ``beta``: Use a Beta distribution.
* ``student_t``: Use a Student T distribution.
* ``gussian_kde``: Use a GaussianKDE distribution. This model is non-parametric,
* ``gaussian_kde``: Use a GaussianKDE distribution. This model is non-parametric,
so using this will make ``get_parameters`` unusable.
* ``truncated_gaussian``: Use a Truncated Gaussian distribution.
# ``uniform``: Use a UniformUnivariate distribution.
missing_value_generation (str or None):
The way missing values are being handled. There are three strategies:
Expand All @@ -269,24 +270,6 @@ class GaussianNormalizer(FloatFormatter):

_univariate = None

def __init__(self, model_missing_values=None, learn_rounding_scheme=False,
enforce_min_max_values=False, distribution='truncated_gaussian',
missing_value_generation='random'):
super().__init__(
model_missing_values=model_missing_values,
missing_value_generation=missing_value_generation,
learn_rounding_scheme=learn_rounding_scheme,
enforce_min_max_values=enforce_min_max_values
)

self.distribution = distribution # Distribution initialized by the user

self._distributions = self._get_distributions()
if isinstance(distribution, str):
distribution = self._distributions[distribution]

self._distribution = distribution

@staticmethod
def _get_distributions():
try:
Expand All @@ -305,8 +288,25 @@ def _get_distributions():
'student_t': univariate.StudentTUnivariate,
'gaussian_kde': univariate.GaussianKDE,
'truncated_gaussian': univariate.TruncatedGaussian,
'uniform': univariate.UniformUnivariate,
}

def __init__(self, model_missing_values=None, learn_rounding_scheme=False,
enforce_min_max_values=False, distribution='truncated_gaussian',
missing_value_generation='random'):
super().__init__(
model_missing_values=model_missing_values,
missing_value_generation=missing_value_generation,
learn_rounding_scheme=learn_rounding_scheme,
enforce_min_max_values=enforce_min_max_values
)

self._distributions = self._get_distributions()
if isinstance(distribution, str):
distribution = self._distributions[distribution]

self._distribution = distribution

def _get_univariate(self):
distribution = self._distribution
if any(isinstance(distribution, dist) for dist in self._distributions.values()):
Expand Down
61 changes: 61 additions & 0 deletions tests/integration/transformers/test_numerical.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import pandas as pd
from copulas import univariate

from rdt.transformers.numerical import ClusterBasedNormalizer, FloatFormatter, GaussianNormalizer

Expand Down Expand Up @@ -195,6 +196,66 @@ def test_int_nan(self):
reverse = ct.reverse_transform(transformed)
np.testing.assert_array_almost_equal(reverse, data, decimal=2)

def test_uniform(self):
"""Test it works when distribution='uniform'."""
# Setup
data = pd.DataFrame(np.random.uniform(size=1000), columns=['a'])
ct = GaussianNormalizer(distribution='uniform')

# Run
ct.fit(data, 'a')
transformed = ct.transform(data)
reverse = ct.reverse_transform(transformed)

# Assert
assert isinstance(transformed, pd.DataFrame)
assert transformed.shape == (1000, 1)

np.testing.assert_almost_equal(transformed['a'].mean(), 0, decimal=1)
np.testing.assert_almost_equal(transformed['a'].std(), 1, decimal=1)

np.testing.assert_array_almost_equal(reverse, data, decimal=1)

def test_uniform_object(self):
"""Test it works when distribution=UniformUnivariate()."""
# Setup
data = pd.DataFrame(np.random.uniform(size=1000), columns=['a'])
ct = GaussianNormalizer(distribution=univariate.UniformUnivariate())

# Run
ct.fit(data, 'a')
transformed = ct.transform(data)
reverse = ct.reverse_transform(transformed)

# Assert
assert isinstance(transformed, pd.DataFrame)
assert transformed.shape == (1000, 1)

np.testing.assert_almost_equal(transformed['a'].mean(), 0, decimal=1)
np.testing.assert_almost_equal(transformed['a'].std(), 1, decimal=1)

np.testing.assert_array_almost_equal(reverse, data, decimal=1)

def test_uniform_class(self):
"""Test it works when distribution=UniformUnivariate."""
# Setup
data = pd.DataFrame(np.random.uniform(size=1000), columns=['a'])
ct = GaussianNormalizer(distribution=univariate.UniformUnivariate)

# Run
ct.fit(data, 'a')
transformed = ct.transform(data)
reverse = ct.reverse_transform(transformed)

# Assert
assert isinstance(transformed, pd.DataFrame)
assert transformed.shape == (1000, 1)

np.testing.assert_almost_equal(transformed['a'].mean(), 0, decimal=1)
np.testing.assert_almost_equal(transformed['a'].std(), 1, decimal=1)

np.testing.assert_array_almost_equal(reverse, data, decimal=1)


class TestClusterBasedNormalizer:

Expand Down
1 change: 1 addition & 0 deletions tests/unit/transformers/test_numerical.py
Original file line number Diff line number Diff line change
Expand Up @@ -841,6 +841,7 @@ def test__get_distributions(self):
'student_t': univariate.StudentTUnivariate,
'gaussian_kde': univariate.GaussianKDE,
'truncated_gaussian': univariate.TruncatedGaussian,
'uniform': univariate.UniformUnivariate
}
assert distributions == expected

Expand Down

0 comments on commit c8c46b5

Please sign in to comment.