diff --git a/examples/model_selection/plot_confusion_matrix.py b/examples/model_selection/plot_confusion_matrix.py
index 250d71c08c442..233c72f658ba5 100644
--- a/examples/model_selection/plot_confusion_matrix.py
+++ b/examples/model_selection/plot_confusion_matrix.py
@@ -26,6 +26,7 @@
 
 print(__doc__)
 
+import itertools
 import numpy as np
 import matplotlib.pyplot as plt
 
@@ -37,6 +38,7 @@
 iris = datasets.load_iris()
 X = iris.data
 y = iris.target
+class_names = iris.target_names
 
 # Split the data into a training set and a test set
 X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
@@ -47,32 +49,58 @@
 y_pred = classifier.fit(X_train, y_train).predict(X_test)
 
 
-def plot_confusion_matrix(cm, title='Confusion matrix', cmap=plt.cm.Blues):
+def plot_confusion_matrix(cm, classes,
+                          normalize=False,
+                          title='Confusion matrix',
+                          cmap=plt.cm.Blues):
+    """
+    This function prints and plots the confusion matrix.
+    Normalization can be applied by setting `normalize=True`.
+
+    `cm` is the confusion matrix (rows: true labels, columns: predicted
+    labels) and `classes` holds the tick labels, one per row/column.
+    """
+    # Normalize before drawing so that the image, the colorbar, the cell
+    # labels and the text-color threshold all use the same values.
+    if normalize:
+        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
+        print("Normalized confusion matrix")
+    else:
+        print('Confusion matrix, without normalization')
+
+    print(cm)
+
     plt.imshow(cm, interpolation='nearest', cmap=cmap)
     plt.title(title)
     plt.colorbar()
-    tick_marks = np.arange(len(iris.target_names))
-    plt.xticks(tick_marks, iris.target_names, rotation=45)
-    plt.yticks(tick_marks, iris.target_names)
+    tick_marks = np.arange(len(classes))
+    plt.xticks(tick_marks, classes, rotation=45)
+    plt.yticks(tick_marks, classes)
+
+    # Counts are shown as integers, normalized rates with two decimals.
+    fmt = '.2f' if normalize else 'd'
+    thresh = cm.max() / 2.
+    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
+        plt.text(j, i, format(cm[i, j], fmt),
+                 horizontalalignment="center",
+                 color="white" if cm[i, j] > thresh else "black")
+
     plt.tight_layout()
     plt.ylabel('True label')
     plt.xlabel('Predicted label')
 
-
 # Compute confusion matrix
-cm = confusion_matrix(y_test, y_pred)
+cnf_matrix = confusion_matrix(y_test, y_pred)
 np.set_printoptions(precision=2)
-print('Confusion matrix, without normalization')
-print(cm)
+
+# Plot non-normalized confusion matrix
 plt.figure()
-plot_confusion_matrix(cm)
+plot_confusion_matrix(cnf_matrix, classes=class_names,
+                      title='Confusion matrix, without normalization')
 
-# Normalize the confusion matrix by row (i.e by the number of samples
-# in each class)
-cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
-print('Normalized confusion matrix')
-print(cm_normalized)
+# Plot normalized confusion matrix
 plt.figure()
-plot_confusion_matrix(cm_normalized, title='Normalized confusion matrix')
+plot_confusion_matrix(cnf_matrix, classes=class_names, normalize=True,
+                      title='Normalized confusion matrix')
 
 plt.show()