Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Improve Caffe Converter (#5247)
Browse files Browse the repository at this point in the history
* update caffe converter

* more

* update

* updatejk

* more

* update

* doc
  • Loading branch information
mli authored Mar 10, 2017
1 parent b0fc714 commit f442b18
Show file tree
Hide file tree
Showing 19 changed files with 699 additions and 351 deletions.
193 changes: 162 additions & 31 deletions docs/how_to/caffe.md
Original file line number Diff line number Diff line change
@@ -1,53 +1,184 @@
# How to Use Caffe Operators in MXNet
# How to | Convert from Caffe to MXNet

[Caffe](http://caffe.berkeleyvision.org/) is a well-known and widely used deep learning framework. MXNet supports calling most Caffe operators (layers) and loss functions directly in its symbolic graph. Using your own customized Caffe layer is also effortless.
Key topics covered include the following:

MXNet also has embedded [Torch modules and its tensor mathematical functions](https://github.com/dmlc/mxnet/blob/master/docs/how_to/torch.md).
- [Converting Caffe trained models to MXNet](#converting-caffe-trained-models-to-mxnet)
- [Calling Caffe operators in MXNet](#calling-caffe-operators-in-mxnet)

This topic explains how to:
## Converting Caffe trained models to MXNet

* Install MXNet with Caffe support
The converting tool is available at
[tools/caffe_converter](https://github.com/dmlc/mxnet/tree/master/tools/caffe_converter). On
the remaining of this section, we assume we are on the `tools/caffe_converter`
directory.

* Embed Caffe operators into MXNet's symbolic graph
### How to build

## Install Caffe With MXNet
If Caffe's python package is installed, namely we can run `import caffe` in
python, then we are ready to go.

For example, we can used
[AWS Deep Learning AMI](https://aws.amazon.com/marketplace/pp/B06VSPXKDX) with
both Caffe and MXNet installed.

1. Download the official Caffe repository, [BVLC/Caffe](https://github.com/BVLC/caffe).
2. Download the [Caffe patch for the MXNet interface](https://github.com/BVLC/caffe/pull/4527.patch). Move the patch file to your Caffe root folder, and apply the patch by using `git apply patch_file_name`.
3. Install Caffe using the [official guide](http://caffe.berkeleyvision.org/installation.html).
Otherwise we can install the
[Google protobuf](https://developers.google.com/protocol-buffers/?hl=en)
compiler and its python binding. It is easier to install, but may be slower
during running.

## Compile with Caffe
1. Install the compiler:
- Linux: install `protobuf-compiler` e.g. `sudo apt-get install
protobuf-compiler` for Ubuntu and `sudo yum install protobuf-compiler` for
Redhat/Fedora.
- Windows: Download the win32 build of
[protobuf](https://github.com/google/protobuf/releases). Make sure to
download the version that corresponds to the version of the python binding
on the next step. Extract to any location then add that location to your
`PATH`
- Mac OS X: `brew install protobuf`

2. Install the python binding by either `conda install -c conda-forge protobuf`
or `pip install protobuf`.

1. If you haven't already, copy `make/config.mk` (for Linux) or `make/osx.mk` (for Mac) into the MXNet root folder as `config.mk`.
2. In the mxnet folder, open `config.mk` and uncomment the lines `CAFFE_PATH = $(HOME)/caffe` and `MXNET_PLUGINS += plugin/caffe/caffe.mk`. Modify `CAFFE_PATH` to your Caffe installation, if necessary.
3. To build with Caffe support, run `make clean && make`.
3. Compile Caffe proto definition. Run `make` in Linux or Mac OS X, or
`make_win32.bat` in Windows

## Using the Caffe Operator (Layer)
Caffe's neural network operator and loss functions are supported by MXNet through `mxnet.symbol.CaffeOp` and `mxnet.symbol.CaffeLoss`, respectively.
For example, the following code shows a [multi-layer perceptron](https://en.wikipedia.org/wiki/Multilayer_perceptron) (MLP) network for classifying MNIST digits: [full code](https://github.com/dmlc/mxnet/blob/master/example/caffe/caffe_net.py):
### How to use

### Python
There are three tools:

- `convert_symbol.py` : convert Caffe model definition in protobuf into MXNet's
Symbol in JSON format.
- `convert_model.py` : convert Caffe model parameters into MXNet's NDArray format
- `convert_mean.py` : convert Caffe input mean file into MXNet's NDArray format

In addition, there are two tools:
- `convert_caffe_modelzoo.py` : download and convert models from Caffe model
zoo.
- `test_converter.py` : test the converted models by checking the prediction
accuracy.

## Calling Caffe operators in MXNet

Besides converting Caffe models, MXNet supports calling most Caffe operators,
including network layer, data layer, and loss function, directly. It is
particularly useful if there are customized operators implemented in Caffe, then
we do not need to re-implement them in MXNet.

### How to install

This feature requires Caffe. In particular, we need to re-compile Caffe before
[PR #4527](https://github.com/BVLC/caffe/pull/4527) is merged into Caffe. There
are the steps of how to rebuild Caffe:

1. Download [Caffe](https://github.com/BVLC/caffe). E.g. `git clone
https://github.com/BVLC/caffe`
2. Download the
[patch for the MXNet interface](https://github.com/BVLC/caffe/pull/4527.patch)
and apply to Caffe. E.g.
```bash
cd caffe && wget https://github.com/BVLC/caffe/pull/4527.patch && git apply 4527.patch
```
3. Build and install Caffe by following the
[official guide](http://caffe.berkeleyvision.org/installation.html).

Next we need to compile MXNet with Caffe supports

1. Copy `make/config.mk` (for Linux) or `make/osx.mk`
(for Mac) into the MXNet root folder as `config.mk` if you have not done it yet
2. Open the copied `config.mk` and uncomment these two lines
```bash
CAFFE_PATH = $(HOME)/caffe
MXNET_PLUGINS += plugin/caffe/caffe.mk
```
Modify `CAFFE_PATH` to your Caffe installation, if necessary.
3. Then build with 8 threads `make clean && make -j8`.

### How to use

This Caffe plugin adds three components into MXNet:

- `sym.CaffeOp` : Caffe neural network layer
- `sym.CaffeLoss` : Caffe loss functions
- `io.CaffeDataIter` : Caffe data layer

#### Use `sym.CaffeOp`
The following example shows the definition of a 10 classes multi-layer perceptron:

```Python
data = mx.symbol.Variable('data')
fc1 = mx.symbol.CaffeOp(data_0=data, num_weight=2, name='fc1', prototxt="layer{type:\"InnerProduct\" inner_product_param{num_output: 128} }")
act1 = mx.symbol.CaffeOp(data_0=fc1, prototxt="layer{type:\"TanH\"}")
fc2 = mx.symbol.CaffeOp(data_0=act1, num_weight=2, name='fc2', prototxt="layer{type:\"InnerProduct\" inner_product_param{num_output: 64} }")
act2 = mx.symbol.CaffeOp(data_0=fc2, prototxt="layer{type:\"TanH\"}")
fc3 = mx.symbol.CaffeOp(data_0=act2, num_weight=2, name='fc3', prototxt="layer{type:\"InnerProduct\" inner_product_param{num_output: 10}}")
mlp = mx.symbol.SoftmaxOutput(data=fc3, name='softmax')
data = mx.sym.Variable('data')
fc1 = mx.sym.CaffeOp(data_0=data, num_weight=2, name='fc1', prototxt="layer{type:\"InnerProduct\" inner_product_param{num_output: 128} }")
act1 = mx.sym.CaffeOp(data_0=fc1, prototxt="layer{type:\"TanH\"}")
fc2 = mx.sym.CaffeOp(data_0=act1, num_weight=2, name='fc2', prototxt="layer{type:\"InnerProduct\" inner_product_param{num_output: 64} }")
act2 = mx.sym.CaffeOp(data_0=fc2, prototxt="layer{type:\"TanH\"}")
fc3 = mx.sym.CaffeOp(data_0=act2, num_weight=2, name='fc3', prototxt="layer{type:\"InnerProduct\" inner_product_param{num_output: 10}}")
mlp = mx.sym.SoftmaxOutput(data=fc3, name='softmax')
```

Let's break it down. First, `data = mx.symbol.Variable('data')` defines a variable as a placeholder for input.
Then, it's fed through Caffe operators with `fc1 = mx.symbol.CaffeOp(data_0=data, num_weight=2, name='fc1', prototxt="layer{type:\"InnerProduct\" inner_product_param{num_output: 128} }")`.
Let's break it down. First, `data = mx.sym.Variable('data')` defines a variable
as a placeholder for input. Then, it's fed through Caffe operators with `fc1 =
mx.sym.CaffeOp(...)`. `CaffeOp` accepts several arguments:

- The inputs to Caffe operators are named as `data_i` for *i=0, ..., num_data-1*
- `num_data` is the number of inputs. In default it is 1, and therefore
skipped in the above example.
- `num_out` is the number of outputs. In default it is 1 and also skipped.
- `num_weight` is the number of weights (`blobs_`). Its default value is 0. We
need to explicitly specify it for a non-zero value.
- `prototxt` is the protobuf configuration string.

#### Use `sym.CaffeLoss`

The inputs to Caffe operators are named as data_i for i=0. num_data-1 as `num_data` is the number of inputs. You can skip the argument, as the example does, if its value is 1. `num_weight` is the number of `blobs_`(weights). Its default value is 0 because many operators maintain no weight. `prototxt` is the configuration string.
Using Caffe loss is similar.
We can replace the MXNet loss with Caffe loss.
We can replace

To use the loss function in Caffe, replace the last line with:
Replacing the last line of the above example with the following two lines we can
call Caffe loss instead of MXNet loss.

```Python
label = mx.symbol.Variable('softmax_label')
mlp = mx.symbol.CaffeLoss(data=fc3, label=label, grad_scale=1, name='softmax', prototxt="layer{type:\"SoftmaxWithLoss\"}")
label = mx.sym.Variable('softmax_label')
mlp = mx.sym.CaffeLoss(data=fc3, label=label, grad_scale=1, name='softmax', prototxt="layer{type:\"SoftmaxWithLoss\"}")
```

Similar to `CaffeOp`, `CaffeLoss` has arguments `num_data` (2 in default) and
`num_out` (1 in default). But there are two differences

1. Inputs are `data` and `label`. And we need to explicitly create a variable
placeholder for label, which is implicitly done in MXNet loss.
2. `grad_scale` is the weight of this loss.

#### Use `io.CaffeDataIter`

We can also wrap a Caffe data layer into MXNet's data iterator. Below is an
example for creating a data iterator for MNIST

```python
train = mx.io.CaffeDataIter(
prototxt =
'layer { \
name: "mnist" \
type: "Data" \
top: "data" \
top: "label" \
include { \
phase: TEST \
} \
transform_param { \
scale: 0.00390625 \
} \
data_param { \
source: "caffe/examples/mnist/mnist_test_lmdb" \
batch_size: 100 \
backend: LMDB \
} \
}',
flat = flat,
num_examples = 60000,
)
```

### Put it all together

The complete example is available at
[example/caffe](https://github.com/dmlc/mxnet/blob/master/example/caffe/)
2 changes: 0 additions & 2 deletions example/image-classification/common/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ def download_file(url, local_fname=None, force_write=False):
if exc.errno != errno.EEXIST:
raise



r = requests.get(url, stream=True)
assert r.status_code == 200, "failed to open %s" % url
with open(local_fname, 'wb') as f:
Expand Down
38 changes: 24 additions & 14 deletions example/image-classification/score.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,37 +5,47 @@
import os
import logging


def score(model, data_val, metrics, gpus, batch_size, rgb_mean,
image_shape='3,224,224', data_nthreads=4):
def score(model, data_val, metrics, gpus, batch_size, rgb_mean=None, mean_img=None,
image_shape='3,224,224', data_nthreads=4, label_name='softmax_label'):
# create data iterator
rgb_mean = [float(i) for i in rgb_mean.split(',')]
data_shape = tuple([int(i) for i in image_shape.split(',')])
if mean_img is not None:
mean_args = {'mean_img':mean_img}
elif rgb_mean is not None:
rgb_mean = [float(i) for i in rgb_mean.split(',')]
mean_args = {'mean_r':rgb_mean[0], 'mean_g':rgb_mean[1],
'mean_b':rgb_mean[2]}

data = mx.io.ImageRecordIter(
path_imgrec = data_val,
label_width = 1,
mean_r = rgb_mean[0],
mean_g = rgb_mean[1],
mean_b = rgb_mean[2],
preprocess_threads = data_nthreads,
batch_size = batch_size,
data_shape = data_shape,
label_name = label_name,
rand_crop = False,
rand_mirror = False)
rand_mirror = False,
**mean_args)

# download model
dir_path = os.path.dirname(os.path.realpath(__file__))
(prefix, epoch) = modelzoo.download_model(
model, os.path.join(dir_path, 'model'))
if isinstance(model, str):
# download model
dir_path = os.path.dirname(os.path.realpath(__file__))
(prefix, epoch) = modelzoo.download_model(
model, os.path.join(dir_path, 'model'))
sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
elif isinstance(model, tuple) or isinstance(model, list):
assert len(model) == 3
(sym, arg_params, aux_params) = model
else:
raise TypeError('model type [%s] is not supported' % str(type(model)))

# create module
sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
if gpus == '':
devs = mx.cpu()
else:
devs = [mx.gpu(int(i)) for i in gpus.split(',')]

mod = mx.mod.Module(symbol=sym, context=devs)
mod = mx.mod.Module(symbol=sym, context=devs, label_names=[label_name,])
mod.bind(for_training=False,
data_shapes=data.provide_data,
label_shapes=data.provide_label)
Expand Down
6 changes: 5 additions & 1 deletion example/image-classification/test_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@
test pretrained models
"""
from __future__ import print_function
import os
import mxnet as mx
from common import find_mxnet, modelzoo
from common.util import download_file, get_gpus
from score import score

def download_data():
download_file('http://data.mxnet.io/data/val-5k-256.rec', 'data/val-5k-256.rec')
if not os.path.isdir('data'):
os.mkdir('data')
return download_file('http://data.mxnet.io/data/val-5k-256.rec', 'data/val-5k-256.rec')


def test_imagenet1k_resnet(**kwargs):
models = ['imagenet1k-resnet-34',
Expand Down
67 changes: 66 additions & 1 deletion python/mxnet/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,17 @@
import numpy as np
import numpy.testing as npt
import mxnet as mx

import subprocess
import os
import errno
from .context import cpu, gpu, Context
from .ndarray import array
from .symbol import Symbol
try:
import requests
except ImportError:
# in rare cases requests may be not installed
pass

_rng = np.random.RandomState(1234)

Expand Down Expand Up @@ -803,3 +810,61 @@ def check_consistency(sym, ctx_list, scale=1.0, grad_req='write',
print(str(e))

return gt

def list_gpus():
"""Return a list of GPUs
Returns
-------
list of int:
If there are n GPUs, then return a list [0,1,...,n-1]. Otherwise returns
[].
"""
try:
re = subprocess.check_output(["nvidia-smi", "-L"], universal_newlines=True)
except OSError:
return []
return range(len([i for i in re.split('\n') if 'GPU' in i]))

def download(url, fname=None, overwrite=False):
"""Download an given URL
Parameters
----------
url : str
URL to download
fname : str, optional
filename of the downloaded file. If None, then will guess a filename
from url.
overwrite : bool, optional
Default is false, which means skipping download if the local file
exists. If true, then download the url to overwrite the local file if
exists.
Returns
-------
str
The filename of the downloaded file
"""
if fname is None:
fname = url.split('/')[-1]
if not overwrite and os.path.exists(fname):
return fname

dir_name = os.path.dirname(fname)
if dir_name != "":
if not os.path.exists(dir_name):
try: # try to create the directory if it doesn't exists
os.makedirs(dir_name)
except OSError as exc:
if exc.errno != errno.EEXIST:
raise OSError('failed to create ' + dir_name)

r = requests.get(url, stream=True)
assert r.status_code == 200, "failed to open %s" % url
with open(fname, 'wb') as f:
for chunk in r.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
f.write(chunk)
return fname
1 change: 1 addition & 0 deletions tools/caffe_converter/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
model/
Loading

0 comments on commit f442b18

Please sign in to comment.