Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add error to NullTransformer when data only contains nans #567

Merged
merged 1 commit into from
Oct 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion rdt/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@ class Error(Exception):


class TransformerInputError(Exception):
"""Error to raise when ``HyperTransformer`` receives incorrect input."""
"""Error to raise when ``HyperTransformer`` receives an incorrect input."""
13 changes: 13 additions & 0 deletions rdt/transformers/null.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import numpy as np
import pandas as pd

from rdt.errors import TransformerInputError

LOGGER = logging.getLogger(__name__)


Expand Down Expand Up @@ -52,10 +54,21 @@ def _get_missing_value_replacement(self, data):
Return:
object:
The fill value that needs to be used.

Raise:
TransformerInputError:
Error raised when data only contains nans and ``_missing_value_replacement``
is set to 'mean' or 'mode'.
"""
if self._missing_value_replacement is None:
return None

if self._missing_value_replacement in {'mean', 'mode'} and pd.isna(data).all():
raise TransformerInputError(
f"'missing_value_replacement' cannot be set to '{self._missing_value_replacement}'"
' when the provided data only contains NaNs.'
)

if self._missing_value_replacement == 'mean':
return data.mean()

Expand Down
57 changes: 31 additions & 26 deletions tests/unit/transformers/test_null.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
"""Unit tests for the NullTransformer."""

import re
from unittest.mock import patch

import numpy as np
import pandas as pd
import pytest

from rdt.errors import TransformerInputError
from rdt.transformers import NullTransformer


Expand Down Expand Up @@ -98,32 +101,6 @@ def test__get_missing_value_replacement_scalar(self):
# Assert
assert missing_value_replacement == 'a_missing_value_replacement'

def test__get_missing_value_replacement_all_nulls(self):
    """Test ``_get_missing_value_replacement`` when every value is null.

    With 'mean' as the ``missing_value_replacement`` and a Series that
    holds only nans, this test asserts that the returned replacement is
    ``np.nan``.

    NOTE(review): the guard shown for ``rdt/transformers/null.py`` in this
    same diff raises ``TransformerInputError`` for 'mean' on all-nan data,
    so this looks like the pre-change test that the diff removes. Confirm
    before keeping it.

    Setup:
        - NullTransformer passing 'mean' as the missing_value_replacement.

    Input:
        - A Series filled with nan values.

    Expected Output:
        - ``np.nan``
    """
    # Setup
    instance = NullTransformer('mean')
    all_nan_column = pd.Series([np.nan] * 3, name='abc')

    # Run
    replacement = instance._get_missing_value_replacement(all_nan_column)

    # Assert
    assert replacement is np.nan

def test__get_missing_value_replacement_none_numerical(self):
"""Test _get_missing_value_replacement when missing_value_replacement is None.

Expand Down Expand Up @@ -205,6 +182,20 @@ def test__get_missing_value_replacement_mean(self):
# Assert
assert missing_value_replacement == 1.5

def test__get_missing_value_replacement_mean_only_nans(self):
    """Check that a 'mean' replacement is rejected for all-nan data.

    Setup:
        - NullTransformer passing 'mean' as the missing_value_replacement.

    Input:
        - A Series whose values are ``float('nan')``, ``None`` and ``np.nan``.

    Side Effects:
        - ``TransformerInputError`` is raised with the expected message.
    """
    # Setup
    expected_message = re.escape(
        "'missing_value_replacement' cannot be set to 'mean' when "
        'the provided data only contains NaNs.'
    )
    only_nans = pd.Series([float('nan'), None, np.nan], name='abc')
    instance = NullTransformer('mean')

    # Run and Assert
    with pytest.raises(TransformerInputError, match=expected_message):
        instance._get_missing_value_replacement(only_nans)

def test__get_missing_value_replacement_mode(self):
"""Test _get_missing_value_replacement when missing_value_replacement is 'mode'.

Expand Down Expand Up @@ -232,6 +223,20 @@ def test__get_missing_value_replacement_mode(self):
# Assert
assert missing_value_replacement == 2

def test__get_missing_value_replacement_mode_only_nans(self):
    """Check that a 'mode' replacement is rejected for all-nan data.

    Setup:
        - NullTransformer passing 'mode' as the missing_value_replacement.

    Input:
        - A Series whose values are ``float('nan')``, ``None`` and ``np.nan``.

    Side Effects:
        - ``TransformerInputError`` is raised with the expected message.
    """
    # Setup
    expected_message = re.escape(
        "'missing_value_replacement' cannot be set to 'mode' when "
        'the provided data only contains NaNs.'
    )
    only_nans = pd.Series([float('nan'), None, np.nan], name='abc')
    instance = NullTransformer('mode')

    # Run and Assert
    with pytest.raises(TransformerInputError, match=expected_message):
        instance._get_missing_value_replacement(only_nans)

def test_fit_model_missing_values_none_and_nulls(self):
"""Test fit when null column is none and there are nulls.

Expand Down