ENH: Pass a classifier object instead of string (scikit-learn-contrib…
glemaitre authored and chkoar committed Nov 6, 2016
1 parent a1af197 commit 7a5afeb
Showing 4 changed files with 524 additions and 211 deletions.
198 changes: 138 additions & 60 deletions imblearn/ensemble/balance_cascade.py
@@ -1,13 +1,18 @@
"""Class to perform under-sampling using balace cascade."""
from __future__ import print_function

import warnings

import numpy as np

from sklearn.base import ClassifierMixin
from sklearn.neighbors import KNeighborsClassifier
from sklearn.utils import check_random_state
from sklearn.utils.validation import has_fit_parameter

from ..base import BaseBinarySampler
from six import string_types

ESTIMATOR_KIND = ('knn', 'decision-tree', 'random-forest', 'adaboost',
'gradient-boosting', 'linear-svm')
from ..base import BaseBinarySampler


class BalanceCascade(BaseBinarySampler):
@@ -40,18 +45,29 @@ class BalanceCascade(BaseBinarySampler):
the training will be selected that could lead to a large number of
subsets. We can probably deduce this number empirically.
classifier : str, optional (default='knn')
classifier : str, optional (default=None)
The classifier that will be selected to confront the prediction
with the real labels. The choices are the following: 'knn',
'decision-tree', 'random-forest', 'adaboost', 'gradient-boosting'
and 'linear-svm'.
NOTE: `classifier` is deprecated from 0.2 and will be replaced in 0.4.
Use `estimator` instead.
estimator : object, optional (default=KNeighborsClassifier())
An estimator inheriting from `sklearn.base.ClassifierMixin` and having
a `predict_proba` attribute.
bootstrap : bool, optional (default=True)
Whether to bootstrap the data before each iteration.
**kwargs : keywords
The parameters associated with the classifier provided.
NOTE: `**kwargs` has been deprecated from 0.2 and will be replaced in
0.4. Use an `estimator` object instead to pass parameters associated
with an estimator.
Attributes
----------
min_c_ : str or int
@@ -100,16 +116,97 @@ class BalanceCascade(BaseBinarySampler):
"""

def __init__(self, ratio='auto', return_indices=False, random_state=None,
n_max_subset=None, classifier='knn', bootstrap=True,
**kwargs):
n_max_subset=None, classifier=None, estimator=None,
bootstrap=True, **kwargs):
super(BalanceCascade, self).__init__(ratio=ratio,
random_state=random_state)
self.return_indices = return_indices
self.classifier = classifier
self.estimator = estimator
self.n_max_subset = n_max_subset
self.bootstrap = bootstrap
self.kwargs = kwargs
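
Both construction paths are visible in the signature above: the new `estimator` object and the deprecated `classifier` string. A minimal, hypothetical usage sketch (toy data and parameter values invented; assumes the 0.2 API this commit introduces):

from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier
from imblearn.ensemble import BalanceCascade

# Toy imbalanced binary problem, for illustration only.
X, y = make_classification(n_samples=100, weights=[0.9, 0.1], random_state=0)

# New in 0.2: pass a classifier object directly.
bc = BalanceCascade(estimator=DecisionTreeClassifier(random_state=0))

# Deprecated path, kept until 0.4: a string plus **kwargs.
bc_old = BalanceCascade(classifier='decision-tree')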

def _validate_estimator(self):
"""Private function to create the classifier"""

if self.classifier is not None:
warnings.warn('`classifier` will be replaced in version'
' 0.4. Use an `estimator` instead.',
DeprecationWarning)
self.estimator = self.classifier

if (self.estimator is not None and
isinstance(self.estimator, ClassifierMixin) and
hasattr(self.estimator, 'predict')):
self.estimator_ = self.estimator
elif self.estimator is None:
self.estimator_ = KNeighborsClassifier()
# To be removed in 0.4
elif (self.estimator is not None and
isinstance(self.estimator, string_types)):
warnings.warn('Passing a string as `estimator` is deprecated from'
' 0.2 and will be removed in 0.4. Use a classifier'
' object instead.',
DeprecationWarning)
# Define the classifier to use
if self.estimator == 'knn':
self.estimator_ = KNeighborsClassifier(
**self.kwargs)
elif self.estimator == 'decision-tree':
from sklearn.tree import DecisionTreeClassifier
self.estimator_ = DecisionTreeClassifier(
random_state=self.random_state,
**self.kwargs)
elif self.estimator == 'random-forest':
from sklearn.ensemble import RandomForestClassifier
self.estimator_ = RandomForestClassifier(
random_state=self.random_state,
**self.kwargs)
elif self.estimator == 'adaboost':
from sklearn.ensemble import AdaBoostClassifier
self.estimator_ = AdaBoostClassifier(
random_state=self.random_state,
**self.kwargs)
elif self.estimator == 'gradient-boosting':
from sklearn.ensemble import GradientBoostingClassifier
self.estimator_ = GradientBoostingClassifier(
random_state=self.random_state,
**self.kwargs)
elif self.estimator == 'linear-svm':
from sklearn.svm import LinearSVC
self.estimator_ = LinearSVC(random_state=self.random_state,
**self.kwargs)
else:
raise NotImplementedError
else:
raise ValueError('Invalid parameter `estimator`')

self.logger.debug(self.estimator_)
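
The branching in `_validate_estimator` reduces to: accept any `ClassifierMixin` with a `predict` method, fall back to `KNeighborsClassifier` when nothing is given, and keep the string lookup alive behind a `DeprecationWarning` until 0.4. A condensed restatement of that logic (the helper name `resolve_estimator` is ours, not the library's, and the string table is abridged):

from sklearn.base import ClassifierMixin
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier

def resolve_estimator(estimator):
    """Hypothetical helper mirroring _validate_estimator's branching."""
    if (estimator is not None and isinstance(estimator, ClassifierMixin)
            and hasattr(estimator, 'predict')):
        return estimator                   # new object-based path
    if estimator is None:
        return KNeighborsClassifier()      # default
    if isinstance(estimator, str):
        # Deprecated string path; the diff maps six names this way.
        table = {'knn': KNeighborsClassifier,
                 'decision-tree': DecisionTreeClassifier}
        return table[estimator]()
    raise ValueError('Invalid parameter `estimator`')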

def fit(self, X, y):
"""Find the classes statistics before to perform sampling.
Parameters
----------
X : ndarray, shape (n_samples, n_features)
Matrix containing the data which have to be sampled.
y : ndarray, shape (n_samples, )
Corresponding label for each sample in X.
Returns
-------
self : object,
Return self.
"""

super(BalanceCascade, self).fit(X, y)

self._validate_estimator()

return self

def _sample(self, X, y):
"""Resample the dataset.
@@ -135,42 +232,9 @@ def _sample(self, X, y):
"""

if self.classifier not in ESTIMATOR_KIND:
raise NotImplementedError

random_state = check_random_state(self.random_state)

# Define the classifier to use
if self.classifier == 'knn':
from sklearn.neighbors import KNeighborsClassifier
classifier = KNeighborsClassifier(
**self.kwargs)
elif self.classifier == 'decision-tree':
from sklearn.tree import DecisionTreeClassifier
classifier = DecisionTreeClassifier(
random_state=random_state,
**self.kwargs)
elif self.classifier == 'random-forest':
from sklearn.ensemble import RandomForestClassifier
classifier = RandomForestClassifier(
random_state=random_state,
**self.kwargs)
elif self.classifier == 'adaboost':
from sklearn.ensemble import AdaBoostClassifier
classifier = AdaBoostClassifier(
random_state=random_state,
**self.kwargs)
elif self.classifier == 'gradient-boosting':
from sklearn.ensemble import GradientBoostingClassifier
classifier = GradientBoostingClassifier(
random_state=random_state,
**self.kwargs)
elif self.classifier == 'linear-svm':
from sklearn.svm import LinearSVC
classifier = LinearSVC(random_state=random_state,
**self.kwargs)
else:
raise NotImplementedError
support_sample_weight = has_fit_parameter(self.estimator_,
"sample_weight")

X_resampled = []
y_resampled = []
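
`has_fit_parameter` simply inspects the signature of the estimator's `fit` method; it decides whether the bootstrap further down is applied through `sample_weight` or through explicit row selection. A sketch of what it reports for two stock classifiers (behaviour of the scikit-learn generation this code targets):

from sklearn.utils.validation import has_fit_parameter
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier

# KNN's fit(X, y) takes no sample_weight: the indexing branch is used.
print(has_fit_parameter(KNeighborsClassifier(), 'sample_weight'))    # False
# Trees accept sample_weight: the weighting branch is used.
print(has_fit_parameter(DecisionTreeClassifier(), 'sample_weight'))  # True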
@@ -185,6 +249,7 @@ def _sample(self, X, y):
# return them later
if self.return_indices:
idx_min = np.flatnonzero(y == self.min_c_)
idx_maj = np.flatnonzero(y == self.maj_c_)

# Condition to initialise before the search
b_subset_search = True
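
The new `idx_maj` array is what makes `return_indices` meaningful: `np.flatnonzero` turns the boolean class mask into absolute positions in `X`, through which later selections can be mapped back. A toy illustration (values invented):

import numpy as np

y = np.array([0, 1, 1, 0, 1, 1])      # class 1 is the majority
idx_maj = np.flatnonzero(y == 1)      # array([1, 2, 4, 5])

idx_sel_from_maj = np.array([0, 2])   # positions *within* the majority subset
print(idx_maj[idx_sel_from_maj])      # [1 4]: valid indices into X and y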
@@ -227,27 +292,42 @@ def _sample(self, X, y):
X_resampled.append(x_data)
y_resampled.append(y_data)
if self.return_indices:
idx_under.append(np.concatenate((idx_min, idx_sel_from_maj),
idx_under.append(np.concatenate((idx_min,
idx_maj[idx_sel_from_maj]),
axis=0))

if (not (self.classifier == 'knn' or
self.classifier == 'linear-svm') and
self.bootstrap):
# Apply a bootstrap on x_data
curr_sample_weight = np.ones((y_data.size,), dtype=np.float64)
# Get the indices of interest
if self.bootstrap:
indices = random_state.randint(0, y_data.size, y_data.size)
sample_counts = np.bincount(indices, minlength=y_data.size)
curr_sample_weight *= sample_counts
else:
indices = np.arange(y_data.size)

# Train the classifier using the current data
classifier.fit(x_data, y_data, curr_sample_weight)
# Draw samples, using sample weights, and then fit
if support_sample_weight:
self.logger.debug('Sample-weight is supported')
curr_sample_weight = np.ones((y_data.size,), dtype=np.float64)

if self.bootstrap:
self.logger.debug('Go for a bootstrap')
sample_counts = np.bincount(indices, minlength=y_data.size)
curr_sample_weight *= sample_counts
else:
self.logger.debug('No bootstrap')
mask = np.zeros(y_data.size, dtype=np.bool)
mask[indices] = True
not_indices_mask = ~mask
curr_sample_weight[not_indices_mask] = 0

self.estimator_.fit(x_data, y_data,
sample_weight=curr_sample_weight)

# Draw samples, using a mask, and then fit
else:
# Train the classifier using the current data
classifier.fit(x_data, y_data)
self.logger.debug('Sample-weight is not supported')
self.estimator_.fit(x_data[indices], y_data[indices])

# Predict using only the majority class
pred_label = classifier.predict(N_x[idx_sel_from_maj, :])
pred_label = self.estimator_.predict(N_x[idx_sel_from_maj, :])

# Basically, let's find which samples have to be retained for the
# next round
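
The weight-based branch leans on a small numpy identity: drawing n indices with replacement and counting them with `np.bincount` yields integer weights whose effect, for any estimator that honours `sample_weight`, matches fitting on the bootstrapped rows themselves. A standalone sketch:

import numpy as np

rng = np.random.RandomState(0)
n = 8
indices = rng.randint(0, n, n)               # bootstrap draw, with replacement
weights = np.bincount(indices, minlength=n)  # multiplicity of each row

# Rows never drawn get weight 0, rows drawn twice get weight 2, and the
# total weight always equals the number of draws.
assert weights.sum() == n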
@@ -288,9 +368,8 @@ def _sample(self, X, y):
X_resampled.append(x_data)
y_resampled.append(y_data)
if self.return_indices:
idx_under.append(np.concatenate((idx_min,
idx_sel_from_maj),
axis=0))
idx_under.append(np.concatenate(
(idx_min, idx_maj[idx_sel_from_maj]), axis=0))

self.logger.debug('Creation of the subset #%s', n_subsets)

@@ -321,9 +400,8 @@ def _sample(self, X, y):
X_resampled.append(x_data)
y_resampled.append(y_data)
if self.return_indices:
idx_under.append(np.concatenate((idx_min,
idx_sel_from_maj),
axis=0))
idx_under.append(np.concatenate(
(idx_min, idx_maj[idx_sel_from_maj]), axis=0))
self.logger.debug('Creation of the subset #%s', n_subsets)

# We found a new subset, increase the counter