-
Notifications
You must be signed in to change notification settings - Fork 44
/
train.m
342 lines (301 loc) · 14.5 KB
/
train.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
% Make the project's custom layers and helper utilities visible on the path.
addpath('./CustomLayers/','./utils/')
%% 1. Prepare the data (suitable for yolov3 / yolov4; VOC-xml format is not required)
% Data requirements - satisfying either of the following is enough:
% 1. Label inside MATLAB, see https://ww2.mathworks.cn/help/vision/ug/get-started-with-the-image-labeler.html?requestedDomain=cn
% 2. Import external annotation files into MATLAB, see https://github.com/cuixing158/imageLabeler-API.git
load gTruth.mat % your own bbox annotation file (format per the links above); ends up as a table, which is easy to inspect
cfg_file = './cfg/yolov3-tiny.cfg';
weight_file = './weights/yolov3-tiny.weights'; % pretrained backbone weights; other types are fine too
% NOTE(review): the two variables below are not referenced anywhere else in the
% visible script; "Weight" in the second name presumably means "Width".
annotateImgHeight = 416; % original height of the annotated images
annotateImgWeight = 416; % original width of the annotated images
% Class names and their corresponding ID numbers.
% The first table column is skipped here; the remaining column names are the class names.
classesNames = gTruth.Properties.VariableNames(2:end);
classIDs = (0:length(classesNames)-1);% IDs start from 0, consistent with darknet
numClasses = length(classesNames);
% Map each class name (as a struct field) to its zero-based class ID.
structNamesIDs = struct();
for i = 1:numClasses
structNamesIDs.(classesNames{i}) = classIDs(i);
end
% Create iterable datastores: boxes/labels from the table columns 2:end,
% images from the imageFilename column, both read miniBatch items at a time.
bldsTrain = boxLabelDatastore(gTruth(:, 2:end));
imdsTrain = imageDatastore(gTruth.imageFilename);
miniBatch = 16;
imdsTrain.ReadSize = miniBatch;
bldsTrain.ReadSize = miniBatch;
trainingData = combine(imdsTrain, bldsTrain);
%% Set hyper-parameters and import pretrained weights (other official MATLAB pretrained weights could be used instead); here the darknet binary ".weight" file is used
[lgModel,hyperParams] = importDarknetWeights(cfg_file,weight_file);
% analyzeNetwork(lgModel);
% The network input size is taken from the width/height entries of the imported cfg.
inputWeight = str2double(hyperParams.width);
inputHeight = str2double(hyperParams.height);
networkInputSize = [inputHeight inputWeight 3];
% Every read of trainingData is passed through preprocessTrainData together with
% the network input size and the class-name -> ID mapping.
preprocessedTrainingData = transform(trainingData,@(data)preprocessTrainData(data,networkInputSize,structNamesIDs));
% Preview the data: read one batch and draw its boxes on the images.
for k = 1:1
data = read(preprocessedTrainingData);
I = data{1,1}{1};
bbox = data{1,2}{1};
annotatedImage = zeros(size(I),'like',I);
for i = 1:size(I,4)
annotatedImage(:,:,:,i) = insertShape(I(:,:,:,i),'Rectangle',bbox{i}(:,1:4));
end
annotatedImage = imresize(annotatedImage,2);
figure
montage(annotatedImage)
end
numAnchors = 6;
% NOTE(review): the anchors are estimated from the un-preprocessed trainingData.
% The original author notes that anchorBoxes are meant to be sizes on
% networkInputSize, but estimateAnchorBoxes is too restrictive about its input
% and preprocessedTrainingData cannot be passed to it - presumably this is only
% exact when the annotated image size equals the network input size; confirm.
anchorBoxes = estimateAnchorBoxes(trainingData,numAnchors);
% Order the anchors from largest to smallest area and round them to integers.
area = anchorBoxes(:, 1).*anchorBoxes(:, 2);
[~, idx] = sort(area, 'descend');
anchorBoxes = anchorBoxes(idx, :);
anchorBoxes = round(anchorBoxes);
anchorBoxMasks = {[1,2,3],[4,5,6]};% large-area anchors go with the yolov3 layer that has the smaller feature map, small-area anchors with the larger feature map
%% 2. Build the darknet network and add the yolov3Layer heads
anchorBoxes(:,[2,1]) = anchorBoxes(:,[1,2]);% swap the two columns: anchorBoxes is now [width height], consistent with darknet
imageSize = lgModel.Layers(1).InputSize(1:2);
% Each head is a 1x1 convolution with nAnchors*(5+numClasses) filters followed
% by a custom yolov3Layer that receives its anchor mask, all anchors, the class
% count, a head index and the network image size.
yoloModule1 = [convolution2dLayer(1,length(anchorBoxMasks{1})*(5+numClasses),'Name','yoloconv1');
yolov3Layer('yolov3layer1',anchorBoxMasks{1},anchorBoxes,numClasses,1,imageSize)];
yoloModule2 = [convolution2dLayer(1,length(anchorBoxMasks{2})*(5+numClasses),'Name','yoloconv2');
yolov3Layer('yolov3layer2',anchorBoxMasks{2},anchorBoxes,numClasses,2,imageSize)];
% Drop the imported yolo layers and replace the two imported convolutions with
% the new heads. The layer names are those of the imported graph.
lgModel = removeLayers(lgModel,{'yolo_v3_id1','yolo_v3_id2'});
lgModel = replaceLayer(lgModel,'conv_17',yoloModule1);
lgModel = replaceLayer(lgModel,'conv_24',yoloModule2);
analyzeNetwork(lgModel);
% NOTE(review): these indices are hard-coded for this particular graph and must
% be re-derived if the cfg or the replaced layers change.
yoloLayerNumber = [36,47];% positions of the yolov3 layers in the Layers array, read off the model diagram!
model = dlnetwork(lgModel); % the original author notes that dlnetwork takes a long time although it is only a conversion
%% 3. Custom training loop: iterate over the data and update the model
% Training options
initialLearningRate = 0.001; % base step size; the schedule below only scales it
learningRate = initialLearningRate;
% Cosine decay factor in [0,1] (https://arxiv.org/pdf/1812.01187.pdf). It is a
% multiplier for the base learning rate, not a learning rate by itself.
scheduler = @(epoch,epochs)0.5*(1+cos(epoch*pi/epochs));
l2Regularization = 0.0005;
velocity = [];
executionEnvironment = "auto";

% Two stacked axes: mean loss on top, learning rate below.
figure;
ax1 = subplot(211);
ax2 = subplot(212);
lossPlotter = animatedline(ax1);
learningRatePlotter = animatedline(ax2);

nEpochs = 10;
allIteration = 0; % number of processed batches over all epochs
cumLoss = 0;      % sum of totalLoss over all processed batches
meanLoss = 0;
for numEpoch = 1:nEpochs
    reset(preprocessedTrainingData);% Reset datastore.
    iteration = 0;
    while hasdata(preprocessedTrainingData)
        t_start = tic;
        % Read one batch and split it into images and ground truths.
        outDataTable = read(preprocessedTrainingData);
        XTrain = outDataTable{1,1}{1};
        YTrain = outDataTable{1,2}{1};
        if isempty(YTrain)
            continue;
        end

        % Convert mini-batch of data to dlarray.
        XTrain = dlarray(single(XTrain),'SSCB');

        % If training on a GPU, then convert data to gpuArray.
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            XTrain = gpuArray(XTrain);
        end

        % Evaluate the model gradients and loss using dlfeval and the
        % modelGradients function.
        [gradients,boxLoss,objLoss,clsLoss,totalLoss,state] = dlfeval(@modelGradients, model, XTrain, YTrain,yoloLayerNumber);

        % Apply L2 regularization.
        gradients = dlupdate(@(g,w) g + l2Regularization*w, gradients, model.Learnables);

        % Update the network learnable parameters using the SGDM optimizer.
        [model, velocity] = sgdmupdate(model, gradients, velocity, learningRate);

        % Update the state parameters of dlnetwork.
        model.State = state;

        % The running mean includes the current batch so that numerator and
        % denominator cover the same number of batches.
        fprintf('[%d][%d/%d]\t BatchTime(s):%.2f , LR:%.5f, boxLoss:%-12.3f, objLoss:%-12.3f, clsLoss:%-12.3f, totalLoss:%-12.5f, meanLoss:%-12.5f\n',...
            numEpoch,iteration+1,floor(numpartitions(preprocessedTrainingData)/miniBatch),...
            toc(t_start),learningRate,boxLoss,objLoss,clsLoss,totalLoss,(cumLoss+totalLoss)/(allIteration+1));

        if isnan(totalLoss)
            fprintf('loss is nan!');
            return
        end
        cumLoss = cumLoss+totalLoss;
        iteration = iteration +1;
        allIteration = allIteration+1;
    end
    % Guard against 0/0 when no batch has been processed yet.
    meanLoss = cumLoss/max(allIteration,1);

    % Save the model every 5 epochs, both as a .mat file and in darknet format.
    if (mod(numEpoch,5)==0)
        timeStr = datestr(now,'yyyy_mm_dd_HH_MM_SS');
        if ~isfolder('./save')
            mkdir('./save'); % save() does not create missing folders
        end
        matlabModel = fullfile('./save',[timeStr,'.mat']);
        save(matlabModel,'model');

        cfgFile = fullfile('./cfg',[timeStr,'.cfg']);
        darknetModel = fullfile('./weights',[timeStr,'.weights']);
        exportDarkNetNetwork(model,hyperParams,cfgFile,darknetModel);
    end

    % Update the training plots with the values that belong to this epoch.
    addpoints(lossPlotter, numEpoch, double(meanLoss));
    addpoints(learningRatePlotter, numEpoch, learningRate);

    % Learning rate for the next epoch: base rate scaled by the cosine factor.
    learningRate = initialLearningRate*scheduler(numEpoch,nEpochs);
    drawnow
end
%% yolov3 loss function
function [gradients, boxLoss, objLoss, clsLoss, totalLoss, state] = modelGradients(net, XTrain, YTrain,yoloLayerNumber)
% modelGradients  Compute the YOLOv3 losses and the gradients of their sum.
%
% Inputs:
%   net             network object; its OutputNames, Layers and Learnables are
%                   read and it is evaluated with forward()
%   XTrain          input batch passed unchanged to forward()
%   YTrain          ground truth passed unchanged to buildTargets()
%   yoloLayerNumber indices into net.Layers; entry i must be the yolov3 layer
%                   that produces the i-th entry of net.OutputNames
% Outputs:
%   gradients       gradients of totalLoss with respect to net.Learnables
%   boxLoss, objLoss, clsLoss, totalLoss
%                   loss values as plain (gathered, extracted) numbers;
%                   totalLoss is the sum of the other three
%   state           network state returned by forward()
%
% Must be called through dlfeval because it uses dlgradient.

yolov3layerNames = net.OutputNames;
outFeatureMaps = cell(size(yolov3layerNames));
[outFeatureMaps{:},state] = forward(net,XTrain,'Outputs',yolov3layerNames);
boxLoss = dlarray(0);
objLoss = dlarray(0);
clsLoss = dlarray(0);
for i = 1:length(outFeatureMaps)
    currentYOLOV3Layer = net.Layers(yoloLayerNumber(i));
    currentFeatureMap = outFeatureMaps{i};

    % The original author notes that predict() of the yolov3Layer class does
    % not update the layer properties, so they are set here from the actual
    % feature-map size before the layer object is handed to buildTargets.
    currentYOLOV3Layer.numY = size(currentFeatureMap,1);
    currentYOLOV3Layer.numX = size(currentFeatureMap,2);
    currentYOLOV3Layer.stride = max(currentYOLOV3Layer.imageSize)./max(currentYOLOV3Layer.numX,...
        currentYOLOV3Layer.numY);

    % Reshape the feature map: h*w*c*bs --> h*w*(5+nc)*na*bs --> bs*na*h*w*(5+nc),
    % the dimension order used by darknet.
    bs = size(currentFeatureMap,4);
    h = currentYOLOV3Layer.numY;
    w = currentYOLOV3Layer.numX;
    na = currentYOLOV3Layer.nAnchors;
    nc = currentYOLOV3Layer.classes;
    currentFeatureMap = reshape(currentFeatureMap,h,w,5+nc,na,bs);% h*w*(5+nc)*na*bs
    currentFeatureMap = permute(currentFeatureMap,[5,4,1,2,3]);% bs*na*h*w*(5+nc)

    % Build the targets for this output layer.
    [tcls,tbox,indices,anchor_grids] = buildTargets(currentYOLOV3Layer,YTrain);
    N = size(tcls,1);% number of matched (anchor, box) pairs
    if ~N
        % No matched target on this layer: it contributes nothing to any loss.
        continue;
    end

    featuresMapSize = size(currentFeatureMap);% bs*na*h*w*(5+nc)
    tobj = zeros(featuresMapSize(1:4),'like',currentFeatureMap);% bs*na*h*w
    featuresCh = zeros(N,(5+nc),'like',currentFeatureMap);
    b = indices(:,1); % N*1
    a = indices(:,2); % N*1
    gj = indices(:,3); % N*1
    gi = indices(:,4); % N*1
    for idx = 1:N
        % Collect the prediction vector at each matched position and mark
        % that position as containing an object.
        featuresChannels = currentFeatureMap(b(idx),a(idx),gj(idx),gi(idx),:);% 1*1*1*1*(5+nc)
        featuresChannels = squeeze(featuresChannels);%(5+nc)*1
        featuresChannels = featuresChannels';%1*(5+nc)
        featuresCh(idx,:) = featuresChannels; % N*(5+nc)
        tobj(b(idx),a(idx),gj(idx),gi(idx)) = 1.0;
    end

    % Box loss: mse or GIoU.
    predictXY = sigmoid(featuresCh(:,1:2)); % N*2, predicted xy
    predictWH = exp(featuresCh(:,3:4)).*anchor_grids;% N*2
    predictBboxs = cat(2,predictXY,predictWH);% N*4
    isUseGIOU = 0;
    if isUseGIOU
        giouRatio = getGIOU(predictBboxs,tbox);% needs gradients; the original author notes backprop through it is very slow
        boxLoss = boxLoss+mean(1-giouRatio,'all');
    else
        boxLoss = boxLoss+mse(predictBboxs,tbox,'DataFormat','BC');
    end

    % Class loss (only meaningful with more than one class).
    if (nc>1)
        % One-hot targets, N*nc. Each row gets a single 1 at the column of
        % its own class; class IDs are zero-based, hence the +1.
        tcls_ = zeros(N,nc,'like',featuresCh);
        for idx = 1:N
            tcls_(idx,tcls(idx)+1) = 1.0;
        end
        clsLoss = clsLoss + crossentropy(sigmoid(featuresCh(:,6:end)),tcls_,...
            'DataFormat','BC',...
            'TargetCategories','independent');
    end

    % Objectness loss over every position of the feature map (channel 5).
    objLoss = objLoss+crossentropy(sigmoid(currentFeatureMap(:,:,:,:,5)),tobj,...
        'DataFormat','BUSS',...
        'TargetCategories','independent');
end
totalLoss = boxLoss+objLoss+clsLoss;

% Compute gradients of learnables with regard to loss.
gradients = dlgradient(totalLoss, net.Learnables);

% Return the losses as ordinary host-side numbers for logging.
boxLoss = gather(extractdata(boxLoss));
objLoss = gather(extractdata(objLoss));
clsLoss = gather(extractdata(clsLoss));
totalLoss = gather(extractdata(totalLoss));
end
function [tcls,tbox,indices,anchor_grids] = buildTargets(currentYOLOV3Layer,YTrain)
% buildTargets  Build the training targets for one yolov3 output layer.
%
% Inputs:
%   currentYOLOV3Layer  one yolo output layer; the fields numY, numX, stride,
%                       imageSize, anchorsUse and classes are read
%   YTrain              cell array with one entry per image; entry i is an
%                       Mi*5 matrix [x,y,width,height,classID] in network-input
%                       pixel coordinates (not normalised). Empty entries are
%                       allowed.
% Outputs:
%   tcls          N*1, zero-based class ID of every matched target
%   tbox          N*4, [Xdecimal,Ydecimal,gw,gh]: fractional xy offset inside
%                 the grid cell and box size in feature-map units
%   indices       N*4, [imageID,anchorID,gy,gx], all 1-based, locating each
%                 target in the bs*na*h*w*(5+nc) prediction array
%   anchor_grids  N*2, [anchor_w,anchor_h] of the matched anchors in
%                 feature-map units
%
% N is the number of (anchor, box) pairs whose width/height IoU reaches the
% threshold; all outputs have zero rows when nothing matches.
%
% author: cuixingxing
% email: cuixingxing150@email.com
% 2020.4.25
%
h = currentYOLOV3Layer.numY;
w = currentYOLOV3Layer.numX;
stride = currentYOLOV3Layer.stride;
bs = numel(YTrain); % works for both column- and row-shaped cell arrays

% Scale factors from network-input pixels to feature-map cells.
scalex = w/currentYOLOV3Layer.imageSize(2);
scaley = h/currentYOLOV3Layer.imageSize(1);
gwh = currentYOLOV3Layer.anchorsUse/stride; % na*2 anchors in feature-map units

% Flatten the per-image boxes into nt*[imageID,classID,gx,gy,gw,gh].
targets = cat(1,YTrain{:}); % nt*5, nt = number of boxes in this batch
if isempty(targets)
    targets = zeros(0,5); % keep the column layout even without any box
end
boxesPerImage = cellfun(@(x)size(x,1),YTrain);
imageIDs = repelem((1:bs)',boxesPerImage(:));
classIDs = targets(:,5);
targets = [imageIDs,classIDs,targets(:,1:end-1)];% nt*6, [imageIDs,classIDs,x,y,w,h]

% Convert box positions and sizes to feature-map units.
targets(:,[3,5]) = targets(:,[3,5]).*scalex;% gx,gw
targets(:,[4,6]) = targets(:,[4,6]).*scaley;% gy,gh

% Match anchors and boxes by width/height IoU.
anchorsID = zeros(0,1); % stays empty when there is nothing to match
if ~isempty(targets)
    iou = getMaxIOUPredictedWithGroundTruth(gwh,targets(:,5:6)); % na*nt
    iouThresh = 0.1;
    reject = true;
    if reject
        iou(iou<iouThresh) = 0;
    end

    use_best_anchor = false;
    if use_best_anchor
        % Keep only the best anchor of every box that still has a match.
        [iou,anchorsID] = max(iou,[],1);
        targetsID = find(iou);
        anchorsID = anchorsID(targetsID);
        targets = targets(targetsID,:);
    else % use all anchors above the threshold
        [anchorsID,targetsID] = find(iou);
        targets = targets(targetsID,:); % N*6 ,[imageIDs,classIDs,gx,gy,gw,gh]
    end
    anchorsID = anchorsID(:); % N*1
end

% Split the position into grid cell and fractional offset inside the cell.
gxy = targets(:,3:4);
targets(:,3:4) = gxy -floor(gxy);

% Assemble the outputs.
tcls = targets(:,2);
if ~isempty(tcls)
    % Class IDs are zero-based, so the largest valid ID is classes-1.
    assert(max(tcls)<currentYOLOV3Layer.classes,'Target classes exceed model classes!');
end
tbox = targets(:,3:6);
% 1-based cell index that matches the floor() used for the offset above;
% clamped so that coordinates on the image border stay inside the grid.
xyPos = min(max(floor(gxy)+1,1),[w,h]);
indices = [targets(:,1),anchorsID,xyPos(:,2),xyPos(:,1)];
anchor_grids = gwh(anchorsID,:);
end
function iou = getMaxIOUPredictedWithGroundTruth(gwh,truth)
% getMaxIOUPredictedWithGroundTruth  Pairwise IoU between anchors and
% ground-truth boxes, comparing width and height only.
%
% Inputs:
%   gwh    anchors on the feature map, na*2 ([w,h] per row)
%   truth  ground-truth boxes on the feature map, width and height only, nt*2
% Output:
%   iou    na*nt; element (i,j) is the intersection over union of anchor i
%          and ground-truth box j
%
% author: cuixingxing
% email: cuixingxing150@email.com
% 2020.4.25
%

% Anchor every box at the same corner (1,1) so that only its size matters.
anchorRects = horzcat(ones(size(gwh)), gwh);
truthRects = horzcat(ones(size(truth)), truth);
iou = bboxOverlapRatio(anchorRects, truthRects);
end