Skip to content

Commit

Permalink
FIX/DEPR follow literature for the implementation of NCR (scikit-lear…
Browse files Browse the repository at this point in the history
  • Loading branch information
glemaitre authored Jul 10, 2023
1 parent 95e21e1 commit 6622afb
Show file tree
Hide file tree
Showing 5 changed files with 177 additions and 107 deletions.
4 changes: 2 additions & 2 deletions doc/under_sampling.rst
Original file line number Diff line number Diff line change
Expand Up @@ -353,10 +353,10 @@ union of samples to be rejected between the :class:`EditedNearestNeighbours`
and the output of a 3 nearest neighbors classifier. The class can be used as::

>>> from imblearn.under_sampling import NeighbourhoodCleaningRule
>>> ncr = NeighbourhoodCleaningRule()
>>> ncr = NeighbourhoodCleaningRule(n_neighbors=11)
>>> X_resampled, y_resampled = ncr.fit_resample(X, y)
>>> print(sorted(Counter(y_resampled).items()))
[(0, 64), (1, 234), (2, 4666)]
[(0, 64), (1, 193), (2, 4535)]

.. image:: ./auto_examples/under-sampling/images/sphx_glr_plot_comparison_under_sampling_005.png
:target: ./auto_examples/under-sampling/plot_comparison_under_sampling.html
Expand Down
18 changes: 17 additions & 1 deletion doc/whats_new/v0.12.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,27 @@ Version 0.12.0 (Under development)
Changelog
---------

Bug fixes
.........

- Fix a bug in :class:`~imblearn.under_sampling.NeighbourhoodCleaningRule` where the
`kind_sel="all"` was not working as explained in the literature.
:pr:`1012` by :user:`Guillaume Lemaitre <glemaitre>`.

- Fix a bug in :class:`~imblearn.under_sampling.NeighbourhoodCleaningRule` where the
`threshold_cleaning` ratio was multiplied by the total number of samples instead of
the number of samples in the minority class.
:pr:`1012` by :user:`Guillaume Lemaitre <glemaitre>`.

Deprecations
............

- Deprecate `estimator_` argument in favor of `estimators_` for the classes
:class:`~imblearn.under_sampling.CondensedNearestNeighbour` and
:class:`~imblearn.under_sampling.OneSidedSelection`. `estimator_` will be removed
in 0.14.
:pr:`xxx` by :user:`Guillaume Lemaitre <glemaitre>`.
:pr:`1011` by :user:`Guillaume Lemaitre <glemaitre>`.

- Deprecate `kind_sel` in :class:`~imblearn.under_sampling.NeighbourhoodCleaningRule`.
It will be removed in 0.14. The parameter does not have any effect.
:pr:`1012` by :user:`Guillaume Lemaitre <glemaitre>`.
2 changes: 1 addition & 1 deletion examples/under-sampling/plot_comparison_under_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def plot_decision_function(X, y, clf, ax, title=None):
samplers = [
CondensedNearestNeighbour(random_state=0),
OneSidedSelection(random_state=0),
NeighbourhoodCleaningRule(),
NeighbourhoodCleaningRule(n_neighbors=11),
]

for ax, sampler in zip(axs, samplers):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,17 @@
# License: MIT

import numbers
import warnings
from collections import Counter

import numpy as np
from sklearn.base import clone
from sklearn.neighbors import KNeighborsClassifier, NearestNeighbors
from sklearn.utils import _safe_indexing

from ...utils import Substitution, check_neighbors_object
from ...utils import Substitution
from ...utils._docstring import _n_jobs_docstring
from ...utils._param_validation import HasMethods, Interval, StrOptions
from ...utils.fixes import _mode
from ...utils._param_validation import HasMethods, Hidden, Interval, StrOptions
from ..base import BaseCleaningSampler
from ._edited_nearest_neighbours import EditedNearestNeighbours

Expand All @@ -35,9 +37,14 @@ class NeighbourhoodCleaningRule(BaseCleaningSampler):
----------
{sampling_strategy}
edited_nearest_neighbours : estimator object, default=None
The :class:`~imblearn.under_sampling.EditedNearestNeighbours` (ENN)
object to clean the dataset. If `None`, a default ENN is created with
`kind_sel="mode"` and `n_neighbors=n_neighbors`.
n_neighbors : int or estimator object, default=3
If ``int``, size of the neighbourhood to consider to compute the
nearest neighbors. If object, an estimator that inherits from
K-nearest neighbors. If object, an estimator that inherits from
:class:`~sklearn.neighbors.base.KNeighborsMixin` that will be used to
find the nearest-neighbors. By default, it will be a 3-NN.
Expand All @@ -52,6 +59,11 @@ class NeighbourhoodCleaningRule(BaseCleaningSampler):
The strategy `"all"` will be less conservative than `'mode'`. Thus,
more samples will be removed when `kind_sel="all"` generally.
.. deprecated:: 0.12
`kind_sel` is deprecated in 0.12 and will be removed in 0.14.
Currently the parameter has no effect and always corresponds to the
`"all"` strategy.
threshold_cleaning : float, default=0.5
Threshold used to decide whether or not to consider a class during the cleaning
after applying ENN. A class will be considered during cleaning when:
Expand All @@ -70,9 +82,16 @@ class NeighbourhoodCleaningRule(BaseCleaningSampler):
corresponds to the class labels from which to sample and the values
are the number of samples to sample.
edited_nearest_neighbours_ : estimator object
The edited nearest neighbour object used to make the first resampling.
nn_ : estimator object
Validated K-nearest Neighbours object created from `n_neighbors` parameter.
classes_to_clean_ : list
The classes considered with under-sampling by `nn_` in the second cleaning
phase.
sample_indices_ : ndarray of shape (n_new_samples,)
Indices of the samples selected.
Expand Down Expand Up @@ -118,52 +137,75 @@ class NeighbourhoodCleaningRule(BaseCleaningSampler):
>>> ncr = NeighbourhoodCleaningRule()
>>> X_res, y_res = ncr.fit_resample(X, y)
>>> print('Resampled dataset shape %s' % Counter(y_res))
Resampled dataset shape Counter({{1: 877, 0: 100}})
Resampled dataset shape Counter({{1: 888, 0: 100}})
"""

_parameter_constraints: dict = {
**BaseCleaningSampler._parameter_constraints,
"edited_nearest_neighbours": [
HasMethods(["fit_resample"]),
None,
],
"n_neighbors": [
Interval(numbers.Integral, 1, None, closed="left"),
HasMethods(["kneighbors", "kneighbors_graph"]),
],
"kind_sel": [StrOptions({"all", "mode"})],
"threshold_cleaning": [Interval(numbers.Real, 0, 1, closed="neither")],
"kind_sel": [StrOptions({"all", "mode"}), Hidden(StrOptions({"deprecated"}))],
"threshold_cleaning": [Interval(numbers.Real, 0, None, closed="neither")],
"n_jobs": [numbers.Integral, None],
}

def __init__(
self,
*,
sampling_strategy="auto",
edited_nearest_neighbours=None,
n_neighbors=3,
kind_sel="all",
kind_sel="deprecated",
threshold_cleaning=0.5,
n_jobs=None,
):
super().__init__(sampling_strategy=sampling_strategy)
self.edited_nearest_neighbours = edited_nearest_neighbours
self.n_neighbors = n_neighbors
self.kind_sel = kind_sel
self.threshold_cleaning = threshold_cleaning
self.n_jobs = n_jobs

def _validate_estimator(self):
"""Create the objects required by NCR."""
self.nn_ = check_neighbors_object(
"n_neighbors", self.n_neighbors, additional_neighbor=1
)
self.nn_.set_params(**{"n_jobs": self.n_jobs})
if isinstance(self.n_neighbors, numbers.Integral):
self.nn_ = KNeighborsClassifier(
n_neighbors=self.n_neighbors, n_jobs=self.n_jobs
)
elif isinstance(self.n_neighbors, NearestNeighbors):
# backward compatibility when passing a NearestNeighbors object
self.nn_ = KNeighborsClassifier(
n_neighbors=self.n_neighbors.n_neighbors - 1, n_jobs=self.n_jobs
)
else:
self.nn_ = clone(self.n_neighbors)

if self.edited_nearest_neighbours is None:
self.edited_nearest_neighbours_ = EditedNearestNeighbours(
sampling_strategy=self.sampling_strategy,
n_neighbors=self.n_neighbors,
kind_sel="mode",
n_jobs=self.n_jobs,
)
else:
self.edited_nearest_neighbours_ = clone(self.edited_nearest_neighbours)

def _fit_resample(self, X, y):
if self.kind_sel != "deprecated":
warnings.warn(
"`kind_sel` is deprecated in 0.12 and will be removed in 0.14. "
"It already has not effect and corresponds to the `'all'` option.",
FutureWarning,
)
self._validate_estimator()
enn = EditedNearestNeighbours(
sampling_strategy=self.sampling_strategy,
n_neighbors=self.n_neighbors,
kind_sel="mode",
n_jobs=self.n_jobs,
)
enn.fit_resample(X, y)
index_not_a1 = enn.sample_indices_
self.edited_nearest_neighbours_.fit_resample(X, y)
index_not_a1 = self.edited_nearest_neighbours_.sample_indices_
index_a1 = np.ones(y.shape, dtype=bool)
index_a1[index_not_a1] = False
index_a1 = np.flatnonzero(index_a1)
Expand All @@ -172,30 +214,34 @@ def _fit_resample(self, X, y):
target_stats = Counter(y)
class_minority = min(target_stats, key=target_stats.get)
# compute which classes to consider for cleaning for the A2 group
classes_under_sample = [
self.classes_to_clean_ = [
c
for c, n_samples in target_stats.items()
if (
c in self.sampling_strategy_.keys()
and (n_samples > X.shape[0] * self.threshold_cleaning)
and (n_samples > target_stats[class_minority] * self.threshold_cleaning)
)
]
self.nn_.fit(X)
self.nn_.fit(X, y)

class_minority_indices = np.flatnonzero(y == class_minority)
X_class = _safe_indexing(X, class_minority_indices)
y_class = _safe_indexing(y, class_minority_indices)
nnhood_idx = self.nn_.kneighbors(X_class, return_distance=False)[:, 1:]
nnhood_label = y[nnhood_idx]
if self.kind_sel == "mode":
nnhood_label_majority, _ = _mode(nnhood_label, axis=1)
nnhood_bool = np.ravel(nnhood_label_majority) == y_class
else: # self.kind_sel == "all":
nnhood_label_majority = nnhood_label == class_minority
nnhood_bool = np.all(nnhood_label, axis=1)
# compute a2 group
index_a2 = np.ravel(nnhood_idx[~nnhood_bool])
index_a2 = np.unique(
[index for index in index_a2 if y[index] in classes_under_sample]
X_minority = _safe_indexing(X, class_minority_indices)
y_minority = _safe_indexing(y, class_minority_indices)

y_pred_minority = self.nn_.predict(X_minority)
# add an additional sample since the query points contains the original dataset
neighbors_to_minority_indices = self.nn_.kneighbors(
X_minority, n_neighbors=self.nn_.n_neighbors + 1, return_distance=False
)[:, 1:]

mask_misclassified_minority = y_pred_minority != y_minority
index_a2 = np.ravel(neighbors_to_minority_indices[mask_misclassified_minority])
index_a2 = np.array(
[
index
for index in np.unique(index_a2)
if y[index] in self.classes_to_clean_
]
)

union_a1_a2 = np.union1d(index_a1, index_a2).astype(int)
Expand Down
Loading

0 comments on commit 6622afb

Please sign in to comment.