Skip to content

Caffe2fluid #617

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Mar 2, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions fluid/image_classification/caffe2fluid/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
### Caffe2Fluid
This tool is used to convert a Caffe model to Fluid model

### Howto
1, Prepare caffepb.py in ./proto, two options provided
1) generate it from caffe.proto using protoc
bash ./proto/compile.sh

2) download one from github directly
cd proto/ && wget https://github.com/ethereon/caffe-tensorflow/blob/master/kaffe/caffe/caffepb.py

2, Convert the caffe model using 'convert.py' which will generate a python script and a weight(in .npy) file

3, Use the converted model to predict
see more detail info in 'tests/lenet/README.md'


### Supported models
- Lenet on mnist dataset

- ResNets:(ResNet-50, ResNet-101, ResNet-152)
model addrs:(https://onedrive.live.com/?authkey=%21AAFW2-FVoxeVRck&id=4006CBB8476FF777%2117887&cid=4006CBB8476FF777)

### Notes
Some of this code come from here: https://github.com/ethereon/caffe-tensorflow
72 changes: 72 additions & 0 deletions fluid/image_classification/caffe2fluid/convert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
#!/usr/bin/env python

import os
import sys
import numpy as np
import argparse
from kaffe import KaffeError, print_stderr

from kaffe.paddle import Transformer


def fatal_error(msg):
""" fatal error encounted
"""
print_stderr(msg)
exit(-1)


def validate_arguments(args):
""" validate args
"""
if (args.data_output_path is not None) and (args.caffemodel is None):
fatal_error('No input data path provided.')
if (args.caffemodel is not None) and (args.data_output_path is None):
fatal_error('No output data path provided.')
if (args.code_output_path is None) and (args.data_output_path is None):
fatal_error('No output path specified.')


def convert(def_path, caffemodel_path, data_output_path, code_output_path,
phase):
""" convert caffe model to tf/paddle models
"""
try:
transformer = Transformer(def_path, caffemodel_path, phase=phase)
print_stderr('Converting data...')
if caffemodel_path is not None:
data = transformer.transform_data()
print_stderr('Saving data...')
with open(data_output_path, 'wb') as data_out:
np.save(data_out, data)
if code_output_path:
print_stderr('Saving source...')
with open(code_output_path, 'wb') as src_out:
src_out.write(transformer.transform_source())
print_stderr('Done.')
except KaffeError as err:
fatal_error('Error encountered: {}'.format(err))


def main():
""" main
"""
parser = argparse.ArgumentParser()
parser.add_argument('def_path', help='Model definition (.prototxt) path')
parser.add_argument('--caffemodel', help='Model data (.caffemodel) path')
parser.add_argument('--data-output-path', help='Converted data output path')
parser.add_argument(
'--code-output-path', help='Save generated source to this path')
parser.add_argument(
'-p',
'--phase',
default='test',
help='The phase to convert: test (default) or train')
args = parser.parse_args()
validate_arguments(args)
convert(args.def_path, args.caffemodel, args.data_output_path,
args.code_output_path, args.phase)


if __name__ == '__main__':
main()
5 changes: 5 additions & 0 deletions fluid/image_classification/caffe2fluid/kaffe/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .graph import GraphBuilder, NodeMapper
from .errors import KaffeError, print_stderr

import os
from . import paddle
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .resolver import get_caffe_resolver, has_pycaffe
61 changes: 61 additions & 0 deletions fluid/image_classification/caffe2fluid/kaffe/caffe/resolver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import os
import sys

SHARED_CAFFE_RESOLVER = None


def import_caffepb():
p = os.path.realpath(__file__)
p = os.path.dirname(p)
p = os.path.join(p, '../../proto')
sys.path.insert(0, p)
import caffepb
return caffepb


class CaffeResolver(object):
def __init__(self):
self.import_caffe()

def import_caffe(self):
self.caffe = None
try:
# Try to import PyCaffe first
import caffe
self.caffe = caffe
except ImportError:
# Fall back to the protobuf implementation
self.caffepb = import_caffepb()
show_fallback_warning()
if self.caffe:
# Use the protobuf code from the imported distribution.
# This way, Caffe variants with custom layers will work.
self.caffepb = self.caffe.proto.caffe_pb2
self.NetParameter = self.caffepb.NetParameter

def has_pycaffe(self):
return self.caffe is not None


def get_caffe_resolver():
global SHARED_CAFFE_RESOLVER
if SHARED_CAFFE_RESOLVER is None:
SHARED_CAFFE_RESOLVER = CaffeResolver()
return SHARED_CAFFE_RESOLVER


def has_pycaffe():
return get_caffe_resolver().has_pycaffe()


def show_fallback_warning():
msg = '''
------------------------------------------------------------
WARNING: PyCaffe not found!
Falling back to a pure protocol buffer implementation.
* Conversions will be drastically slower.
* This backend is UNTESTED!
------------------------------------------------------------

'''
sys.stderr.write(msg)
34 changes: 34 additions & 0 deletions fluid/image_classification/caffe2fluid/kaffe/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import sys

#debug level, can be 'warn', 'verbose'
log_level = 'warn'


class KaffeError(Exception):
pass


def print_stderr(msg):
sys.stderr.write('%s\n' % msg)


def debug(msg):
if log_level == 'verbose':
print_stderr('[DEBUG]' + msg)


def notice(msg):
print_stderr('[NOTICE]' + msg)


def warn(msg):
print_stderr('[WARNING]' + msg)


def set_loglevel(level):
global log_level

if 'warn' != level and 'verbose' != level:
raise Exception('not supported log level[%s]' % (level))

log_level = level
Loading