Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions examples/datasets/plot_make_imbalance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""
===========================
make_imbalance function
===========================

An illustration of the make_imbalance function

"""

print(__doc__)

import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.datasets import make_moons
from imblearn.datasets import make_imbalance

sns.set()

# Define some color for the plotting
almost_black = '#262626'
palette = sns.color_palette()


def plot_two_classes(ax, X, y, title):
    """Scatter class 0 and class 1 of ``(X, y)`` on ``ax`` and title it."""
    for label, color in ((0, palette[0]), (1, palette[2])):
        ax.scatter(X[y == label, 0], X[y == label, 1],
                   label="Class #{}".format(label), alpha=0.5,
                   edgecolor=almost_black, facecolor=color,
                   linewidth=0.15)
    ax.set_title(title)


# Generate the dataset
X, y = make_moons(n_samples=200, shuffle=True, noise=0.5, random_state=10)

# A 2 x 3 grid of subplots: the original set followed by one panel per ratio
f, axs = plt.subplots(2, 3)

# Flatten the grid so that the panels can be addressed with a single index
axs = axs.ravel()

plot_two_classes(axs[0], X, y, 'Original set')

ratios = [0.9, 0.75, 0.5, 0.25, 0.1]
for ax, ratio in zip(axs[1:], ratios):
    # Seed the sub-sampling so that the figure is identical on every run,
    # in the same way the generated dataset is.
    X_, y_ = make_imbalance(X, y, ratio=ratio, min_c_=1, random_state=10)
    plot_two_classes(ax, X_, y_, 'make_imbalance ratio ({})'.format(ratio))

plt.show()
8 changes: 8 additions & 0 deletions imblearn/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""
The :mod:`imblearn.datasets` provides methods to generate
imbalanced data.
"""

from .imbalance import make_imbalance

__all__ = ['make_imbalance']
74 changes: 74 additions & 0 deletions imblearn/datasets/imbalance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
"""Transform a dataset into an imbalanced dataset."""

import numpy as np

from collections import Counter

from sklearn.utils import check_X_y
from sklearn.utils import check_random_state

def make_imbalance(X, y, ratio, min_c_=None, random_state=None):
    """Turn a dataset into an imbalanced dataset at a specific ratio.

    Samples of the class ``min_c_`` are randomly selected, without
    replacement, so that their number equals ``ratio`` times the number
    of samples which do not belong to ``min_c_``. The samples of all the
    other classes are kept.

    Parameters
    ----------
    X : ndarray, shape (n_samples, n_features)
        Matrix containing the data to be imbalanced.

    y : ndarray, shape (n_samples, )
        Corresponding label for each sample in X.

    ratio : float
        The desired ratio given by the number of samples in the
        minority class over the number of samples which are not in the
        minority class (i.e. the majority class for a binary problem).
        It must be such that ``0.0 < ratio < 1.0``.

    min_c_ : str or int, optional (default=None)
        The identifier of the class to be the minority class.
        If None, min_c_ is set to be the current minority class.

    random_state : int, RandomState instance or None, optional (default=None)
        If int, random_state is the seed used by the random number generator;
        If RandomState instance, random_state is the random number generator;
        If None, the random number generator is the RandomState instance used
        by np.random.

    Returns
    -------
    X_resampled : ndarray, shape (n_samples_new, n_features)
        The array containing the imbalanced data. The selected minority
        samples come first, followed by all the other samples.

    y_resampled : ndarray, shape (n_samples_new)
        The corresponding label of `X_resampled`

    Raises
    ------
    ValueError
        If ``ratio`` is not a number such that ``0.0 < ratio < 1.0``, if
        ``min_c_`` is not a label found in ``y``, if ``y`` holds fewer
        ``min_c_`` samples than the desired ratio requires, or if the
        desired ratio would leave no sample in the minority class.
    """
    # Comparing a non-numeric ratio (e.g. a list) raises TypeError on
    # Python 3; report it with the same ValueError as an out-of-range one.
    # NOTE(review): it was asked during review to check these boundary
    # values against the _validate_ratio function -- to be confirmed.
    try:
        ratio_is_valid = bool(0.0 < ratio < 1.0)
    except TypeError:
        ratio_is_valid = False
    if not ratio_is_valid:
        raise ValueError('ratio value must be such that 0.0 < ratio < 1.0')

    X, y = check_X_y(X, y)

    random_state = check_random_state(random_state)

    stats_c_ = Counter(y)

    if min_c_ is None:
        # Default to the class which is currently the least represented
        min_c_ = min(stats_c_, key=stats_c_.get)
    elif min_c_ not in stats_c_:
        # Without this check an unknown label is reported below as a
        # misleading "imbalance ratio" problem.
        raise ValueError('min_c_ ({!r}) is not a class label found in '
                         'y'.format(min_c_))

    mask = y == min_c_

    # Number of minority samples to keep, relative to every other sample
    n_min_samples = int(np.count_nonzero(~mask) * ratio)
    if n_min_samples > stats_c_[min_c_]:
        raise ValueError('Current imbalance ratio of data is lower than'
                         ' desired ratio!')
    if n_min_samples == 0:
        raise ValueError('Not enough samples for desired ratio!')

    idx_maj = np.where(~mask)[0]
    idx_min = np.where(mask)[0]
    # Only the minority class is sub-sampled
    idx_min = random_state.choice(idx_min, size=n_min_samples, replace=False)
    idx = np.concatenate((idx_min, idx_maj), axis=0)

    X_resampled, y_resampled = X[idx, :], y[idx]

    return X_resampled, y_resampled

