--
-- train_yolo_train.lua
-- Training script for a YOLO-style region-proposal network in Torch7.
-- User: changqi
-- Date: 3/14/16
-- Time: 12:25 PM
--
require 'nn'
require 'optim'
require 'cunn'
require 'cudnn'
require 'image'
require 'xlua'
require './yoloCriterion/RegionProposalCriterion.lua'
dofile './provider_yolo.lua'
local class = require 'class'
local c = require 'trepl.colorize'
opt = lapp [[
   -s,--save            (default "logs")               subdirectory to save logs
   -b,--batchSize       (default 6)                    batch size
   -r,--learningRate    (default 1e-3)                 learning rate
   --learningRateDecay  (default 1e-7)                 learning rate decay
   --weightDecay        (default 0.0005)               weight decay
   -m,--momentum        (default 0.9)                  momentum
   --epoch_step         (default 25)                   halve the learning rate every this many epochs
   --model              (default model_yolo_train)     model name
   --max_epoch          (default 300)                  maximum number of epochs
   --backend            (default nn)                   backend (nn | cudnn)
   -i,--log_interval    (default 5.0)                  print stats every this many batches
   --modelPath          (default logs/model_fcnn.net)  path to an existing model
   --trainSize          (default 0)                    limit the training set size; 0 -> unlimited
   --testSize           (default 0)                    limit the test set size; 0 -> unlimited
   --criWeight          (default 4)                    criterion weight for the region proposal term
   --bWeight            (default 0.25)                 criterion weight for balancing the classification term
]]
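-- lapp parses the declarations above into the opt table; for example (illustrative):
--   th train_yolo_train.lua -b 12 --backend=cudnn
-- yields opt.batchSize == 12 and opt.backend == 'cudnn'.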
print(opt)
------------------------------------ loading data ----------------------------------------
print(c.blue '==>' .. ' loading data')
-- provider is expected to be created by provider_yolo.lua (loaded via dofile above);
-- alternatively, load a serialized copy:
-- provider = torch.load('provider_yolo.t7')
provider.trainData.data = provider.trainData.data:float()
provider.testData.data = provider.testData.data:float()
-- Optionally truncate the training and test sets (0 keeps the full set).
function _narrow(trSize, teSize)
    if trSize ~= 0 then
        provider.trainData.data = provider.trainData.data:narrow(1, 1, trSize)
        provider.trainData.labels = provider.trainData.labels:narrow(1, 1, trSize)
    end
    if teSize ~= 0 then
        provider.testData.data = provider.testData.data:narrow(1, 1, teSize)
        provider.testData.labels = provider.testData.labels:narrow(1, 1, teSize)
    end
end
-- if opt.trainSize ~= 0 or opt.testSize ~= 0 then
--     _narrow(opt.trainSize, opt.testSize)
-- end
do -- data augmentation module
    local BatchFlip, parent = torch.class('nn.BatchFlip', 'nn.Module')

    function BatchFlip:__init()
        parent.__init(self)
        self.train = true
    end

    -- During training, horizontally flip a random half of the batch in place.
    function BatchFlip:updateOutput(input)
        if self.train then
            local bs = input:size(1)
            local flip_mask = torch.randperm(bs):le(bs / 2)
            for i = 1, bs do
                if flip_mask[i] == 1 then image.hflip(input[i], input[i]) end
            end
        end
        self.output:set(input)
        return self.output
    end
end
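-- Minimal sanity check for BatchFlip (illustrative only; the module is
-- currently commented out of the model below). Because updateOutput uses
-- :set(), the output shares storage with the input:
--   local flip = nn.BatchFlip():float()
--   local batch = torch.rand(4, 3, 32, 32):float()
--   local out = flip:forward(batch)
--   out[1][1][1][1] = 42
--   assert(batch[1][1][1][1] == 42) -- writes through to the input tensor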
------------------------------------ configuring ----------------------------------------
print(c.red '==>' .. ' configuring model')
local modelPath = opt.modelPath
local model = nn.Sequential()
-- model:add(nn.BatchFlip():float())
model:add(nn.Copy('torch.FloatTensor', 'torch.CudaTensor'))
-- To resume from an existing model instead of building a fresh one:
-- if modelPath and paths.filep(modelPath) then
--     model:add(torch.load(modelPath))
--     print('==> loaded existing model: ' .. modelPath)
-- else
model:add(dofile(opt.model .. '.lua'):cuda())
-- end
-- model:get(1).updateGradInput = function(input) return end
----------------------------------- backend conversion -----------------------------------
if opt.backend == 'cudnn' then
    require 'cudnn'
    cudnn.convert(model:get(2), cudnn)
end
print(model)
parameters, gradParameters = model:getParameters()
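-- getParameters flattens all learnable weights and their gradients into two
-- matching vectors; optim.sgd below updates the flat parameter vector in place.
-- A quick check (illustrative only):
--   print(parameters:nElement()) -- total number of learnable parameters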
------------------------------------ save log ----------------------------------------
print('Will save at ' .. opt.save)
paths.mkdir(opt.save)
testLogger = optim.Logger(paths.concat(opt.save, 'test.log'))
-- the two logged values are the per-class accuracies from the confusion matrix
testLogger:setNames { 'class 1 accuracy', 'class 2 accuracy' }
testLogger.showPlot = false
------------------------------------ set criterion ---------------------------------------
print(c.blue '==>' .. ' setting criterion')
-- train set: 351,641,088 total, of which 2,344,174 are positive
criterion = nn.RegionProposalCriterion(opt.criWeight, opt.bWeight, 7, true):cuda()
-- criterion = cudnn.SpatialCrossEntropyCriterion(torch.Tensor{1.006, 150}):cuda()
-- criterion = nn.CrossEntropyCriterion(torch.Tensor({1.006, 150})):cuda()
confusion = optim.ConfusionMatrix(2)
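-- Illustrative use of the 2-class confusion matrix (commented out so it does
-- not pollute the real statistics):
--   local preds = torch.Tensor{{0.9, 0.1}, {0.2, 0.8}} -- per-row class scores
--   local truth = torch.Tensor{1, 2}                   -- 1-based class labels
--   confusion:batchAdd(preds, truth)
--   confusion:updateValids()
--   print(confusion.totalValid) -- overall accuracy in [0, 1]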
------------------------------------ optimizer config ------------------------------------
print(c.blue '==>' .. ' configuring optimizer')
optimState = {
    learningRate = opt.learningRate,
    weightDecay = opt.weightDecay,
    momentum = opt.momentum,
    learningRateDecay = opt.learningRateDecay,
}
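-- optim.sgd consumes a closure returning (loss, gradParameters) together with
-- this state table; a single step looks like (illustrative only):
--   local newx, fs = optim.sgd(feval, parameters, optimState) -- fs[1] is the loss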
function train()
    model:training()
    epoch = epoch or 1

    -- halve the learning rate every "epoch_step" epochs
    if epoch % opt.epoch_step == 0 then optimState.learningRate = optimState.learningRate / 2 end

    print(c.blue '==>' .. " online epoch # " .. epoch .. ' [batchSize = ' .. opt.batchSize .. ']')

    -- reusable GPU buffer shaped like one batch of labels
    local targets = torch.CudaTensor(opt.batchSize, provider.trainData.labels:size(2),
        provider.trainData.labels:size(3), provider.trainData.labels:size(4))

    -- shuffle all sample indices and split them into batches; drop the last
    -- (possibly smaller) batch so every batch has exactly opt.batchSize samples
    local indices = torch.randperm(provider.trainData.data:size(1)):long():split(opt.batchSize)
    indices[#indices] = nil

    for t, v in ipairs(indices) do
        xlua.progress(t, #indices)
        local innerTic = torch.tic()

        local inputs = provider.trainData.data:index(1, v)
        targets:copy(provider.trainData.labels:index(1, v))

        local feval = function(x)
            if x ~= parameters then parameters:copy(x) end
            gradParameters:zero()

            -- the nn.Copy front end moves the float inputs onto the GPU
            local outputs = model:forward(inputs)
            local f = criterion:forward(outputs, targets)
            local df = criterion:backward(outputs, targets)
            model:backward(inputs, df)

            -- outputs and targets hold 6 values per cell of the 7x7 grid; the
            -- last two output values are the class scores and the last target
            -- value is the class label (0/1, shifted to 1/2 for the confusion matrix)
            local outputsClasses = outputs:reshape(outputs:size(1), 7, 7, 6):narrow(4, 5, 2)
            local targetClasses = targets:reshape(targets:size(1), 7, 7, 6):narrow(4, 6, 1)
            confusion:batchAdd(outputsClasses:reshape(outputs:size(1) * 7 * 7, 2),
                targetClasses:reshape(targets:size(1) * 7 * 7, 1) + 1)

            return f, gradParameters
        end
        local x, fx = optim.sgd(feval, parameters, optimState)

        local innerToc = torch.toc(innerTic)
        local function printInfo()
            local tmpl = '---------%d/%d (epoch %.3f), ' ..
                'train_loss = %6.8f, grad/param norm = %6.4e, ' ..
                'speed = %5.1f/s, %5.3fs/iter -----------'
            print(string.format(tmpl,
                t, #indices, epoch,
                fx[1], gradParameters:norm() / parameters:norm(),
                opt.batchSize / innerToc, innerToc))
        end
        if t % opt.log_interval == 0 then
            printInfo()
        end
    end

    confusion:updateValids()
    print('Train accuracy:', confusion.totalValid * 100)
    print(confusion)
    confusion:zero()

    epoch = epoch + 1
end
-- Flatten a Bx2xHxW two-class score map into a (B*H*W)x2 matrix with one row
-- of class scores per spatial cell. (Only referenced from commented-out code.)
function _flatTensor(tensor)
    local subset1 = tensor:select(2, 1)
    local subset2 = tensor:select(2, 2)
    return torch.cat(subset1:reshape(subset1:nElement(), 1),
        subset2:reshape(subset2:nElement(), 1), 2)
end

-- Inverse of _flatTensor: reshape each score column back to origSize (assumed
-- BxHxW) and stack the two class planes to recover Bx2xHxW.
function _unflatTensor(tensor, origSize)
    local outputs1 = tensor:select(2, 1):reshape(origSize[1], 1, origSize[2], origSize[3])
    local outputs2 = tensor:select(2, 2):reshape(origSize[1], 1, origSize[2], origSize[3])
    return torch.cat(outputs1, outputs2, 2)
end
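-- Round-trip sketch (illustrative only, matching the assumed shapes above):
--   local scores = torch.rand(4, 2, 7, 7)                         -- Bx2xHxW
--   local flat   = _flatTensor(scores)                            -- (4*7*7) x 2
--   local back   = _unflatTensor(flat, torch.LongStorage{4, 7, 7})
--   assert(back:isSameSizeAs(scores))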
function test()
    -- disable flips, dropouts and batch normalization
    model:evaluate()
    print(c.blue '==>' .. " testing")

    local bs = opt.batchSize
    local len = provider.testData.data:size(1)
    for i = 1, len, bs do
        xlua.progress(i, len)

        -- shrink the final chunk so it does not run past the end of the data
        local chunk = math.min(bs, len - i + 1)
        local inputs = provider.testData.data:narrow(1, i, chunk)
        local targets = provider.testData.labels:narrow(1, i, chunk)
        local outputs = model:forward(inputs)

        -- same per-cell layout as in train(): the last two of the 6 values per
        -- 7x7 grid cell are class scores, the last target value is the label
        local outputsClasses = outputs:reshape(outputs:size(1), 7, 7, 6):narrow(4, 5, 2)
        local targetClasses = targets:reshape(targets:size(1), 7, 7, 6):narrow(4, 6, 1)
        confusion:batchAdd(outputsClasses:reshape(outputs:size(1) * 7 * 7, 2),
            targetClasses:reshape(targets:size(1) * 7 * 7, 1) + 1)
    end
    confusion:updateValids()
    print('Test accuracy:', confusion.totalValid * 100)
    print(confusion)

    if testLogger then
        paths.mkdir(opt.save)
        testLogger:add { confusion.valids[1], confusion.valids[2] }
        testLogger:style { '+-', '+-' }
        testLogger:plot()

        -- render the gnuplot .eps output to .png (ImageMagick) and base64-encode
        -- it so it can be inlined into the HTML report below
        local base64im
        do
            os.execute(('convert -density 200 %s/test.log.eps %s/test.png'):format(opt.save, opt.save))
            os.execute(('openssl base64 -in %s/test.png -out %s/test.base64'):format(opt.save, opt.save))
            local f = io.open(opt.save .. '/test.base64')
            if f then base64im = f:read '*all' end
        end
        local file = io.open(opt.save .. '/report.html', 'w')
        file:write(([[
<!DOCTYPE html>
<html>
<head><title>%s - %s</title></head>
<body>
<img src="data:image/png;base64,%s">
<h4>optimState:</h4>
<table>
]]):format(opt.save, epoch, base64im))
        for k, v in pairs(optimState) do
            if torch.type(v) == 'number' then
                file:write('<tr><td>' .. k .. '</td><td>' .. v .. '</td></tr>\n')
            end
        end
        file:write '</table><pre>\n'
        file:write(tostring(confusion) .. '\n')
        file:write(tostring(model) .. '\n')
        file:write '</pre></body></html>'
        file:close()
    end
    -- save a snapshot every 5 epochs (strip the nn.Copy front end and buffers)
    if epoch % 5 == 0 then
        local filename = paths.concat(opt.save, 'model_' .. epoch .. '.net')
        print('==> saving model to ' .. filename)
        torch.save(filename, model:get(2):clearState())
    end

    confusion:zero()
end
for i = 1, opt.max_epoch do
    train()
    test()
end

-- Example invocation:
-- CUDA_VISIBLE_DEVICES=0 th -i train_yolo.lua --backend=cudnn --save=logs/i1_yoloTest --model=model_yolo