-
Notifications
You must be signed in to change notification settings - Fork 128
/
main.py
81 lines (61 loc) · 2.78 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import os
import argparse
import pprint
from data import dataloader
from run_networks import model
import warnings
from utils import source_import
# ================
# LOAD CONFIGURATIONS
#
# Script entry point: parse the command-line options, load the experiment
# configuration, build the data loaders, then either train the model or
# evaluate it (closed-set or open-set) depending on the flags.

# Root directory of each supported dataset, keyed by the dataset name with
# its '_LT' suffix removed.
data_root = {'ImageNet': '/home/public/public_dataset/ILSVRC2014/Img',
             'Places': '/home/public/dataset/Places365'}

_LT_SUFFIX = '_LT'
# Spellings (compared case-insensitively) that switch a boolean option off.
_FALSE_STRINGS = ('', '0', 'false', 'f', 'no', 'n', 'off')


def _parse_flag(value):
    """Convert the text given to a boolean command-line option into a bool.

    Only the spellings in ``_FALSE_STRINGS`` disable the option; every other
    value enables it. This keeps older invocations such as
    ``--output_logits True`` working while making ``--output_logits False``
    actually mean "off" (a plain string 'False' would be truthy).
    """
    return value.strip().lower() not in _FALSE_STRINGS


def _base_dataset_name(name):
    """Return *name* without a trailing '_LT' suffix.

    ``str.rstrip('_LT')`` is deliberately not used: it removes any trailing
    run of the characters '_', 'L' and 'T', not the literal suffix.
    """
    if name.endswith(_LT_SUFFIX):
        return name[:-len(_LT_SUFFIX)]
    return name


parser = argparse.ArgumentParser()
parser.add_argument('--config', default='./config/Imagenet_LT/Stage_1.py', type=str)
parser.add_argument('--test', default=False, action='store_true')
parser.add_argument('--test_open', default=False, action='store_true')
# Accepts both a bare `--output_logits` and `--output_logits <value>`.
parser.add_argument('--output_logits', nargs='?', const=True, default=False,
                    type=_parse_flag)
args = parser.parse_args()

test_open = args.test_open
# Open-set testing implies test mode.
test_mode = args.test or test_open
output_logits = args.output_logits

config = source_import(args.config).config
training_opt = config['training_opt']
# Options from the 'memory' section of the config; 'init_centroids' decides
# whether the extra 'train_plain' phase is loaded for training.
memory_opt = config['memory']
dataset = training_opt['dataset']

# exist_ok avoids the check-then-create race of isdir() + makedirs().
os.makedirs(training_opt['log_dir'], exist_ok=True)

# Resolve the dataset root once; raises KeyError for a dataset that has no
# entry in data_root.
dataset_root = data_root[_base_dataset_name(dataset)]

print('Loading dataset from: %s' % dataset_root)
pprint.pprint(config)

if not test_mode:

    sampler_defs = training_opt['sampler']
    if sampler_defs:
        # The sampler is built by get_sampler() from the file named in the config.
        sampler_dic = {'sampler': source_import(sampler_defs['def_file']).get_sampler(),
                       'num_samples_cls': sampler_defs['num_samples_cls']}
    else:
        sampler_dic = None

    phases = ['train', 'val']
    if memory_opt['init_centroids']:
        phases.append('train_plain')

    data = {phase: dataloader.load_data(data_root=dataset_root, dataset=dataset, phase=phase,
                                        batch_size=training_opt['batch_size'],
                                        sampler_dic=sampler_dic,
                                        num_workers=training_opt['num_workers'])
            for phase in phases}

    training_model = model(config, data, test=False)

    training_model.train()

else:

    warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)

    print('Under testing phase, we load training data simply to calculate training data number for each class.')

    data = {phase: dataloader.load_data(data_root=dataset_root, dataset=dataset, phase=phase,
                                        batch_size=training_opt['batch_size'],
                                        sampler_dic=None,
                                        test_open=test_open,
                                        num_workers=training_opt['num_workers'],
                                        shuffle=False)
            for phase in ['train', 'test']}

    training_model = model(config, data, test=True)
    training_model.load_model()
    training_model.eval(phase='test', openset=test_open)

    if output_logits:
        training_model.output_logits(openset=test_open)

print('ALL COMPLETED.')