Empty file.
108 changes: 108 additions & 0 deletions imblearn/datasets/tests/test_make_imbalance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
"""Test the module easy ensemble."""
from __future__ import print_function

import numpy as np
from numpy.testing import assert_raises
from numpy.testing import assert_equal

from collections import Counter

from imblearn.datasets import make_imbalance

# Generate a global dataset to use; the generator is seeded so that a
# failing test can be reproduced from one run to the next.
X = np.random.RandomState(0).random_sample((1000, 2))
# Balanced binary target: the first half is class 0, the second half class 1
Y = np.zeros(1000)
Y[500:] = 1

def test_make_imbalance_bad_ratio():
    """Check that an unsupported ratio value raises a ValueError."""
    min_c_ = 1

    # In order: a zero ratio, a negative ratio, a ratio greater than 1
    # and a list, which is not supported.
    for ratio in (0.0, -2.0, 2.0, [.5, .5]):
        assert_raises(ValueError, make_imbalance, X, Y, ratio, min_c_)


def test_make_imbalance_invalid_ratio():
    """Check that a ratio higher than the current one raises a ValueError."""
    # A single sample belongs to class 1
    target = np.zeros((X.shape[0], ))
    target[0] = 1

    assert_raises(ValueError, make_imbalance, X, target, 0.5)

def test_make_imbalance_single_class():
    """Check that a target holding a single class raises a ValueError."""
    single_class_target = np.zeros((X.shape[0], ))
    assert_raises(ValueError, make_imbalance, X, single_class_target, 0.5)

def test_make_imbalance_1():
    """Check the class counts and the samples obtained with ratio=0.5."""
    X_, y_ = make_imbalance(X, Y, ratio=0.5, min_c_=1)
    counter = Counter(y_)
    assert_equal(counter[0], 500)
    assert_equal(counter[1], 250)
    # ``row in X`` on a 2-D array is true as soon as one element matches;
    # compare whole rows to check that every sample comes from X.
    assert np.all([np.any(np.all(X == X_i, axis=1)) for X_i in X_])

def test_make_imbalance_2():
    """Check the class counts and the samples obtained with ratio=0.25."""
    X_, y_ = make_imbalance(X, Y, ratio=0.25, min_c_=1)
    counter = Counter(y_)
    assert_equal(counter[0], 500)
    assert_equal(counter[1], 125)
    # ``row in X`` on a 2-D array is true as soon as one element matches;
    # compare whole rows to check that every sample comes from X.
    assert np.all([np.any(np.all(X == X_i, axis=1)) for X_i in X_])

def test_make_imbalance_3():
    """Check the class counts and the samples obtained with ratio=0.1."""
    X_, y_ = make_imbalance(X, Y, ratio=0.1, min_c_=1)
    counter = Counter(y_)
    assert_equal(counter[0], 500)
    assert_equal(counter[1], 50)
    # ``row in X`` on a 2-D array is true as soon as one element matches;
    # compare whole rows to check that every sample comes from X.
    assert np.all([np.any(np.all(X == X_i, axis=1)) for X_i in X_])

def test_make_imbalance_4():
    """Check the class counts and the samples obtained with ratio=0.01."""
    X_, y_ = make_imbalance(X, Y, ratio=0.01, min_c_=1)
    counter = Counter(y_)
    assert_equal(counter[0], 500)
    assert_equal(counter[1], 5)
    # ``row in X`` on a 2-D array is true as soon as one element matches;
    # compare whole rows to check that every sample comes from X.
    assert np.all([np.any(np.all(X == X_i, axis=1)) for X_i in X_])

def test_make_imbalance_5():
    """Check ratio=0.01 when class 0 is requested as the minority class."""
    X_, y_ = make_imbalance(X, Y, ratio=0.01, min_c_=0)
    counter = Counter(y_)
    assert_equal(counter[1], 500)
    assert_equal(counter[0], 5)
    # ``row in X`` on a 2-D array is true as soon as one element matches;
    # compare whole rows to check that every sample comes from X.
    assert np.all([np.any(np.all(X == X_i, axis=1)) for X_i in X_])

def test_make_imbalance_multiclass():
    """Test make_imbalance with multiclass data"""
    # Make y to be multiclass: 100, 400 and 500 samples for classes 0, 1, 2
    y_ = np.zeros(1000)
    y_[100:500] = 1
    y_[500:] = 2

    # Resample the data; the ratio applies to the 900 samples outside class 0
    X_, y_ = make_imbalance(X, y_, ratio=0.1, min_c_=0)
    counter = Counter(y_)
    assert_equal(counter[0], 90)
    assert_equal(counter[1], 400)
    assert_equal(counter[2], 500)
    # ``row in X`` on a 2-D array is true as soon as one element matches;
    # compare whole rows to check that every sample comes from X.
    assert np.all([np.any(np.all(X == X_i, axis=1)) for X_i in X_])