-
Notifications
You must be signed in to change notification settings - Fork 115
/
eval_existingOnes.py
148 lines (137 loc) · 8.29 KB
/
eval_existingOnes.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import os
import argparse
from glob import glob
import prettytable as pt
from evaluation.metrics import evaluator
from config import Config
# Project configuration, instantiated once at import time. This file reads
# `task` and `verbose_eval` from it during evaluation, and `data_root_dir`,
# `task` and `testsets` for the command-line defaults.
config = Config()
# Column order of the result table for each supported `config.task`.
# Every name must be a key understood by `collect_scores`.
_DIS_STYLE_COLUMNS = (
    "maxFm", "wFmeasure", 'MAE', "Smeasure", "meanEm", "HCE", "maxEm",
    "meanFm", "adpEm", "adpFm", 'mBA', 'maxBIoU', 'meanBIoU',
)
TASK_COLUMNS = {
    'DIS5K': _DIS_STYLE_COLUMNS,
    'COD': (
        "Smeasure", "wFmeasure", "meanFm", "meanEm", "maxEm", 'MAE', "maxFm",
        "adpEm", "adpFm", "HCE", 'mBA', 'maxBIoU', 'meanBIoU',
    ),
    'HRSOD': (
        "Smeasure", "maxFm", "meanEm", 'MAE', "maxEm", "meanFm", "wFmeasure",
        "adpEm", "adpFm", "HCE", 'mBA', 'maxBIoU', 'meanBIoU',
    ),
    'General': _DIS_STYLE_COLUMNS,
    'General-2K': _DIS_STYLE_COLUMNS,
    'Matting': (
        "Smeasure", "maxFm", "meanEm", 'MSE', "maxEm", "meanFm", "wFmeasure",
        "adpEm", "adpFm", "HCE", 'mBA', 'maxBIoU', 'meanBIoU',
    ),
}
# Column order used when `config.task` matches none of the keys above.
DEFAULT_COLUMNS = (
    "Smeasure", 'MAE', "maxEm", "meanEm", "maxFm", "meanFm", "wFmeasure",
    "adpEm", "adpFm", "HCE", 'mBA', 'maxBIoU', 'meanBIoU',
)


def format_score(score):
    """Render one metric value as a table cell.

    Integers (the HCE column) and values above 1 are left-aligned in a
    four-character field. Every other value is shown with three decimals,
    and the leading zero is dropped for values that print as ``0.xyz``.

    Unlike a plain ``score <= 1`` test, this keeps ``1.0`` visible as
    ``1.000`` (not ``.000``), keeps small integers such as 0 or 1 intact,
    and does not strip the sign or integer part of negative values.
    """
    if isinstance(score, int) or score > 1:
        return format(score, '<4')
    text = format(score, '.3f')
    return text[1:] if text.startswith('0.') else text


def collect_scores(columns, em, sm, fm, mae, mse, wfm, hce, mba, biou):
    """Return the rounded metric values for `columns`, in that order.

    `em` and `fm` are indexed with 'curve' and 'adp'; `biou` with 'curve'.
    All values must support `.round()`, and the 'curve' entries `.max()` and
    `.mean()` -- presumably NumPy scalars/arrays; confirm against `evaluator`.
    Extraction is lazy, so a metric that is not among `columns` is never
    touched.
    """
    extractors = {
        'Smeasure': lambda: sm.round(3),
        'wFmeasure': lambda: wfm.round(3),
        'MAE': lambda: mae.round(3),
        # Rounded to five places here, although format_score renders three.
        'MSE': lambda: mse.round(5),
        'HCE': lambda: int(hce.round()),
        'mBA': lambda: mba.round(3),
        'maxFm': lambda: fm['curve'].max().round(3),
        'meanFm': lambda: fm['curve'].mean().round(3),
        'adpFm': lambda: fm['adp'].round(3),
        'maxEm': lambda: em['curve'].max().round(3),
        'meanEm': lambda: em['curve'].mean().round(3),
        'adpEm': lambda: em['adp'].round(3),
        'maxBIoU': lambda: biou['curve'].max().round(3),
        'meanBIoU': lambda: biou['curve'].mean().round(3),
    }
    return [extractors[name]() for name in columns]


