-
Notifications
You must be signed in to change notification settings - Fork 7
/
validation.py
118 lines (96 loc) · 4.22 KB
/
validation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import numpy as np
import utils
from torch import nn
import torch
def validation_binary(model, criterion, valid_loader, num_classes=None):
    """Run one validation pass for binary segmentation.

    Averages the criterion loss and the per-image Jaccard index over
    ``valid_loader``; predictions are obtained by thresholding the raw
    model outputs at 0 (i.e. logit > 0 means foreground).

    ``num_classes`` is accepted only for signature parity with
    :func:`validation_multi` and is unused here.

    Returns a dict with keys ``'valid_loss'`` and ``'jaccard_loss'``.
    """
    with torch.no_grad():
        model.eval()
        batch_losses = []
        jaccard_scores = []
        for batch_inputs, batch_targets in valid_loader:
            batch_inputs = utils.cuda(batch_inputs)
            batch_targets = utils.cuda(batch_targets)
            predictions = model(batch_inputs)
            batch_losses.append(criterion(predictions, batch_targets).item())
            jaccard_scores += get_jaccard(batch_targets, (predictions > 0).float())

        valid_loss = np.mean(batch_losses)  # type: float
        valid_jaccard = np.mean(jaccard_scores).astype(np.float64)

        print('Valid loss: {:.5f}, jaccard: {:.5f}'.format(valid_loss, valid_jaccard))
        return {'valid_loss': valid_loss, 'jaccard_loss': valid_jaccard}
def get_jaccard(y_true, y_pred):
    """Per-sample Jaccard (IoU) scores for binary masks.

    Sums over the last two (spatial) dimensions, so every remaining
    leading slice (e.g. each image in the batch) yields one score.
    ``epsilon`` prevents division by zero and makes two empty masks
    score ~1 instead of 0/0.

    Returns a list with one numpy entry per element along dim 0.
    """
    epsilon = 1e-15
    # Single call over both spatial dims instead of two chained sums.
    intersection = (y_pred * y_true).sum(dim=(-2, -1))
    union = y_true.sum(dim=(-2, -1)) + y_pred.sum(dim=(-2, -1))
    scores = (intersection + epsilon) / (union - intersection + epsilon)
    # detach() instead of the deprecated .data attribute: same values,
    # but it does not bypass autograd's correctness checks.
    return list(scores.detach().cpu().numpy())
def validation_multi(model: nn.Module, criterion, valid_loader, num_classes):
    """Run one validation pass for multi-class segmentation.

    Accumulates a ``num_classes x num_classes`` confusion matrix over the
    whole loader (rows: ground truth, cols: argmax prediction), drops the
    background class (label 0), and reports per-class IoU and Dice plus
    their means.

    Returns a dict with ``'valid_loss'``, ``'iou'`` (mean IoU), and one
    ``'iou_k'`` / ``'dice_k'`` entry per foreground class k.
    """
    with torch.no_grad():
        model.eval()
        batch_losses = []
        conf = np.zeros((num_classes, num_classes), dtype=np.uint32)

        for batch_inputs, batch_targets in valid_loader:
            batch_inputs = utils.cuda(batch_inputs)
            batch_targets = utils.cuda(batch_targets)
            logits = model(batch_inputs)
            batch_losses.append(criterion(logits, batch_targets).item())
            predicted = logits.data.cpu().numpy().argmax(axis=1)
            reference = batch_targets.data.cpu().numpy()
            conf += calculate_confusion_matrix_from_arrays(
                predicted, reference, num_classes)

        conf = conf[1:, 1:]  # exclude background
        valid_loss = np.mean(batch_losses)  # type: float

        ious = {'iou_{}'.format(idx + 1): value
                for idx, value in enumerate(calculate_iou(conf))}
        dices = {'dice_{}'.format(idx + 1): value
                 for idx, value in enumerate(calculate_dice(conf))}
        average_iou = np.mean(list(ious.values()))
        average_dices = np.mean(list(dices.values()))

        print(
            'Valid loss: {:.4f}, average IoU: {:.4f}, average Dice: {:.4f}'.format(valid_loss,
                                                                                   average_iou,
                                                                                   average_dices))
        metrics = {'valid_loss': valid_loss, 'iou': average_iou}
        metrics.update(ious)
        metrics.update(dices)
        return metrics
def calculate_confusion_matrix_from_arrays(prediction, ground_truth, nr_labels):
    """Build a confusion matrix from two label arrays of equal size.

    Entry ``[i, j]`` counts pixels with ground-truth label ``i`` that were
    predicted as ``j``. Labels outside ``[0, nr_labels)`` fall outside the
    histogram range and are ignored.

    Returns an ``(nr_labels, nr_labels)`` ``np.uint32`` array.
    """
    # One (truth, prediction) row per pixel; a 2-D histogram over these
    # pairs is exactly the confusion matrix.
    label_pairs = np.column_stack((ground_truth.flatten(), prediction.flatten()))
    histogram, _ = np.histogramdd(
        label_pairs,
        bins=(nr_labels, nr_labels),
        range=[(0, nr_labels), (0, nr_labels)],
    )
    return histogram.astype(np.uint32)
def calculate_iou(confusion_matrix):
    """Per-class IoU from a confusion matrix (rows: truth, cols: prediction).

    IoU = TP / (TP + FP + FN). A class absent from both truth and
    prediction (zero denominator) scores 0.

    Returns a list with one score per class.
    """
    scores = []
    for cls in range(confusion_matrix.shape[0]):
        tp = confusion_matrix[cls, cls]
        fp = confusion_matrix[:, cls].sum() - tp
        fn = confusion_matrix[cls, :].sum() - tp
        total = tp + fp + fn
        scores.append(float(tp) / total if total else 0)
    return scores
def calculate_dice(confusion_matrix):
    """Per-class Dice score from a confusion matrix (rows: truth, cols: prediction).

    Dice = 2*TP / (2*TP + FP + FN). A class absent from both truth and
    prediction (zero denominator) scores 0.

    Returns a list with one score per class.
    """
    scores = []
    for cls in range(confusion_matrix.shape[0]):
        tp = confusion_matrix[cls, cls]
        fp = confusion_matrix[:, cls].sum() - tp
        fn = confusion_matrix[cls, :].sum() - tp
        total = 2 * tp + fp + fn
        scores.append(2 * float(tp) / total if total else 0)
    return scores