Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
albanie committed Sep 7, 2017
1 parent d55be22 commit 46c61ec
Showing 4 changed files with 34 additions and 30 deletions.
9 changes: 5 additions & 4 deletions benchmarks/run_se_benchmarks.m
Original file line number Diff line number Diff line change
@@ -9,12 +9,13 @@
useCached = 1 ; % load results from cache if available

importedModels = {
'SE-ResNet-152-mcn', ...
'SENet-mcn', ...
'SE-ResNeXt-101-32x4d-mcn', ...
'SE-ResNeXt-50-32x4d-mcn', ...
'SE-BN-Inception-mcn', ...
'SE-ResNet-50-mcn', ...
'SE-ResNet-101-mcn', ...
'SE-ResNet-152-mcn', ...
'SE-ResNeXt-50-32x4d-mcn', ...
'SE-ResNeXt-101-32x4d-mcn', ...
'SENet-mcn', ...
} ;

for ii = 1:numel(importedModels)
17 changes: 12 additions & 5 deletions check.m
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
%net = load('data/models-import/SE-ResNeXt-50-mcn.mat') ;
%net = load('data/models-import/SE-ResNet-50-mcn.mat') ;
net = load('data/models-import/SE-ResNet-101-mcn.mat') ;
%net = load('data/models-import/SE-ResNet-101-mcn.mat') ;
%net = load('data/models-import/imagenet-resnet-50-dag.mat') ;
%net = load('data/models-import/SE-BN-Inception-mcn.mat') ;
net = load('data/models-import/SE-BN-Inception-mcn.mat') ;
dag = dagnn.DagNN.loadobj(net) ;

imPath = '000017.jpg' ;
labelMapPath = fullfile(vl_rootnn, 'contrib/mcnSENets/misc/label_map.txt') ;
labelMap = importdata(labelMapPath) ;

%imPath = '000017.jpg' ;
imPath = 'peppers.png' ;
im = single(imresize(imread(imPath), [224 224])) ;
RGB = [123, 117, 104] ;
rgb = permute(RGB, [1 3 2]) ;
@@ -14,7 +18,10 @@
dag.eval({'data', im}) ;
preds = dag.vars(dag.getVarIndex('prob')).value ;

[bestScore, best] = max(preds) ;
%figure(1) ; clf ; imagesc(im) ; zs_dispFig ;
[bestScore, best_] = max(preds) ;
%best = best_ ;
best = find(labelMap == best_) ; % remap label
%best = labelMap(best_) ;
figure(1) ; clf ; imagesc(im) ; zs_dispFig ;
fprintf('%s (%d), score %.3f\n',...
dag.meta.classes.description{best}, best, bestScore) ;
31 changes: 16 additions & 15 deletions matlab/vl_nnreshape.m
Original file line number Diff line number Diff line change
@@ -24,23 +24,24 @@
% block projected onto DZDY. DZDX and DZDY have the same dimensions
% as X and Y respectively.
%
% Copyright (C) 2017 Samuel Albanie and Andrea Vedaldi.
% Copyright (C) 2017 Samuel Albanie and Andrea Vedaldi
% Licensed under The MIT License [see LICENSE.md for details]

[~, dzdy] = vl_argparsepos(struct(), varargin) ;
[~, dzdy] = vl_argparsepos(struct(), varargin) ;

if isnumeric(shape) % apply caffe style conventions if needed
shape_ = num2cell(shape) ;
k = find(shape == -1) ; if k, shape_{k} = [] ; end
k = find(shape == 0) ;
if k, rep = arrayfun(@(i) {size(x,i)}, k) ; shape_(k) = rep ; end
shape = shape_ ;
end
if isnumeric(shape) % apply caffe style conventions if needed
shape_ = num2cell(shape) ;
if numel(shape_) == 2, shape_{3} = [] ; end
k = find(shape == -1) ; if k, shape_{k} = [] ; end
k = find(shape == 0) ;
if k, rep = arrayfun(@(i) {size(x,i)}, k) ; shape_(k) = rep ; end
shape = shape_ ;
end

batchSize = size(x, 4);
batchSize = size(x, 4);

if isempty(dzdy)
y = reshape(x, shape{1}, shape{2}, shape{3}, batchSize) ;
else
y = reshape(dzdy{1}, size(x)) ;
end
if isempty(dzdy)
y = reshape(x, shape{1}, shape{2}, shape{3}, batchSize) ;
else
y = reshape(dzdy{1}, size(x)) ;
end
7 changes: 1 addition & 6 deletions misc/fix_se_imports.m
Original file line number Diff line number Diff line change
@@ -42,12 +42,7 @@ function fix_se_imports(varargin)
% fix meta
fprintf('adding info to %s (%d/%d)\n', modelPath, mm, numel(modelNames)) ;
net.meta.classes = imdb.classes ;
if contains(modelPath, 'Inception')
imSize = [299 299 3] ;
else
imSize = [224 224 3] ;
end
net.meta.normalization.imageSize = imSize ;
net.meta.normalization.imageSize = [224 224 3] ;
net = dagnn.DagNN.loadobj(net) ;
net = net.saveobj() ; save(modelPath, '-struct', 'net') ; %#ok
end

0 comments on commit 46c61ec

Please sign in to comment.