Skip to content

Commit 636f4e2

Browse files
authored
fix (#706)
1 parent 8fc2e5e commit 636f4e2

File tree

3 files changed

+81
-19
lines changed

3 files changed

+81
-19
lines changed

rdt/transformers/numerical.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -252,9 +252,10 @@ class GaussianNormalizer(FloatFormatter):
252252
* ``gamma``: Use a Gamma distribution.
253253
* ``beta``: Use a Beta distribution.
254254
* ``student_t``: Use a Student T distribution.
255-
* ``gussian_kde``: Use a GaussianKDE distribution. This model is non-parametric,
255+
* ``gaussian_kde``: Use a GaussianKDE distribution. This model is non-parametric,
256256
so using this will make ``get_parameters`` unusable.
257257
* ``truncated_gaussian``: Use a Truncated Gaussian distribution.
258+
# ``uniform``: Use a UniformUnivariate distribution.
258259
259260
missing_value_generation (str or None):
260261
The way missing values are being handled. There are three strategies:
@@ -269,24 +270,6 @@ class GaussianNormalizer(FloatFormatter):
269270

270271
_univariate = None
271272

272-
def __init__(self, model_missing_values=None, learn_rounding_scheme=False,
273-
enforce_min_max_values=False, distribution='truncated_gaussian',
274-
missing_value_generation='random'):
275-
super().__init__(
276-
model_missing_values=model_missing_values,
277-
missing_value_generation=missing_value_generation,
278-
learn_rounding_scheme=learn_rounding_scheme,
279-
enforce_min_max_values=enforce_min_max_values
280-
)
281-
282-
self.distribution = distribution # Distribution initialized by the user
283-
284-
self._distributions = self._get_distributions()
285-
if isinstance(distribution, str):
286-
distribution = self._distributions[distribution]
287-
288-
self._distribution = distribution
289-
290273
@staticmethod
291274
def _get_distributions():
292275
try:
@@ -305,8 +288,25 @@ def _get_distributions():
305288
'student_t': univariate.StudentTUnivariate,
306289
'gaussian_kde': univariate.GaussianKDE,
307290
'truncated_gaussian': univariate.TruncatedGaussian,
291+
'uniform': univariate.UniformUnivariate,
308292
}
309293

294+
def __init__(self, model_missing_values=None, learn_rounding_scheme=False,
295+
enforce_min_max_values=False, distribution='truncated_gaussian',
296+
missing_value_generation='random'):
297+
super().__init__(
298+
model_missing_values=model_missing_values,
299+
missing_value_generation=missing_value_generation,
300+
learn_rounding_scheme=learn_rounding_scheme,
301+
enforce_min_max_values=enforce_min_max_values
302+
)
303+
304+
self._distributions = self._get_distributions()
305+
if isinstance(distribution, str):
306+
distribution = self._distributions[distribution]
307+
308+
self._distribution = distribution
309+
310310
def _get_univariate(self):
311311
distribution = self._distribution
312312
if any(isinstance(distribution, dist) for dist in self._distributions.values()):

tests/integration/transformers/test_numerical.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22
import pandas as pd
3+
from copulas import univariate
34

45
from rdt.transformers.numerical import ClusterBasedNormalizer, FloatFormatter, GaussianNormalizer
56

@@ -195,6 +196,66 @@ def test_int_nan(self):
195196
reverse = ct.reverse_transform(transformed)
196197
np.testing.assert_array_almost_equal(reverse, data, decimal=2)
197198

199+
def test_uniform(self):
200+
"""Test it works when distribution='uniform'."""
201+
# Setup
202+
data = pd.DataFrame(np.random.uniform(size=1000), columns=['a'])
203+
ct = GaussianNormalizer(distribution='uniform')
204+
205+
# Run
206+
ct.fit(data, 'a')
207+
transformed = ct.transform(data)
208+
reverse = ct.reverse_transform(transformed)
209+
210+
# Assert
211+
assert isinstance(transformed, pd.DataFrame)
212+
assert transformed.shape == (1000, 1)
213+
214+
np.testing.assert_almost_equal(transformed['a'].mean(), 0, decimal=1)
215+
np.testing.assert_almost_equal(transformed['a'].std(), 1, decimal=1)
216+
217+
np.testing.assert_array_almost_equal(reverse, data, decimal=1)
218+
219+
def test_uniform_object(self):
220+
"""Test it works when distribution=UniformUnivariate()."""
221+
# Setup
222+
data = pd.DataFrame(np.random.uniform(size=1000), columns=['a'])
223+
ct = GaussianNormalizer(distribution=univariate.UniformUnivariate())
224+
225+
# Run
226+
ct.fit(data, 'a')
227+
transformed = ct.transform(data)
228+
reverse = ct.reverse_transform(transformed)
229+
230+
# Assert
231+
assert isinstance(transformed, pd.DataFrame)
232+
assert transformed.shape == (1000, 1)
233+
234+
np.testing.assert_almost_equal(transformed['a'].mean(), 0, decimal=1)
235+
np.testing.assert_almost_equal(transformed['a'].std(), 1, decimal=1)
236+
237+
np.testing.assert_array_almost_equal(reverse, data, decimal=1)
238+
239+
def test_uniform_class(self):
240+
"""Test it works when distribution=UniformUnivariate."""
241+
# Setup
242+
data = pd.DataFrame(np.random.uniform(size=1000), columns=['a'])
243+
ct = GaussianNormalizer(distribution=univariate.UniformUnivariate)
244+
245+
# Run
246+
ct.fit(data, 'a')
247+
transformed = ct.transform(data)
248+
reverse = ct.reverse_transform(transformed)
249+
250+
# Assert
251+
assert isinstance(transformed, pd.DataFrame)
252+
assert transformed.shape == (1000, 1)
253+
254+
np.testing.assert_almost_equal(transformed['a'].mean(), 0, decimal=1)
255+
np.testing.assert_almost_equal(transformed['a'].std(), 1, decimal=1)
256+
257+
np.testing.assert_array_almost_equal(reverse, data, decimal=1)
258+
198259

199260
class TestClusterBasedNormalizer:
200261

tests/unit/transformers/test_numerical.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -841,6 +841,7 @@ def test__get_distributions(self):
841841
'student_t': univariate.StudentTUnivariate,
842842
'gaussian_kde': univariate.GaussianKDE,
843843
'truncated_gaussian': univariate.TruncatedGaussian,
844+
'uniform': univariate.UniformUnivariate
844845
}
845846
assert distributions == expected
846847

0 commit comments

Comments
 (0)