Skip to content

Commit

Permalink
FIX remove smoothed_bootstrap and use only shrinkage param (scikit-le…
Browse files Browse the repository at this point in the history
…arn-contrib#794)

Co-authored-by: Christos Aridas <chkoar@users.noreply.github.com>
  • Loading branch information
glemaitre and chkoar authored Feb 12, 2021
1 parent 3444430 commit 1130324
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 79 deletions.
12 changes: 6 additions & 6 deletions doc/over_sampling.rst
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,12 @@ It would also work with pandas dataframe::
>>> df_resampled, y_resampled = ros.fit_resample(df_adult, y_adult)
>>> df_resampled.head() # doctest: +SKIP

If repeating samples is an issue, the parameter `smoothed_bootstrap` can be
turned to `True` to create a smoothed bootstrap. However, the original data
needs to be numerical. The `shrinkage` parameter controls the dispersion of the
new generated samples. We show an example illustrate that the new samples are
not overlapping anymore once using a smoothed bootstrap. This ways of
generating smoothed bootstrap is also known a Random Over-Sampler Examples
If repeating samples is an issue, the parameter `shrinkage` allows creating a
smoothed bootstrap. However, the original data needs to be numerical. The
`shrinkage` parameter controls the dispersion of the newly generated samples.
We show an example illustrating that the new samples no longer overlap when
using a smoothed bootstrap. This way of generating a smoothed bootstrap is
also known as Random Over-Sampling Examples
(ROSE) :cite:`torelli2014rose`.

.. image:: ./auto_examples/over-sampling/images/sphx_glr_plot_comparison_over_sampling_003.png
Expand Down
4 changes: 2 additions & 2 deletions doc/whats_new/v0.7.rst
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ Enhancements

- Added an option to generate smoothed bootstrap in
:class:`imblearn.over_sampling.RandomOverSampler`. It is controlled by the
parameters `smoothed_bootstrap` and `shrinkage`. This method is also known as
Random Over-Sampling Examples (ROSE).
parameter `shrinkage`. This method is also known as Random Over-Sampling
Examples (ROSE).
:pr:`754` by :user:`Andrea Lorenzon <andrealorenzon>` and
:user:`Guillaume Lemaitre <glemaitre>`.

Expand Down
4 changes: 2 additions & 2 deletions examples/over-sampling/plot_comparison_over_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,15 +144,15 @@ def plot_decision_function(X, y, clf, ax):

###############################################################################
# By default, random over-sampling generates a bootstrap. The parameter
# `smoothed_bootstrap` allows adding a small perturbation to the generated data
# `shrinkage` allows adding a small perturbation to the generated data
# to generate a smoothed bootstrap instead. The plot below shows the difference
# between the two data generation strategies.

fig, axs = plt.subplots(1, 2, figsize=(15, 7))
sampler = RandomOverSampler(random_state=0)
plot_resampling(X, y, sampler, ax=axs[0])
axs[0].set_title("RandomOverSampler with normal bootstrap")
sampler = RandomOverSampler(smoothed_bootstrap=True, shrinkage=0.2, random_state=0)
sampler = RandomOverSampler(shrinkage=0.2, random_state=0)
plot_resampling(X, y, sampler, ax=axs[1])
axs[1].set_title("RandomOverSampler with smoothed bootstrap")
fig.tight_layout()
Expand Down
8 changes: 4 additions & 4 deletions examples/over-sampling/plot_shrinkage_effect.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@
# from the majority class. Indeed, it is due to the fact that these samples
# of the minority class are repeated during the bootstrap generation.
#
# We can set `smoothed_bootstrap=True` to add a small perturbation to the
# We can set `shrinkage` to a floating value to add a small perturbation to the
# samples created and therefore create a smoothed bootstrap.
sampler = RandomOverSampler(smoothed_bootstrap=True, random_state=0)
sampler = RandomOverSampler(shrinkage=1, random_state=0)
X_res, y_res = sampler.fit_resample(X, y)
Counter(y_res)

Expand All @@ -81,7 +81,7 @@
#
# The parameter `shrinkage` allows adding more or less perturbation. Let's
# add more perturbation when generating the smoothed bootstrap.
sampler = RandomOverSampler(smoothed_bootstrap=True, shrinkage=3, random_state=0)
sampler = RandomOverSampler(shrinkage=3, random_state=0)
X_res, y_res = sampler.fit_resample(X, y)
Counter(y_res)

Expand All @@ -96,7 +96,7 @@
# %%
# Increasing the value of `shrinkage` will disperse the new samples. Forcing
# the shrinkage to 0 will be equivalent to generating a normal bootstrap.
sampler = RandomOverSampler(smoothed_bootstrap=True, shrinkage=0, random_state=0)
sampler = RandomOverSampler(shrinkage=0, random_state=0)
X_res, y_res = sampler.fit_resample(X, y)
Counter(y_res)

Expand Down
78 changes: 45 additions & 33 deletions imblearn/over_sampling/_random_over_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# Christos Aridas
# License: MIT

from collections.abc import Mapping
from numbers import Real

import numpy as np
Expand Down Expand Up @@ -37,20 +38,20 @@ class RandomOverSampler(BaseOverSampler):
{random_state}
smoothed_bootstrap : bool, default=False
Whether or not to generate smoothed bootstrap samples. When this option
is triggered, be aware that the data to be resampled needs to be
numerical data since a Gaussian perturbation will be generated and
added to the bootstrap.
shrinkage : float or dict, default=None
Parameter controlling the shrinkage applied to the covariance matrix
when a smoothed bootstrap is generated. The options are:
.. versionadded:: 0.7
- if `None`, a normal bootstrap will be generated without perturbation.
It is equivalent to `shrinkage=0` as well;
- if a `float` is given, the shrinkage factor will be used for all
classes to generate the smoothed bootstrap;
- if a `dict` is given, the shrinkage factor will be specific for each
class. The keys correspond to the targeted classes and the values are
the shrinkage factors.
shrinkage : float or dict, default=1.0
Factor to shrink the covariance matrix used to generate the
smoothed bootstrap. A factor could be shared by all classes by
providing a floating number or different for each class over-sampled
by providing a dictionary where the key are the class targeted and the
value is the shrinkage factor.
The value needs of the shrinkage parameter needs to be higher or equal
to 0.
.. versionadded:: 0.7
Expand All @@ -63,7 +64,7 @@ class RandomOverSampler(BaseOverSampler):
shrinkage_ : dict or None
The per-class shrinkage factor used to generate the smoothed bootstrap
sample. `None` when `smoothed_bootstrap=False`.
sample. When `shrinkage=None` a normal bootstrap will be generated.
.. versionadded:: 0.7
Expand Down Expand Up @@ -125,12 +126,10 @@ def __init__(
*,
sampling_strategy="auto",
random_state=None,
smoothed_bootstrap=False,
shrinkage=1.0,
shrinkage=None,
):
super().__init__(sampling_strategy=sampling_strategy)
self.random_state = random_state
self.smoothed_bootstrap = smoothed_bootstrap
self.shrinkage = shrinkage

def _check_X_y(self, X, y):
Expand All @@ -148,34 +147,47 @@ def _check_X_y(self, X, y):
def _fit_resample(self, X, y):
random_state = check_random_state(self.random_state)

if self.smoothed_bootstrap:
if isinstance(self.shrinkage, Real):
self.shrinkage_ = {
klass: self.shrinkage for klass in self.sampling_strategy_
}
else:
missing_shrinkage_keys = (
self.sampling_strategy_.keys() - self.shrinkage.keys()
if isinstance(self.shrinkage, Real):
self.shrinkage_ = {
klass: self.shrinkage for klass in self.sampling_strategy_
}
elif self.shrinkage is None or isinstance(self.shrinkage, Mapping):
self.shrinkage_ = self.shrinkage
else:
raise ValueError(
f"`shrinkage` should either be a positive floating number or "
f"a dictionary mapping a class to a positive floating number. "
f"Got {repr(self.shrinkage)} instead."
)

if self.shrinkage_ is not None:
missing_shrinkage_keys = (
self.sampling_strategy_.keys() - self.shrinkage_.keys()
)
if missing_shrinkage_keys:
raise ValueError(
f"`shrinkage` should contain a shrinkage factor for "
f"each class that will be resampled. The missing "
f"classes are: {repr(missing_shrinkage_keys)}"
)
if missing_shrinkage_keys:

for klass, shrink_factor in self.shrinkage_.items():
if shrink_factor < 0:
raise ValueError(
f"`shrinkage` should contain a shrinkage factor for "
f"each class that will be resampled. The missing "
f"classes are: {repr(missing_shrinkage_keys)}"
f"The shrinkage factor needs to be >= 0. "
f"Got {shrink_factor} for class {klass}."
)
self.shrinkage_ = self.shrinkage

# a smoothed bootstrap requires numerical operations; we need
# to be sure to have only numerical data in X
try:
X = check_array(X, accept_sparse=["csr", "csc"], dtype="numeric")
except ValueError as exc:
raise ValueError(
"When smoothed_bootstrap=True, X needs to contain only "
"When shrinkage is not None, X needs to contain only "
"numerical data to later generate a smoothed bootstrap "
"sample."
) from exc
else:
self.shrinkage_ = None

X_resampled = [X.copy()]
y_resampled = [y.copy()]
Expand All @@ -189,7 +201,7 @@ def _fit_resample(self, X, y):
replace=True,
)
sample_indices = np.append(sample_indices, bootstrap_indices)
if self.smoothed_bootstrap:
if self.shrinkage_ is not None:
# generate a smoothed bootstrap with a perturbation
n_samples, n_features = X.shape
smoothing_constant = (4 / ((n_features + 2) * n_samples)) ** (
Expand Down
56 changes: 24 additions & 32 deletions imblearn/over_sampling/tests/test_random_over_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,7 @@ def test_ros_init():
assert ros.random_state == RND_SEED


@pytest.mark.parametrize(
"params",
[{"smoothed_bootstrap": False}, {"smoothed_bootstrap": True, "shrinkage": 0}]
)
@pytest.mark.parametrize("params", [{"shrinkage": None}, {"shrinkage": 0}])
@pytest.mark.parametrize("X_type", ["array", "dataframe"])
def test_ros_fit_resample(X_type, data, params):
X, Y = data
Expand Down Expand Up @@ -80,16 +77,13 @@ def test_ros_fit_resample(X_type, data, params):
assert_allclose(X_resampled, X_gt)
assert_array_equal(y_resampled, y_gt)

if not params["smoothed_bootstrap"]:
if params["shrinkage"] is None:
assert ros.shrinkage_ is None
else:
assert ros.shrinkage_ == {0: 0}


@pytest.mark.parametrize(
"params",
[{"smoothed_bootstrap": False}, {"smoothed_bootstrap": True, "shrinkage": 0}]
)
@pytest.mark.parametrize("params", [{"shrinkage": None}, {"shrinkage": 0}])
def test_ros_fit_resample_half(data, params):
X, Y = data
sampling_strategy = {0: 3, 1: 7}
Expand All @@ -115,16 +109,13 @@ def test_ros_fit_resample_half(data, params):
assert_allclose(X_resampled, X_gt)
assert_array_equal(y_resampled, y_gt)

if not params["smoothed_bootstrap"]:
if params["shrinkage"] is None:
assert ros.shrinkage_ is None
else:
assert ros.shrinkage_ == {0: 0, 1: 0}


@pytest.mark.parametrize(
"params",
[{"smoothed_bootstrap": False}, {"smoothed_bootstrap": True, "shrinkage": 0}]
)
@pytest.mark.parametrize("params", [{"shrinkage": None}, {"shrinkage": 0}])
def test_multiclass_fit_resample(data, params):
# check the random over-sampling with a multiclass problem
X, Y = data
Expand All @@ -138,7 +129,7 @@ def test_multiclass_fit_resample(data, params):
assert count_y_res[1] == 5
assert count_y_res[2] == 5

if not params["smoothed_bootstrap"]:
if params["shrinkage"] is None:
assert ros.shrinkage_ is None
else:
assert ros.shrinkage_ == {0: 0, 2: 0}
Expand Down Expand Up @@ -188,11 +179,8 @@ def test_random_over_sampling_heterogeneous_data_smoothed_bootstrap():
[["xxx", 1, 1.0], ["yyy", 2, 2.0], ["zzz", 3, 3.0]], dtype=object
)
y = np.array([0, 0, 1])
ros = RandomOverSampler(
smoothed_bootstrap=True,
random_state=RND_SEED,
)
err_msg = "When smoothed_bootstrap=True, X needs to contain only numerical"
ros = RandomOverSampler(shrinkage=1, random_state=RND_SEED)
err_msg = "When shrinkage is not None, X needs to contain only numerical"
with pytest.raises(ValueError, match=err_msg):
ros.fit_resample(X_hetero, y)

Expand All @@ -201,7 +189,7 @@ def test_random_over_sampling_heterogeneous_data_smoothed_bootstrap():
def test_random_over_sampler_smoothed_bootstrap(X_type, data):
# check that smoothed bootstrap is working for numerical array
X, y = data
sampler = RandomOverSampler(smoothed_bootstrap=True, shrinkage=1)
sampler = RandomOverSampler(shrinkage=1)
X = _convert_container(X, X_type)
X_res, y_res = sampler.fit_resample(X, y)

Expand All @@ -217,10 +205,8 @@ def test_random_over_sampler_equivalence_shrinkage(data):
# bootstrap
X, y = data

ros_not_shrink = RandomOverSampler(
smoothed_bootstrap=True, shrinkage=0, random_state=0
)
ros_hard_bootstrap = RandomOverSampler(smoothed_bootstrap=False, random_state=0)
ros_not_shrink = RandomOverSampler(shrinkage=0, random_state=0)
ros_hard_bootstrap = RandomOverSampler(shrinkage=None, random_state=0)

X_res_not_shrink, y_res_not_shrink = ros_not_shrink.fit_resample(X, y)
X_res, y_res = ros_hard_bootstrap.fit_resample(X, y)
Expand All @@ -240,7 +226,7 @@ def test_random_over_sampler_shrinkage_behaviour(data):
# should also be larger.
X, y = data

ros = RandomOverSampler(smoothed_bootstrap=True, shrinkage=1, random_state=0)
ros = RandomOverSampler(shrinkage=1, random_state=0)
X_res_shink_1, y_res_shrink_1 = ros.fit_resample(X, y)

ros.set_params(shrinkage=5)
Expand All @@ -252,12 +238,18 @@ def test_random_over_sampler_shrinkage_behaviour(data):
assert disperstion_shrink_1 < disperstion_shrink_5


def test_random_over_sampler_shrinkage_error(data):
# check that we raise proper error when shrinkage do not contain the
# necessary information
@pytest.mark.parametrize(
"shrinkage, err_msg",
[
({}, "`shrinkage` should contain a shrinkage factor for each class"),
(-1, "The shrinkage factor needs to be >= 0"),
({0: -1}, "The shrinkage factor needs to be >= 0"),
([1, ], "`shrinkage` should either be a positive floating number or")
]
)
def test_random_over_sampler_shrinkage_error(data, shrinkage, err_msg):
# check the validation of the shrinkage parameter
X, y = data
shrinkage = {}
ros = RandomOverSampler(smoothed_bootstrap=True, shrinkage=shrinkage)
err_msg = "`shrinkage` should contain a shrinkage factor for each class"
ros = RandomOverSampler(shrinkage=shrinkage)
with pytest.raises(ValueError, match=err_msg):
ros.fit_resample(X, y)

0 comments on commit 1130324

Please sign in to comment.