Skip to content

added example and tests for exporting keras models #32

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,26 @@ batch = np.random.rand(10000, 784)
result = y.eval({x: batch})
```

##### Convert a Keras Model

You can also convert Keras models. Many layers are still not supported; see the test output of the `tests/keras.py` test suite for details.

```python
>>> import tfdeploy as td
>>> from keras.models import Sequential, Model
>>> from keras.layers import Convolution2D
>>> k_model = Sequential()
>>> k_model.add(Convolution2D(5, (3,3), input_shape = (9,9,1)))
>>> k_model.compile('sgd', 'mse')
>>> t_model, i_names, o_names = td.deploy_keras(k_model)
>>> type(t_model)
<class 'tfdeploy.Model'>
>>> i_names
OrderedDict([('conv2d_1_input', 'conv2d_1_input:0')])
>>> o_names
OrderedDict([('conv2d_1', 'conv2d_1/BiasAdd:0')])
```

##### Write your own `Operation`

tfdeploy supports most of the `Operation`s [implemented in tensorflow](https://www.tensorflow.org/versions/master/api_docs/python/math_ops.html). However, if you miss one (in that case, please submit a PR or an issue ;) ) or if you're using custom ops, you might want to extend tfdeploy by defining a new op class that inherits from `tfdeploy.Operation`:
Expand Down
4 changes: 4 additions & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
numpy
tensorflow>=1.0
matplotlib
keras>=2.0
11 changes: 8 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,14 @@
readme = os.path.join(os.path.dirname(os.path.abspath(__file__)), "README.md")
if os.path.isfile(readme):
cmd = "pandoc --from=markdown --to=rst " + readme
p = Popen(cmd, stdout=PIPE, stderr=PIPE, shell=True, executable="/bin/bash")
out, err = p.communicate()
if p.returncode != 0:
try:
p = Popen(cmd, stdout=PIPE, stderr=PIPE, shell=True, executable="/bin/bash")
out, err = p.communicate()
returncode = p.returncode
except FileNotFoundError as file_exp:
print('pandoc and/or bash not found')
out, err, returncode = -1,str(file_exp), -1
if returncode != 0:
print("pandoc conversion failed: " + err)
long_description = out
else:
Expand Down
201 changes: 201 additions & 0 deletions tests/keras.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
# -*- coding: utf-8 -*-


import os
import unittest
from itertools import product

import numpy as np
import tensorflow as tf

from .base import TestCase, td

# Keras is an optional dependency of the test suite: record whether it could
# be imported so that the test case below can be skipped when it is missing.
# noinspection PyUnresolvedReferences
try:
    from keras.models import Sequential, Model, Input
    from keras.layers import (
        Dense, Convolution2D, MaxPooling2D, UpSampling2D, BatchNormalization,
        Dropout, Reshape, Conv2DTranspose, LSTM, LeakyReLU, Activation,
        RepeatVector, Lambda, LocallyConnected2D,
    )
    from keras.optimizers import Adam
    from keras.backend import tensorflow_backend as tfb
    import keras.backend as K
    from keras import applications as kapps  # for bigger prebuilt models
except ImportError:
    KERAS_MISSING = True
else:
    KERAS_MISSING = False

# Layer names (as assigned in ``_build_simple_2d``) for which a failed
# conversion is tolerated by the tests.
UNSUPPORTED_LAYERS = [
    'Dropout',
    'BatchNormalization',
    'UpSampling2D',
    'Convolution2DTranspose',
    'LSTM',
    'RepeatVector',
    'LocallyConnected2D',
]

__all__ = ["KerasTestCase"]


@unittest.skipIf(KERAS_MISSING, "requires Keras to be installed")
class KerasTestCase(TestCase):
    """
    Tests for converting Keras models (tensorflow backend) into tfdeploy models.

    Small sequential CNNs are built with individual layer types switched on and off,
    exported to a pickle file, reloaded through tfdeploy and evaluated on random data.
    """

    def __init__(self, *args, **kwargs):
        super(KerasTestCase, self).__init__(*args, **kwargs)
        td.setup(tf)
        # unittest instantiates the test case even when the class is skipped, so the
        # Keras backend may only be touched when the guarded import succeeded
        if not KERAS_MISSING:
            # Keras 2 name of the former ``set_image_dim_ordering('tf')``
            K.set_image_data_format('channels_last')

    def test_deploy_tool(self):
        """
        ``td.deploy_keras`` must return a tfdeploy model together with input and output
        mappings whose tensor names can be resolved inside that model.
        """
        c_model = KerasTestCase._build_simple_2d(use_leakyrelu=True, use_pooling=True)
        t_model, in_mapping, out_mapping = td.deploy_keras(c_model)
        print(in_mapping)
        print(out_mapping)

        self.assertIsInstance(t_model, td.Model, "Output should be tfdeploy model")
        self.assertEqual(len(in_mapping), 1, "only one input")
        self.assertIn('Reshape_input', in_mapping, "Reshape not found in input")
        self.assertIn('MaxPooling2D', out_mapping, "MaxPooling not found in output")
        self.assertEqual(len(out_mapping), 1, "only one ouput")
        for c_mapping in [in_mapping, out_mapping]:
            for keras_name, tf_name in c_mapping.items():
                cur_tensor = t_model.get(tf_name)
                self.assertIsNotNone(cur_tensor, "Layer: {} -> TF:{}, not found in model".format(
                    keras_name, tf_name))
                self.assertIsInstance(cur_tensor, td.Tensor, "Layer should be tensor: {}".format(cur_tensor))

    def test_cnn_models(self):
        """
        Export one model per optional layer type and evaluate every model that could be
        serialized. A failed export is only accepted when the model contains at least
        one layer listed in ``UNSUPPORTED_LAYERS``.
        """
        model_kwargs = dict(use_dense=False, use_dropout=False, use_pooling=False, use_bn=False, use_upsample=False,
                            use_conv2dtrans=False, use_lstm=False, use_leakyrelu=False, use_repeatvec=False,
                            use_lambda=False, use_locallyconnected=False)

        def _try_args(**kw_args):
            # all switches off, except for the ones passed explicitly
            new_args = model_kwargs.copy()
            new_args.update(kw_args)
            return new_args

        test_models = [('base_cnn', KerasTestCase._build_simple_2d())]
        test_models += [(c_arg,
                         KerasTestCase._build_simple_2d(**_try_args(**{c_arg: True})))
                        for c_arg in model_kwargs]

        deployed_models = []
        try:
            for i, (model_name, cur_keras_model) in enumerate(test_models):
                model_layers = ','.join(layer.name for layer in cur_keras_model.layers)
                out_path = "%04d.pkl" % i
                try:
                    deployed_models += \
                        KerasTestCase.export_keras_model(cur_keras_model, out_path, model_name=model_layers)
                except td.UnknownOperationException as uoe:
                    print('Model {}: {}'.format(i, model_name), 'could not be serialized', uoe)
                    bad_layer_count = sum(us_layer in model_layers for us_layer in UNSUPPORTED_LAYERS)
                    self.assertGreater(bad_layer_count, 0,
                                       "Model contains no unsupported layers {}, "
                                       "Unsupported Layers:{}".format(model_layers, UNSUPPORTED_LAYERS))

            self.assertGreater(len(deployed_models), 0, "No models could be tested")
            print("Testing #{} models".format(len(deployed_models)))
            for c_model_pkl in deployed_models:
                result = KerasTestCase.deploy_model(c_model_pkl)
                self.assertIsNotNone(result, "Result should not be empty")
                self.assertEqual(len(result.shape), 4, "Output should be 4D Tensor: {}".format(result.shape))
        finally:
            # remove the exported files even when one of the assertions above failed
            for c_model_pkl in deployed_models:
                KerasTestCase._remove_file(c_model_pkl['path'])

    @unittest.skip("Takes quite awhile to run (and fails for all models)")
    def test_big_models(self):
        """
        A test for bigger commonly used pretrained models (for this we skip the weights)
        :return:
        """
        kapp_kwargs = dict(
            input_shape=(99, 99, 3),
            weights=None,
            include_top=False  # so we can use different sizes
        )
        test_models = [
            ('Resnet50', kapps.ResNet50(**kapp_kwargs)),
            ('InceptionV3', kapps.InceptionV3(**kapp_kwargs)),
            ('VGG19', kapps.VGG19(**kapp_kwargs)),
            ('Xception', kapps.Xception(**kapp_kwargs)),
        ]
        # one random image with a leading batch axis, matching the declared input shape
        test_input = np.random.uniform(0, 1, size=(1,) + kapp_kwargs['input_shape'])

        for i, (model_name, cur_keras_model) in enumerate(test_models):
            model_layers = ','.join(layer.name for layer in cur_keras_model.layers)
            out_path = "%04d.pkl" % i
            try:
                # export_keras_model returns a one-element list of description dicts
                c_model_pkl = KerasTestCase.export_keras_model(
                    cur_keras_model, out_path, model_name=model_layers)[0]
            except td.UnknownOperationException as uoe:
                print('Model {}: {}'.format(i, model_layers), 'could not be serialized', uoe)
                bad_layer_count = sum(us_layer in model_layers for us_layer in UNSUPPORTED_LAYERS)
                self.assertGreater(bad_layer_count, 0,
                                   "Model contains no unsupported layers {}, "
                                   "Unsupported Layers:{}".format(model_layers, UNSUPPORTED_LAYERS))
                continue
            except tf.errors.ResourceExhaustedError:
                # many of the bigger models take up quite a bit of GPU memory
                print('Model {} with #{} layers is too big for memory'.format(model_name, len(cur_keras_model.layers)))
                # nothing was exported for this model, so there is nothing to evaluate
                continue

            try:
                result = KerasTestCase.deploy_model(c_model_pkl, test_input)
                self.assertIsNotNone(result, "Result should not be empty")
                self.assertEqual(len(result.shape), 4, "Output should be 4D Tensor: {}".format(result.shape))
            finally:
                KerasTestCase._remove_file(c_model_pkl['path'])

    @staticmethod
    def _remove_file(path):
        """
        Delete the file at *path* if it exists.
        :param path: path of a file written by ``export_keras_model``
        """
        if os.path.isfile(path):
            os.remove(path)

    @staticmethod
    def deploy_model(c_model_pkl, input=None):
        """
        Load an exported model with tfdeploy and evaluate it.
        :param c_model_pkl: description dict as returned by ``export_keras_model``
            (keys ``path``, ``input`` and ``output`` are read)
        :param input: array fed to the input tensor, defaults to ``np.random.rand(50, 81)``.
            The name shadows the builtin but is kept for backward compatibility.
        :return: the evaluated output tensor
        """
        model = td.Model(c_model_pkl['path'])
        inp, outp = model.get(c_model_pkl['input'], c_model_pkl['output'])
        if input is None:
            input = np.random.rand(50, 81)
        return outp.eval({inp: input})

    @staticmethod
    def export_keras_model(in_ks_model, out_path, model_name):
        """
        Convert a Keras model to a tfdeploy model and save it to *out_path*.
        :param in_ks_model: the Keras model to export
        :param out_path: target path of the pickle file
        :param model_name: free-form name stored in the returned description
        :return: a list holding one dict with the keys ``path``, ``output``, ``input``
            and ``name``; the tensor names are those of the model's first node
        """
        out_tensor = in_ks_model.get_output_at(0)
        td_model = td.Model()
        # y and all its ops and related tensors are added recursively
        td_model.add(out_tensor, tfb.get_session())

        td_model.save(out_path)
        return [dict(path=out_path,
                     output=out_tensor.name,
                     input=in_ks_model.get_input_at(0).name,
                     name=model_name)]

    @staticmethod
    def compile_model(i_model):
        """
        Compile *i_model* in place with the Adam optimizer (lr=2e-3) and an MSE loss.
        :param i_model: the Keras model to compile
        """
        i_model.compile(optimizer=Adam(lr=2e-3), loss='mse')

    @staticmethod
    def _build_simple_2d(use_dense=False, use_dropout=False, use_pooling=False, use_bn=False, use_upsample=False,
                         use_conv2dtrans=False, use_lstm=False, use_leakyrelu=False, use_repeatvec=False,
                         use_lambda=False, use_locallyconnected=False):
        """
        Simple function for building CNN models with various layers turned on and off.
        Every model contains a Reshape to (9, 9, 1) followed by a Convolution2D layer,
        the flags add further layers before or after that core.
        :param use_dense: prepend a Dense layer
        :param use_dropout: add a Dropout layer
        :param use_pooling: add a MaxPooling2D layer
        :param use_bn: add a BatchNormalization layer
        :param use_upsample: add an UpSampling2D layer
        :param use_conv2dtrans: add a Conv2DTranspose layer
        :param use_lstm: prepend a Reshape and an LSTM layer
        :param use_leakyrelu: add a LeakyReLU layer
        :param use_repeatvec: prepend a RepeatVector layer followed by a slicing Lambda
        :param use_lambda: add a Lambda layer that adds 1
        :param use_locallyconnected: add a LocallyConnected2D layer
        :return: the compiled Sequential model
        """
        out_model = Sequential()
        # optional layers operating on the flat 81-element input
        if use_lstm:
            out_model.add(Reshape(target_shape=(1, 81), input_shape=(81,), name='Reshape_LSTM'))
            out_model.add(LSTM(81, name='LSTM'))
        if use_dense:
            out_model.add(Dense(81, input_shape=(81,), name='Dense'))
        if use_repeatvec:
            out_model.add(RepeatVector(3, input_shape=(81,), name='RepeatVector'))
            out_model.add(Lambda(lambda x: x[0, :], name='Lambda'))

        # core of every model
        out_model.add(Reshape(target_shape=(9, 9, 1), input_shape=(81,), name='Reshape'))
        out_model.add(Convolution2D(2, (3, 3), input_shape=(9, 9, 1), name='Convolution2D'))

        # optional layers operating on the convolution output
        if use_lambda:
            out_model.add(Lambda(lambda x: x + 1, name='Lambda_add'))
        if use_leakyrelu:
            out_model.add(LeakyReLU(0.1, name='LeakyRelu'))
        if use_dropout:
            out_model.add(Dropout(0.5, name='Dropout'))
        if use_pooling:
            out_model.add(MaxPooling2D((2, 2), name='MaxPooling2D'))
        if use_upsample:
            out_model.add(UpSampling2D((2, 2), name='UpSampling2D'))
        if use_bn:
            out_model.add(BatchNormalization(name='BatchNormalization'))
        if use_conv2dtrans:
            out_model.add(Conv2DTranspose(2, kernel_size=(3, 3), strides=(2, 2), name='Convolution2DTranspose'))
        if use_locallyconnected:
            out_model.add(LocallyConnected2D(3, (3, 3), name='LocallyConnected2D'))

        KerasTestCase.compile_model(out_model)
        return out_model
46 changes: 44 additions & 2 deletions tfdeploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"InvalidImplementationException", "UnknownImplementationException",
"EnsembleMismatchException", "ScipyOperationException",
"reset", "optimize", "print_tensor", "print_op", "print_tf_tensor", "print_tf_op",
"deploy_keras",
"IMPL_NUMPY", "IMPL_SCIPY", "IMPLS",
"METHOD_MEAN", "METHOD_MAX", "METHOD_MIN", "METHOD_CUSTOM", "METHODS",
"HAS_SCIPY"]
Expand All @@ -29,6 +30,7 @@
import re
from uuid import uuid4
from functools import reduce
from collections import OrderedDict

try:
# python 2
Expand Down Expand Up @@ -56,7 +58,6 @@ def wrapper(cls):
return metaclass(cls.__name__, cls.__bases__, orig_vars)
return wrapper


class Model(object):
"""
A trained model that contains one or more converted tensorflow graphs. When *path* is set, a
Expand Down Expand Up @@ -177,7 +178,7 @@ def save(self, path):
"""
path = os.path.expandvars(os.path.expanduser(path))
with open(path, "wb") as f:
pickle.dump(self.roots, f)
pickle.dump(self.roots, f, protocol = 2) # make it python2 compatible always


class TensorRegister(type):
Expand All @@ -196,6 +197,47 @@ def __call__(cls, tf_tensor, *args, **kwargs):
return cls.instances[tf_tensor]


def deploy_keras(in_keras_model):
    # type: (keras.models.Model) -> Tuple[Model, Dict[str, str], Dict[str, str]]
    """
    Converts a keras model (>=2.0) to a tfdeploy model and provides ordered input and output
    mappings from keras layer names to tensorflow tensor names. Models with several inputs
    and/or outputs are supported; every output tensor is added to the returned model.
    :param in_keras_model: the keras model to convert (it must expose *input_names*, *inputs*,
        *output_names* and *outputs*, for Sequential models this may require building or
        compiling it first)
    :return: the tfdeploy model, a dictionary mapping inputs and a dictionary mapping outputs
    The dictionaries map keras layer names to tensorflow names
    :raises NotImplementedError: when keras or its tensorflow backend cannot be imported
    Usage
    ====
    >>> from keras.models import Sequential, Model
    >>> from keras.layers import Convolution2D
    >>> k_model = Sequential()
    >>> k_model.add(Convolution2D(5, (3,3), input_shape = (9,9,1)))
    >>> k_model.compile('sgd', 'mse')
    >>> t_model, i_names, o_names = deploy_keras(k_model)
    >>> type(t_model)
    <class 'tfdeploy.Model'>
    >>> i_names
    OrderedDict([('conv2d_1_input', 'conv2d_1_input:0')])
    >>> o_names
    OrderedDict([('conv2d_1', 'conv2d_1/BiasAdd:0')])
    """
    try:
        from keras.backend import tensorflow_backend as tfb
    except ImportError:
        raise NotImplementedError("Keras is not installed or not setup with the tensorflow backend!")

    # fetched once and reused for every output tensor
    session = tfb.get_session()

    # pair names with the model's own tensor lists: get_input_at / get_output_at expect a
    # *node* index, so indexing them per tensor fails for multi-input / multi-output models
    keras_in_mapping = OrderedDict()
    for in_name, in_tensor in zip(in_keras_model.input_names, in_keras_model.inputs):
        keras_in_mapping[in_name] = in_tensor.name

    td_model = Model()
    keras_out_mapping = OrderedDict()
    for out_name, out_tensor in zip(in_keras_model.output_names, in_keras_model.outputs):
        keras_out_mapping[out_name] = out_tensor.name
        # the tensor and all its ops and related tensors are added recursively
        td_model.add(out_tensor, session)

    return td_model, keras_in_mapping, keras_out_mapping

@add_metaclass(TensorRegister)
class Tensor(object):
"""
Expand Down