Skip to content

Commit

Permalink
FIX handle full NaT columns properly in Random*Sampler (scikit-learn-…
Browse files Browse the repository at this point in the history
  • Loading branch information
glemaitre authored Jan 24, 2024
1 parent 9e976a4 commit dcfa5f3
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 1 deletion.
5 changes: 5 additions & 0 deletions doc/whats_new/v0.12.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ Bug fixes
the number of samples in the minority class.
:pr:`1012` by :user:`Guillaume Lemaitre <glemaitre>`.

- Fix a bug in :class:`~imblearn.under_sampling.RandomUnderSampler` and
:class:`~imblearn.over_sampling.RandomOverSampler` where a column containing only
NaT was not handled correctly.
:pr:`1059` by :user:`Guillaume Lemaitre <glemaitre>`.

Compatibility
.............

Expand Down
23 changes: 23 additions & 0 deletions imblearn/over_sampling/tests/test_random_over_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,3 +287,26 @@ def test_random_over_sampling_datetime():
pd.testing.assert_series_equal(X_res.dtypes, X.dtypes)
pd.testing.assert_index_equal(X_res.index, y_res.index)
assert_array_equal(y_res.to_numpy(), np.array([0, 0, 0, 1, 1, 1]))


def test_random_over_sampler_full_nat():
"""Check that we can return timedelta columns full of NaT.
Non-regression test for:
https://github.com/scikit-learn-contrib/imbalanced-learn/issues/1055
"""
pd = pytest.importorskip("pandas")

X = pd.DataFrame(
{
"col_str": ["abc", "def", "xyz"],
"col_timedelta": pd.to_timedelta([np.nan, np.nan, np.nan]),
}
)
y = np.array([0, 0, 1])

X_res, y_res = RandomOverSampler().fit_resample(X, y)
assert X_res.shape == (4, 2)
assert y_res.shape == (4,)

assert X_res["col_timedelta"].dtype == "timedelta64[ns]"
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,26 @@ def test_random_under_sampling_datetime():
pd.testing.assert_series_equal(X_res.dtypes, X.dtypes)
pd.testing.assert_index_equal(X_res.index, y_res.index)
assert_array_equal(y_res.to_numpy(), np.array([0, 1]))


def test_random_under_sampler_full_nat():
"""Check that we can return timedelta columns full of NaT.
Non-regression test for:
https://github.com/scikit-learn-contrib/imbalanced-learn/issues/1055
"""
pd = pytest.importorskip("pandas")

X = pd.DataFrame(
{
"col_str": ["abc", "def", "xyz"],
"col_timedelta": pd.to_timedelta([np.nan, np.nan, np.nan]),
}
)
y = np.array([0, 0, 1])

X_res, y_res = RandomUnderSampler().fit_resample(X, y)
assert X_res.shape == (2, 2)
assert y_res.shape == (2,)

assert X_res["col_timedelta"].dtype == "timedelta64[ns]"
19 changes: 18 additions & 1 deletion imblearn/utils/_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,24 @@ def _transfrom_one(self, array, props):
ret = pd.DataFrame.sparse.from_spmatrix(array, columns=props["columns"])
else:
ret = pd.DataFrame(array, columns=props["columns"])
ret = ret.astype(props["dtypes"])

try:
ret = ret.astype(props["dtypes"])
except TypeError:
# We special case the following error:
# https://github.com/scikit-learn-contrib/imbalanced-learn/issues/1055
# There is no easy way to have a generic workaround. Here, we detect
# that we have a column with only null values that is datetime64
# (resulting from the np.vstack of the resampling).
for col in ret.columns:
if (
ret[col].isnull().all()
and ret[col].dtype == "datetime64[ns]"
and props["dtypes"][col] == "timedelta64[ns]"
):
ret[col] = pd.to_timedelta(["NaT"] * len(ret[col]))
# try again
ret = ret.astype(props["dtypes"])
elif type_ == "series":
import pandas as pd

Expand Down

0 comments on commit dcfa5f3

Please sign in to comment.