|
| 1 | +#coding:utf8 |
| 2 | +from config import opt |
| 3 | +import os |
| 4 | +import torch as t |
| 5 | +import models |
| 6 | +from data.dataset import DogCat |
| 7 | +from torch.utils.data import DataLoader |
| 8 | +from torch.autograd import Variable |
| 9 | +from torchnet import meter |
| 10 | +from utils.visualize import Visualizer |
| 11 | + |
| 12 | +def test(**kwargs): |
| 13 | + opt.parse(kwargs) |
| 14 | + import ipdb; |
| 15 | + ipdb.set_trace() |
| 16 | + # configure model |
| 17 | + model = getattr(models, opt.model)().eval() |
| 18 | + if opt.load_model_path: |
| 19 | + model.load(opt.load_model_path) |
| 20 | + if opt.use_gpu: model.cuda() |
| 21 | + |
| 22 | + # data |
| 23 | + train_data = DogCat(opt.test_data_root,test=True) |
| 24 | + test_dataloader = DataLoader(train_data,batch_size=opt.batch_size,shuffle=False,num_workers=opt.num_workers) |
| 25 | + results = [] |
| 26 | + for ii,(data,path) in enumerate(test_dataloader): |
| 27 | + input = t.autograd.Variable(data,volatile = True) |
| 28 | + if opt.use_gpu: input = input.cuda() |
| 29 | + score = model(input) |
| 30 | + probability = t.nn.functional.softmax(score)[:,0].data.tolist() |
| 31 | + # label = score.max(dim = 1)[1].data.tolist() |
| 32 | + |
| 33 | + batch_results = [(path_,probability_) for path_,probability_ in zip(path,probability) ] |
| 34 | + |
| 35 | + results += batch_results |
| 36 | + write_csv(results,opt.result_file) |
| 37 | + |
| 38 | + return results |
| 39 | + |
| 40 | +def write_csv(results,file_name): |
| 41 | + import csv |
| 42 | + with open(file_name,'w') as f: |
| 43 | + writer = csv.writer(f) |
| 44 | + writer.writerow(['id','label']) |
| 45 | + writer.writerows(results) |
| 46 | + |
| 47 | +def train(**kwargs): |
| 48 | + opt.parse(kwargs) |
| 49 | + vis = Visualizer(opt.env) |
| 50 | + |
| 51 | + # step1: configure model |
| 52 | + model = getattr(models, opt.model)() |
| 53 | + if opt.load_model_path: |
| 54 | + model.load(opt.load_model_path) |
| 55 | + if opt.use_gpu: model.cuda() |
| 56 | + |
| 57 | + # step2: data |
| 58 | + train_data = DogCat(opt.train_data_root,train=True) |
| 59 | + val_data = DogCat(opt.train_data_root,train=False) |
| 60 | + train_dataloader = DataLoader(train_data,opt.batch_size, |
| 61 | + shuffle=True,num_workers=opt.num_workers) |
| 62 | + val_dataloader = DataLoader(val_data,opt.batch_size, |
| 63 | + shuffle=False,num_workers=opt.num_workers) |
| 64 | + |
| 65 | + # step3: criterion and optimizer |
| 66 | + criterion = t.nn.CrossEntropyLoss() |
| 67 | + lr = opt.lr |
| 68 | + optimizer = t.optim.Adam(model.parameters(),lr = lr,weight_decay = opt.weight_decay) |
| 69 | + |
| 70 | + # step4: meters |
| 71 | + loss_meter = meter.AverageValueMeter() |
| 72 | + confusion_matrix = meter.ConfusionMeter(2) |
| 73 | + previous_loss = 1e100 |
| 74 | + |
| 75 | + # train |
| 76 | + for epoch in range(opt.max_epoch): |
| 77 | + |
| 78 | + loss_meter.reset() |
| 79 | + confusion_matrix.reset() |
| 80 | + |
| 81 | + for ii,(data,label) in enumerate(train_dataloader): |
| 82 | + |
| 83 | + # train model |
| 84 | + input = Variable(data) |
| 85 | + target = Variable(label) |
| 86 | + if opt.use_gpu: |
| 87 | + input = input.cuda() |
| 88 | + target = target.cuda() |
| 89 | + |
| 90 | + optimizer.zero_grad() |
| 91 | + score = model(input) |
| 92 | + loss = criterion(score,target) |
| 93 | + loss.backward() |
| 94 | + optimizer.step() |
| 95 | + |
| 96 | + |
| 97 | + # meters update and visualize |
| 98 | + loss_meter.add(loss.data[0]) |
| 99 | + confusion_matrix.add(score.data, target.data) |
| 100 | + |
| 101 | + if ii%opt.print_freq==opt.print_freq-1: |
| 102 | + vis.plot('loss', loss_meter.value()[0]) |
| 103 | + |
| 104 | + # 进入debug模式 |
| 105 | + if os.path.exists(opt.debug_file): |
| 106 | + import ipdb; |
| 107 | + ipdb.set_trace() |
| 108 | + |
| 109 | + |
| 110 | + model.save() |
| 111 | + |
| 112 | + # validate and visualize |
| 113 | + val_cm,val_accuracy = val(model,val_dataloader) |
| 114 | + |
| 115 | + vis.plot('val_accuracy',val_accuracy) |
| 116 | + vis.log("epoch:{epoch},lr:{lr},loss:{loss},train_cm:{train_cm},val_cm:{val_cm}".format( |
| 117 | + epoch = epoch,loss = loss_meter.value()[0],val_cm = str(val_cm.value()),train_cm=str(confusion_matrix.value()),lr=lr)) |
| 118 | + |
| 119 | + # update learning rate |
| 120 | + if loss_meter.value()[0] > previous_loss: |
| 121 | + lr = lr * opt.lr_decay |
| 122 | + # 第二种降低学习率的方法:不会有moment等信息的丢失 |
| 123 | + for param_group in optimizer.param_groups: |
| 124 | + param_group['lr'] = lr |
| 125 | + |
| 126 | + |
| 127 | + previous_loss = loss_meter.value()[0] |
| 128 | + |
| 129 | +def val(model,dataloader): |
| 130 | + ''' |
| 131 | + 计算模型在验证集上的准确率等信息 |
| 132 | + ''' |
| 133 | + model.eval() |
| 134 | + confusion_matrix = meter.ConfusionMeter(2) |
| 135 | + for ii, data in enumerate(dataloader): |
| 136 | + input, label = data |
| 137 | + val_input = Variable(input, volatile=True) |
| 138 | + val_label = Variable(label.type(t.LongTensor), volatile=True) |
| 139 | + if opt.use_gpu: |
| 140 | + val_input = val_input.cuda() |
| 141 | + val_label = val_label.cuda() |
| 142 | + score = model(val_input) |
| 143 | + confusion_matrix.add(score.data.squeeze(), label.type(t.LongTensor)) |
| 144 | + |
| 145 | + model.train() |
| 146 | + cm_value = confusion_matrix.value() |
| 147 | + accuracy = 100. * (cm_value[0][0] + cm_value[1][1]) / (cm_value.sum()) |
| 148 | + return confusion_matrix, accuracy |
| 149 | + |
| 150 | +def help(): |
| 151 | + ''' |
| 152 | + 打印帮助的信息: python file.py help |
| 153 | + ''' |
| 154 | + |
| 155 | + print(''' |
| 156 | + usage : python file.py <function> [--args=value] |
| 157 | + <function> := train | test | help |
| 158 | + example: |
| 159 | + python {0} train --env='env0701' --lr=0.01 |
| 160 | + python {0} test --dataset='path/to/dataset/root/' |
| 161 | + python {0} help |
| 162 | + avaiable args:'''.format(__file__)) |
| 163 | + |
| 164 | + from inspect import getsource |
| 165 | + source = (getsource(opt.__class__)) |
| 166 | + print(source) |
| 167 | + |
| 168 | +if __name__=='__main__': |
| 169 | + import fire |
| 170 | + fire.Fire() |
0 commit comments