diff --git a/doc/whats_new/v0.11.rst b/doc/whats_new/v0.11.rst index aa49204f1..1f0b2721a 100644 --- a/doc/whats_new/v0.11.rst +++ b/doc/whats_new/v0.11.rst @@ -17,6 +17,10 @@ Bug fixes since it requires a conversion to dense matrices. :pr:`1003` by :user:`Guillaume Lemaitre `. +- Remove spurious warning raised when a minority class gets over-sampled more than the + number of samples in the majority class. + :pr:`1007` by :user:`Guillaume Lemaitre `. + Compatibility ............. diff --git a/imblearn/utils/_validation.py b/imblearn/utils/_validation.py index da1e492f4..a36e6d81b 100644 --- a/imblearn/utils/_validation.py +++ b/imblearn/utils/_validation.py @@ -307,8 +307,8 @@ def _sampling_strategy_dict(sampling_strategy, y, sampling_type): ) sampling_strategy_ = {} if sampling_type == "over-sampling": - n_samples_majority = max(target_stats.values()) - class_majority = max(target_stats, key=target_stats.get) + max(target_stats.values()) + max(target_stats, key=target_stats.get) for class_sample, n_samples in sampling_strategy.items(): if n_samples < target_stats[class_sample]: raise ValueError( @@ -318,13 +318,6 @@ def _sampling_strategy_dict(sampling_strategy, y, sampling_type): f" Originally, there is {target_stats[class_sample]} " f"samples and {n_samples} samples are asked." 
) - if n_samples > n_samples_majority: - warnings.warn( - f"After over-sampling, the number of samples ({n_samples})" - f" in class {class_sample} will be larger than the number of" - f" samples in the majority class (class #{class_majority} ->" - f" {n_samples_majority})" - ) sampling_strategy_[class_sample] = n_samples - target_stats[class_sample] elif sampling_type == "under-sampling": for class_sample, n_samples in sampling_strategy.items(): diff --git a/imblearn/utils/tests/test_validation.py b/imblearn/utils/tests/test_validation.py index 587b5e278..4394f04fc 100644 --- a/imblearn/utils/tests/test_validation.py +++ b/imblearn/utils/tests/test_validation.py @@ -256,14 +256,6 @@ def test_check_sampling_strategy( assert sampling_strategy_ == expected_sampling_strategy -def test_sampling_strategy_dict_over_sampling(): - y = np.array([1] * 50 + [2] * 100 + [3] * 25) - sampling_strategy = {1: 70, 2: 140, 3: 70} - expected_msg = "After over-sampling, the number of samples " - with pytest.warns(UserWarning, match=expected_msg): - check_sampling_strategy(sampling_strategy, y, "over-sampling") - - def test_sampling_strategy_callable_args(): y = np.array([1] * 50 + [2] * 100 + [3] * 25) multiplier = {1: 1.5, 2: 1, 3: 3}