Skip to content

Commit

Permalink
Address the issue (#567)
Browse files Browse the repository at this point in the history
  • Loading branch information
fealho authored Oct 19, 2022
1 parent 9b89282 commit 257c972
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 27 deletions.
2 changes: 1 addition & 1 deletion rdt/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@ class Error(Exception):


class TransformerInputError(Exception):
"""Error to raise when ``HyperTransformer`` receives incorrect input."""
"""Error to raise when ``HyperTransformer`` receives an incorrect input."""
13 changes: 13 additions & 0 deletions rdt/transformers/null.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import numpy as np
import pandas as pd

from rdt.errors import TransformerInputError

LOGGER = logging.getLogger(__name__)


Expand Down Expand Up @@ -52,10 +54,21 @@ def _get_missing_value_replacement(self, data):
Return:
object:
The fill value that needs to be used.
Raise:
TransformerInputError:
Error raised when data only contains nans and ``_missing_value_replacement``
is set to 'mean' or 'mode'.
"""
if self._missing_value_replacement is None:
return None

if self._missing_value_replacement in {'mean', 'mode'} and pd.isna(data).all():
raise TransformerInputError(
f"'missing_value_replacement' cannot be set to '{self._missing_value_replacement}'"
' when the provided data only contains NaNs.'
)

if self._missing_value_replacement == 'mean':
return data.mean()

Expand Down
57 changes: 31 additions & 26 deletions tests/unit/transformers/test_null.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
"""Unit tests for the NullTransformer."""

import re
from unittest.mock import patch

import numpy as np
import pandas as pd
import pytest

from rdt.errors import TransformerInputError
from rdt.transformers import NullTransformer


Expand Down Expand Up @@ -98,32 +101,6 @@ def test__get_missing_value_replacement_scalar(self):
# Assert
assert missing_value_replacement == 'a_missing_value_replacement'

def test__get_missing_value_replacement_all_nulls(self):
"""Test _get_missing_value_replacement when all the values are null.
With 'mean' as the missing_value_replacement and a Series in which every
value is null, the returned replacement is expected to be `np.nan`.
Setup:
- NullTransformer passing 'mean' as the missing_value_replacement.
Input:
- A Series filled with nan values.
Expected Output:
- np.nan
"""
# Setup
transformer = NullTransformer('mean')

# Run
data = pd.Series([np.nan, np.nan, np.nan], name='abc')
missing_value_replacement = transformer._get_missing_value_replacement(data)

# Assert
# Identity check: NaN never compares equal to itself, so ``==`` would always fail here.
assert missing_value_replacement is np.nan

def test__get_missing_value_replacement_none_numerical(self):
"""Test _get_missing_value_replacement when missing_value_replacement is None.
Expand Down Expand Up @@ -205,6 +182,20 @@ def test__get_missing_value_replacement_mean(self):
# Assert
assert missing_value_replacement == 1.5

def test__get_missing_value_replacement_mean_only_nans(self):
"""Test _get_missing_value_replacement with 'mean' when data only contains nans.
Setup:
- NullTransformer passing 'mean' as the missing_value_replacement.
Input:
- A Series holding only null values (``float('nan')``, ``None`` and ``np.nan``).
Expected Output:
- A ``TransformerInputError`` whose message states that 'mean' cannot be
used when the provided data only contains NaNs.
"""
# Setup
transformer = NullTransformer('mean')
data = pd.Series([float('nan'), None, np.nan], name='abc')

# Run and Assert
# ``pytest.raises(match=...)`` interprets the pattern as a regular expression,
# so the literal message is escaped before matching.
err_msg = re.escape(
"'missing_value_replacement' cannot be set to 'mean' when "
'the provided data only contains NaNs.'
)
with pytest.raises(TransformerInputError, match=err_msg):
transformer._get_missing_value_replacement(data)

def test__get_missing_value_replacement_mode(self):
"""Test _get_missing_value_replacement when missing_value_replacement is 'mode'.
Expand Down Expand Up @@ -232,6 +223,20 @@ def test__get_missing_value_replacement_mode(self):
# Assert
assert missing_value_replacement == 2

def test__get_missing_value_replacement_mode_only_nans(self):
"""Test _get_missing_value_replacement with 'mode' when data only contains nans.
Setup:
- NullTransformer passing 'mode' as the missing_value_replacement.
Input:
- A Series holding only null values (``float('nan')``, ``None`` and ``np.nan``).
Expected Output:
- A ``TransformerInputError`` whose message states that 'mode' cannot be
used when the provided data only contains NaNs.
"""
# Setup
transformer = NullTransformer('mode')
data = pd.Series([float('nan'), None, np.nan], name='abc')

# Run and Assert
# ``pytest.raises(match=...)`` interprets the pattern as a regular expression,
# so the literal message is escaped before matching.
err_msg = re.escape(
"'missing_value_replacement' cannot be set to 'mode' when "
'the provided data only contains NaNs.'
)
with pytest.raises(TransformerInputError, match=err_msg):
transformer._get_missing_value_replacement(data)

def test_fit_model_missing_values_none_and_nulls(self):
"""Test fit when null column is none and there are nulls.
Expand Down

0 comments on commit 257c972

Please sign in to comment.