Skip to content

Commit

Permalink
Address feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
fealho committed Oct 17, 2023
1 parent 1ac0a8f commit fa08380
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 10 deletions.
15 changes: 8 additions & 7 deletions rdt/transformers/numerical.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,11 @@ class GaussianNormalizer(FloatFormatter):
"""

_univariate = None
_DEPRECATED_DISTRIBUTIONS_MAPPING = {
'gaussian': 'norm',
'student_t': 't',
'truncated_gaussian': 'truncnorm'
}

@staticmethod
def _get_distributions():
Expand All @@ -292,10 +297,6 @@ def _get_distributions():
'gaussian_kde': univariate.GaussianKDE,
'truncnorm': univariate.TruncatedGaussian,
'uniform': univariate.UniformUnivariate,
# the following are deprecated
'gaussian': univariate.GaussianUnivariate,
'student_t': univariate.StudentTUnivariate,
'truncated_gaussian': univariate.TruncatedGaussian,
}

def __init__(self, model_missing_values=None, learn_rounding_scheme=False,
Expand All @@ -316,13 +317,13 @@ def __init__(self, model_missing_values=None, learn_rounding_scheme=False,
self._distributions = self._get_distributions()
if isinstance(distribution, str):
if distribution in {'gaussian', 'student_t', 'truncated_gaussian'}:
deprecated_distributions_mapping = {
'gaussian': 'norm', 'student_t': 't', 'truncated_gaussian': 'truncnorm'}
warnings.warn(
f"Future versions of RDT will not support '{distribution}' as an option. "
f"Please use '{deprecated_distributions_mapping[distribution]}' instead.",
f"Please use '{self._DEPRECATED_DISTRIBUTIONS_MAPPING[distribution]}' "
'instead.',
FutureWarning
)
distribution = self._DEPRECATED_DISTRIBUTIONS_MAPPING[distribution]

distribution = self._distributions[distribution]

Expand Down
3 changes: 0 additions & 3 deletions tests/unit/transformers/test_numerical.py
Original file line number Diff line number Diff line change
Expand Up @@ -849,12 +849,9 @@ def test__get_distributions(self):

# Assert
expected = {
'gaussian': univariate.GaussianUnivariate,
'gamma': univariate.GammaUnivariate,
'beta': univariate.BetaUnivariate,
'student_t': univariate.StudentTUnivariate,
'gaussian_kde': univariate.GaussianKDE,
'truncated_gaussian': univariate.TruncatedGaussian,
'uniform': univariate.UniformUnivariate,
'truncnorm': univariate.TruncatedGaussian,
'norm': univariate.GaussianUnivariate,
Expand Down

0 comments on commit fa08380

Please sign in to comment.