Skip to content

Commit

Permalink
DOC/FIX fix Tomek links example (scikit-learn-contrib#255)
Browse files Browse the repository at this point in the history
  • Loading branch information
glemaitre authored and chkoar committed Mar 20, 2017
1 parent 67128f8 commit 2508f96
Showing 1 changed file with 36 additions and 42 deletions.
78 changes: 36 additions & 42 deletions examples/under-sampling/plot_tomek_links.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,53 +7,47 @@
"""

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import make_classification
from sklearn.decomposition import PCA

from sklearn.datasets import make_blobs

from imblearn.under_sampling import TomekLinks

# Print the module docstring when the example is run.
print(__doc__)

# Switch on seaborn's default plot styling.
sns.set()

# Colours reused by the scatter plots: a near-black marker outline and
# seaborn's current palette for the marker faces.
almost_black = '#262626'
palette = sns.color_palette()

# Build an imbalanced two-class problem (class weights 0.1 / 0.9) with
# 5000 samples and 20 features; the seed is fixed for reproducibility.
X, y = make_classification(
    n_samples=5000,
    n_features=20,
    n_informative=3,
    n_redundant=1,
    n_classes=2,
    n_clusters_per_class=1,
    weights=[0.1, 0.9],
    class_sep=2,
    flip_y=0,
    random_state=10)

# Instantiate a two-component PCA and project the data with it, so the
# 20-dimensional samples can be drawn in a 2D feature space.
pca = PCA(n_components=2)
X_vis = pca.fit_transform(X)

# Clean the dataset with Tomek links, then map the resampled points
# through the already-fitted PCA so both sets share the same 2D axes.
tl = TomekLinks()
X_resampled, y_resampled = tl.fit_sample(X, y)
X_res_vis = pca.transform(X_resampled)

# One row, two panels: the original data on the left and the data after
# Tomek links cleaning on the right.
f, (ax1, ax2) = plt.subplots(1, 2)

# Each panel is described by (axes, 2D points, labels, title); each class
# by (label value, legend text, face colour).  Iterating keeps the scatter
# calls in the same order and with the same arguments for both panels.
for axis, points, labels, panel_title in (
        (ax1, X_vis, y, 'Original set'),
        (ax2, X_res_vis, y_resampled, 'Tomek links')):
    for class_value, legend_text, face in (
            (0, "Class #0", palette[0]),
            (1, "Class #1", palette[2])):
        members = labels == class_value
        axis.scatter(points[members, 0], points[members, 1],
                     label=legend_text, alpha=0.5,
                     edgecolor=almost_black, facecolor=face,
                     linewidth=0.15)
    axis.set_title(panel_title)

# Draw 500 two-dimensional samples around two blob centres placed inside
# the (-5, 5) box; the seed is fixed for reproducibility.
X, y = make_blobs(
    n_samples=500,
    n_features=2,
    centers=2,
    center_box=(-5.0, 5.0),
    random_state=0)

# Remove the Tomek links, asking the sampler to also hand back the
# third return value (stored as ``idx_resampled``).
tl = TomekLinks(return_indices=True)
X_resampled, y_resampled, idx_resampled = tl.fit_sample(X, y)

# A single figure holding one set of axes.
fig = plt.figure()
ax = fig.add_subplot(111)

# Positions of each class *within the resampled arrays*.  These index
# ``X_resampled`` / ``y_resampled`` and must not be applied to the
# original ``X``, whose rows are shifted once any sample is dropped.
idx_class_0 = np.flatnonzero(y_resampled == 0)
idx_class_1 = np.flatnonzero(y_resampled == 1)

# Samples removed by the cleaning are the original row positions that the
# sampler did not keep, whichever class they belong to.
# NOTE(review): relies on ``idx_resampled`` holding integer row positions
# into ``X`` (what ``return_indices=True`` is expected to yield) -- confirm
# against the installed imbalanced-learn version.
idx_samples_removed = np.setdiff1d(np.arange(X.shape[0]), idx_resampled)

# Kept samples are drawn from the resampled data, one colour per class.
plt.scatter(X_resampled[idx_class_0, 0], X_resampled[idx_class_0, 1],
            c='g', alpha=.8, label='Class #0')
plt.scatter(X_resampled[idx_class_1, 0], X_resampled[idx_class_1, 1],
            c='b', alpha=.8, label='Class #1')
# Removed samples only exist in the original data, so index ``X`` here.
plt.scatter(X[idx_samples_removed, 0], X[idx_samples_removed, 1],
            c='r', alpha=.8, label='Removed samples')

# Hide the top and right spines.
for hidden_side in ('top', 'right'):
    ax.spines[hidden_side].set_visible(False)

# Keep tick marks on the bottom (x) and left (y) sides only.
ax.xaxis.tick_bottom()
ax.yaxis.tick_left()

# Offset the two remaining spines outward by 10.
for offset_side in ('left', 'bottom'):
    ax.spines[offset_side].set_position(('outward', 10))

# Title and legend, then display the figure.
plt.title('Under-sampling removing Tomek links')
plt.legend()

plt.show()

0 comments on commit 2508f96

Please sign in to comment.