def do_eval(args):
    """Evaluate every model on every dataset and save one table per dataset.

    Reads from `args`:
        data_lst:  dataset names joined by '+'.
        model_lst: list of prediction folder names under `pred_root`.
        pred_root: root directory of the predictions.
        gt_root:   root directory of the ground truth.
        save_dir:  directory receiving '<dataset>_eval.txt'.
        metrics:   metric names joined by '+', forwarded to `evaluator`.

    A dataset is skipped when the first model has no prediction folder for
    it. The table file is rewritten after each model so partial results
    survive an interrupted run.
    """
    if not args.model_lst:
        # Without this guard the lookup of model_lst[0] below raises IndexError.
        print('No model predictions found in {}.'.format(args.pred_root))
        return

    columns = TASK_COLUMNS.get(config.task, DEFAULT_COLUMNS)
    metrics = args.metrics.split('+')

    # evaluation for whole dataset
    # dataset first in evaluation
    for _data_name in args.data_lst.split('+'):
        pred_data_dir = glob(os.path.join(args.pred_root, args.model_lst[0], _data_name))
        if not pred_data_dir:
            print('Skip dataset {}.'.format(_data_name))
            continue
        gt_src = os.path.join(args.gt_root, _data_name)
        gt_paths = sorted(glob(os.path.join(gt_src, 'gt', '*')))
        print('#' * 20, _data_name, '#' * 20)
        filename = os.path.join(args.save_dir, '{}_eval.txt'.format(_data_name))
        tb = pt.PrettyTable()
        tb.vertical_char = '&'
        tb.field_names = ["Dataset", "Method"] + list(columns)
        for _model_name in args.model_lst:
            print('\t', 'Evaluating model: {}...'.format(_model_name))
            # Predictions mirror the ground-truth tree, minus the 'gt' folder level.
            model_root = os.path.join(args.pred_root, _model_name)
            pred_paths = [
                p.replace(args.gt_root, model_root).replace('/gt/', '/')
                for p in gt_paths
            ]
            em, sm, fm, mae, mse, wfm, hce, mba, biou = evaluator(
                gt_paths=gt_paths,
                pred_paths=pred_paths,
                metrics=metrics,
                verbose=config.verbose_eval
            )
            scores = [
                format_score(score)
                for score in collect_scores(columns, em, sm, fm, mae, mse, wfm, hce, mba, biou)
            ]
            tb.add_row([_data_name, _model_name] + scores)
            # Write results after every check.
            with open(filename, 'w') as file_to_write:
                file_to_write.write(str(tb) + '\n')
        print(tb)
def str2bool(value):
    """Parse a command-line boolean.

    `argparse` with `type=bool` treats every non-empty string -- including
    'False' -- as True. This accepts the usual spellings of true/false
    (case-insensitive) and raises `argparse.ArgumentTypeError` otherwise.
    An empty string maps to False, as `bool('')` did.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


def sort_models(model_names):
    """Order prediction folder names for evaluation.

    When every name ends in an integer after its last 'epoch_' (or is a bare
    integer), sort by that number, highest first. Otherwise fall back to
    plain alphabetical order.
    """
    try:
        return sorted(model_names, key=lambda name: int(name.split('epoch_')[-1]), reverse=True)
    except ValueError:
        # At least one name carries no epoch number.
        return sorted(model_names)


if __name__ == '__main__':
    # set parameters
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--gt_root', type=str, help='ground-truth root',
        default=os.path.join(config.data_root_dir, config.task))
    parser.add_argument(
        '--pred_root', type=str, help='prediction root',
        default='./e_preds')
    parser.add_argument(
        '--data_lst', type=str, help='test dataset',
        default=config.testsets.replace(',', '+'))
    parser.add_argument(
        '--save_dir', type=str, help='directory where the result tables are written',
        default='e_results')
    parser.add_argument(
        '--check_integrity', type=str2bool, help='whether to check the file integrity',
        default=False)
    parser.add_argument(
        '--metrics', type=str,
        help="metrics joined by '+'; chosen from the datasets when omitted",
        default=None)
    args = parser.parse_args()

    if args.metrics is None:
        # HCE is the last entry and is only kept when a 'DIS-' dataset is evaluated.
        metric_names = ['S', 'MAE', 'E', 'F', 'WF', 'MBA', 'BIoU', 'MSE', 'HCE']
        if not any('DIS-' in _data for _data in args.data_lst.split('+')):
            metric_names = metric_names[:-1]
        args.metrics = '+'.join(metric_names)

    os.makedirs(args.save_dir, exist_ok=True)
    args.model_lst = sort_models(os.listdir(args.pred_root))

    # check the integrity of each candidates
    if args.check_integrity:
        for _data_name in args.data_lst.split('+'):
            for _model_name in args.model_lst:
                # do_eval reads ground truth from the 'gt' sub-folder, so the
                # prediction folder must be compared against that listing.
                gt_pth = os.path.join(args.gt_root, _data_name, 'gt')
                pred_pth = os.path.join(args.pred_root, _model_name, _data_name)
                gt_files = sorted(os.listdir(gt_pth))
                pred_files = sorted(os.listdir(pred_pth))
                if gt_files != pred_files:
                    print(len(gt_files), len(pred_files))
                    print('The {} Dataset of {} Model is not matching to the ground-truth'.format(_data_name, _model_name))
    else:
        print('>>> skip check the integrity of each candidates')

    # start engine
    do_eval(args)