
Commit e7684f0

walloollaw authored and qingqing01 committed

caffe2fluid: upgrade argmax implementation (#866)
1 parent 237fe2f commit e7684f0

File tree

15 files changed: +507 additions, -52 deletions


fluid/image_classification/caffe2fluid/README.md

Lines changed: 48 additions & 20 deletions
@@ -1,35 +1,63 @@
### Caffe2Fluid
-This tool is used to convert a Caffe model to Fluid model
+This tool is used to convert a Caffe model to a Fluid model

-### Howto
+### HowTo

1. Prepare caffepb.py in ./proto if your python has no 'pycaffe' module; two options are provided here:
-    - Generate pycaffe from caffe.proto
-        <pre><code>bash ./proto/compile.sh</code></pre>
+    - Generate pycaffe from caffe.proto
+        ```
+        bash ./proto/compile.sh
+        ```

-    - download one from github directly
-        <pre><code>cd proto/ && wget https://github.com/ethereon/caffe-tensorflow/blob/master/kaffe/caffe/caffepb.py
-        </code></pre>
+    - Download one from github directly
+        ```
+        cd proto/ && wget https://github.com/ethereon/caffe-tensorflow/blob/master/kaffe/caffe/caffepb.py
+        ```

2. Convert the Caffe model to a Fluid model
-    - generate fluid code and weight file
-        <pre><code>python convert.py alexnet.prototxt \
-            --caffemodel alexnet.caffemodel \
-            --data-output-path alexnet.npy \
-            --code-output-path alexnet.py
-        </code></pre>
+    - Generate fluid code and weight file
+        ```
+        python convert.py alexnet.prototxt \
+            --caffemodel alexnet.caffemodel \
+            --data-output-path alexnet.npy \
+            --code-output-path alexnet.py
+        ```

-    - save weights as fluid model file
-        <pre><code>python alexnet.py alexnet.npy ./fluid_model
-        </code></pre>
+    - Save weights as a fluid model file
+        ```
+        python alexnet.py alexnet.npy ./fluid
+        ```

3. Use the converted model to infer
-    - see more details in '*examples/imagenet/run.sh*'
+    - See more details in '*examples/imagenet/run.sh*'

-4. compare the inference results with caffe
-    - see more details in '*examples/imagenet/diff.sh*'
+4. Compare the inference results with caffe
+    - See more details in '*examples/imagenet/diff.sh*'
+
+### How to convert a custom layer
+1. Implement your custom layer in a file under '*kaffe/custom_layers*', eg: mylayer.py
+    - Implement ```shape_func(input_shape, [other_caffe_params])``` to calculate the output shape
+    - Implement ```layer_func(inputs, name, [other_caffe_params])``` to construct a fluid layer
+    - Register these two functions: ```register(kind='MyType', shape=shape_func, layer=layer_func)```
+    - Notes: more examples can be found in '*kaffe/custom_layers*'
+
+2. Add ```import mylayer``` to '*kaffe/custom_layers/\_\_init__.py*'
+
+3. Prepare your pycaffe as your customized version (same as the env preparation above)
+    - (option1) replace 'proto/caffe.proto' with your own caffe.proto and compile it
+    - (option2) change your pycaffe to the customized version
+
+4. Convert the Caffe model to a Fluid model
+
+5. Set the env variable $CAFFE2FLUID_CUSTOM_LAYERS to the parent directory of 'custom_layers'
+    ```
+    export CAFFE2FLUID_CUSTOM_LAYERS=/path/to/caffe2fluid/kaffe
+    ```
+
+6. Use the converted model when loading it from 'xxxnet.py' and 'xxxnet.npy' (not needed if the model is already in 'fluid/model' and 'fluid/params')

### Tested models
-- Lenet
+- Lenet:
+[model addr](https://github.com/ethereon/caffe-tensorflow/blob/master/examples/mnist)

- ResNets:(ResNet-50, ResNet-101, ResNet-152)
[model addr](https://onedrive.live.com/?authkey=%21AAFW2-FVoxeVRck&id=4006CBB8476FF777%2117887&cid=4006CBB8476FF777)
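
As a companion to the 'How to convert a custom layer' steps added in this README, a minimal sketch of what such a mylayer.py could look like is shown below; the kind name 'MyType' and the relu body are illustrative assumptions, not code from this commit:

```python
# kaffe/custom_layers/mylayer.py -- illustrative sketch only.
# 'MyType' and the relu body are placeholder assumptions; a real custom
# layer would mirror the parameters and math of its caffe counterpart.
from .register import register


def mylayer_shape(input_shape):
    """ toy shape function: output shape equals input shape """
    return list(input_shape)


def mylayer_layer(input, name):
    """ toy layer function: just apply a relu via fluid """
    import paddle.fluid as fluid
    return fluid.layers.relu(input)


register(kind='MyType', shape=mylayer_shape, layer=mylayer_layer)
```

After adding ```import mylayer``` to '*kaffe/custom_layers/\_\_init__.py*' and exporting $CAFFE2FLUID_CUSTOM_LAYERS as in step 5, caffe layers of type 'MyType' would be dispatched to these two functions.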

fluid/image_classification/caffe2fluid/convert.py

Lines changed: 7 additions & 1 deletion
@@ -43,11 +43,17 @@ def convert(def_path, caffemodel_path, data_output_path, code_output_path,
        print_stderr('Saving source...')
        with open(code_output_path, 'wb') as src_out:
            src_out.write(transformer.transform_source())
+        print_stderr('set env variable before using converted model '\
+                'if used custom_layers:')
+        custom_pk_path = os.path.dirname(os.path.abspath(__file__))
+        custom_pk_path = os.path.join(custom_pk_path, 'kaffe')
+        print_stderr('export CAFFE2FLUID_CUSTOM_LAYERS=%s' % (custom_pk_path))
        print_stderr('Done.')
+        return 0
    except KaffeError as err:
        fatal_error('Error encountered: {}'.format(err))

-    return 0
+    return 1


def main():

fluid/image_classification/caffe2fluid/examples/imagenet/infer.py

Lines changed: 0 additions & 1 deletion
@@ -164,7 +164,6 @@ def infer(model_path, imgfile, net_file=None, net_name=None, debug=True):
        debug = False
        print('found a inference model for fluid')
    except ValueError as e:
-        pass
        print('try to load model using net file and weight file')
        net_weight = model_path
        ret = load_model(exe, place, net_file, net_name, net_weight, debug)

fluid/image_classification/caffe2fluid/examples/mnist/evaluate.py

Lines changed: 2 additions & 5 deletions
@@ -7,8 +7,8 @@
import sys
import os
import numpy as np
+import paddle.fluid as fluid
import paddle.v2 as paddle
-import paddle.v2.fluid as fluid


def test_model(exe, test_program, fetch_list, test_reader, feeder):
@@ -34,9 +34,6 @@ def evaluate(net_file, model_file):

    from lenet import LeNet as MyNet

-    with_gpu = False
-    paddle.init(use_gpu=with_gpu)
-
    #1, define network topology
    images = fluid.layers.data(name='image', shape=[1, 28, 28], dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
@@ -45,7 +42,7 @@ def evaluate(net_file, model_file):
    prediction = net.layers['prob']
    acc = fluid.layers.accuracy(input=prediction, label=label)

-    place = fluid.CUDAPlace(0) if with_gpu is True else fluid.CPUPlace()
+    place = fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

fluid/image_classification/caffe2fluid/kaffe/custom_layers/__init__.py

Lines changed: 104 additions & 0 deletions
@@ -0,0 +1,104 @@
+"""
+"""
+
+from .register import get_registered_layers
+#custom layer import begins
+
+import axpy
+import flatten
+import argmax
+
+#custom layer import ends
+
+custom_layers = get_registered_layers()
+
+
+def set_args(f, params):
+    """ set args for function 'f' using the parameters in 'params' (typically node.layer.parameters)
+
+    Args:
+        f (function): a python function object
+        params (object): an object that contains the attributes needed by f's arguments
+
+    Returns:
+        arg_names (list): a list of argument names
+        kwargs (dict): a dict of the needed arguments
+    """
+    argc = f.__code__.co_argcount
+    arg_list = f.__code__.co_varnames[0:argc]
+
+    kwargs = {}
+    for arg_name in arg_list:
+        try:
+            v = getattr(params, arg_name, None)
+        except Exception as e:
+            v = None
+
+        if v is not None:
+            kwargs[arg_name] = v
+
+    return arg_list, kwargs
+
+
+def has_layer(kind):
+    """ test whether this layer exists in the custom layers
+    """
+    return kind in custom_layers
+
+
+def compute_output_shape(kind, node):
+    assert kind in custom_layers, "layer[%s] not exist in custom layers" % (
+        kind)
+    shape_func = custom_layers[kind]['shape']
+
+    parents = node.parents
+    inputs = [list(p.output_shape) for p in parents]
+    arg_names, kwargs = set_args(shape_func, node.layer.parameters)
+
+    if len(inputs) == 1:
+        inputs = inputs[0]
+
+    return shape_func(inputs, **kwargs)
+
+
+def make_node(template, kind, node):
+    """ make a TensorFlowNode for a custom layer, which means constructing
+        a piece of code to define a layer implemented in 'custom_layers'
+
+    Args:
+        @template (TensorFlowNode): a factory to create an instance of TensorFlowNode
+        @kind (str): type of custom layer
+        @node (graph.Node): a layer in the net
+
+    Returns:
+        instance of TensorFlowNode
+    """
+    assert kind in custom_layers, "layer[%s] not exist in custom layers" % (
+        kind)
+
+    layer_func = custom_layers[kind]['layer']
+
+    #construct arguments needed by the custom layer function from the node's parameters
+    arg_names, kwargs = set_args(layer_func, node.layer.parameters)
+
+    return template('custom_layer', kind, **kwargs)
+
+
+def make_custom_layer(kind, inputs, name, *args, **kwargs):
+    """ execute a custom layer which is implemented by users
+
+    Args:
+        @kind (str): type name of this layer
+        @inputs (vars): variable list created by fluid
+        @name (str): name for this layer
+        @args (tuple): other positional arguments
+        @kwargs (dict): other kv arguments
+
+    Returns:
+        output (var): output variable for this layer
+    """
+    assert kind in custom_layers, "layer[%s] not exist in custom layers" % (
+        kind)
+
+    layer_func = custom_layers[kind]['layer']
+    return layer_func(inputs, name, *args, **kwargs)
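
The dispatch above hinges on set_args() matching a function's argument names against a caffe parameters object. A small, Paddle-free sketch of that idea follows; ArgMaxParams is a made-up stand-in, not a real caffe class:

```python
# Paddle-free sketch of the set_args() introspection: keep only those
# attributes of a parameters object whose names match the target
# function's arguments. ArgMaxParams is a hypothetical stand-in.
class ArgMaxParams(object):
    def __init__(self):
        self.top_k = 5
        self.out_max_val = True
        self.unrelated_field = 'ignored'  # no matching argument, so dropped


def argmax_shape(input_shape, out_max_val=False, top_k=1, axis=-1):
    return input_shape  # body irrelevant here; only the signature matters


def set_args(f, params):
    argc = f.__code__.co_argcount
    arg_list = f.__code__.co_varnames[0:argc]
    kwargs = {}
    for arg_name in arg_list:
        v = getattr(params, arg_name, None)
        if v is not None:
            kwargs[arg_name] = v
    return arg_list, kwargs


arg_names, kwargs = set_args(argmax_shape, ArgMaxParams())
print(arg_names)  # ('input_shape', 'out_max_val', 'top_k', 'axis')
print(kwargs)     # {'out_max_val': True, 'top_k': 5}
```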

fluid/image_classification/caffe2fluid/kaffe/custom_layers/argmax.py

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
+""" a custom layer for 'argmax'; maybe we should implement this in a standard way.
+    more info can be found here: http://caffe.berkeleyvision.org/tutorial/layers/argmax.html
+"""
+from .register import register
+
+
+def import_fluid():
+    import paddle.fluid as fluid
+    return fluid
+
+
+def argmax_shape(input_shape, out_max_val=False, top_k=1, axis=-1):
+    """ calculate the output shape of this layer using the input shape
+
+    Args:
+        @input_shape (list of num): a list of numbers which represents the input shape
+        @out_max_val (bool): parameter from caffe's ArgMax layer
+        @top_k (int): parameter from caffe's ArgMax layer
+        @axis (int): parameter from caffe's ArgMax layer
+
+    Returns:
+        @output_shape (list of num): a list of numbers which represents the output shape
+    """
+    input_shape = list(input_shape)
+
+    if axis < 0:
+        axis += len(input_shape)
+
+    assert (axis + 1 == len(input_shape)
+            ), 'can only be applied on the last dimension for now'
+
+    output_shape = input_shape
+    output_shape[-1] = top_k
+    if out_max_val is True:
+        output_shape[-1] *= 2
+
+    return output_shape
+
+
+def argmax_layer(input, name, out_max_val=False, top_k=1, axis=-1):
+    """ build a layer of type 'ArgMax' using fluid
+
+    Args:
+        @input (variable): input fluid variable for this layer
+        @name (str): name for this layer
+        @out_max_val (bool): parameter from caffe's ArgMax layer
+        @top_k (int): parameter from caffe's ArgMax layer
+        @axis (int): parameter from caffe's ArgMax layer
+
+    Returns:
+        output (variable): output variable for this layer
+    """
+
+    fluid = import_fluid()
+
+    if axis < 0:
+        axis += len(input.shape)
+
+    assert (axis + 1 == len(input.shape)
+            ), 'can only be applied on the last dimension for now'
+
+    topk_var, index_var = fluid.layers.topk(input=input, k=top_k)
+    if out_max_val is True:
+        output = fluid.layers.concat([topk_var, index_var], axis=axis)
+    else:
+        output = topk_var
+    return output
+
+
+register(kind='ArgMax', shape=argmax_shape, layer=argmax_layer)
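
As a quick sanity check of the shape rule implemented above, the snippet below repeats argmax_shape() standalone and prints how top_k and out_max_val change the last dimension:

```python
# Standalone check of the ArgMax shape rule: the last dimension becomes
# top_k, and it doubles when out_max_val is True (value and index pairs).
def argmax_shape(input_shape, out_max_val=False, top_k=1, axis=-1):
    input_shape = list(input_shape)
    if axis < 0:
        axis += len(input_shape)
    assert axis + 1 == len(input_shape), 'last dimension only'
    output_shape = input_shape
    output_shape[-1] = top_k
    if out_max_val is True:
        output_shape[-1] *= 2
    return output_shape


print(argmax_shape([8, 1000]))                             # [8, 1]
print(argmax_shape([8, 1000], top_k=5))                    # [8, 5]
print(argmax_shape([8, 1000], out_max_val=True, top_k=5))  # [8, 10]
```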

fluid/image_classification/caffe2fluid/kaffe/custom_layers/axpy.py

Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
+""" A custom layer for 'axpy' which receives 3 tensors and outputs 1 tensor.
+    the function performed is (the multiplication and addition are elementwise):
+        output = inputs[0] * inputs[1] + inputs[2]
+"""
+
+from .register import register
+
+
+def axpy_shape(input_shapes):
+    """ calculate the output shape of this layer using the input shapes
+
+    Args:
+        @input_shapes (list of tuples): a list of input shapes
+
+    Returns:
+        @output_shape (list of num): a list of numbers which represents the output shape
+    """
+    assert len(input_shapes) == 3, "not a valid input shape for axpy layer"
+    assert len(input_shapes[0]) == len(input_shapes[1]), 'should have same dims'
+
+    output_shape = input_shapes[1]
+    assert (input_shapes[2] == output_shape),\
+            "shape not consistent for axpy[%s <--> %s]" \
+            % (str(output_shape), str(input_shapes[2]))
+
+    return output_shape
+
+
+def axpy_layer(inputs, name):
+    """ build a layer of type 'Axpy' using fluid
+
+    Args:
+        @inputs (list of variables): input fluid variables for this layer
+        @name (str): name for this layer
+
+    Returns:
+        output (variable): output variable for this layer
+    """
+    import paddle.fluid as fluid
+
+    assert len(inputs) == 3, "invalid inputs for axpy[%s]" % (name)
+    alpha = inputs[0]
+    x = inputs[1]
+    y = inputs[2]
+    output = fluid.layers.elementwise_mul(x, alpha, axis=0)
+    output = fluid.layers.elementwise_add(output, y)
+
+    return output
+
+
+register(kind='Axpy', shape=axpy_shape, layer=axpy_layer)
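
For intuition, this numpy sketch mirrors what the two fluid calls are intended to compute, under the usual caffe 'axpy' assumption that inputs[0] is a per-channel scale of shape [N, C, 1, 1]; the shapes are illustrative only:

```python
# numpy sketch of the axpy semantics: output = inputs[0] * inputs[1] + inputs[2],
# with inputs[0] broadcast over the spatial dimensions (per-channel scale).
import numpy as np

n, c, h, w = 2, 3, 4, 4
alpha = np.random.rand(n, c, 1, 1).astype('float32')  # one scale per (n, c)
x = np.random.rand(n, c, h, w).astype('float32')
y = np.random.rand(n, c, h, w).astype('float32')

# broadcasting over the trailing (1, 1) dims plays the role of
# elementwise_mul(x, alpha, axis=0) followed by elementwise_add(..., y)
output = alpha * x + y

print(output.shape)  # (2, 3, 4, 4)
print(np.allclose(output[0, 1], alpha[0, 1, 0, 0] * x[0, 1] + y[0, 1]))  # True
```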
