Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Add prelu layer support for caffe convert tool (#4277)
Browse files Browse the repository at this point in the history
* add prelu support

* fix params
  • Loading branch information
fengshikun authored and piiswrong committed Jan 11, 2017
1 parent d6328a5 commit 541d109
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 1 deletion.
10 changes: 9 additions & 1 deletion tools/caffe_converter/convert_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,15 @@ def main():
first_conv = True

for layer_name, layer_type, layer_blobs in iter:
if layer_type == 'Convolution' or layer_type == 'InnerProduct' or layer_type == 4 or layer_type == 14:
if layer_type == 'Convolution' or layer_type == 'InnerProduct' or layer_type == 4 or layer_type == 14 \
or layer_type == 'PReLU':
if layer_type == 'PReLU':
assert(len(layer_blobs) == 1)
wmat = layer_blobs[0].data
weight_name = layer_name + '_gamma'
arg_params[weight_name] = mx.nd.zeros(wmat.shape)
arg_params[weight_name][:] = wmat
continue
assert(len(layer_blobs) == 2)
wmat_dim = []
if getattr(layer_blobs[0].shape, 'dim', None) is not None:
Expand Down
5 changes: 5 additions & 0 deletions tools/caffe_converter/convert_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,11 @@ def proto2script(proto_file):
type_string = 'mx.symbol.BatchNorm'
param = layer[i].batch_norm_param
param_string = 'use_global_stats=%s' % param.use_global_stats
if layer[i].type == 'PReLU':
type_string = 'mx.symbol.LeakyReLU'
param = layer[i].prelu_param
param_string = "act_type='prelu', slope=%f" % param.filler.value
need_flatten[name] = need_flatten[mapping[layer[i].bottom[0]]]
if type_string == '':
raise Exception('Unknown Layer %s!' % layer[i].type)
if type_string != 'split':
Expand Down

0 comments on commit 541d109

Please sign in to comment.