Skip to content

Commit

Permalink
allow labelmap switching
Browse files Browse the repository at this point in the history
  • Loading branch information
albanie committed Sep 4, 2018
1 parent a1887cc commit e1ec0d7
Showing 1 changed file with 14 additions and 4 deletions.
18 changes: 14 additions & 4 deletions benchmarks/cnn_imagenet_se_mcn.m
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
opts.gpus = 3 ;
opts.continue = 1 ;
opts.batchSize = 256 ;
opts.labelFormat = 'v1' ;
opts.model = 'SE-ResNet-50-mcn' ;
opts.modelDir = fullfile(vl_rootnn, 'data/models-import') ;
opts.dataDir = fullfile(vl_rootnn, 'data/datasets/ILSVRC2012') ;
Expand Down Expand Up @@ -38,7 +39,11 @@
end

% remap labels to match the order used in training
imdb = updateLabelMap(imdb, opts) ;
switch opts.labelFormat
case 'v1', imdb = updateLabelMap(imdb, opts) ;
case 'v2' % do nothing
otherwise, error('unrecognised label format %s', opts.labelFormat) ;
end

% -------------------------------------------------------------------------
% Prepare model
Expand Down Expand Up @@ -71,7 +76,12 @@
end

dag = dagnn.DagNN.loadobj(load(modelPath)) ;
dag.addLayer('softmax', dagnn.SoftMax(), dag.layers(end).outputs, 'prediction', {}) ;
if isa(dag.layers(end).block, 'dagnn.SoftMax')
dag.renameVar(dag.layers(end).outputs, 'prediction') ;
else % add softmax to logits if softmax isn't yet present
dag.addLayer('softmax', dagnn.SoftMax(), dag.layers(end).outputs, ...
'prediction', {}) ;
end
dag.addLayer('top1err', dagnn.Loss('loss', 'classerror'), ...
{'prediction','label'}, 'top1err') ;
dag.addLayer('top5err', dagnn.Loss('loss', 'topkerror', 'opts', {'topK',5}), ...
Expand All @@ -95,9 +105,9 @@
end

% ----------------------------------------
function imdb = updateLabelMap(imdb, opts)
function imdb = updateLabelMap(imdb, opts)
% ----------------------------------------
labelMap = importdata(opts.labelMap) ;
labelMap = importdata(opts.labelMap) ;
keep = imdb.images.label ~= 0 ;
newLabels = labelMap(imdb.images.label(keep)) ;
imdb.images.label(keep) = newLabels ;
Expand Down

0 comments on commit e1ec0d7

Please sign in to comment.