This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit
Added ResNet support to Caffe converter
nirbenz authored and mli committed Mar 2, 2017
1 parent 7cd0610 commit b4e8743
Showing 2 changed files with 107 additions and 27 deletions.
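
With this change the converter handles ResNet-style networks, whose Caffe definitions pair each BatchNorm layer with a separate Scale layer. A hypothetical invocation of the updated script (file names are placeholders; the three positional arguments are assumed to be caffe_prototxt, caffe_model, and save_model_name in that order, matching the names referenced in convert_model.py):

    python convert_model.py ResNet-50-deploy.prototxt ResNet-50-model.caffemodel resnet-50-mx

Since the code saves a checkpoint at epoch 1, this should produce resnet-50-mx-symbol.json and resnet-50-mx-0001.params.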
98 changes: 77 additions & 21 deletions tools/caffe_converter/convert_model.py
@@ -3,6 +3,9 @@
import numpy as np
import argparse
import re

import sys

from convert_symbol import proto2symbol

caffe_flag = True
@@ -38,24 +41,35 @@ def main():
parser.add_argument('save_model_name', help='The name of the output model prefix')
args = parser.parse_args()

prob, input_dim = proto2symbol(args.caffe_prototxt)
sym, arg_params, aux_params, input_dim = process_caffe_model(args.caffe_prototxt, args.caffe_model)
model = mx.mod.Module(symbol=sym, label_names=['prob_label', ])
model.bind(data_shapes=[('data', tuple(input_dim))])
model.init_params(arg_params=arg_params, aux_params=aux_params)
model.save_checkpoint(args.save_model_name, 1)

print ('Saved model successfully to {}'.format(args.save_model_name))

def process_caffe_model(caffe_prototxt, caffe_model, output_file=None, data=None, data_shapes=None):
prob, input_dim = proto2symbol(caffe_prototxt)

layers = ''
layer_names = ''

if caffe_flag:
caffe.set_mode_cpu()
net_caffe = caffe.Net(args.caffe_prototxt, args.caffe_model, caffe.TEST)
net_caffe = caffe.Net(caffe_prototxt, caffe_model, caffe.TEST)
layer_names = net_caffe._layer_names
layers = net_caffe.layers
else:
layers = parse.parse_caffemodel(args.caffe_model)
layers = parse.parse_caffemodel(caffe_model)

arg_shapes, output_shapes, aux_shapes = prob.infer_shape(data=tuple(input_dim))
arg_names = prob.list_arguments()
aux_names = prob.list_auxiliary_states()
arg_shape_dic = dict(zip(arg_names, arg_shapes))
aux_shape_dic = dict(zip(aux_names, aux_shapes))
arg_params = {}

aux_params = {}
iter = ''
if caffe_flag:
iter = get_caffe_iter(layer_names, layers)
@@ -73,31 +87,38 @@ def main():
arg_params[weight_name] = mx.nd.zeros(wmat.shape)
arg_params[weight_name][:] = wmat
continue
assert (len(layer_blobs) == 2)
wmat_dim = []
if getattr(layer_blobs[0].shape, 'dim', None) is not None:
if len(layer_blobs[0].shape.dim) > 0:
wmat_dim = layer_blobs[0].shape.dim
else:
wmat_dim = [layer_blobs[0].num, layer_blobs[0].channels, layer_blobs[0].height,
layer_blobs[0].width]
wmat_dim = [layer_blobs[0].num, layer_blobs[0].channels, layer_blobs[0].height, layer_blobs[0].width]
else:
wmat_dim = list(layer_blobs[0].shape)
wmat = np.array(layer_blobs[0].data).reshape(wmat_dim)
bias = np.array(layer_blobs[1].data)

channels = wmat_dim[1]
if channels == 3 or channels == 4: # RGB or RGBA
if first_conv:
print('Swapping BGR of caffe into RGB in mxnet')
print ('Swapping BGR of caffe into RGB in mxnet')
wmat[:, [0, 2], :, :] = wmat[:, [2, 0], :, :]

assert (wmat.flags['C_CONTIGUOUS'] is True)
assert (bias.flags['C_CONTIGUOUS'] is True)
print('converting layer {0}, wmat shape = {1}, bias shape = {2}'.format(layer_name, wmat.shape, bias.shape))
assert(wmat.flags['C_CONTIGUOUS'] is True)
sys.stdout.write('converting layer {0}, wmat shape = {1}'.format(layer_name, wmat.shape))
if len(layer_blobs) == 2:
bias = np.array(layer_blobs[1].data)
bias = bias.reshape((bias.shape[0], 1))
assert(bias.flags['C_CONTIGUOUS'] is True)
bias_name = layer_name + "_bias"
bias = bias.reshape(arg_shape_dic[bias_name])
arg_params[bias_name] = mx.nd.zeros(bias.shape)
arg_params[bias_name][:] = bias
sys.stdout.write(', bias shape = {}'.format(bias.shape))

sys.stdout.write('\n')
sys.stdout.flush()
wmat = wmat.reshape((wmat.shape[0], -1))
bias = bias.reshape((bias.shape[0], 1))
weight_name = layer_name + "_weight"
bias_name = layer_name + "_bias"

if weight_name not in arg_shape_dic:
print(weight_name + ' not found in arg_shape_dic.')
@@ -106,18 +127,53 @@ def main():
arg_params[weight_name] = mx.nd.zeros(wmat.shape)
arg_params[weight_name][:] = wmat

bias = bias.reshape(arg_shape_dic[bias_name])
arg_params[bias_name] = mx.nd.zeros(bias.shape)
arg_params[bias_name][:] = bias

if first_conv and (layer_type == 'Convolution' or layer_type == 4):
first_conv = False

model = mx.mod.Module(symbol=prob, label_names=['prob_label', ])
model.bind(data_shapes=[('data', tuple(input_dim))])
model.init_params(arg_params=arg_params, aux_params={})
elif layer_type == 'Scale':
bn_name = layer_name.replace('scale', 'bn')
gamma = layer_blobs[0].data
beta = layer_blobs[1].data
# beta = np.expand_dims(beta, 1)
beta_name = '{}_beta'.format(bn_name)
gamma_name = '{}_gamma'.format(bn_name)

beta = beta.reshape(arg_shape_dic[beta_name])
gamma = gamma.reshape(arg_shape_dic[gamma_name])
arg_params[beta_name] = mx.nd.zeros(beta.shape)
arg_params[gamma_name] = mx.nd.zeros(gamma.shape)
arg_params[beta_name][:] = beta
arg_params[gamma_name][:] = gamma

assert gamma.flags['C_CONTIGUOUS'] is True
assert beta.flags['C_CONTIGUOUS'] is True
print ('converting scale layer, beta shape = {}, gamma shape = {}'.format(beta.shape, gamma.shape))
elif layer_type == 'BatchNorm':
bn_name = layer_name
mean = layer_blobs[0].data
var = layer_blobs[1].data
moving_average_factor = layer_blobs[2].data
mean_name = '{}_moving_mean'.format(bn_name)
var_name = '{}_moving_var'.format(bn_name)
maf_name = '{}_momentum'.format(bn_name)
mean = mean.reshape(aux_shape_dic[mean_name])
var = var.reshape(aux_shape_dic[var_name])
aux_params[mean_name] = mx.nd.zeros(mean.shape)
aux_params[var_name] = mx.nd.zeros(var.shape)
arg_params[maf_name] = mx.nd.zeros(moving_average_factor.shape)
aux_params[mean_name][:] = mean
aux_params[var_name][:] = var
arg_params[maf_name][:] = moving_average_factor
assert var.flags['C_CONTIGUOUS'] is True
assert mean.flags['C_CONTIGUOUS'] is True
print ('converting batchnorm layer, mean shape = {}, var shape = {}'.format(mean.shape, var.shape))
else:
assert len(layer_blobs) == 0
print ('\tskipping layer {} of type {}'.format(layer_name, layer_type))
return prob, arg_params, aux_params, input_dim


model.save_checkpoint(args.save_model_name, 1)


if __name__ == '__main__':
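For reference, here is the essence of the new BatchNorm/Scale handling in process_caffe_model, written out as a standalone sketch rather than the converter itself. Caffe keeps the running mean, running variance, and a moving-average factor in the BatchNorm layer's blobs, and gamma/beta in a separate Scale layer; the code above maps all of them onto a single MXNet BatchNorm's parameter names, renaming 'scale*' layers to their 'bn*' counterparts. The layer dicts below are hypothetical stand-ins for the Caffe layer objects:

    import numpy as np

    def map_resnet_bn_pair(bn_layer, scale_layer, arg_params, aux_params):
        # BatchNorm blobs: running mean, running variance, moving-average factor
        bn_name = bn_layer['name']            # e.g. 'bn2a_branch1'
        mean, var, factor = bn_layer['blobs']
        aux_params[bn_name + '_moving_mean'] = np.asarray(mean)
        aux_params[bn_name + '_moving_var'] = np.asarray(var)
        arg_params[bn_name + '_momentum'] = np.asarray(factor)

        # Scale blobs: gamma and beta, stored under the matching bn name
        # (mirrors layer_name.replace('scale', 'bn') in the diff above)
        target = scale_layer['name'].replace('scale', 'bn')
        gamma, beta = scale_layer['blobs']
        arg_params[target + '_gamma'] = np.asarray(gamma)
        arg_params[target + '_beta'] = np.asarray(beta)
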
36 changes: 30 additions & 6 deletions tools/caffe_converter/convert_symbol.py
@@ -92,9 +92,11 @@ def proto2script(proto_file):
output_name = ""
mapping = {input_name: 'data'}
need_flatten = {input_name: False}
prev_bn = None
for i in range(len(layer)):
type_string = ''
param_string = ''
skip_layer = False
name = re.sub('[-/]', '_', layer[i].name)
if layer[i].type == 'Convolution' or layer[i].type == 4:
type_string = 'mx.symbol.Convolution'
@@ -167,32 +169,54 @@
if layer[i].type == 'BatchNorm':
type_string = 'mx.symbol.BatchNorm'
param = layer[i].batch_norm_param
param_string = 'use_global_stats=%s' % param.use_global_stats
param_string = 'use_global_stats=%s, fix_gamma=False' % param.use_global_stats
need_flatten[name] = need_flatten[mapping[layer[i].bottom[0]]]
if layer[i].type == 'Scale':
assert layer[i-1].type == 'BatchNorm'
need_flatten[name] = need_flatten[mapping[layer[i].bottom[0]]]
skip_layer = True
prev_bn = re.sub('[-/]', '_', layer[i-1].name)
if layer[i].type == 'PReLU':
type_string = 'mx.symbol.LeakyReLU'
param = layer[i].prelu_param
param_string = "act_type='prelu', slope=%f" % param.filler.value
need_flatten[name] = need_flatten[mapping[layer[i].bottom[0]]]
if type_string == '':
if layer[i].type == 'Eltwise':
type_string = 'mx.symbol.broadcast_add'
param_string = ""
need_flatten[name] = False

if layer[i].type == 'Reshape':
type_string = 'mx.symbol.Reshape'
need_flatten[name] = False
param = layer[i].reshape_param
param_string = "shape=(%s, %s, %s, %s)" % \
(param.shape.dim[0], param.shape.dim[1], param.shape.dim[2], param.shape.dim[3])

if skip_layer:
assert len(layer[i].bottom) == 1
symbol_string += "%s = %s\n" % (name, prev_bn)

elif type_string == '':
raise Exception('Unknown Layer %s!' % layer[i].type)
if type_string != 'split':
elif type_string != 'split':
bottom = layer[i].bottom
if param_string != "":
param_string = ", " + param_string
if len(bottom) == 1:
if need_flatten[mapping[bottom[0]]] and type_string == 'mx.symbol.FullyConnected':
flatten_name = "flatten_%d" % flatten_count
symbol_string += "%s=mx.symbol.Flatten(name='%s', data=%s)\n" % \
(flatten_name, flatten_name, mapping[bottom[0]])
(flatten_name, flatten_name, mapping[bottom[0]])
flatten_count += 1
need_flatten[flatten_name] = False
bottom[0] = flatten_name
mapping[bottom[0]] = bottom[0]
symbol_string += "%s = %s(name='%s', data=%s %s)\n" % \
(name, type_string, name, mapping[bottom[0]], param_string)
(name, type_string, name, mapping[bottom[0]], param_string)
else:
symbol_string += "%s = %s(name='%s', *[%s] %s)\n" % \
(name, type_string, name, ','.join([mapping[x] for x in bottom]), param_string)
(name, type_string, name, ','.join([mapping[x] for x in bottom]), param_string)
for j in range(len(layer[i].top)):
mapping[layer[i].top[j]] = name
output_name = name
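The effect of the new skip_layer path on the generated script: a Scale layer no longer produces its own operator, but is emitted as an alias of the preceding BatchNorm symbol, whose parameter string now carries fix_gamma=False so that the gamma converted from the Scale layer is actually applied. A hand-written sketch of the output for one BatchNorm/Scale pair (layer names taken from the standard ResNet-50 prototxt; the data input is assumed):

    # Sketch of the Python emitted by proto2script for a BatchNorm + Scale pair
    bn2a_branch1 = mx.symbol.BatchNorm(name='bn2a_branch1', data=res2a_branch1,
        use_global_stats=True, fix_gamma=False)
    scale2a_branch1 = bn2a_branch1  # Scale is skipped; it aliases the BatchNorm output
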
