metrics.py
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score, roc_curve

# `device` is assumed to be defined at module level elsewhere in the project;
# a typical definition is:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def compute_metrics(model, test_loader, plot_roc_curve=False):
    """Evaluate `model` on `test_loader` and return a dict of metrics.

    Each batch from `test_loader` is expected to be a dict with keys
    'img', 'label', and 'paths'.
    """
    model.eval()
    val_loss = 0
    val_correct = 0
    criterion = nn.CrossEntropyLoss()
    score_list = torch.Tensor([]).to(device)
    pred_list = torch.Tensor([]).to(device).long()
    target_list = torch.Tensor([]).to(device).long()
    path_list = []

    for iter_num, data in enumerate(test_loader):
        # Move the batch to the evaluation device
        image, target = data['img'].to(device), data['label'].to(device)
        paths = data['paths']
        path_list.extend(paths)

        # Forward pass without gradient tracking
        with torch.no_grad():
            output = model(image)

        # Accumulate the loss
        val_loss += criterion(output, target.long()).item()

        # Count correctly classified examples
        pred = output.argmax(dim=1, keepdim=True)
        val_correct += pred.eq(target.long().view_as(pred)).sum().item()

        # Bookkeeping: positive-class probabilities, predictions, and targets
        score_list = torch.cat([score_list, nn.Softmax(dim=1)(output)[:, 1].squeeze()])
        pred_list = torch.cat([pred_list, pred.squeeze()])
        target_list = torch.cat([target_list, target.squeeze()])

    classification_metrics = classification_report(target_list.tolist(), pred_list.tolist(),
                                                   target_names=['CT_NonCOVID', 'CT_COVID'],
                                                   output_dict=True)
    # Sensitivity is the recall of the positive class
    sensitivity = classification_metrics['CT_COVID']['recall']
    # Specificity is the recall of the negative class
    specificity = classification_metrics['CT_NonCOVID']['recall']
    accuracy = classification_metrics['accuracy']
    conf_matrix = confusion_matrix(target_list.tolist(), pred_list.tolist())
    roc_score = roc_auc_score(target_list.tolist(), score_list.tolist())

    # Optionally plot the ROC curve
    if plot_roc_curve:
        fpr, tpr, _ = roc_curve(target_list.tolist(), score_list.tolist())
        plt.plot(fpr, tpr, label="Area under ROC = {:.4f}".format(roc_score))
        plt.legend(loc='best')
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.show()

    # Assemble all values into a single dict
    metrics_dict = {"Accuracy": accuracy,
                    "Sensitivity": sensitivity,
                    "Specificity": specificity,
                    "Roc_score": roc_score,
                    "Confusion Matrix": conf_matrix,
                    "Validation Loss": val_loss / len(test_loader),
                    "score_list": score_list.tolist(),
                    "pred_list": pred_list.tolist(),
                    "target_list": target_list.tolist(),
                    "paths": path_list}
    return metrics_dict
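

# A minimal usage sketch, not part of the original file: it shows how
# compute_metrics could be driven by any Dataset whose items are dicts with
# 'img', 'label', and 'paths' keys. DummyCTDataset and the resnet18 model
# below are stand-ins for illustration, not the project's real data or model.
if __name__ == "__main__":
    from torch.utils.data import Dataset, DataLoader
    import torchvision

    class DummyCTDataset(Dataset):
        """Stand-in dataset yielding random 'CT' images with binary labels."""
        def __len__(self):
            return 32

        def __getitem__(self, idx):
            return {'img': torch.randn(3, 224, 224),
                    'label': idx % 2,
                    'paths': "dummy/img_{}.png".format(idx)}

    model = torchvision.models.resnet18(num_classes=2).to(device)
    loader = DataLoader(DummyCTDataset(), batch_size=8)
    metrics = compute_metrics(model, loader, plot_roc_curve=False)
    print("Accuracy: {:.3f}, AUC: {:.3f}".format(metrics["Accuracy"], metrics["Roc_score"]))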