Skip to content

Commit

Permalink
API change behaviour of bootstrap in BRF (scikit-learn-contrib#1010)
Browse files Browse the repository at this point in the history
  • Loading branch information
glemaitre authored Jul 9, 2023
1 parent 124d108 commit d8cf8d6
Show file tree
Hide file tree
Showing 8 changed files with 81 additions and 22 deletions.
3 changes: 2 additions & 1 deletion doc/ensemble.rst
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ each tree of the forest will be provided a balanced bootstrap sample

>>> from imblearn.ensemble import BalancedRandomForestClassifier
>>> brf = BalancedRandomForestClassifier(
... n_estimators=100, random_state=0, sampling_strategy="all", replacement=True
... n_estimators=100, random_state=0, sampling_strategy="all", replacement=True,
... bootstrap=False,
... )
>>> brf.fit(X_train, y_train)
BalancedRandomForestClassifier(...)
Expand Down
13 changes: 11 additions & 2 deletions doc/whats_new/v0.11.rst
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
.. _changes_0_11:

Version 0.11.1
==============

Changelog
---------


Version 0.11.0
==============

Expand Down Expand Up @@ -40,9 +47,11 @@ Deprecation
and will be removed in version 0.13. Use `categorical_encoder_` instead.
:pr:`1000` by :user:`Guillaume Lemaitre <glemaitre>`.

- The default of the parameters `sampling_strategy` and `replacement` will change in
- The default of the parameters `sampling_strategy`, `bootstrap` and
`replacement` will change in
:class:`~imblearn.ensemble.BalancedRandomForestClassifier` to follow the
implementation of the original paper. This changes will take effect in version 0.13.
implementation of the original paper. These changes will take effect in
version 0.13.
:pr:`1006` by :user:`Guillaume Lemaitre <glemaitre>`.

Enhancements
Expand Down
6 changes: 5 additions & 1 deletion examples/applications/plot_impact_imbalanced_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,11 @@
rf_clf = make_pipeline(
preprocessor_tree,
BalancedRandomForestClassifier(
sampling_strategy="all", replacement=True, random_state=42, n_jobs=2
sampling_strategy="all",
replacement=True,
bootstrap=False,
random_state=42,
n_jobs=2,
),
)

Expand Down
6 changes: 5 additions & 1 deletion examples/ensemble/plot_comparison_ensemble_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,11 @@

rf = RandomForestClassifier(n_estimators=50, random_state=0)
brf = BalancedRandomForestClassifier(
n_estimators=50, sampling_strategy="all", replacement=True, random_state=0
n_estimators=50,
sampling_strategy="all",
replacement=True,
bootstrap=False,
random_state=0,
)

rf.fit(X_train, y_train)
Expand Down
34 changes: 28 additions & 6 deletions imblearn/ensemble/_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,10 @@ def _local_parallel_build_trees(
class BalancedRandomForestClassifier(_ParamsValidationMixin, RandomForestClassifier):
"""A balanced random forest classifier.
A balanced random forest randomly under-samples each bootstrap sample to
balance it.
A balanced random forest differs from a classical random forest by the
fact that it will draw a bootstrap sample from the minority class and
sample with replacement the same number of samples from the majority
class.
Read more in the :ref:`User Guide <forest>`.
Expand Down Expand Up @@ -187,6 +189,12 @@ class BalancedRandomForestClassifier(_ParamsValidationMixin, RandomForestClassif
bootstrap : bool, default=True
Whether bootstrap samples are used when building trees.
.. versionchanged:: 0.13
The default of `bootstrap` will change from `True` to `False` in
version 0.13. Bootstrapping is already taken care of by the internal
sampler using `replacement=True`. This implementation follows the
algorithm proposed in [1]_.
oob_score : bool, default=False
Whether to use out-of-bag samples to estimate
the generalization accuracy.
Expand Down Expand Up @@ -395,7 +403,8 @@ class labels (multi-output problem).
... n_informative=4, weights=[0.2, 0.3, 0.5],
... random_state=0)
>>> clf = BalancedRandomForestClassifier(
... sampling_strategy="all", replacement=True, max_depth=2, random_state=0)
... sampling_strategy="all", replacement=True, max_depth=2, random_state=0,
... bootstrap=False)
>>> clf.fit(X, y)
BalancedRandomForestClassifier(...)
>>> print(clf.feature_importances_)
Expand All @@ -415,6 +424,7 @@ class labels (multi-output problem).

_parameter_constraints.update(
{
"bootstrap": ["boolean", Hidden(StrOptions({"warn"}))],
"sampling_strategy": [
Interval(numbers.Real, 0, 1, closed="right"),
StrOptions({"auto", "majority", "not minority", "not majority", "all"}),
Expand All @@ -438,7 +448,7 @@ def __init__(
max_features="sqrt",
max_leaf_nodes=None,
min_impurity_decrease=0.0,
bootstrap=True,
bootstrap="warn",
oob_score=False,
sampling_strategy="warn",
replacement="warn",
Expand Down Expand Up @@ -566,6 +576,18 @@ def fit(self, X, y, sample_weight=None):
else:
self._replacement = self.replacement

if self.bootstrap == "warn":
warn(
"The default of `bootstrap` will change from `True` to "
"`False` in version 0.13. This change will follow the implementation "
"proposed in the original paper. Set to `False` to silence this "
"warning and adopt the future behaviour.",
FutureWarning,
)
self._bootstrap = True
else:
self._bootstrap = self.bootstrap

# Validate or convert input data
if issparse(y):
raise ValueError("sparse multilabel-indicator for y is not supported.")
Expand Down Expand Up @@ -629,7 +651,7 @@ def fit(self, X, y, sample_weight=None):
# Check parameters
self._validate_estimator()

if not self.bootstrap and self.oob_score:
if not self._bootstrap and self.oob_score:
raise ValueError("Out of bag estimation only available if bootstrap=True")

random_state = check_random_state(self.random_state)
Expand Down Expand Up @@ -681,7 +703,7 @@ def fit(self, X, y, sample_weight=None):
delayed(_local_parallel_build_trees)(
s,
t,
self.bootstrap,
self._bootstrap,
X,
y_encoded,
sample_weight,
Expand Down
37 changes: 28 additions & 9 deletions imblearn/ensemble/tests/test_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def imbalanced_dataset():

def test_balanced_random_forest_error_warning_warm_start(imbalanced_dataset):
brf = BalancedRandomForestClassifier(
n_estimators=5, sampling_strategy="all", replacement=True
n_estimators=5, sampling_strategy="all", replacement=True, bootstrap=False
)
brf.fit(*imbalanced_dataset)

Expand All @@ -51,6 +51,7 @@ def test_balanced_random_forest(imbalanced_dataset):
random_state=0,
sampling_strategy="all",
replacement=True,
bootstrap=False,
)
brf.fit(*imbalanced_dataset)

Expand All @@ -68,6 +69,7 @@ def test_balanced_random_forest_attributes(imbalanced_dataset):
random_state=0,
sampling_strategy="all",
replacement=True,
bootstrap=False,
)
brf.fit(X, y)

Expand All @@ -93,7 +95,11 @@ def test_balanced_random_forest_sample_weight(imbalanced_dataset):
X, y = imbalanced_dataset
sample_weight = rng.rand(y.shape[0])
brf = BalancedRandomForestClassifier(
n_estimators=5, random_state=0, sampling_strategy="all", replacement=True
n_estimators=5,
random_state=0,
sampling_strategy="all",
replacement=True,
bootstrap=False,
)
brf.fit(X, y, sample_weight)

Expand All @@ -111,6 +117,7 @@ def test_balanced_random_forest_oob(imbalanced_dataset):
min_samples_leaf=2,
sampling_strategy="all",
replacement=True,
bootstrap=True,
)

est.fit(X_train, y_train)
Expand All @@ -132,7 +139,9 @@ def test_balanced_random_forest_oob(imbalanced_dataset):


def test_balanced_random_forest_grid_search(imbalanced_dataset):
brf = BalancedRandomForestClassifier(sampling_strategy="all", replacement=True)
brf = BalancedRandomForestClassifier(
sampling_strategy="all", replacement=True, bootstrap=False
)
grid = GridSearchCV(brf, {"n_estimators": (1, 2), "max_depth": (1, 2)}, cv=3)
grid.fit(*imbalanced_dataset)

Expand All @@ -150,6 +159,7 @@ def test_little_tree_with_small_max_samples():
max_samples=None,
sampling_strategy="all",
replacement=True,
bootstrap=True,
)

# Second fit with max samples restricted to just 2
Expand All @@ -159,6 +169,7 @@ def test_little_tree_with_small_max_samples():
max_samples=2,
sampling_strategy="all",
replacement=True,
bootstrap=True,
)

est1.fit(X, y)
Expand All @@ -172,12 +183,14 @@ def test_little_tree_with_small_max_samples():


def test_balanced_random_forest_pruning(imbalanced_dataset):
brf = BalancedRandomForestClassifier(sampling_strategy="all", replacement=True)
brf = BalancedRandomForestClassifier(
sampling_strategy="all", replacement=True, bootstrap=False
)
brf.fit(*imbalanced_dataset)
n_nodes_no_pruning = brf.estimators_[0].tree_.node_count

brf_pruned = BalancedRandomForestClassifier(
ccp_alpha=0.015, sampling_strategy="all", replacement=True
ccp_alpha=0.015, sampling_strategy="all", replacement=True, bootstrap=False
)
brf_pruned.fit(*imbalanced_dataset)
n_nodes_pruning = brf_pruned.estimators_[0].tree_.node_count
Expand All @@ -200,6 +213,7 @@ def test_balanced_random_forest_oob_binomial(ratio):
random_state=42,
sampling_strategy="not minority",
replacement=False,
bootstrap=True,
)
erf.fit(X, y)
assert np.abs(erf.oob_score_ - 0.5) < 0.1
Expand All @@ -209,7 +223,7 @@ def test_balanced_bagging_classifier_n_features():
"""Check that we raise a FutureWarning when accessing `n_features_`."""
X, y = load_iris(return_X_y=True)
estimator = BalancedRandomForestClassifier(
sampling_strategy="all", replacement=True
sampling_strategy="all", replacement=True, bootstrap=False
).fit(X, y)
with pytest.warns(FutureWarning, match="`n_features_` was deprecated"):
estimator.n_features_
Expand All @@ -222,7 +236,7 @@ def test_balanced_random_forest_classifier_base_estimator():
"""Check that we raise a FutureWarning when accessing `base_estimator_`."""
X, y = load_iris(return_X_y=True)
estimator = BalancedRandomForestClassifier(
sampling_strategy="all", replacement=True
sampling_strategy="all", replacement=True, bootstrap=False
).fit(X, y)
with pytest.warns(FutureWarning, match="`base_estimator_` was deprecated"):
estimator.base_estimator_
Expand All @@ -233,9 +247,14 @@ def test_balanced_random_forest_change_behaviour(imbalanced_dataset):
"""Check that we raise a FutureWarning about the change of default for the
parameters `sampling_strategy`, `replacement` and `bootstrap`.
"""
estimator = BalancedRandomForestClassifier(sampling_strategy="all")
estimator = BalancedRandomForestClassifier(sampling_strategy="all", bootstrap=False)
with pytest.warns(FutureWarning, match="The default of `replacement`"):
estimator.fit(*imbalanced_dataset)
estimator = BalancedRandomForestClassifier(replacement=True)
estimator = BalancedRandomForestClassifier(replacement=True, bootstrap=False)
with pytest.warns(FutureWarning, match="The default of `sampling_strategy`"):
estimator.fit(*imbalanced_dataset)
estimator = BalancedRandomForestClassifier(
sampling_strategy="all", replacement=True
)
with pytest.warns(FutureWarning, match="The default of `bootstrap`"):
estimator.fit(*imbalanced_dataset)
2 changes: 1 addition & 1 deletion imblearn/tests/test_docstring_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def test_fit_docstring_attributes(name, Estimator):
X = _enforce_estimator_tags_x(est, X)

if "oob_score" in est.get_params():
est.set_params(oob_score=True)
est.set_params(bootstrap=True, oob_score=True)

if is_sampler(est):
est.fit_resample(X, y)
Expand Down
2 changes: 1 addition & 1 deletion imblearn/utils/estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def _set_checking_parameters(estimator):
if name == "BalancedRandomForestClassifier":
# TODO: remove in 0.13
# future default in 0.13
estimator.set_params(replacement=True, sampling_strategy="all")
estimator.set_params(replacement=True, sampling_strategy="all", bootstrap=False)


def _yield_sampler_checks(sampler):
Expand Down

0 comments on commit d8cf8d6

Please sign in to comment.