
Commit

API deprecate estimator_ in favor of estimators_ in CNN and OSS (scik…
glemaitre authored Jul 10, 2023
1 parent d8cf8d6 commit 95e21e1
Showing 5 changed files with 137 additions and 17 deletions.
10 changes: 9 additions & 1 deletion doc/whats_new/v0.12.rst
@@ -3,6 +3,14 @@
Version 0.12.0 (Under development)
==================================


Changelog
---------

Deprecations
............

- Deprecate the `estimator_` fitted attribute in favor of `estimators_` for the classes
:class:`~imblearn.under_sampling.CondensedNearestNeighbour` and
:class:`~imblearn.under_sampling.OneSidedSelection`. `estimator_` will be removed
in 0.14.
:pr:`xxx` by :user:`Guillaume Lemaitre <glemaitre>`.
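
A minimal migration sketch for downstream users (the dataset and variable names are illustrative, not part of this change):

from sklearn.datasets import make_classification
from imblearn.under_sampling import CondensedNearestNeighbour

X, y = make_classification(
    n_classes=3, weights=[0.1, 0.3, 0.6], n_clusters_per_class=1, random_state=0
)
cnn = CondensedNearestNeighbour(random_state=0)
cnn.fit_resample(X, y)

# New in 0.12: one fitted k-NN per (minority class, other class) pair.
print(len(cnn.estimators_))

# Deprecated: emits a FutureWarning and returns the last fitted k-NN.
knn = cnn.estimator_

The same migration applies to OneSidedSelection.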
imblearn/under_sampling/_prototype_selection/_condensed_nearest_neighbour.py
@@ -6,6 +6,7 @@
# License: MIT

import numbers
import warnings
from collections import Counter

import numpy as np
@@ -59,6 +60,16 @@ class CondensedNearestNeighbour(BaseCleaningSampler):
estimator_ : estimator object
The validated K-nearest neighbors estimator created from the `n_neighbors` parameter.
.. deprecated:: 0.12
`estimator_` is deprecated in 0.12 and will be removed in 0.14. Use
`estimators_` instead, which contains the list of all K-nearest
neighbors estimators used for each pair of classes.
estimators_ : list of estimator objects of shape (n_resampled_classes - 1,)
Contains the K-nearest neighbors estimators used for each pair of classes.
.. versionadded:: 0.12
sample_indices_ : ndarray of shape (n_new_samples,)
Indices of the samples selected.
@@ -87,8 +98,8 @@ class CondensedNearestNeighbour(BaseCleaningSampler):
-----
The method is based on [1]_.
Supports multi-class resampling. A one-vs.-rest scheme is used when
sampling a class as proposed in [1]_.
Supports multi-class resampling: a one (minority) vs. each other
class strategy is applied.
References
----------
@@ -142,22 +153,25 @@ def __init__(
def _validate_estimator(self):
"""Private function to create the NN estimator"""
if self.n_neighbors is None:
self.estimator_ = KNeighborsClassifier(n_neighbors=1, n_jobs=self.n_jobs)
estimator = KNeighborsClassifier(n_neighbors=1, n_jobs=self.n_jobs)
elif isinstance(self.n_neighbors, numbers.Integral):
self.estimator_ = KNeighborsClassifier(
estimator = KNeighborsClassifier(
n_neighbors=self.n_neighbors, n_jobs=self.n_jobs
)
elif isinstance(self.n_neighbors, KNeighborsClassifier):
self.estimator_ = clone(self.n_neighbors)
estimator = clone(self.n_neighbors)

return estimator

def _fit_resample(self, X, y):
self._validate_estimator()
estimator = self._validate_estimator()

random_state = check_random_state(self.random_state)
target_stats = Counter(y)
class_minority = min(target_stats, key=target_stats.get)
idx_under = np.empty((0,), dtype=int)

self.estimators_ = []
for target_class in np.unique(y):
if target_class in self.sampling_strategy_.keys():
# Randomly get one sample from the majority class
@@ -184,7 +198,7 @@ def _fit_resample(self, X, y):
S_y = _safe_indexing(y, S_indices)

# fit knn on C
self.estimator_.fit(C_x, C_y)
self.estimators_.append(clone(estimator).fit(C_x, C_y))

good_classif_label = idx_maj_sample.copy()
# Check each sample in S if we keep it or drop it
@@ -196,7 +210,7 @@ def _fit_resample(self, X, y):
# Classify on S
if not issparse(x_sam):
x_sam = x_sam.reshape(1, -1)
pred_y = self.estimator_.predict(x_sam)
pred_y = self.estimators_[-1].predict(x_sam)

# If the prediction does not agree with the true label,
# append it to C_x
@@ -210,12 +224,12 @@ def _fit_resample(self, X, y):
C_y = _safe_indexing(y, C_indices)

# fit a knn on C
self.estimator_.fit(C_x, C_y)
self.estimators_[-1].fit(C_x, C_y)

# This is experimental, to speed up the search:
# classify all the elements in S and avoid testing the
# well-classified elements
pred_S_y = self.estimator_.predict(S_x)
pred_S_y = self.estimators_[-1].predict(S_x)
good_classif_label = np.unique(
np.append(idx_maj_sample, np.flatnonzero(pred_S_y == S_y))
)
@@ -230,5 +244,15 @@ def _fit_resample(self, X, y):

return _safe_indexing(X, idx_under), _safe_indexing(y, idx_under)

@property
def estimator_(self):
"""Last fitted k-NN estimator."""
warnings.warn(
"`estimator_` attribute has been deprecated in 0.12 and will be "
"removed in 0.14. Use `estimators_` instead.",
FutureWarning,
)
return self.estimators_[-1]

def _more_tags(self):
return {"sample_indices": True}
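
The Notes section and the new `estimators_` attribute above describe a one (minority) vs. each other class scheme. A short sketch of inspecting the per-pair estimators after fitting, mirroring the multiclass test added below (data and names are illustrative):

from sklearn.datasets import make_classification
from imblearn.under_sampling import CondensedNearestNeighbour

X, y = make_classification(
    n_samples=1_000,
    n_classes=4,
    weights=[0.1, 0.2, 0.2, 0.5],
    n_clusters_per_class=1,
    random_state=0,
)
cnn = CondensedNearestNeighbour(random_state=0)
cnn.fit_resample(X, y)

# One k-NN per resampled class, each fitted on the minority class plus that class.
assert len(cnn.estimators_) == len(cnn.sampling_strategy_)
for est in cnn.estimators_:
    print(est.classes_)  # e.g. [0 1], [0 2], [0 3]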
imblearn/under_sampling/_prototype_selection/_one_sided_selection.py
@@ -5,6 +5,7 @@
# License: MIT

import numbers
import warnings
from collections import Counter

import numpy as np
@@ -58,6 +59,16 @@ class OneSidedSelection(BaseCleaningSampler):
estimator_ : estimator object
The validated K-nearest neighbors estimator created from the `n_neighbors` parameter.
.. deprecated:: 0.12
`estimator_` is deprecated in 0.12 and will be removed in 0.14. Use
`estimators_` instead, which contains the list of all K-nearest
neighbors estimators used for each pair of classes.
estimators_ : list of estimator objects of shape (n_resampled_classes - 1,)
Contains the K-nearest neighbors estimators used for each pair of classes.
.. versionadded:: 0.12
sample_indices_ : ndarray of shape (n_new_samples,)
Indices of the samples selected.
@@ -138,23 +149,26 @@ def __init__(
def _validate_estimator(self):
"""Private function to create the NN estimator"""
if self.n_neighbors is None:
self.estimator_ = KNeighborsClassifier(n_neighbors=1, n_jobs=self.n_jobs)
estimator = KNeighborsClassifier(n_neighbors=1, n_jobs=self.n_jobs)
elif isinstance(self.n_neighbors, int):
self.estimator_ = KNeighborsClassifier(
estimator = KNeighborsClassifier(
n_neighbors=self.n_neighbors, n_jobs=self.n_jobs
)
elif isinstance(self.n_neighbors, KNeighborsClassifier):
self.estimator_ = clone(self.n_neighbors)
estimator = clone(self.n_neighbors)

return estimator

def _fit_resample(self, X, y):
self._validate_estimator()
estimator = self._validate_estimator()

random_state = check_random_state(self.random_state)
target_stats = Counter(y)
class_minority = min(target_stats, key=target_stats.get)

idx_under = np.empty((0,), dtype=int)

self.estimators_ = []
for target_class in np.unique(y):
if target_class in self.sampling_strategy_.keys():
# select a sample from the current class
@@ -177,8 +191,8 @@ def _fit_resample(self, X, y):
idx_maj_extracted = np.delete(idx_maj, sel_idx_maj, axis=0)
S_x = _safe_indexing(X, idx_maj_extracted)
S_y = _safe_indexing(y, idx_maj_extracted)
self.estimator_.fit(C_x, C_y)
pred_S_y = self.estimator_.predict(S_x)
self.estimators_.append(clone(estimator).fit(C_x, C_y))
pred_S_y = self.estimators_[-1].predict(S_x)

S_misclassified_indices = np.flatnonzero(pred_S_y != S_y)
idx_tmp = idx_maj_extracted[S_misclassified_indices]
@@ -199,5 +213,15 @@ def _fit_resample(self, X, y):

return X_cleaned, y_cleaned

@property
def estimator_(self):
"""Last fitted k-NN estimator."""
warnings.warn(
"`estimator_` attribute has been deprecated in 0.12 and will be "
"removed in 0.14. Use `estimators_` instead.",
FutureWarning,
)
return self.estimators_[-1]

def _more_tags(self):
return {"sample_indices": True}
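
Both classes expose the deprecated attribute through the same read-only property, so code that still reads `estimator_` keeps working until 0.14 but triggers a FutureWarning. A hedged sketch of handling the transition (the warning filter scope is a choice of this example, not a library recommendation):

import warnings

from sklearn.datasets import make_classification
from imblearn.under_sampling import OneSidedSelection

X, y = make_classification(weights=[0.2, 0.8], random_state=0)
oss = OneSidedSelection(random_state=0)
oss.fit_resample(X, y)

# Preferred: read the new attribute directly.
knn = oss.estimators_[-1]

# Legacy access still works but warns; silence it explicitly if needed.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", FutureWarning)
    knn_legacy = oss.estimator_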
imblearn/under_sampling/_prototype_selection/tests/test_condensed_nearest_neighbour.py
@@ -5,6 +5,7 @@

import numpy as np
import pytest
from sklearn.datasets import make_classification
from sklearn.neighbors import KNeighborsClassifier
from sklearn.utils._testing import assert_array_equal

@@ -95,3 +96,34 @@ def test_cnn_fit_resample_with_object(n_neighbors):
X_resampled, y_resampled = cnn.fit_resample(X, Y)
assert_array_equal(X_resampled, X_gt)
assert_array_equal(y_resampled, y_gt)


def test_condensed_nearest_neighbour_multiclass():
"""Check the validity of the fitted attributes `estimators_`."""
X, y = make_classification(
n_samples=1_000,
n_classes=4,
weights=[0.1, 0.2, 0.2, 0.5],
n_clusters_per_class=1,
random_state=0,
)
cnn = CondensedNearestNeighbour(random_state=RND_SEED)
cnn.fit_resample(X, y)

assert len(cnn.estimators_) == len(cnn.sampling_strategy_)
other_classes = []
for est in cnn.estimators_:
assert est.classes_[0] == 0 # minority class
assert est.classes_[1] in {1, 2, 3} # other classes
other_classes.append(est.classes_[1])
assert len(set(other_classes)) == len(other_classes)


# TODO: remove in 0.14
def test_condensed_nearest_neighbors_deprecation():
"""Check that we raise a FutureWarning when accessing the parameter `estimator_`."""
cnn = CondensedNearestNeighbour(random_state=RND_SEED)
cnn.fit_resample(X, Y)
warn_msg = "`estimator_` attribute has been deprecated"
with pytest.warns(FutureWarning, match=warn_msg):
cnn.estimator_
imblearn/under_sampling/_prototype_selection/tests/test_one_sided_selection.py
@@ -5,6 +5,7 @@

import numpy as np
import pytest
from sklearn.datasets import make_classification
from sklearn.neighbors import KNeighborsClassifier
from sklearn.utils._testing import assert_array_equal

@@ -95,3 +96,34 @@ def test_oss_with_object(n_neighbors):
X_resampled, y_resampled = oss.fit_resample(X, Y)
assert_array_equal(X_resampled, X_gt)
assert_array_equal(y_resampled, y_gt)


def test_one_sided_selection_multiclass():
"""Check the validity of the fitted attributes `estimators_`."""
X, y = make_classification(
n_samples=1_000,
n_classes=4,
weights=[0.1, 0.2, 0.2, 0.5],
n_clusters_per_class=1,
random_state=0,
)
oss = OneSidedSelection(random_state=RND_SEED)
oss.fit_resample(X, y)

assert len(oss.estimators_) == len(oss.sampling_strategy_)
other_classes = []
for est in oss.estimators_:
assert est.classes_[0] == 0 # minority class
assert est.classes_[1] in {1, 2, 3} # other classes
other_classes.append(est.classes_[1])
assert len(set(other_classes)) == len(other_classes)


# TODO: remove in 0.14
def test_one_sided_selection_deprecation():
"""Check that we raise a FutureWarning when accessing the parameter `estimator_`."""
oss = OneSidedSelection(random_state=RND_SEED)
oss.fit_resample(X, Y)
warn_msg = "`estimator_` attribute has been deprecated"
with pytest.warns(FutureWarning, match=warn_msg):
oss.estimator_
