Skip to content

Commit

Permalink
DOC/EXA solve Tomek examples (scikit-learn-contrib#263)
Browse files Browse the repository at this point in the history
  • Loading branch information
glemaitre authored Mar 22, 2017
1 parent 8d9d21a commit a951029
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 20 deletions.
1 change: 0 additions & 1 deletion examples/applications/plot_over_sampling_benchmark_lfw.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ def fit_sample(self, X, y):
y[y == majority_person] = 0
y[y == minority_person] = 1


classifier = ['3NN', neighbors.KNeighborsClassifier(3)]

samplers = [
Expand Down
46 changes: 27 additions & 19 deletions examples/under-sampling/plot_tomek_links.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,48 +4,56 @@
===========
An illustration of the Tomek links method.
"""

import numpy as np
import matplotlib.pyplot as plt

from sklearn.datasets import make_blobs
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle

from imblearn.under_sampling import TomekLinks

print(__doc__)


# NOTE(review): this span looks like a rendered diff whose +/- markers were
# lost, so the superseded and the replacement version of several statements
# appear one after the other. Confirm against the real file before running.

# create a synthetic dataset
# Variant 1: 500 two-feature samples around 2 centers, with a fixed seed.
X, y = make_blobs(n_samples=500, centers=2, n_features=2,
random_state=0, center_box=(-5.0, 5.0))
# Variant 2: a hand-built imbalanced dataset drawn from a seeded generator.
rng = np.random.RandomState(0)
n_samples_1 = 500
n_samples_2 = 50
# Stack 500 draws scaled by 1.5 on top of 50 draws scaled by 0.5 and shifted
# by [2, 2], giving a 10:1 ratio between the two groups.
X_syn = np.r_[1.5 * rng.randn(n_samples_1, 2),
0.5 * rng.randn(n_samples_2, 2) + [2, 2]]
# Label the first group 0 and the second group 1, in the same order as X_syn.
y_syn = np.array([0] * (n_samples_1) + [1] * (n_samples_2))
# NOTE(review): no random_state is passed here, so the shuffled order is
# presumably not pinned by the seed above - confirm if reproducibility matters.
X_syn, y_syn = shuffle(X_syn, y_syn)
# NOTE(review): none of the four split outputs is used in the rest of this
# span; resampling below works on the full X_syn / y_syn.
X_syn_train, X_syn_test, y_syn_train, y_syn_test = train_test_split(X_syn,
y_syn)

# remove Tomek links
# return_indices=True makes fit_sample return a third value, unpacked below as
# idx_resampled (presumably the positions of the samples that were kept).
tl = TomekLinks(return_indices=True)
# NOTE(review): the next two calls assign the same three names, so the second
# one overwrites the first; they look like the old and new line of the diff.
X_resampled, y_resampled, idx_resampled = tl.fit_sample(X, y)
X_resampled, y_resampled, idx_resampled = tl.fit_sample(X_syn, y_syn)

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)

# Variant 1 of the plotting code.
# NOTE(review): idx_class_0 / idx_class_1 are positions within y_resampled, yet
# they are compared with positions in y and used to index X below. That looks
# like the defect this commit addresses - confirm.
idx_class_0 = np.flatnonzero(y_resampled == 0)
idx_class_1 = np.flatnonzero(y_resampled == 1)
idx_samples_removed = np.setdiff1d(np.flatnonzero(y == 1),
np.union1d(idx_class_0, idx_class_1))

plt.scatter(X[idx_class_0, 0], X[idx_class_0, 1],
c='g', alpha=.8, label='Class #0')
plt.scatter(X[idx_class_1, 0], X[idx_class_1, 1],
c='b', alpha=.8, label='Class #1')
plt.scatter(X[idx_samples_removed, 0], X[idx_samples_removed, 1],
c='r', alpha=.8, label='Removed samples')

# Variant 2 of the plotting code.
# Removed samples are every row position of X_syn that is absent from
# idx_resampled.
idx_samples_removed = np.setdiff1d(np.arange(X_syn.shape[0]),
idx_resampled)
# Boolean mask over the resampled labels; its negation selects the other class.
# This rebinds idx_class_0, which held integer positions above.
idx_class_0 = y_resampled == 0
plt.scatter(X_resampled[idx_class_0, 0], X_resampled[idx_class_0, 1],
alpha=.8, label='Class #0')
plt.scatter(X_resampled[~idx_class_0, 0], X_resampled[~idx_class_0, 1],
alpha=.8, label='Class #1')
# Removed points are looked up in the original X_syn, since they are no longer
# part of X_resampled.
plt.scatter(X_syn[idx_samples_removed, 0], X_syn[idx_samples_removed, 1],
alpha=.8, label='Removed samples')

# Tidy up the axes: hide the top and right spines, keep ticks on the bottom
# and left only, and move the remaining spines outward by 10.
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.get_xaxis().tick_bottom()
ax.get_yaxis().tick_left()
ax.spines['left'].set_position(('outward', 10))
ax.spines['bottom'].set_position(('outward', 10))
# Fix both axes to [-5, 5] with one tick per integer.
ax.set_xlim([-5, 5])
ax.set_ylim([-5, 5])
plt.yticks(range(-5, 6))
plt.xticks(range(-5, 6))

plt.title('Under-sampling removing Tomek links')
plt.legend()
Expand Down

0 comments on commit a951029

Please sign in to comment.