diff --git a/doc/whats_new/v0.12.rst b/doc/whats_new/v0.12.rst index 88017b547..1d4a34a4b 100644 --- a/doc/whats_new/v0.12.rst +++ b/doc/whats_new/v0.12.rst @@ -18,6 +18,11 @@ Bug fixes the number of samples in the minority class. :pr:`1012` by :user:`Guillaume Lemaitre `. +- Fix a bug in :class:`~imblearn.under_sampling.RandomUnderSampler` and + :class:`~imblearn.over_sampling.RandomOverSampler` where a column containing only + NaT was not handled correctly. + :pr:`1059` by :user:`Guillaume Lemaitre `. + Compatibility ............. diff --git a/imblearn/over_sampling/tests/test_random_over_sampler.py b/imblearn/over_sampling/tests/test_random_over_sampler.py index 6ad4b75ef..efa40c855 100644 --- a/imblearn/over_sampling/tests/test_random_over_sampler.py +++ b/imblearn/over_sampling/tests/test_random_over_sampler.py @@ -287,3 +287,26 @@ def test_random_over_sampling_datetime(): pd.testing.assert_series_equal(X_res.dtypes, X.dtypes) pd.testing.assert_index_equal(X_res.index, y_res.index) assert_array_equal(y_res.to_numpy(), np.array([0, 0, 0, 1, 1, 1])) + + +def test_random_over_sampler_full_nat(): + """Check that we can return timedelta columns full of NaT. + + Non-regression test for: + https://github.com/scikit-learn-contrib/imbalanced-learn/issues/1055 + """ + pd = pytest.importorskip("pandas") + + X = pd.DataFrame( + { + "col_str": ["abc", "def", "xyz"], + "col_timedelta": pd.to_timedelta([np.nan, np.nan, np.nan]), + } + ) + y = np.array([0, 0, 1]) + + X_res, y_res = RandomOverSampler().fit_resample(X, y) + assert X_res.shape == (4, 2) + assert y_res.shape == (4,) + + assert X_res["col_timedelta"].dtype == "timedelta64[ns]" diff --git a/imblearn/under_sampling/_prototype_selection/tests/test_random_under_sampler.py b/imblearn/under_sampling/_prototype_selection/tests/test_random_under_sampler.py index 9fc9f084c..f4e927902 100644 --- a/imblearn/under_sampling/_prototype_selection/tests/test_random_under_sampler.py +++ b/imblearn/under_sampling/_prototype_selection/tests/test_random_under_sampler.py @@ -162,3 +162,26 @@ def test_random_under_sampling_datetime(): pd.testing.assert_series_equal(X_res.dtypes, X.dtypes) pd.testing.assert_index_equal(X_res.index, y_res.index) assert_array_equal(y_res.to_numpy(), np.array([0, 1])) + + +def test_random_under_sampler_full_nat(): + """Check that we can return timedelta columns full of NaT. + + Non-regression test for: + https://github.com/scikit-learn-contrib/imbalanced-learn/issues/1055 + """ + pd = pytest.importorskip("pandas") + + X = pd.DataFrame( + { + "col_str": ["abc", "def", "xyz"], + "col_timedelta": pd.to_timedelta([np.nan, np.nan, np.nan]), + } + ) + y = np.array([0, 0, 1]) + + X_res, y_res = RandomUnderSampler().fit_resample(X, y) + assert X_res.shape == (2, 2) + assert y_res.shape == (2,) + + assert X_res["col_timedelta"].dtype == "timedelta64[ns]" diff --git a/imblearn/utils/_validation.py b/imblearn/utils/_validation.py index bf1d8351f..b21c15788 100644 --- a/imblearn/utils/_validation.py +++ b/imblearn/utils/_validation.py @@ -66,7 +66,24 @@ def _transfrom_one(self, array, props): ret = pd.DataFrame.sparse.from_spmatrix(array, columns=props["columns"]) else: ret = pd.DataFrame(array, columns=props["columns"]) - ret = ret.astype(props["dtypes"]) + + try: + ret = ret.astype(props["dtypes"]) + except TypeError: + # We special case the following error: + # https://github.com/scikit-learn-contrib/imbalanced-learn/issues/1055 + # There is no easy way to have a generic workaround. Here, we detect + # that we have a column with only null values that is datetime64 + # (resulting from the np.vstack of the resampling). + for col in ret.columns: + if ( + ret[col].isnull().all() + and ret[col].dtype == "datetime64[ns]" + and props["dtypes"][col] == "timedelta64[ns]" + ): + ret[col] = pd.to_timedelta(["NaT"] * len(ret[col])) + # try again + ret = ret.astype(props["dtypes"]) elif type_ == "series": import pandas as pd