diff --git a/rdt/transformers/numerical.py b/rdt/transformers/numerical.py
index d0edb2984..72bb347f9 100644
--- a/rdt/transformers/numerical.py
+++ b/rdt/transformers/numerical.py
@@ -252,9 +252,10 @@ class GaussianNormalizer(FloatFormatter):
                 * ``gamma``: Use a Gamma distribution.
                 * ``beta``: Use a Beta distribution.
                 * ``student_t``: Use a Student T distribution.
-                * ``gussian_kde``: Use a GaussianKDE distribution. This model is non-parametric,
+                * ``gaussian_kde``: Use a GaussianKDE distribution. This model is non-parametric,
                   so using this will make ``get_parameters`` unusable.
                 * ``truncated_gaussian``: Use a Truncated Gaussian distribution.
+                * ``uniform``: Use a UniformUnivariate distribution.
 
         missing_value_generation (str or None):
             The way missing values are being handled. There are three strategies:
@@ -269,24 +270,6 @@ class GaussianNormalizer(FloatFormatter):
 
     _univariate = None
 
-    def __init__(self, model_missing_values=None, learn_rounding_scheme=False,
-                 enforce_min_max_values=False, distribution='truncated_gaussian',
-                 missing_value_generation='random'):
-        super().__init__(
-            model_missing_values=model_missing_values,
-            missing_value_generation=missing_value_generation,
-            learn_rounding_scheme=learn_rounding_scheme,
-            enforce_min_max_values=enforce_min_max_values
-        )
-
-        self.distribution = distribution  # Distribution initialized by the user
-
-        self._distributions = self._get_distributions()
-        if isinstance(distribution, str):
-            distribution = self._distributions[distribution]
-
-        self._distribution = distribution
-
     @staticmethod
     def _get_distributions():
         try:
@@ -305,8 +288,25 @@ def _get_distributions():
             'student_t': univariate.StudentTUnivariate,
             'gaussian_kde': univariate.GaussianKDE,
             'truncated_gaussian': univariate.TruncatedGaussian,
+            'uniform': univariate.UniformUnivariate,
         }
 
+    def __init__(self, model_missing_values=None, learn_rounding_scheme=False,
+                 enforce_min_max_values=False, distribution='truncated_gaussian',
+                 missing_value_generation='random'):
+        super().__init__(
+            model_missing_values=model_missing_values,
+            missing_value_generation=missing_value_generation,
+            learn_rounding_scheme=learn_rounding_scheme,
+            enforce_min_max_values=enforce_min_max_values
+        )
+
+        self._distributions = self._get_distributions()
+        if isinstance(distribution, str):
+            distribution = self._distributions[distribution]
+
+        self._distribution = distribution
+
     def _get_univariate(self):
         distribution = self._distribution
         if any(isinstance(distribution, dist) for dist in self._distributions.values()):
diff --git a/tests/integration/transformers/test_numerical.py b/tests/integration/transformers/test_numerical.py
index 6ae2e92ff..a4a78f26c 100644
--- a/tests/integration/transformers/test_numerical.py
+++ b/tests/integration/transformers/test_numerical.py
@@ -1,5 +1,6 @@
 import numpy as np
 import pandas as pd
+from copulas import univariate
 
 from rdt.transformers.numerical import ClusterBasedNormalizer, FloatFormatter, GaussianNormalizer
 
@@ -195,6 +196,66 @@ def test_int_nan(self):
         reverse = ct.reverse_transform(transformed)
         np.testing.assert_array_almost_equal(reverse, data, decimal=2)
 
+    def test_uniform(self):
+        """Test it works when distribution='uniform'."""
+        # Setup
+        data = pd.DataFrame(np.random.uniform(size=1000), columns=['a'])
+        ct = GaussianNormalizer(distribution='uniform')
+
+        # Run
+        ct.fit(data, 'a')
+        transformed = ct.transform(data)
+        reverse = ct.reverse_transform(transformed)
+
+        # Assert
+        assert isinstance(transformed, pd.DataFrame)
+        assert transformed.shape == (1000, 1)
+
+        np.testing.assert_almost_equal(transformed['a'].mean(), 0, decimal=1)
+        np.testing.assert_almost_equal(transformed['a'].std(), 1, decimal=1)
+
+        np.testing.assert_array_almost_equal(reverse, data, decimal=1)
+
+    def test_uniform_object(self):
+        """Test it works when distribution=UniformUnivariate()."""
+        # Setup
+        data = pd.DataFrame(np.random.uniform(size=1000), columns=['a'])
+        ct = GaussianNormalizer(distribution=univariate.UniformUnivariate())
+
+        # Run
+        ct.fit(data, 'a')
+        transformed = ct.transform(data)
+        reverse = ct.reverse_transform(transformed)
+
+        # Assert
+        assert isinstance(transformed, pd.DataFrame)
+        assert transformed.shape == (1000, 1)
+
+        np.testing.assert_almost_equal(transformed['a'].mean(), 0, decimal=1)
+        np.testing.assert_almost_equal(transformed['a'].std(), 1, decimal=1)
+
+        np.testing.assert_array_almost_equal(reverse, data, decimal=1)
+
+    def test_uniform_class(self):
+        """Test it works when distribution=UniformUnivariate."""
+        # Setup
+        data = pd.DataFrame(np.random.uniform(size=1000), columns=['a'])
+        ct = GaussianNormalizer(distribution=univariate.UniformUnivariate)
+
+        # Run
+        ct.fit(data, 'a')
+        transformed = ct.transform(data)
+        reverse = ct.reverse_transform(transformed)
+
+        # Assert
+        assert isinstance(transformed, pd.DataFrame)
+        assert transformed.shape == (1000, 1)
+
+        np.testing.assert_almost_equal(transformed['a'].mean(), 0, decimal=1)
+        np.testing.assert_almost_equal(transformed['a'].std(), 1, decimal=1)
+
+        np.testing.assert_array_almost_equal(reverse, data, decimal=1)
+
 
 class TestClusterBasedNormalizer:
diff --git a/tests/unit/transformers/test_numerical.py b/tests/unit/transformers/test_numerical.py
index 9b35c1135..e13300dcf 100644
--- a/tests/unit/transformers/test_numerical.py
+++ b/tests/unit/transformers/test_numerical.py
@@ -841,6 +841,7 @@ def test__get_distributions(self):
             'student_t': univariate.StudentTUnivariate,
             'gaussian_kde': univariate.GaussianKDE,
             'truncated_gaussian': univariate.TruncatedGaussian,
+            'uniform': univariate.UniformUnivariate
         }
 
         assert distributions == expected
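For reference, a minimal usage sketch of the new option, mirroring the integration tests added above. It assumes the patch is applied and that `rdt` and `copulas` are installed; the `fit`/`transform`/`reverse_transform` call signatures follow those tests.

```python
import numpy as np
import pandas as pd

from rdt.transformers.numerical import GaussianNormalizer

# Column of uniformly distributed values to normalize.
data = pd.DataFrame(np.random.uniform(size=1000), columns=['a'])

# The new option: fit a UniformUnivariate marginal instead of the default
# truncated Gaussian before mapping the column to a standard normal.
transformer = GaussianNormalizer(distribution='uniform')
transformer.fit(data, 'a')

transformed = transformer.transform(data)                # roughly N(0, 1) values
recovered = transformer.reverse_transform(transformed)   # close to the original data
```

Passing `univariate.UniformUnivariate` (the class) or `univariate.UniformUnivariate()` (an instance) instead of the string `'uniform'` is exercised by the other two integration tests.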