This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Feature/caffe importer improvements (#3397)
* add mean image to ndarray converter * special handling for input layer type * added note about inconsistency of pooling with Caffe * add support for TanH and Sigmoid activations
- Loading branch information
1 parent
b0f0b81
commit 06583ee
Showing
3 changed files
with
65 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import mxnet as mx | ||
import numpy as np | ||
import argparse | ||
|
||
caffe_flag = True | ||
try: | ||
import caffe | ||
from caffe.proto import caffe_pb2 | ||
except ImportError: | ||
caffe_flag = False | ||
import caffe_parse.caffe_pb2 | ||
|
||
def protoBlobFileToND(protofile): | ||
data = '' | ||
file = open(protofile, "r") | ||
if not file: | ||
raise self.ProcessException("ERROR (" + protofile + ")!") | ||
data = file.read() | ||
file.close() | ||
|
||
if caffe_flag: | ||
mean_blob = caffe.proto.caffe_pb2.BlobProto() | ||
else: | ||
mean_blob = caffe_parse.caffe_pb2.BlobProto() | ||
|
||
mean_blob.ParseFromString(data) | ||
img_mean_np = np.array(mean_blob.data) | ||
img_mean_np = img_mean_np.reshape( | ||
mean_blob.channels, mean_blob.height, mean_blob.width | ||
) | ||
# swap channels from Caffe BGR to RGB | ||
img_mean_np2 = img_mean_np | ||
img_mean_np[0] = img_mean_np2[2] | ||
img_mean_np[2] = img_mean_np2[0] | ||
return mx.nd.array(img_mean_np) | ||
|
||
def main(): | ||
parser = argparse.ArgumentParser(description='Caffe prototxt to mxnet model parameter converter.\ | ||
Note that only basic functions are implemented. You are welcomed to contribute to this file.') | ||
parser.add_argument('mean_image_proto', help='The protobuf file in Caffe format') | ||
parser.add_argument('save_name', help='The name of the output file prefix') | ||
args = parser.parse_args() | ||
nd = protoBlobFileToND(args.mean_image_proto) | ||
mx.nd.save(args.save_name + ".nd", {"mean_image": nd}) | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |