-
Notifications
You must be signed in to change notification settings - Fork 5
/
predict.py
112 lines (95 loc) · 4.33 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import os
import time
import torch
import numpy as np
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
from argparse import ArgumentParser
# user
from builders.model_builder import build_model
from builders.dataset_builder import build_dataset_test
from utils.utils import save_predict
from utils.convert_state import convert_state_dict
def predict(args, test_loader, model):
    """
    Run inference on an unlabeled test set and save greyscale prediction maps.

    args:
      args: parsed command-line arguments; reads ``dataset`` and
            ``save_seg_dir``, plus ``cuda`` if present (defaults to True,
            matching the previous always-on-GPU behaviour)
      test_loader: loader for the test dataset (for datasets that do not
                   provide labels on the test set); yields
                   ``(images, size, name)`` batches
      model: segmentation network whose forward pass returns four outputs
    return: None -- predictions are written to disk through ``save_predict``
    """
    # evaluation or test mode
    model.eval()
    use_cuda = getattr(args, 'cuda', True)
    total_batches = len(test_loader)
    # One no_grad context for the whole pass instead of re-entering it per batch.
    with torch.no_grad():
        for i, (images, size, name) in enumerate(test_loader):
            if use_cuda:
                images = images.cuda()
            start_time = time.time()
            # The model yields four outputs; the first is taken as the
            # prediction (presumably the main head, the rest auxiliary --
            # TODO confirm against the model definition). The previous code
            # read an undefined name `output` here and crashed with NameError.
            output, _aux1, _aux2, _aux3 = model(images)
            if use_cuda:
                # Wait for queued GPU work so the timing covers the full forward pass.
                torch.cuda.synchronize()
            time_taken = time.time() - start_time
            print('[%d/%d] time: %.2f' % (i + 1, total_batches, time_taken))
            # Only the first sample of the batch is used; take its per-pixel
            # argmax over the class axis (axis 0 of the (C, H, W) score map).
            scores = output[0].cpu().numpy()
            output = np.asarray(np.argmax(scores, axis=0), dtype=np.uint8)
            # Save the predict greyscale output for Cityscapes official evaluation
            # Modify image name to meet official requirement
            name[0] = name[0].rsplit('_', 1)[0] + '*'
            save_predict(output, None, name[0], args.dataset, args.save_seg_dir,
                         output_grey=True, output_color=False, gt_color=False)
def test_model(args):
    """
    Main function for testing: build the model, load its checkpoint and write
    predictions for the test set.

    param args: global arguments; reads ``cuda``, ``gpus``, ``model``,
                ``classes``, ``save_seg_dir``, ``dataset``, ``num_workers``
                and ``checkpoint``
    return: None
    raises: Exception if ``args.cuda`` is set but no GPU is available;
            FileNotFoundError if ``args.checkpoint`` names a missing file
    """
    print(args)
    if args.cuda:
        print("=====> use gpu id: '{}'".format(args.gpus))
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
        if not torch.cuda.is_available():
            raise Exception("no GPU found or wrong gpu id, please run without --cuda")

    # build the model
    model = build_model(args.model, num_classes=args.classes)
    if args.cuda:
        model = model.cuda()  # using GPU for inference
        cudnn.benchmark = True

    # exist_ok avoids the race between an existence check and creation.
    os.makedirs(args.save_seg_dir, exist_ok=True)

    # load the test set (the dataset object itself is not needed here)
    _, test_loader = build_dataset_test(args.dataset, args.num_workers, none_gt=True)

    if args.checkpoint:
        if not os.path.isfile(args.checkpoint):
            print("=====> no checkpoint found at '{}'".format(args.checkpoint))
            raise FileNotFoundError("no checkpoint found at '{}'".format(args.checkpoint))
        print("=====> loading checkpoint '{}'".format(args.checkpoint))
        # Without CUDA, remap stored tensors to the CPU so a checkpoint saved
        # from a GPU run can still be loaded; with CUDA keep the default.
        checkpoint = torch.load(args.checkpoint,
                                map_location=None if args.cuda else 'cpu')
        model.load_state_dict(checkpoint['model'])
        # Alternative kept from the original: presumably for checkpoints whose
        # key names need converting (e.g. saved from a wrapped model) -- verify.
        # model.load_state_dict(convert_state_dict(checkpoint['model']))

    print("=====> beginning testing")
    # NOTE: len() of the loader is the number of batches, not samples.
    print("test set length: ", len(test_loader))
    predict(args, test_loader, model)
if __name__ == '__main__':
    def _str2bool(value):
        """Parse a command-line boolean ('true'/'false', 'yes'/'no', '1'/'0')."""
        # argparse without a `type` hands back the raw string, so
        # `--cuda False` used to be the truthy string "False".
        lowered = value.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n', ''):
            return False
        # argparse turns ValueError from a type function into a usage error.
        raise ValueError(value)

    parser = ArgumentParser()
    parser.add_argument('--model', default="SSFPN", help="model name: [DSANet,SPFNet,SSFPN]")
    parser.add_argument('--dataset', default="cityscapes", help="dataset: cityscapes or camvid")
    parser.add_argument('--num_workers', type=int, default=1, help="the number of parallel threads")
    parser.add_argument('--batch_size', type=int, default=1,
                        help=" the batch_size is set to 1 when evaluating or testing")
    parser.add_argument('--checkpoint', type=str,
                        default="./checkpoint/cityscapes/SSFPNbs8gpu1_trainval/model_500.pth",
                        help="use the file to load the checkpoint for evaluating or testing ")
    parser.add_argument('--save_seg_dir', type=str, default="result/",
                        help="saving path of prediction result")
    parser.add_argument('--cuda', type=_str2bool, default=True, help="run on CPU or GPU")
    parser.add_argument("--gpus", default="0", type=str, help="gpu ids (default: 0)")
    args = parser.parse_args()

    # Predictions go to <save_seg_dir>/<dataset>/predict/<model>.
    args.save_seg_dir = os.path.join(args.save_seg_dir, args.dataset, 'predict', args.model)

    # Number of output classes for each supported dataset.
    if args.dataset == 'cityscapes':
        args.classes = 19
    elif args.dataset == 'camvid':
        args.classes = 11
    else:
        raise NotImplementedError(
            "This repository now supports two datasets: cityscapes and camvid, %s is not included" % args.dataset)

    test_model(args)