-
Notifications
You must be signed in to change notification settings - Fork 426
/
Copy pathtrain.py
32 lines (26 loc) · 1.35 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from config import get_config
from Learner import face_learner
import argparse
# python train.py -net mobilefacenet -b 200 -w 4


def parse_args(argv=None):
    """Parse the training command line.

    Args:
        argv: List of argument strings to parse, excluding the program name.
            ``None`` (the default) makes argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with attributes ``epochs``, ``net_mode``,
        ``net_depth``, ``lr``, ``batch_size``, ``num_workers`` and
        ``data_mode``.

    Raises:
        SystemExit: if an option is malformed or outside the set listed in
            its help text (argparse prints the usage message first).
    """
    parser = argparse.ArgumentParser(description='for face verification')
    parser.add_argument("-e", "--epochs", help="training epochs", default=20, type=int)
    # ``choices`` enforces the sets the help strings already document, so a
    # typo fails here instead of being copied verbatim into the config.
    parser.add_argument("-net", "--net_mode", help="which network, [ir, ir_se, mobilefacenet]",
                        default='ir_se', type=str,
                        choices=['ir', 'ir_se', 'mobilefacenet'])
    parser.add_argument("-depth", "--net_depth", help="how many layers [50,100,152]",
                        default=50, type=int, choices=[50, 100, 152])
    parser.add_argument('-lr', '--lr', help='learning rate', default=1e-3, type=float)
    parser.add_argument("-b", "--batch_size", help="batch_size", default=96, type=int)
    parser.add_argument("-w", "--num_workers", help="workers number", default=3, type=int)
    parser.add_argument("-d", "--data_mode", help="use which database, [vgg, ms1m, emore, concat]",
                        default='emore', type=str,
                        choices=['vgg', 'ms1m', 'emore', 'concat'])
    return parser.parse_args(argv)


def main(argv=None):
    """Build the training config from the command line and run training.

    Example:
        python train.py -net mobilefacenet -b 200 -w 4

    Starts from ``get_config()`` and overrides its fields with the parsed
    options. It then builds ``face_learner(conf)`` and calls
    ``learner.train(conf, args.epochs)``.

    Args:
        argv: Optional argument list forwarded to :func:`parse_args`;
            ``None`` means the process command line.
    """
    args = parse_args(argv)
    conf = get_config()
    if args.net_mode == 'mobilefacenet':
        # NOTE(review): attribute spelled ``use_mobilfacenet`` (sic); presumably
        # this matches the name get_config() defines, so confirm before renaming.
        conf.use_mobilfacenet = True
        # --net_depth is deliberately ignored for mobilefacenet.
    else:
        conf.net_mode = args.net_mode
        conf.net_depth = args.net_depth
    conf.lr = args.lr
    conf.batch_size = args.batch_size
    conf.num_workers = args.num_workers
    conf.data_mode = args.data_mode
    learner = face_learner(conf)
    learner.train(conf, args.epochs)


if __name__ == '__main__':
    main()