Skip to content

Commit

Permalink
FIX remove spurious warning raised when over-sampling the minority cl…
Browse files Browse the repository at this point in the history
  • Loading branch information
glemaitre authored Jul 8, 2023
1 parent f14033b commit b468f7f
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 17 deletions.
4 changes: 4 additions & 0 deletions doc/whats_new/v0.11.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ Bug fixes
since it requires a conversion to dense matrices.
:pr:`1003` by :user:`Guillaume Lemaitre <glemaitre>`.

- Remove spurious warning raised when minority class get over-sampled more than the
number of sample in the majority class.
:pr:`1007` by :user:`Guillaume Lemaitre <glemaitre>`.

Compatibility
.............

Expand Down
11 changes: 2 additions & 9 deletions imblearn/utils/_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,8 +307,8 @@ def _sampling_strategy_dict(sampling_strategy, y, sampling_type):
)
sampling_strategy_ = {}
if sampling_type == "over-sampling":
n_samples_majority = max(target_stats.values())
class_majority = max(target_stats, key=target_stats.get)
max(target_stats.values())
max(target_stats, key=target_stats.get)
for class_sample, n_samples in sampling_strategy.items():
if n_samples < target_stats[class_sample]:
raise ValueError(
Expand All @@ -318,13 +318,6 @@ def _sampling_strategy_dict(sampling_strategy, y, sampling_type):
f" Originally, there is {target_stats[class_sample]} "
f"samples and {n_samples} samples are asked."
)
if n_samples > n_samples_majority:
warnings.warn(
f"After over-sampling, the number of samples ({n_samples})"
f" in class {class_sample} will be larger than the number of"
f" samples in the majority class (class #{class_majority} ->"
f" {n_samples_majority})"
)
sampling_strategy_[class_sample] = n_samples - target_stats[class_sample]
elif sampling_type == "under-sampling":
for class_sample, n_samples in sampling_strategy.items():
Expand Down
8 changes: 0 additions & 8 deletions imblearn/utils/tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,14 +256,6 @@ def test_check_sampling_strategy(
assert sampling_strategy_ == expected_sampling_strategy


def test_sampling_strategy_dict_over_sampling():
y = np.array([1] * 50 + [2] * 100 + [3] * 25)
sampling_strategy = {1: 70, 2: 140, 3: 70}
expected_msg = "After over-sampling, the number of samples "
with pytest.warns(UserWarning, match=expected_msg):
check_sampling_strategy(sampling_strategy, y, "over-sampling")


def test_sampling_strategy_callable_args():
y = np.array([1] * 50 + [2] * 100 + [3] * 25)
multiplier = {1: 1.5, 2: 1, 3: 3}
Expand Down

0 comments on commit b468f7f

Please sign in to comment.