Commit a4947ea

[MRG] EHN add BalancedBaggingClassifier (scikit-learn-contrib#315)

* EHN add BalancedBaggingClassifier
* TST add two missing tests
* DOC add examples
* FIX not passing sample_weight at fit
* DOC add api documentation
* DOC fix docstring
* iter
* DOC fix docstring
* DOC add user guide entry and cross-referencing
* FIX mv into a new module
* FIX add missing dependency

1 parent 2e7c070 commit a4947ea

File tree: 10 files changed, +898 −7 lines changed

doc/api.rst
Lines changed: 1 addition & 0 deletions

@@ -109,6 +109,7 @@ Prototype selection
    :template: class.rst

    ensemble.BalanceCascade
+   ensemble.BalancedBaggingClassifier
    ensemble.EasyEnsemble
doc/ensemble.rst
Lines changed: 56 additions & 0 deletions

@@ -6,6 +6,11 @@ Ensemble of samplers

 .. currentmodule:: imblearn.ensemble

+.. _ensemble_samplers:
+
+Samplers
+--------
+
 An imbalanced data set can be balanced by creating several balanced
 subsets. The module :mod:`imblearn.ensemble` allows creating such sets.

@@ -54,3 +59,54 @@ parameter ``n_max_subset``, and additional bootstrapping can be activated with
 See
 :ref:`sphx_glr_auto_examples_ensemble_plot_easy_ensemble.py` and
 :ref:`sphx_glr_auto_examples_ensemble_plot_balance_cascade.py`.
+
+.. _ensemble_meta_estimators:
+
+Chaining ensemble of samplers and estimators
+--------------------------------------------
+
+In ensemble classifiers, bagging methods build several estimators on
+different randomly selected subsets of data. In scikit-learn, this
+classifier is named ``BaggingClassifier``. However, it does not allow
+balancing each subset of data. Therefore, when training on an imbalanced
+data set, this classifier will favor the majority classes::
+
+  >>> from sklearn.model_selection import train_test_split
+  >>> from sklearn.metrics import confusion_matrix
+  >>> from sklearn.ensemble import BaggingClassifier
+  >>> from sklearn.tree import DecisionTreeClassifier
+  >>> X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
+  >>> bc = BaggingClassifier(base_estimator=DecisionTreeClassifier(),
+  ...                        random_state=0)
+  >>> bc.fit(X_train, y_train)  # doctest: +ELLIPSIS
+  BaggingClassifier(...)
+  >>> y_pred = bc.predict(X_test)
+  >>> confusion_matrix(y_test, y_pred)
+  array([[   0,    0,   12],
+         [   0,    0,   59],
+         [   0,    0, 1179]])
+
+:class:`BalancedBaggingClassifier` allows resampling each subset of data
+before training each estimator of the ensemble. In short, it combines the
+output of an :class:`EasyEnsemble` sampler with an ensemble of classifiers
+(i.e. ``BaggingClassifier``). Therefore, :class:`BalancedBaggingClassifier`
+takes the same parameters as the scikit-learn ``BaggingClassifier``. In
+addition, there are two extra parameters, ``ratio`` and ``replacement``, as
+in the :class:`EasyEnsemble` sampler::
+
+  >>> from imblearn.ensemble import BalancedBaggingClassifier
+  >>> bbc = BalancedBaggingClassifier(base_estimator=DecisionTreeClassifier(),
+  ...                                 ratio='auto',
+  ...                                 replacement=False,
+  ...                                 random_state=0)
+  >>> bbc.fit(X, y)  # doctest: +ELLIPSIS
+  BalancedBaggingClassifier(...)
+  >>> y_pred = bbc.predict(X_test)
+  >>> confusion_matrix(y_test, y_pred)
+  array([[  12,    0,    0],
+         [   0,   55,    4],
+         [  68,   53, 1058]])
+
+See
+:ref:`sphx_glr_auto_examples_ensemble_plot_comparison_bagging_classifier.py`.
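
The doctests above reuse ``X`` and ``y`` defined earlier in the user guide
page. (Note also that the second doctest fits on the full ``X, y`` and then
evaluates on ``X_test``, so its confusion matrix partly reflects training
data.) As a minimal sketch of a comparable setup — the ``make_classification``
parameters below are illustrative assumptions, not necessarily the guide's
own::

  # Hypothetical setup: an imbalanced 3-class problem similar in spirit
  # to the one the user guide constructs before these doctests.
  from sklearn.datasets import make_classification

  X, y = make_classification(n_samples=5000, n_classes=3,
                             n_informative=3, n_redundant=1,
                             n_clusters_per_class=1,
                             weights=[0.01, 0.05, 0.94],
                             random_state=0)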

doc/whats_new.rst
Lines changed: 4 additions & 0 deletions

@@ -53,6 +53,10 @@ New features
 Enhancement
 ~~~~~~~~~~~

+- Add :class:`ensemble.BalancedBaggingClassifier`, a meta-estimator that
+  chains the :class:`ensemble.EasyEnsemble` sampler with a classifier. By
+  `Guillaume Lemaitre`_.
+
 - All samplers accept sparse matrices, defaulting to the CSR format. By
   `Guillaume Lemaitre`_.

examples/ensemble/plot_comparison_bagging_classifier.py (new file)
Lines changed: 104 additions & 0 deletions

"""
=========================================================
Comparison of balanced and imbalanced bagging classifiers
=========================================================

This example shows the benefit of balancing the training set when using a
bagging classifier. ``BalancedBaggingClassifier`` chains a
``RandomUnderSampler`` and a given classifier, while ``BaggingClassifier``
uses the imbalanced data directly.

Balancing the data set before training the classifier improves the
classification performance. In addition, it keeps the ensemble from focusing
on the majority class, a known drawback of decision-tree classifiers.

"""

# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
# License: MIT

from collections import Counter
import itertools

import matplotlib.pyplot as plt
import numpy as np

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import BaggingClassifier
from sklearn.metrics import confusion_matrix

from imblearn.datasets import make_imbalance
from imblearn.ensemble import BalancedBaggingClassifier

from imblearn.metrics import classification_report_imbalanced


def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """Print and plot the confusion matrix.

    Normalization can be applied by setting ``normalize=True``.
    """
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # Annotate each cell with its count, switching text color for contrast
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')


# Build an imbalanced version of the iris data set
iris = load_iris()
X, y = make_imbalance(iris.data, iris.target, ratio={0: 25, 1: 40, 2: 50},
                      random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

bagging = BaggingClassifier(random_state=0)
balanced_bagging = BalancedBaggingClassifier(random_state=0)

print('Class distribution of the training set: {}'.format(Counter(y_train)))

bagging.fit(X_train, y_train)
balanced_bagging.fit(X_train, y_train)

print('Class distribution of the test set: {}'.format(Counter(y_test)))

print('Classification results using a bagging classifier on imbalanced data')
y_pred_bagging = bagging.predict(X_test)
print(classification_report_imbalanced(y_test, y_pred_bagging))
cm_bagging = confusion_matrix(y_test, y_pred_bagging)
plt.figure()
plot_confusion_matrix(cm_bagging, classes=iris.target_names,
                      title='Confusion matrix using BaggingClassifier')

print('Classification results using a bagging classifier on balanced data')
y_pred_balanced_bagging = balanced_bagging.predict(X_test)
print(classification_report_imbalanced(y_test, y_pred_balanced_bagging))
cm_balanced_bagging = confusion_matrix(y_test, y_pred_balanced_bagging)
plt.figure()
plot_confusion_matrix(cm_balanced_bagging, classes=iris.target_names,
                      title='Confusion matrix using BalancedBaggingClassifier')

plt.show()
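
For intuition, each estimator in the balanced ensemble is fit on a resampled
subset of the training data. A rough sketch of that behaviour follows; this
is not the library's implementation (which extends scikit-learn's
``BaggingClassifier`` machinery), the helper names are hypothetical, and
``fit_sample`` reflects the imblearn API of this era::

  # Simplified sketch, NOT the actual implementation: each base estimator
  # sees a balanced undersample of the training set, then predictions are
  # combined by majority vote. Helper names here are hypothetical.
  import numpy as np
  from sklearn.tree import DecisionTreeClassifier
  from imblearn.under_sampling import RandomUnderSampler


  def fit_balanced_bag(X, y, n_estimators=10, random_state=0):
      """Fit one tree per balanced resample of the training set."""
      rng = np.random.RandomState(random_state)
      estimators = []
      for _ in range(n_estimators):
          # Undersample the majority classes before fitting each member
          sampler = RandomUnderSampler(random_state=rng.randint(2 ** 31))
          X_res, y_res = sampler.fit_sample(X, y)
          tree = DecisionTreeClassifier(random_state=rng.randint(2 ** 31))
          estimators.append(tree.fit(X_res, y_res))
      return estimators


  def predict_majority(estimators, X):
      """Majority vote across ensemble members."""
      votes = np.stack([est.predict(X) for est in estimators])
      return np.apply_along_axis(
          lambda col: np.bincount(col).argmax(), axis=0, arr=votes)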

imblearn/ensemble/__init__.py
Lines changed: 3 additions & 1 deletion

@@ -6,4 +6,6 @@
 from .easy_ensemble import EasyEnsemble
 from .balance_cascade import BalanceCascade

-__all__ = ['EasyEnsemble', 'BalanceCascade']
+from .classifier import BalancedBaggingClassifier
+
+__all__ = ['EasyEnsemble', 'BalancedBaggingClassifier', 'BalanceCascade']
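
With this export in place, the new class is importable directly from the
subpackage; a quick smoke test, consistent with the doctests above::

  >>> from imblearn.ensemble import BalancedBaggingClassifier
  >>> BalancedBaggingClassifier(random_state=0)  # doctest: +ELLIPSIS
  BalancedBaggingClassifier(...)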

imblearn/ensemble/balance_cascade.py
Lines changed: 2 additions & 2 deletions

@@ -27,7 +27,7 @@ class BalanceCascade(BaseEnsembleSampler):
     This method iteratively selects a subset and makes an ensemble of the
     different sets. The selection is performed using a specific classifier.

-    Read more in the :ref:`User Guide <ensemble>`.
+    Read more in the :ref:`User Guide <ensemble_samplers>`.

     Parameters
     ----------
@@ -99,7 +99,7 @@ class BalanceCascade(BaseEnsembleSampler):

     See also
     --------
-    EasyEnsemble
+    BalancedBaggingClassifier, EasyEnsemble

     References
     ----------
