Skip to content

Commit 50afdd8

Browse files
authored
Merge pull request #617 from dragonwarrior/caffe2fluid
Caffe2fluid
2 parents 243ee52 + 39daecc commit 50afdd8

File tree

20 files changed

+3529
-0
lines changed

20 files changed

+3529
-0
lines changed
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
### Caffe2Fluid
2+
This tool is used to convert a Caffe model to Fluid model
3+
4+
### Howto
5+
1, Prepare caffepb.py in ./proto, two options provided
6+
1) generate it from caffe.proto using protoc
7+
bash ./proto/compile.sh
8+
9+
2) download one from github directly
10+
cd proto/ && wget https://github.com/ethereon/caffe-tensorflow/blob/master/kaffe/caffe/caffepb.py
11+
12+
2, Convert the caffe model using 'convert.py' which will generate a python script and a weight(in .npy) file
13+
14+
3, Use the converted model to predict
15+
see more detail info in 'tests/lenet/README.md'
16+
17+
18+
### Supported models
19+
- Lenet on mnist dataset
20+
21+
- ResNets:(ResNet-50, ResNet-101, ResNet-152)
22+
model addrs:(https://onedrive.live.com/?authkey=%21AAFW2-FVoxeVRck&id=4006CBB8476FF777%2117887&cid=4006CBB8476FF777)
23+
24+
### Notes
25+
Some of this code come from here: https://github.com/ethereon/caffe-tensorflow
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
#!/usr/bin/env python
2+
3+
import os
4+
import sys
5+
import numpy as np
6+
import argparse
7+
from kaffe import KaffeError, print_stderr
8+
9+
from kaffe.paddle import Transformer
10+
11+
12+
def fatal_error(msg):
13+
""" fatal error encounted
14+
"""
15+
print_stderr(msg)
16+
exit(-1)
17+
18+
19+
def validate_arguments(args):
20+
""" validate args
21+
"""
22+
if (args.data_output_path is not None) and (args.caffemodel is None):
23+
fatal_error('No input data path provided.')
24+
if (args.caffemodel is not None) and (args.data_output_path is None):
25+
fatal_error('No output data path provided.')
26+
if (args.code_output_path is None) and (args.data_output_path is None):
27+
fatal_error('No output path specified.')
28+
29+
30+
def convert(def_path, caffemodel_path, data_output_path, code_output_path,
31+
phase):
32+
""" convert caffe model to tf/paddle models
33+
"""
34+
try:
35+
transformer = Transformer(def_path, caffemodel_path, phase=phase)
36+
print_stderr('Converting data...')
37+
if caffemodel_path is not None:
38+
data = transformer.transform_data()
39+
print_stderr('Saving data...')
40+
with open(data_output_path, 'wb') as data_out:
41+
np.save(data_out, data)
42+
if code_output_path:
43+
print_stderr('Saving source...')
44+
with open(code_output_path, 'wb') as src_out:
45+
src_out.write(transformer.transform_source())
46+
print_stderr('Done.')
47+
except KaffeError as err:
48+
fatal_error('Error encountered: {}'.format(err))
49+
50+
51+
def main():
52+
""" main
53+
"""
54+
parser = argparse.ArgumentParser()
55+
parser.add_argument('def_path', help='Model definition (.prototxt) path')
56+
parser.add_argument('--caffemodel', help='Model data (.caffemodel) path')
57+
parser.add_argument('--data-output-path', help='Converted data output path')
58+
parser.add_argument(
59+
'--code-output-path', help='Save generated source to this path')
60+
parser.add_argument(
61+
'-p',
62+
'--phase',
63+
default='test',
64+
help='The phase to convert: test (default) or train')
65+
args = parser.parse_args()
66+
validate_arguments(args)
67+
convert(args.def_path, args.caffemodel, args.data_output_path,
68+
args.code_output_path, args.phase)
69+
70+
71+
if __name__ == '__main__':
72+
main()
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from .graph import GraphBuilder, NodeMapper
2+
from .errors import KaffeError, print_stderr
3+
4+
import os
5+
from . import paddle
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .resolver import get_caffe_resolver, has_pycaffe
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import os
2+
import sys
3+
4+
SHARED_CAFFE_RESOLVER = None
5+
6+
7+
def import_caffepb():
8+
p = os.path.realpath(__file__)
9+
p = os.path.dirname(p)
10+
p = os.path.join(p, '../../proto')
11+
sys.path.insert(0, p)
12+
import caffepb
13+
return caffepb
14+
15+
16+
class CaffeResolver(object):
17+
def __init__(self):
18+
self.import_caffe()
19+
20+
def import_caffe(self):
21+
self.caffe = None
22+
try:
23+
# Try to import PyCaffe first
24+
import caffe
25+
self.caffe = caffe
26+
except ImportError:
27+
# Fall back to the protobuf implementation
28+
self.caffepb = import_caffepb()
29+
show_fallback_warning()
30+
if self.caffe:
31+
# Use the protobuf code from the imported distribution.
32+
# This way, Caffe variants with custom layers will work.
33+
self.caffepb = self.caffe.proto.caffe_pb2
34+
self.NetParameter = self.caffepb.NetParameter
35+
36+
def has_pycaffe(self):
37+
return self.caffe is not None
38+
39+
40+
def get_caffe_resolver():
41+
global SHARED_CAFFE_RESOLVER
42+
if SHARED_CAFFE_RESOLVER is None:
43+
SHARED_CAFFE_RESOLVER = CaffeResolver()
44+
return SHARED_CAFFE_RESOLVER
45+
46+
47+
def has_pycaffe():
48+
return get_caffe_resolver().has_pycaffe()
49+
50+
51+
def show_fallback_warning():
52+
msg = '''
53+
------------------------------------------------------------
54+
WARNING: PyCaffe not found!
55+
Falling back to a pure protocol buffer implementation.
56+
* Conversions will be drastically slower.
57+
* This backend is UNTESTED!
58+
------------------------------------------------------------
59+
60+
'''
61+
sys.stderr.write(msg)
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import sys
2+
3+
#debug level, can be 'warn', 'verbose'
4+
log_level = 'warn'
5+
6+
7+
class KaffeError(Exception):
8+
pass
9+
10+
11+
def print_stderr(msg):
12+
sys.stderr.write('%s\n' % msg)
13+
14+
15+
def debug(msg):
16+
if log_level == 'verbose':
17+
print_stderr('[DEBUG]' + msg)
18+
19+
20+
def notice(msg):
21+
print_stderr('[NOTICE]' + msg)
22+
23+
24+
def warn(msg):
25+
print_stderr('[WARNING]' + msg)
26+
27+
28+
def set_loglevel(level):
29+
global log_level
30+
31+
if 'warn' != level and 'verbose' != level:
32+
raise Exception('not supported log level[%s]' % (level))
33+
34+
log_level = level

0 commit comments

Comments
 (0)