
Commit e343416

Caffe Plugin (apache#2746)
* [Caffe Plugin] Add caffe plugin.
* [caffe-plugin] Rename caffe_init to caffe_mode.
* [caffe-plugin] Re-organize caffe plugin
* [caffe-plugin] Remove caffe namespace.
* [caffe-plugin] Add comments.
* [caffe-plugin] Rewrite CaffeOperatorProp.ListFunction().
* [caffe-plugin] Finish comments.
* Add caffe md.
* [Caffe-Plugin] Done install section.
* [Caffe-Plugin] Complement plugin document with adding new layers steps.
* [Caffe-Plugin] Fix typo.
* [Caffe-Plugin] Rename variable.
* [Caffe-Plugin] Clean code.
* [Caffe-Plugin] Add todos.
* [caffe-plugin] Support multiple inputs & outputs.
* [caffe-plugin] Add Todo list.
* [caffe-plugin] Write caffe initial setting registry.
* [caffe-plugin] Remove caffeNum::types.
* [caffe-plugin] Rewrite register macro.
* [caffe-plugin] Rewrite comment.
* [caffe-plugin] Add caffe-op layers (no cudnn or data layer).
* [caffe-plugin] Add caffe folder.
* [caffe-plugin] Change caffe symbol param name.
* [caffe-plugin] Clean improper code & Use caffe pooling op.
* [caffe-plugin] Caffe Operator In&Out num & Default prototxt done.
* [caffe-plugin] Clean caffe net.
* [caffe-plugin] Rewrite get weight number in ListArg().
* [caffe-plugin] Refine arglist().
* [caffe-plugin] Add in_num case to caffe example.
* [caffe-plugin] Refine caffe doc.
* [caffe-plugin] Fix typo.
* [caffe-plugin] Remove extra space.
* [caffe-plugin] Remove op_type_string.
* [master] Add set weight func.
* [caffe-plugin] Complete weight.
* [caffe-plugin] rewrite makefile
* [caffe-plugin] Support cudnn layers.
* [caffe-plugin] Document & Clean code.
* [caffe-plugin] Remove tensor.
* [caffe-plugin] Rewrite args.
* [caffe-plugin] Remove check codes.
* [caffe-plugin] Rewrite Tblob convertion to Caffe Blob.
* [caffe-plugin] Clean code.
* [caffe-plugin] Add Caffe Loss.
* [caffe-plugin] Pass caffe loss examples!
* [caffe-plugin] Python interface refine.
* [caffe-plugin] Deal with 0-dim Blob.
* [caffe-plugin] Fix bug.
* [caffe-plugin] Add specific setup for caffe layers.
* [caffe-plugin] Clean code.
* [caffe-plugin] Support Dtype.
* [caffe-plugin] Clean code.
* [caffe-plugin] Add mxnet/op/caffe and remove using in global namespace.
* [caffe-plugin] Replace by caff registry.
* [caffe-plugin] Use FlatTo2D.
* [caffe-plugin] Caffe metric inherits torch's.
* [caffe-plugin] Support data=[symbol1, symbol2]
* [caffe-plugin] Rename to caffeop.
* [caffe-plugin] clean code.
* [plugin-master] Rename data_num to num_data.
* [caffe-plugin] Add comment & Fix lint.
* [caffe-plugin] Refine caffe doc.
1 parent 95ef531 commit e343416

20 files changed, +1382 -2 lines changed

docs/how_to/caffe.md

+43
@@ -0,0 +1,43 @@
# How to use Caffe Op(Layer) in MXNet

This tutorial demonstrates how to call Caffe operators in MXNet:

* 1) Compile MXNet with Caffe support.

* 2) Embed Caffe's neural network layers into MXNet's symbolic graph.

## Install Caffe With MXNet interface
* Download the official Caffe repository [BVLC/Caffe](https://github.com/BVLC/caffe).
* Download the mxnet-interface [patch](https://github.com/BVLC/caffe/pull/4527.patch). Move the patch file into your caffe folder and apply it with `git apply 4527.patch`.
* Install caffe following the [official guide](http://caffe.berkeleyvision.org/installation.html).

## Compile with Caffe
* In the mxnet folder, open `config.mk` (if you haven't already, copy `make/config.mk` (Linux) or `make/osx.mk` (Mac) into the MXNet root folder as `config.mk`) and uncomment the `CAFFE_PATH` and `MXNET_PLUGINS += plugin/caffe/caffe.mk` lines. Set `CAFFE_PATH` to your caffe installation, for example `CAFFE_PATH = $(HOME)/caffe`.
* Run `make clean && make` to build with caffe support.
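
Before running `make`, it can help to confirm that both lines are really uncommented. The sketch below is one way to do that; the helper names `active_make_settings` and `caffe_plugin_enabled` are hypothetical (they are not part of MXNet), and the parser only understands plain `NAME = value` and `NAME += value` lines.

```python
import re

def active_make_settings(config_text):
    """Collect uncommented `NAME = value` / `NAME += value` lines.

    Hypothetical helper: conditionals, `:=` and line continuations are
    ignored, which is enough for the two settings discussed here.
    """
    settings = {}
    for raw in config_text.splitlines():
        line = raw.split('#', 1)[0].strip()  # drop comments and whitespace
        match = re.match(r'^(\w+)\s*\+?=\s*(.*)$', line)
        if match:
            settings.setdefault(match.group(1), []).append(match.group(2))
    return settings

def caffe_plugin_enabled(config_text):
    """True when CAFFE_PATH is set and the caffe plugin makefile is listed."""
    settings = active_make_settings(config_text)
    plugins = ' '.join(settings.get('MXNET_PLUGINS', [])).split()
    return 'CAFFE_PATH' in settings and 'plugin/caffe/caffe.mk' in plugins

# Sample config fragments: still commented out vs. enabled.
commented = "# CAFFE_PATH = $(HOME)/caffe\n# MXNET_PLUGINS += plugin/caffe/caffe.mk\n"
enabled = "CAFFE_PATH = $(HOME)/caffe\nMXNET_PLUGINS += plugin/caffe/caffe.mk\n"
print(caffe_plugin_enabled(commented))  # False
print(caffe_plugin_enabled(enabled))    # True
```

To check a real checkout, pass the contents of your own `config.mk` to `caffe_plugin_enabled`.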

## Caffe Operators(Layers)
Caffe's neural network layers are supported by MXNet through the `mxnet.symbol.CaffeOp` symbol.
For example, the following code builds a multi-layer perceptron for classifying MNIST digits (the [full code](https://github.com/HrWangChengdu/mxnet/blob/master/example/caffe/caffe_net.py) also defines LeNet):
```Python
data = mx.symbol.Variable('data')
fc1 = mx.symbol.CaffeOp(data_0=data, num_weight=2, name='fc1', prototxt="layer{type:\"InnerProduct\" inner_product_param{num_output: 128} }")
act1 = mx.symbol.CaffeOp(data_0=fc1, prototxt="layer{type:\"TanH\"}")
fc2 = mx.symbol.CaffeOp(data_0=act1, num_weight=2, name='fc2', prototxt="layer{type:\"InnerProduct\" inner_product_param{num_output: 64} }")
act2 = mx.symbol.CaffeOp(data_0=fc2, prototxt="layer{type:\"TanH\"}")
fc3 = mx.symbol.CaffeOp(data_0=act2, num_weight=2, name='fc3', prototxt="layer{type:\"InnerProduct\" inner_product_param{num_output: 10}}")
mlp = mx.symbol.SoftmaxOutput(data=fc3, name='softmax')
```
Let's break it down. First, `data = mx.symbol.Variable('data')` defines a Variable as a placeholder for the input.
Then it's fed through Caffe's operators with `fc1 = mx.symbol.CaffeOp(data_0=data, num_weight=2, name='fc1', prototxt="layer{type:\"InnerProduct\" inner_product_param{num_output: 128} }")`.

The inputs to a caffe layer are named `data_i` for i = 0 ... num_data-1, where `num_data` is the number of inputs. You may skip `num_data`, as the example does, if its value is 1. `num_weight` is the number of `blobs_` (weights) in the caffe layer. Its default value is 0, as most layers, e.g. tanh, own no weights. `prototxt` is the caffe layer's configuration string.
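
The escaped quotes in the `prototxt` strings are easy to get wrong. As a rough sketch, two small pure-Python helpers can build the configuration string and the `data_i` keyword arguments described above. Both `caffe_layer_prototxt` and `caffe_inputs` are hypothetical names introduced here for illustration; they are not part of MXNet or Caffe.

```python
def caffe_layer_prototxt(layer_type, param_name=None, **params):
    """Build a one-layer prototxt string such as layer{type:"TanH"}.

    Hypothetical helper. Values are written as-is, so it suits numbers
    and enum names (e.g. pool='MAX'), not quoted string fields.
    """
    text = 'layer{type:"%s"' % layer_type
    if param_name is not None:
        body = ' '.join('%s: %s' % (key, value)
                        for key, value in sorted(params.items()))
        text += ' %s{%s}' % (param_name, body)
    return text + '}'

def caffe_inputs(*symbols):
    """Name the inputs data_0 ... data_{n-1} and record num_data."""
    kwargs = dict(('data_%d' % i, sym) for i, sym in enumerate(symbols))
    kwargs['num_data'] = len(symbols)
    return kwargs

fc_proto = caffe_layer_prototxt('InnerProduct', 'inner_product_param',
                                num_output=128)
print(fc_proto)                      # layer{type:"InnerProduct" inner_product_param{num_output: 128}}
print(caffe_layer_prototxt('TanH'))  # layer{type:"TanH"}
print(sorted(caffe_inputs('a', 'b')))  # ['data_0', 'data_1', 'num_data']

# Assumed usage, only valid in an MXNet build with the caffe plugin:
#   fc1 = mx.symbol.CaffeOp(num_weight=2, name='fc1', prototxt=fc_proto,
#                           **caffe_inputs(data))
```

Single-quoted Python strings avoid the backslash escapes entirely, which is the main point of the helper.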

We could also replace the last line with:
```Python
label = mx.symbol.Variable('softmax_label')
mlp = mx.symbol.CaffeLoss(data=fc3, label=label, grad_scale=1, name='softmax', prototxt="layer{type:\"SoftmaxWithLoss\"}")
```
to use a loss function from caffe.

## Use your own customized layers
Running a new caffe layer from mxnet is no different from using regular caffe layers: follow the rules above. There's no need to add any code in mxnet.

example/caffe/caffe_net.py

+103
@@ -0,0 +1,103 @@
import os, sys
import mxnet as mx
from data import get_iterator
import argparse
import train_model

def get_mlp():
    """
    multi-layer perceptron
    """
    data = mx.symbol.Variable('data')
    fc1 = mx.symbol.CaffeOp(data_0=data, num_weight=2, name='fc1', prototxt="layer{type:\"InnerProduct\" inner_product_param{num_output: 128} }")
    act1 = mx.symbol.CaffeOp(data_0=fc1, prototxt="layer{type:\"TanH\"}")
    fc2 = mx.symbol.CaffeOp(data_0=act1, num_weight=2, name='fc2', prototxt="layer{type:\"InnerProduct\" inner_product_param{num_output: 64} }")
    act2 = mx.symbol.CaffeOp(data_0=fc2, prototxt="layer{type:\"TanH\"}")
    fc3 = mx.symbol.CaffeOp(data_0=act2, num_weight=2, name='fc3', prototxt="layer{type:\"InnerProduct\" inner_product_param{num_output: 10}}")
    if use_caffe_loss:
        label = mx.symbol.Variable('softmax_label')
        mlp = mx.symbol.CaffeLoss(data=fc3, label=label, grad_scale=1, name='softmax', prototxt="layer{type:\"SoftmaxWithLoss\"}")
    else:
        mlp = mx.symbol.SoftmaxOutput(data=fc3, name='softmax')
    return mlp

def get_lenet():
    """
    LeCun, Yann, Leon Bottou, Yoshua Bengio, and Patrick
    Haffner. "Gradient-based learning applied to document recognition."
    Proceedings of the IEEE (1998)
    """
    data = mx.symbol.Variable('data')

    # first conv
    conv1 = mx.symbol.CaffeOp(data_0=data, num_weight=2, prototxt="layer{type:\"Convolution\" convolution_param { num_output: 20 kernel_size: 5 stride: 1} }")
    act1 = mx.symbol.CaffeOp(data_0=conv1, prototxt="layer{type:\"TanH\"}")
    pool1 = mx.symbol.CaffeOp(data_0=act1, prototxt="layer{type:\"Pooling\" pooling_param { pool: MAX kernel_size: 2 stride: 2}}")

    # second conv
    conv2 = mx.symbol.CaffeOp(data_0=pool1, num_weight=2, prototxt="layer{type:\"Convolution\" convolution_param { num_output: 50 kernel_size: 5 stride: 1} }")
    act2 = mx.symbol.CaffeOp(data_0=conv2, prototxt="layer{type:\"TanH\"}")
    pool2 = mx.symbol.CaffeOp(data_0=act2, prototxt="layer{type:\"Pooling\" pooling_param { pool: MAX kernel_size: 2 stride: 2}}")

    fc1 = mx.symbol.CaffeOp(data_0=pool2, num_weight=2, prototxt="layer{type:\"InnerProduct\" inner_product_param{num_output: 500} }")
    act3 = mx.symbol.CaffeOp(data_0=fc1, prototxt="layer{type:\"TanH\"}")

    # second fullc
    fc2 = mx.symbol.CaffeOp(data_0=act3, num_weight=2, prototxt="layer{type:\"InnerProduct\" inner_product_param{num_output: 10} }")
    if use_caffe_loss:
        label = mx.symbol.Variable('softmax_label')
        lenet = mx.symbol.CaffeLoss(data=fc2, label=label, grad_scale=1, name='softmax', prototxt="layer{type:\"SoftmaxWithLoss\"}")
    else:
        lenet = mx.symbol.SoftmaxOutput(data=fc2, name='softmax')
    return lenet

def parse_args():
    parser = argparse.ArgumentParser(description='train an image classifier on mnist')
    parser.add_argument('--network', type=str, default='lenet',
                        choices = ['mlp', 'lenet'],
                        help='the cnn to use')
    parser.add_argument('--caffe-loss', type=int, default=0,
                        help='Use CaffeLoss symbol')
    parser.add_argument('--data-dir', type=str, default='mnist/',
                        help='the input data directory')
    parser.add_argument('--gpus', type=str,
                        help='the gpus will be used, e.g "0,1,2,3"')
    parser.add_argument('--num-examples', type=int, default=60000,
                        help='the number of training examples')
    parser.add_argument('--batch-size', type=int, default=128,
                        help='the batch size')
    parser.add_argument('--lr', type=float, default=.1,
                        help='the initial learning rate')
    parser.add_argument('--model-prefix', type=str,
                        help='the prefix of the model to load/save')
    parser.add_argument('--save-model-prefix', type=str,
                        help='the prefix of the model to save')
    parser.add_argument('--num-epochs', type=int, default=10,
                        help='the number of training epochs')
    parser.add_argument('--load-epoch', type=int,
                        help="load the model on an epoch using the model-prefix")
    parser.add_argument('--kv-store', type=str, default='local',
                        help='the kvstore type')
    parser.add_argument('--lr-factor', type=float, default=1,
                        help='times the lr with a factor for every lr-factor-epoch epoch')
    parser.add_argument('--lr-factor-epoch', type=float, default=1,
                        help='the number of epoch to factor the lr, could be .5')
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_args()
    # module-level flag read by get_mlp() and get_lenet()
    use_caffe_loss = args.caffe_loss

    if args.network == 'mlp':
        data_shape = (784, )
        net = get_mlp()
    else:
        data_shape = (1, 28, 28)
        net = get_lenet()

    # train
    if use_caffe_loss:
        train_model.fit(args, net, get_iterator(data_shape), mx.metric.Caffe())
    else:
        train_model.fit(args, net, get_iterator(data_shape))

example/caffe/data.py

+37
@@ -0,0 +1,37 @@
import sys
import os
# code to automatically download dataset
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.append(os.path.join(curr_path, "../../tests/python/common"))
import get_data
import mxnet as mx

def get_iterator(data_shape):
    def get_iterator_impl(args, kv):
        """return train and val iterators for mnist"""
        # download data
        get_data.GetMNIST_ubyte()
        # flatten the images unless a (channel, height, width) shape is requested
        flat = len(data_shape) != 3

        train = mx.io.MNISTIter(
            image = "data/train-images-idx3-ubyte",
            label = "data/train-labels-idx1-ubyte",
            input_shape = data_shape,
            batch_size = args.batch_size,
            shuffle = True,
            flat = flat,
            num_parts = kv.num_workers,
            part_index = kv.rank)

        val = mx.io.MNISTIter(
            image = "data/t10k-images-idx3-ubyte",
            label = "data/t10k-labels-idx1-ubyte",
            input_shape = data_shape,
            batch_size = args.batch_size,
            flat = flat,
            num_parts = kv.num_workers,
            part_index = kv.rank)

        return (train, val)
    return get_iterator_impl

example/caffe/train_model.py

+100
@@ -0,0 +1,100 @@
import mxnet as mx
import logging
import os

def fit(args, network, data_loader, eval_metrics=None, batch_end_callback=None):
    # kvstore
    kv = mx.kvstore.create(args.kv_store)

    # logging
    head = '%(asctime)-15s Node[' + str(kv.rank) + '] %(message)s'
    if 'log_file' in args and args.log_file is not None:
        log_file = args.log_file
        log_dir = args.log_dir
        log_file_full_name = os.path.join(log_dir, log_file)
        if not os.path.exists(log_dir):
            os.mkdir(log_dir)
        logger = logging.getLogger()
        handler = logging.FileHandler(log_file_full_name)
        formatter = logging.Formatter(head)
        handler.setFormatter(formatter)
        logger.addHandler(handler)
        logger.setLevel(logging.DEBUG)
        logger.info('start with arguments %s', args)
    else:
        logging.basicConfig(level=logging.DEBUG, format=head)
        logging.info('start with arguments %s', args)

    # load model
    model_prefix = args.model_prefix
    if model_prefix is not None:
        model_prefix += "-%d" % (kv.rank)
    model_args = {}
    if args.load_epoch is not None:
        assert model_prefix is not None
        tmp = mx.model.FeedForward.load(model_prefix, args.load_epoch)
        model_args = {'arg_params' : tmp.arg_params,
                      'aux_params' : tmp.aux_params,
                      'begin_epoch' : args.load_epoch}
    # save model
    save_model_prefix = args.save_model_prefix
    if save_model_prefix is None:
        save_model_prefix = model_prefix
    checkpoint = None if save_model_prefix is None else mx.callback.do_checkpoint(save_model_prefix)

    # data
    (train, val) = data_loader(args, kv)

    # train
    devs = mx.cpu() if args.gpus is None else [
        mx.gpu(int(i)) for i in args.gpus.split(',')]

    # floor division keeps the batch count an integer on Python 2 and 3
    epoch_size = args.num_examples // args.batch_size

    if args.kv_store == 'dist_sync':
        epoch_size //= kv.num_workers
        model_args['epoch_size'] = epoch_size

    if 'lr_factor' in args and args.lr_factor < 1:
        model_args['lr_scheduler'] = mx.lr_scheduler.FactorScheduler(
            step = max(int(epoch_size * args.lr_factor_epoch), 1),
            factor = args.lr_factor)

    if 'clip_gradient' in args and args.clip_gradient is not None:
        model_args['clip_gradient'] = args.clip_gradient

    # disable kvstore for single device
    if 'local' in kv.type and (
            args.gpus is None or len(args.gpus.split(',')) == 1):
        kv = None

    model = mx.model.FeedForward(
        ctx = devs,
        symbol = network,
        num_epoch = args.num_epochs,
        learning_rate = args.lr,
        momentum = 0.9,
        wd = 0.00001,
        initializer = mx.init.Xavier(factor_type="in", magnitude=2.34),
        **model_args)

    if eval_metrics is None:
        eval_metrics = ['accuracy']
        ## TopKAccuracy only allows top_k > 1
        for top_k in [5, 10, 20]:
            eval_metrics.append(mx.metric.create('top_k_accuracy', top_k = top_k))

    if batch_end_callback is not None:
        if not isinstance(batch_end_callback, list):
            batch_end_callback = [batch_end_callback]
    else:
        batch_end_callback = []
    batch_end_callback.append(mx.callback.Speedometer(args.batch_size, 50))

    model.fit(
        X = train,
        eval_data = val,
        eval_metric = eval_metrics,
        kvstore = kv,
        batch_end_callback = batch_end_callback,
        epoch_end_callback = checkpoint)
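
The epoch-size and scheduler-step arithmetic in `fit()` can be checked in isolation. The sketch below mirrors it in a standalone function; `schedule_sizes` is a hypothetical name, and it uses floor division (`//`) on the assumption that the batch count should stay an integer on both Python 2 and Python 3.

```python
def schedule_sizes(num_examples, batch_size, num_workers=1,
                   lr_factor_epoch=1.0, dist_sync=False):
    """Return (epoch_size, scheduler_step) as computed in fit().

    Hypothetical helper for illustration only.
    """
    epoch_size = num_examples // batch_size   # whole batches per epoch
    if dist_sync:
        epoch_size //= num_workers            # each worker handles a share
    # the scheduler step is never allowed to drop below one batch
    step = max(int(epoch_size * lr_factor_epoch), 1)
    return epoch_size, step

# Defaults from caffe_net.py: 60000 examples, batch size 128.
print(schedule_sizes(60000, 128))                      # (468, 468)
print(schedule_sizes(60000, 128, num_workers=2,
                     lr_factor_epoch=0.5, dist_sync=True))  # (234, 117)
```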

make/config.mk

+4
@@ -111,6 +111,10 @@ EXTRA_OPERATORS =
 # plugins
 #----------------------------
 
+# whether to use caffe integration. This requires including caffe submodule.
+# CAFFE_PATH = caffe-lite
+# MXNET_PLUGINS += plugin/caffe/caffe.mk
+
 # whether to use torch integration. This requires installing torch.
 # You also need to add TORCH_PATH/install/lib to your LD_LIBRARY_PATH
 # TORCH_PATH = $(HOME)/torch

plugin/caffe/caffe.mk

+15
@@ -0,0 +1,15 @@
CFLAGS += -I$(CAFFE_PATH)/include -I$(CAFFE_PATH)/build/src
LDFLAGS += -lprotobuf -lboost_system -lboost_thread -lboost_filesystem -lgflags -lglog -L$(CAFFE_PATH)/build/lib -lcaffe

ifeq ($(USE_CUDNN), 1)
  CFLAGS += -DUSE_CUDNN=1
endif

ifeq ($(USE_CUDA), 0)
  CFLAGS += -DCPU_ONLY=1
endif

CAFFE_SRC = $(wildcard plugin/caffe/*.cc)
PLUGIN_OBJ += $(patsubst %.cc, build/%.o, $(CAFFE_SRC))
CAFFE_CUSRC = $(wildcard plugin/caffe/*.cu)
PLUGIN_CUOBJ += $(patsubst %.cu, build/%_gpu.o, $(CAFFE_CUSRC))

plugin/caffe/caffe_blob.cc

+77
@@ -0,0 +1,77 @@
/*!
 * Copyright (c) 2016 by Contributors
 * \file caffe_blob.cc
 * \brief Implementations of SetDataGradToBlob given various device/dimension
 * \author Haoran Wang
*/
#include "caffe_blob.h"
namespace mxnet {
namespace op {
namespace caffe {

template<>
void SetDataGradToBlob<mshadow::cpu, float>(caffeMemoryTypes memType,
    std::vector<::caffe::Blob<float>*>::iterator blob,
    std::vector<mshadow::TBlob>::const_iterator itr) {
  float *data_ptr = reinterpret_cast<float*>((*itr).dptr_);
  if (memType == Data)
    (*blob)->set_cpu_data(data_ptr);
  else
    (*blob)->set_cpu_diff(data_ptr);
}

template<>
void SetDataGradToBlob<mshadow::cpu, double>(caffeMemoryTypes memType,
    std::vector<::caffe::Blob<double>*>::iterator blob,
    std::vector<mshadow::TBlob>::const_iterator itr) {
  double *data_ptr = reinterpret_cast<double*>((*itr).dptr_);
  if (memType == Data)
    (*blob)->set_cpu_data(data_ptr);
  else
    (*blob)->set_cpu_diff(data_ptr);
}

template<>
void SetDataGradToBlob<mshadow::gpu, float>(caffeMemoryTypes memType,
    std::vector<::caffe::Blob<float>*>::iterator blob,
    std::vector<mshadow::TBlob>::const_iterator itr) {
  float *data_ptr = reinterpret_cast<float*>((*itr).dptr_);
  if (memType == Data)
    (*blob)->set_gpu_data(data_ptr);
  else
    (*blob)->set_gpu_diff(data_ptr);
}

template<>
void SetDataGradToBlob<mshadow::gpu, double>(caffeMemoryTypes memType,
    std::vector<::caffe::Blob<double>*>::iterator blob,
    std::vector<mshadow::TBlob>::const_iterator itr) {
  double *data_ptr = reinterpret_cast<double*>((*itr).dptr_);
  if (memType == Data)
    (*blob)->set_gpu_data(data_ptr);
  else
    (*blob)->set_gpu_diff(data_ptr);
}

mshadow::TShape Vector2TShape(const std::vector<int> &vec_int) {
  mshadow::TShape res;
  std::vector<mshadow::index_t> vec_indx;
  for (int i = 0; i < vec_int.size(); ++i)
    vec_indx.push_back(vec_int[i]);
  // 0-dim represents scalar in caffe
  if (vec_int.size() == 0)
    vec_indx.push_back(1);
  res = vec_indx;
  return res;
}

std::vector<int> TShape2Vector(const mshadow::TShape &tshape) {
  std::vector<int> s;
  for (int i = 0; i < tshape.ndim(); ++i)
    s.push_back(tshape[i]);
  return s;
}

}  // namespace caffe
}  // namespace op
}  // namespace mxnet
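
The shape conversions at the end of `caffe_blob.cc` carry one special case: a 0-dim Caffe blob stands for a scalar and is mapped to a one-element shape. A rough Python analogue of that logic, with the hypothetical names `vector_to_shape` and `shape_to_vector` standing in for `Vector2TShape` and `TShape2Vector`, might look like this:

```python
def vector_to_shape(dims):
    """Analogue of Vector2TShape: an empty (0-dim) Caffe shape is a
    scalar and becomes the one-element shape (1,)."""
    shape = tuple(int(d) for d in dims)
    return shape if shape else (1,)

def shape_to_vector(shape):
    """Analogue of TShape2Vector: copy every dimension into a list."""
    return [int(d) for d in shape]

print(vector_to_shape([64, 10]))                # (64, 10)
print(vector_to_shape([]))                      # (1,)
print(shape_to_vector(vector_to_shape([])))     # [1]
```

Note that the round trip is not symmetric for scalars: an empty shape comes back as `[1]`, not `[]`.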
