Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make colorbar optional in plot_confusion_matrix() #114

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 7 additions & 20 deletions scikitplot/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,67 +34,52 @@ def plot_confusion_matrix(y_true, y_pred, labels=None, true_labels=None,
pred_labels=None, title=None, normalize=False,
hide_zeros=False, hide_counts=False, x_tick_rotation=0, ax=None,
figsize=None, cmap='Blues', title_fontsize="large",
text_fontsize="medium"):
text_fontsize="medium", colorbar=True):
"""Generates confusion matrix plot from predictions and true labels

Args:
y_true (array-like, shape (n_samples)):
Ground truth (correct) target values.

y_pred (array-like, shape (n_samples)):
Estimated targets as returned by a classifier.

labels (array-like, shape (n_classes), optional): List of labels to
index the matrix. This may be used to reorder or select a subset
of labels. If none is given, those that appear at least once in
``y_true`` or ``y_pred`` are used in sorted order. (new in v0.2.5)

true_labels (array-like, optional): The true labels to display.
If none is given, then all of the labels are used.

pred_labels (array-like, optional): The predicted labels to display.
If none is given, then all of the labels are used.

title (string, optional): Title of the generated plot. Defaults to
"Confusion Matrix" if `normalize` is True. Else, defaults to
"Normalized Confusion Matrix.

normalize (bool, optional): If True, normalizes the confusion matrix
before plotting. Defaults to False.

hide_zeros (bool, optional): If True, does not plot cells containing a
value of zero. Defaults to False.

hide_counts (bool, optional): If True, doe not overlay counts.
Defaults to False.

x_tick_rotation (int, optional): Rotates x-axis tick labels by the
specified angle. This is useful in cases where there are numerous
categories and the labels overlap each other.

ax (:class:`matplotlib.axes.Axes`, optional): The axes upon which to
plot the curve. If None, the plot is drawn on a new set of axes.

figsize (2-tuple, optional): Tuple denoting figure size of the plot
e.g. (6, 6). Defaults to ``None``.

cmap (string or :class:`matplotlib.colors.Colormap` instance, optional):
Colormap used for plotting the projection. View Matplotlib Colormap
documentation for available options.
https://matplotlib.org/users/colormaps.html

title_fontsize (string or int, optional): Matplotlib-style fontsizes.
Use e.g. "small", "medium", "large" or integer-values. Defaults to
"large".

text_fontsize (string or int, optional): Matplotlib-style fontsizes.
Use e.g. "small", "medium", "large" or integer-values. Defaults to
"medium".

colorbar (bool, optional): If False, does not add colour bar.
Defaults to True.
Returns:
ax (:class:`matplotlib.axes.Axes`): The axes on which the plot was
drawn.

Example:
>>> import scikitplot as skplt
>>> rf = RandomForestClassifier()
Expand All @@ -103,7 +88,6 @@ def plot_confusion_matrix(y_true, y_pred, labels=None, true_labels=None,
>>> skplt.metrics.plot_confusion_matrix(y_test, y_pred, normalize=True)
<matplotlib.axes._subplots.AxesSubplot object at 0x7fe967d64490>
>>> plt.show()

.. image:: _static/examples/plot_confusion_matrix.png
:align: center
:alt: Confusion matrix
Expand Down Expand Up @@ -153,7 +137,10 @@ def plot_confusion_matrix(y_true, y_pred, labels=None, true_labels=None,
ax.set_title('Confusion Matrix', fontsize=title_fontsize)

image = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.get_cmap(cmap))
plt.colorbar(mappable=image)

if colorbar == True:
plt.colorbar(mappable=image)

x_tick_marks = np.arange(len(pred_classes))
y_tick_marks = np.arange(len(true_classes))
ax.set_xticks(x_tick_marks)
Expand Down