Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Feature/caffe importer improvements (#3397)
Browse files Browse the repository at this point in the history
* add mean image to ndarray converter

* special handling for input layer type

* added note about inconsistency of pooling with Caffe

* add support for TanH and Sigmoid activations
  • Loading branch information
sbodenstein authored and mli committed Sep 29, 2016
1 parent b0f0b81 commit 06583ee
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 0 deletions.
3 changes: 3 additions & 0 deletions tools/caffe_converter/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,6 @@ For example: `python convert_model.py VGG_ILSVRC_16_layers_deploy.prototxt VGG_I
* We have verified the results of VGG_16/VGG_19 model and BVLC_googlenet results from Caffe model zoo.
* The tool only supports single input and single output network.
* The tool can only work with the L2LayerParameter in Caffe.
* Caffe's convention for computing the output shape of strided pooling is inconsistent with MXNet's.
* This importer doesn't handle this problem properly yet
* An example of this failure is importing bvlc_googlenet. For now, the user needs to add padding to the stride-2 pooling layers to make this work.
14 changes: 14 additions & 0 deletions tools/caffe_converter/convert_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,12 @@ def proto2script(proto_file):
input_dim = proto.input_dim
elif len(proto.input_shape) > 0:
input_dim = proto.input_shape[0].dim
elif (layer[0].type == "Input"):
input_dim = layer[0].input_param.shape._values[0].dim
layer.pop(0)
else:
raise Exception('Invalid proto file.')

# We assume the first bottom blob of first layer is the output from data layer
input_name = layer[0].bottom[0]
output_name = ""
Expand Down Expand Up @@ -116,6 +122,14 @@ def proto2script(proto_file):
type_string = 'mx.symbol.Activation'
param_string = "act_type='relu'"
need_flatten[name] = need_flatten[mapping[layer[i].bottom[0]]]
if layer[i].type == 'TanH' or layer[i].type == 23:
type_string = 'mx.symbol.Activation'
param_string = "act_type='tanh'"
need_flatten[name] = need_flatten[mapping[layer[i].bottom[0]]]
if layer[i].type == 'Sigmoid' or layer[i].type == 19:
type_string = 'mx.symbol.Activation'
param_string = "act_type='sigmoid'"
need_flatten[name] = need_flatten[mapping[layer[i].bottom[0]]]
if layer[i].type == 'LRN' or layer[i].type == 15:
type_string = 'mx.symbol.LRN'
param = layer[i].lrn_param
Expand Down
48 changes: 48 additions & 0 deletions tools/caffe_converter/mean_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import mxnet as mx
import numpy as np
import argparse

# caffe_flag records whether the `caffe` Python package could be imported.
# It is consulted later to choose which generated protobuf module provides
# the BlobProto message class.
caffe_flag = True
try:
    import caffe
    from caffe.proto import caffe_pb2
except ImportError:
    # caffe is not importable: fall back to the caffe_parse.caffe_pb2 module.
    caffe_flag = False
    import caffe_parse.caffe_pb2

def protoBlobFileToND(protofile):
    """Load a Caffe mean-image ``BlobProto`` file into an MXNet NDArray.

    The serialized blob is parsed, reshaped to ``(channels, height, width)``
    using the dimensions stored in the blob, and its first and third channels
    are exchanged (Caffe BGR order -> RGB order).

    Parameters
    ----------
    protofile : str
        Path to the serialized ``BlobProto`` file.

    Returns
    -------
    NDArray
        The result of ``mx.nd.array`` applied to the channel-swapped mean image.

    Raises
    ------
    IOError
        If ``protofile`` cannot be opened or read (raised by ``open``).
    IndexError
        If the blob has fewer than three channels.
    """
    # A serialized protobuf is binary data, so it must be read as bytes;
    # text mode can corrupt or fail to decode the payload.
    with open(protofile, "rb") as blob_file:
        data = blob_file.read()

    # Pick whichever protobuf module was importable at module load time.
    if caffe_flag:
        mean_blob = caffe.proto.caffe_pb2.BlobProto()
    else:
        mean_blob = caffe_parse.caffe_pb2.BlobProto()

    mean_blob.ParseFromString(data)
    img_mean_np = np.array(mean_blob.data)
    img_mean_np = img_mean_np.reshape(
        mean_blob.channels, mean_blob.height, mean_blob.width
    )
    # Swap channels from Caffe BGR to RGB. The right-hand side uses advanced
    # indexing, which yields a copy, so channel 0 is not overwritten before it
    # is read (swapping through an alias would leave both channels equal).
    img_mean_np[[0, 2]] = img_mean_np[[2, 0]]
    return mx.nd.array(img_mean_np)

def main():
    """Command-line entry point: convert a Caffe mean image to an ``.nd`` file.

    Reads two positional arguments, ``mean_image_proto`` (input protobuf file)
    and ``save_name`` (output prefix), converts the input with
    ``protoBlobFileToND`` and saves it via ``mx.nd.save`` to
    ``<save_name>.nd`` under the key ``"mean_image"``.
    """
    # The description previously advertised the prototxt/model-parameter
    # converter; it now describes what this script actually does.
    parser = argparse.ArgumentParser(
        description='Caffe mean image (binaryproto) to mxnet NDArray converter. '
                    'Note that only basic functions are implemented. '
                    'You are welcomed to contribute to this file.')
    parser.add_argument('mean_image_proto', help='The protobuf file in Caffe format')
    parser.add_argument('save_name', help='The name of the output file prefix')
    args = parser.parse_args()

    mean_nd = protoBlobFileToND(args.mean_image_proto)
    mx.nd.save(args.save_name + ".nd", {"mean_image": mean_nd})


# Run the converter only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()

0 comments on commit 06583ee

Please sign in to comment.