-
Notifications
You must be signed in to change notification settings - Fork 1
/
segmenter_test_predictions.py
131 lines (99 loc) · 4.25 KB
/
segmenter_test_predictions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# Compute segmentation metrics (Dice, ASSD, HD) on the test set for a trained segmenter model
# Authors:
# Christian F. Baumgartner (c.f.baumgartner@gmail.com)
import numpy as np
import os
import glob
from importlib.machinery import SourceFileLoader
import argparse
from sklearn.metrics import f1_score, classification_report, confusion_matrix
from medpy.metric import dc, assd, hd
import config.system as sys_config
from segmenter.model_segmenter import segmenter as segmenter
if not sys_config.running_on_gpu_host:
import matplotlib.pyplot as plt
import logging
from data.data_switch import data_switch
logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)

# Label index -> structure name, logged next to the per-label metric means.
# Label 0 has no entry here.
structures_dict = dict(zip((1, 2, 3), ('RV', 'Myo', 'LV')))
def _per_label_metrics(pred, gt, nlabels):
    """Compute Dice, ASSD and HD for every label ``0 .. nlabels - 1``.

    Args:
        pred: Predicted label map.
        gt: Ground-truth label map, compared element-wise with ``pred``.
        nlabels: Number of labels to evaluate (label 0 included).

    Returns:
        Three lists ``(dice, assd, hd)``, each holding one value per label.
    """
    dice_vals, assd_vals, hd_vals = [], [], []
    for lbl in range(nlabels):
        binary_pred = (pred == lbl) * 1
        binary_gt = (gt == lbl) * 1
        # The masks are 0/1, so "any nonzero" is the same test as "sum > 0".
        in_pred = bool(np.any(binary_pred))
        in_gt = bool(np.any(binary_gt))

        if not in_pred and not in_gt:
            # Label absent from both prediction and GT: counted as a perfect match.
            dice_vals.append(1)
            assd_vals.append(0)
            hd_vals.append(0)
        elif in_pred != in_gt:
            # Label present in exactly one of the two masks. Distance metrics are
            # presumably not meaningful against an empty mask, so fixed
            # placeholder values are recorded instead of calling assd/hd.
            logging.warning('Structure missing in either GT (x)or prediction. ASSD and HD will not be accurate.')
            dice_vals.append(0)
            assd_vals.append(1)
            hd_vals.append(1)
        else:
            dice_vals.append(dc(binary_pred, binary_gt))
            assd_vals.append(assd(binary_pred, binary_gt))
            hd_vals.append(hd(binary_pred, binary_gt))
    return dice_vals, assd_vals, hd_vals


def _log_metric_summary(name, per_image_values):
    """Log per-label means, the overall mean and the foreground mean of one metric.

    Args:
        name: Heading logged before the numbers (e.g. 'Dice').
        per_image_values: Array of shape (n_images, nlabels).
    """
    mean_per_lbl = per_image_values.mean(axis=0)
    logging.info(name)
    logging.info(structures_dict)
    logging.info(mean_per_lbl)
    logging.info(np.mean(mean_per_lbl))
    # Label 0 is left out of the foreground mean (structures_dict has no entry
    # for it). Skip the line when there is no label other than 0.
    if mean_per_lbl.size > 1:
        logging.info('foreground mean: %f', np.mean(mean_per_lbl[1:]))


def main(model_path, exp_config, do_plots=False):
    """Evaluate a trained segmenter on the test split and log Dice, ASSD and HD.

    The weights tagged 'best_dice' are restored from ``model_path``. Each test
    image is predicted with batch size 1, and per-label metrics are averaged
    over the whole test set.

    Args:
        model_path: Experiment directory handed to ``segmenter.load_weights``.
        exp_config: Experiment config module; ``data_identifier`` and
            ``nlabels`` are read from it.
        do_plots: If True, show image / prediction / ground truth for every
            test image. Ignored when ``sys_config.running_on_gpu_host`` is set,
            because matplotlib is only imported off GPU hosts.
    """
    # Get data
    data_loader = data_switch(exp_config.data_identifier)
    data = data_loader(exp_config)

    # Build the segmenter and restore its best-Dice weights.
    segmenter_model = segmenter(exp_config=exp_config, data=data, fixed_batch_size=1)  # CRF model requires fixed batch size
    segmenter_model.load_weights(model_path, type='best_dice')

    dice_list = []
    assd_list = []
    hd_list = []

    # Single pass over the test set, one image per batch.
    for ii, batch in enumerate(data.test.iterate_batches(1)):
        if ii % 100 == 0:
            logging.info("Progress: %d", ii)

        x, y = batch
        y_ = segmenter_model.predict(x)[0]

        if do_plots and not sys_config.running_on_gpu_host:
            fig = plt.figure()
            fig.add_subplot(131)
            plt.imshow(np.squeeze(x), cmap='gray')
            fig.add_subplot(132)
            plt.imshow(np.squeeze(y_))
            fig.add_subplot(133)
            plt.imshow(np.squeeze(y))
            plt.show()

        per_lbl_dice, per_lbl_assd, per_lbl_hd = _per_label_metrics(y_, y, exp_config.nlabels)
        dice_list.append(per_lbl_dice)
        assd_list.append(per_lbl_assd)
        hd_list.append(per_lbl_hd)

    if not dice_list:
        # Averaging over zero images would only produce NaNs and numpy warnings.
        logging.warning('No test images were evaluated; no metrics to report.')
        return

    _log_metric_summary('Dice', np.asarray(dice_list))
    _log_metric_summary('ASSD', np.asarray(assd_list))
    _log_metric_summary('HD', np.asarray(hd_list))
if __name__ == '__main__':
parser = argparse.ArgumentParser(
description="Script for a simple test loop evaluating a network on the test dataset")
parser.add_argument("EXP_PATH", type=str, help="Path to experiment folder (assuming you are in the working directory)")
args = parser.parse_args()
base_path = sys_config.project_root
model_path = os.path.join(base_path, args.EXP_PATH)
config_file = glob.glob(model_path + '/*py')[0]
config_module = config_file.split('/')[-1].rstrip('.py')
exp_config = SourceFileLoader(config_module, os.path.join(config_file)).load_module()
main(model_path, exp_config=exp_config, do_plots=False)