-
Notifications
You must be signed in to change notification settings - Fork 65
/
Copy pathtest.py
147 lines (121 loc) · 5.2 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import os
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import logging
from config import setup_config
from dataset.dataset import FGDataset
from torchvision import transforms
from model.registry import MODEL
from utils import PerformanceMeter, TqdmHandler, AverageMeter, accuracy
class Tester(object):
    """Test a model from a config which could be a training config.

    Construction reads the config through ``setup_config()``, selects the
    devices, and builds the transform, dataset, dataloader, model and meters.
    Call :meth:`test` to run the evaluation and log the result.

    Attributes set in ``__init__``:
        device: GPU ids exactly as listed in ``config.experiment.cuda``
            (physical ids), or ``[]`` for CPU only.
        device_ids: ids actually used to address the GPUs from this process
            (see ``_logical_device_ids``).
    """

    def __init__(self):
        self.config = setup_config()
        self.report_one_line = True
        self.logger = self.get_logger()

        # set device. `config.experiment.cuda` should be a list of gpu device ids, None or [] for cpu only.
        cuda = self.config.experiment.cuda
        self.device = cuda if isinstance(cuda, list) else []
        if self.device:
            # Sample this BEFORE touching the environment: CUDA_VISIBLE_DEVICES is
            # only honoured if it is set before CUDA gets initialized.
            cuda_started = torch.cuda.is_initialized()
            os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(i) for i in self.device])
            # Once the mask is in effect the visible GPUs are renumbered from 0,
            # so the configured (physical) ids must not be used for placement.
            self.device_ids = self._logical_device_ids(self.device, remapped=not cuda_started)
            self.logger.info(f'Using GPU: {self.device}')
        else:
            self.device_ids = []
            self.logger.info('Using CPU!')

        # build dataloader and model
        self.transformer = self.get_transformer(self.config.dataset.transformer)
        self.collate_fn = self.get_collate_fn()
        self.dataset = self.get_dataset(self.config.dataset)
        self.dataloader = self.get_dataloader(self.config.dataset)
        self.logger.info(f'Building model {self.config.model.name} ...')
        self.model = self.get_model(self.config.model)
        self.model = self.to_device(self.model, parallel=True)
        self.logger.info(f'Building model {self.config.model.name} OK!')

        # build meters
        self.performance_meters = self.get_performance_meters()
        self.average_meters = self.get_average_meters()

    @staticmethod
    def _logical_device_ids(device_ids, remapped=True):
        """Translate configured GPU ids into ids valid inside this process.

        Args:
            device_ids: GPU ids as written in the config.
            remapped: True when ``CUDA_VISIBLE_DEVICES`` was set to exactly
                ``device_ids`` and takes effect, so the GPUs are visible as
                ``0 .. len(device_ids) - 1``. False when the ids are still
                addressed as given.

        Returns:
            A new list of ids to use for ``cuda:<id>`` and ``DataParallel``.
        """
        if remapped:
            return list(range(len(device_ids)))
        return list(device_ids)

    def get_logger(self):
        """Reset the root logger to a single INFO-level ``TqdmHandler`` and return it."""
        logger = logging.getLogger()
        logger.handlers = []
        logger.setLevel(logging.INFO)
        screen_handler = TqdmHandler()
        screen_handler.setFormatter(logging.Formatter('[%(asctime)s] %(message)s'))
        logger.addHandler(screen_handler)
        return logger

    def get_performance_meters(self):
        """Return a dict with one ``PerformanceMeter`` per reported metric (only ``'acc'``)."""
        return {
            metric: PerformanceMeter() for metric in ['acc']
        }

    def get_average_meters(self):
        """Return a dict with one ``AverageMeter`` per running statistic (only ``'acc'``)."""
        return {
            meter: AverageMeter() for meter in ['acc']
        }

    def get_model(self, config):
        """Build and load model in config

        The model class is looked up in the ``MODEL`` registry by ``config.name``
        and its weights are read from the checkpoint at ``config.load``.

        Raises:
            AssertionError: if ``config`` has no ``load`` entry or it is empty.
        """
        name = config.name
        model = MODEL.get(name)(config)
        # Raised explicitly (instead of `assert`) so the check is not stripped by
        # `python -O`; the exception type and message are unchanged.
        if 'load' not in config or config.load == '':
            raise AssertionError('There is no valid `load` in config[model.load]!')
        self.logger.info(f'Loading model from {config.load}')
        # NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source.
        state_dict = torch.load(config.load, map_location='cpu')
        model.load_state_dict(state_dict)
        self.logger.info(f'OK! Model loaded from {config.load}')
        return model

    def get_transformer(self, config):
        """Return the evaluation transform: resize, center crop, to tensor, normalize.

        ``config.resize_size`` and ``config.image_size`` give the resize and
        crop sizes; the normalization constants are fixed in code.
        """
        return transforms.Compose([
            transforms.Resize(size=config.resize_size),
            transforms.CenterCrop(size=config.image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
        ])

    def get_collate_fn(self):
        """Return the collate function for the dataloader; ``None`` selects the default."""
        return None

    def get_dataset(self, config):
        """Build the ``FGDataset`` described by ``val.txt`` inside ``config.meta_dir``."""
        path = os.path.join(config.meta_dir, 'val.txt')
        return FGDataset(config.root_dir, path, transform=self.transformer)

    def get_dataloader(self, config):
        """Wrap ``self.dataset`` in an unshuffled ``DataLoader`` sized by ``config``."""
        return DataLoader(self.dataset, config.batch_size, num_workers=config.num_workers, pin_memory=False,
                          shuffle=False, collate_fn=self.collate_fn)

    def to_device(self, m, parallel=False):
        """Move a module or tensor to the configured device.

        Args:
            m: object exposing ``.to()`` / ``.cuda()`` (a module or a tensor).
            parallel: when True and several GPUs are configured, wrap ``m`` in
                ``torch.nn.DataParallel`` over all of them.

        Returns:
            ``m`` on the CPU, on the first configured GPU, or wrapped in
            ``DataParallel``.
        """
        if not self.device:
            return m.to('cpu')
        # `self.device_ids` (not `self.device`) holds the ids that are valid in
        # this process once CUDA_VISIBLE_DEVICES has renumbered the GPUs.
        if len(self.device) == 1 or not parallel:
            return m.to(f'cuda:{self.device_ids[0]}')
        m = m.cuda(self.device_ids[0])
        return torch.nn.DataParallel(m, device_ids=self.device_ids)

    def get_model_module(self, model=None):
        """get `model` in single-gpu mode or `model.module` in multi-gpu mode.

        ``model`` defaults to ``self.model`` when omitted.
        """
        if model is None:
            model = self.model
        if isinstance(model, torch.nn.DataParallel):
            return model.module
        return model

    def test(self):
        """Run the evaluation, record the averaged accuracy and log the report."""
        self.logger.info(f'Testing model from {self.config.model.load}')
        self.validate()
        self.performance_meters['acc'].update(self.average_meters['acc'].avg)
        self.report()

    def validate(self):
        """Evaluate every batch of ``self.dataloader`` without gradient tracking."""
        self.model.train(False)
        with torch.no_grad():
            # The description never changes, so it is set once on the bar itself.
            val_bar = tqdm(self.dataloader, ncols=100, desc='Testing')
            for data in val_bar:
                self.batch_validate(data)
                val_bar.set_postfix(acc=self.average_meters['acc'].avg)

    def batch_validate(self, data):
        """Score one batch and fold its accuracy into the running average.

        Args:
            data: mapping with an ``'img'`` batch and the matching ``'label'`` batch.
        """
        images, labels = self.to_device(data['img']), self.to_device(data['label'])
        logits = self.model(images)
        # NOTE(review): the `1` is presumably top-k with k=1 - confirm against utils.accuracy.
        acc = accuracy(logits, labels, 1)
        self.average_meters['acc'].update(acc, images.size(0))

    def report(self):
        """Log every performance meter's current value as ``name: value`` on one line."""
        metric_str = '  '.join([f'{metric}: {self.performance_meters[metric].current_value:.2f}'
                                for metric in self.performance_meters])
        self.logger.info(metric_str)
if __name__ == '__main__':
    # Script entry point: build the tester (which reads the config) and run the evaluation.
    Tester().test()