-
Notifications
You must be signed in to change notification settings - Fork 0
/
plot_matrix.py
35 lines (31 loc) · 1.51 KB
/
plot_matrix.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import matplotlib.pyplot as pl
import numpy as np
from sklearn import metrics
# 相关库
def plot_matrix(y_true, y_pred, labels_name, title=None, thresh=0.8, axis_labels=None):
# 利用sklearn中的函数生成混淆矩阵并归一化
cm = metrics.confusion_matrix(y_true, y_pred, labels=labels_name, sample_weight=None) # 生成混淆矩阵
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] # 归一化
# 画图,如果希望改变颜色风格,可以改变此部分的cmap=pl.get_cmap('Blues')处
pl.imshow(cm, interpolation='nearest', cmap=pl.get_cmap('Blues'))
pl.colorbar() # 绘制图例
# 图像标题
if title is not None:
pl.title(title)
# 绘制坐标
num_local = np.array(range(len(labels_name)))
if axis_labels is None:
axis_labels = labels_name
pl.xticks(num_local, axis_labels, rotation=45) # 将标签印在x轴坐标上, 并倾斜45度
pl.yticks(num_local, axis_labels) # 将标签印在y轴坐标上
pl.ylabel('True label')
pl.xlabel('Predicted label')
# 将百分比打印在相应的格子内,大于thresh的用白字,小于的用黑字
for i in range(np.shape(cm)[0]):
for j in range(np.shape(cm)[1]):
if int(cm[i][j] * 100 + 0.5) > 0:
pl.text(j, i, format(int(cm[i][j] * 100 + 0.5), 'd') + '%',
ha="center", va="center",
color="white" if cm[i][j] > thresh else "black") # 如果要更改颜色风格,需要同时更改此行
# 显示
pl.show()