diff --git a/.travis.yml b/.travis.yml index db3413a0..c03ee0c2 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,6 +1,7 @@ language: python python: - "2.7" + - "3.7" install: - pip install -r requirements.txt - pip install . diff --git a/CHANGELOG b/CHANGELOG index a03874cd..69c4cc86 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -1 +1,2 @@ v0.5, 07/18/2019 -- Initial release. +v0.6, 12/03/2019 -- Support tensorflow 2.0 and tf.keras diff --git a/examples/example_act.py b/examples/example_act.py index eff11407..1aa1c901 100644 --- a/examples/example_act.py +++ b/examples/example_act.py @@ -19,7 +19,9 @@ from __future__ import print_function import warnings import numpy as np -import keras.backend as K + +import tensorflow as tf +import tensorflow.keras.backend as K from qkeras import binary from qkeras import bernoulli diff --git a/examples/example_cifar10_po2.py b/examples/example_cifar10_po2.py index c9b4905b..8730fc2a 100644 --- a/examples/example_cifar10_po2.py +++ b/examples/example_cifar10_po2.py @@ -22,12 +22,12 @@ import os from collections import defaultdict -import keras.backend as K -from keras.datasets import cifar10 -from keras.layers import * -from keras.models import Model -from keras.optimizers import * -from keras.utils.np_utils import to_categorical +import tensorflow.keras.backend as K +from tensorflow.keras.datasets import cifar10 +from tensorflow.keras.layers import * +from tensorflow.keras.models import Model +from tensorflow.keras.optimizers import * +from tensorflow.keras.utils import to_categorical import numpy as np from qkeras import * diff --git a/examples/example_keras_to_qkeras.py b/examples/example_keras_to_qkeras.py index 93378344..b32423ca 100644 --- a/examples/example_keras_to_qkeras.py +++ b/examples/example_keras_to_qkeras.py @@ -21,8 +21,8 @@ from collections import defaultdict -from keras.layers import * -from keras.models import Model +from tensorflow.keras.layers import * +from tensorflow.keras.models import Model from qkeras import * diff --git a/examples/example_mnist.py b/examples/example_mnist.py index 2011c958..eaea5dca 100644 --- a/examples/example_mnist.py +++ b/examples/example_mnist.py @@ -22,17 +22,16 @@ import os from collections import defaultdict -import keras.backend as K -from keras.datasets import mnist -from keras.layers import Activation -from keras.layers import Flatten -from keras.layers import Input -from keras.layers import * -from keras.models import Model -from keras.optimizers import Adam -from keras.optimizers import SGD -from keras.optimizers import TFOptimizer -from keras.utils.np_utils import to_categorical +import tensorflow.keras.backend as K +from tensorflow.keras.datasets import mnist +from tensorflow.keras.layers import Activation +from tensorflow.keras.layers import Flatten +from tensorflow.keras.layers import Input +from tensorflow.keras.layers import * +from tensorflow.keras.models import Model +from tensorflow.keras.optimizers import Adam +from tensorflow.keras.optimizers import SGD +from tensorflow.keras.utils import to_categorical from qkeras import * diff --git a/examples/example_mnist_b2t.py b/examples/example_mnist_b2t.py index d2287d05..22286645 100644 --- a/examples/example_mnist_b2t.py +++ b/examples/example_mnist_b2t.py @@ -21,16 +21,16 @@ import os -import keras.backend as K -from keras.datasets import mnist -from keras.layers import Activation -from keras.layers import Flatten -from keras.layers import Input -from keras.layers import * -from keras.models import Model -from keras.optimizers import Adam -from 
keras.optimizers import SGD -from keras.utils.np_utils import to_categorical +import tensorflow.keras.backend as K +from tensorflow.keras.datasets import mnist +from tensorflow.keras.layers import Activation +from tensorflow.keras.layers import Flatten +from tensorflow.keras.layers import Input +from tensorflow.keras.layers import * +from tensorflow.keras.models import Model +from tensorflow.keras.optimizers import Adam +from tensorflow.keras.optimizers import SGD +from tensorflow.keras.utils import to_categorical import numpy as np from qkeras import * diff --git a/examples/example_mnist_bn.py b/examples/example_mnist_bn.py new file mode 100644 index 00000000..d4765b04 --- /dev/null +++ b/examples/example_mnist_bn.py @@ -0,0 +1,205 @@ +# Copyright 2019 Google LLC +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests mnist batchnormalization used as learned scale factor.""" + +# to run, THRESHOLD=0.05 WITH_BN=1 EPOCHS=5 TRAIN=1 python example_mnist_bn.py + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +from collections import defaultdict + +import tensorflow.keras.backend as K +from tensorflow.keras import callbacks +from tensorflow.keras.datasets import mnist +from tensorflow.keras.layers import * +from tensorflow.keras.models import Model +from tensorflow.keras.optimizers import * +from tensorflow.keras.utils import to_categorical +import numpy as np + +from qkeras import * + +np.random.seed(42) + +TRAIN=1 # int(os.environ.get("TRAIN", 0)): +NB_EPOCH = 2 # int(os.environ.get("EPOCHS",10)) +BATCH_SIZE = 64 +VERBOSE = 1 +NB_CLASSES = 10 +OPTIMIZER = Adam(lr=0.0001) +VALIDATION_SPLIT = 0.1 +WITH_BN = 1 # int(os.environ.get("WITH_BN",0)) +THRESHOLD = 0.1 # float(os.environ.get("THRESHOLD",0.1)) + +class LearningRateAdjuster(callbacks.Callback): + def __init__(self): + self.learning_rate_factor = 1.0 + pass + + def on_epoch_end(self, epochs, logs): + max_variance = -1 + + for layer in self.model.layers: + if layer.__class__.__name__ in [ + "BatchNormalization", + "QBatchNormalization" + ]: + variance = np.max(layer.get_weights()[-1]) + if variance > max_variance: + max_variance = variance + + if max_variance > 32 and self.learning_rate_factor < 100: + learning_rate = K.get_value(self.model.optimizer.learning_rate) + self.learning_rate_factor /= 2.0 + print("***** max_variance is {} / lr is {} *****".format( + max_variance, learning_rate)) + K.eval(K.update( + self.model.optimizer.learning_rate, learning_rate / 2.0 + )) + +lra = LearningRateAdjuster() + +(x_train, y_train), (x_test, y_test) = mnist.load_data() + +x_train = x_train.reshape(x_train.shape + (1,)).astype("float32") +x_test = x_test.reshape(x_test.shape + (1,)).astype("float32") + +x_train /= 255.0 +x_test /= 255.0 + +print(x_train.shape[0], "train samples") +print(x_test.shape[0], "test samples") + +print(y_train[0:10]) + +y_train = to_categorical(y_train, NB_CLASSES) 
+y_test = to_categorical(y_test, NB_CLASSES) + +x = x_in = Input(x_train.shape[1:], name="input") +#x = QActivation("quantized_relu_po2(4,1)", name="acti")(x) +x = QConv2D( + 128, (3, 3), + strides=1, + kernel_quantizer=ternary(threshold=THRESHOLD), #quantized_po2(4, 1), + bias_quantizer=quantized_bits(4,2,0) if not WITH_BN else None, + bias_range=4 if not WITH_BN else None, + use_bias=not WITH_BN, + name="conv2d_0_m")(x) +if WITH_BN: + x = QBatchNormalization( + gamma_quantizer=quantized_relu_po2(4,8), + variance_quantizer=quantized_relu_po2(6), + beta_quantizer=quantized_po2(4, 4), + gamma_range=8, + beta_range=4, + name="bn0")(x) +x = QActivation("quantized_relu(3,1)", name="act0_m")(x) +x = MaxPooling2D(2, 2, name="mp_0")(x) +x = QConv2D( + 256, (3, 3), + strides=1, + kernel_quantizer=ternary(threshold=THRESHOLD), #quantized_bits(2,0,1), + bias_quantizer=quantized_bits(4,2,1) if not WITH_BN else None, + bias_range=4 if not WITH_BN else None, + use_bias=not WITH_BN, + name="conv2d_1_m")(x) +if WITH_BN: + x = QBatchNormalization( + gamma_quantizer=quantized_relu_po2(4,8), + variance_quantizer=quantized_relu_po2(6), + beta_quantizer=quantized_po2(4, 4), + gamma_range=8, + beta_range=4, + name="bn1")(x) +x = QActivation("quantized_relu(3,1)", name="act1_m")(x) +x = MaxPooling2D(2, 2, name="mp_1")(x) +x = QConv2D( + 128, (3, 3), + strides=1, + kernel_quantizer=ternary(threshold=THRESHOLD), #quantized_bits(2,0,1), + bias_quantizer=quantized_bits(4,2,1) if not WITH_BN else None, + bias_range=4 if not WITH_BN else None, + use_bias=not WITH_BN, + name="conv2d_2_m")(x) +if WITH_BN: + x = QBatchNormalization( + gamma_quantizer=quantized_relu_po2(4,8), + variance_quantizer=quantized_relu_po2(6), + beta_quantizer=quantized_po2(4, 4), + gamma_range=8, + beta_range=4, + name="bn2")(x) +x = QActivation("quantized_relu(3,1)", name="act2_m")(x) +x = MaxPooling2D(2, 2, name="mp_2")(x) +x = Flatten()(x) +x = QDense( + NB_CLASSES, + kernel_quantizer=quantized_ulaw(4, 0, 1), + bias_quantizer=quantized_bits(4, 0, 1), + name="dense")( + x) +x = Activation("softmax", name="softmax")(x) + +model = Model(inputs=[x_in], outputs=[x]) +model.summary() + +model.compile( + loss="categorical_crossentropy", optimizer=OPTIMIZER, metrics=["accuracy"]) + + +if TRAIN: + history = model.fit( + x_train, y_train, batch_size=BATCH_SIZE, + epochs=NB_EPOCH, initial_epoch=1, verbose=VERBOSE, + validation_split=VALIDATION_SPLIT, + callbacks=[]) #lra]) + + outputs = [] + output_names = [] + + for layer in model.layers: + if layer.__class__.__name__ in [ + "QActivation", "QBatchNormalization", "Activation", "QDense", + "QConv2D", "QDepthwiseConv2D" + ]: + output_names.append(layer.name) + outputs.append(layer.output) + + model_debug = Model(inputs=[x_in], outputs=outputs) + + outputs = model_debug.predict(x_train) + + print("{:30} {: 8.4f} {: 8.4f}".format( + "input", np.min(x_train), np.max(x_train))) + + for n, p in zip(output_names, outputs): + print("{:30} {: 8.4f} {: 8.4f}".format(n, np.min(p), np.max(p)), end="") + layer = model.get_layer(n) + for i, weights in enumerate(layer.get_weights()): + if layer.get_quantizers()[i]: + weights = K.eval(layer.get_quantizers()[i](K.constant(weights))) + print(" ({: 8.4f} {: 8.4f})".format(np.min(weights), np.max(weights)), + end="") + print("") + + score = model.evaluate(x_test, y_test, verbose=False) + print("Test score:", score[0]) + print("Test accuracy:", score[1]) + +print_qstats(model) diff --git a/examples/example_mnist_po2.py b/examples/example_mnist_po2.py index 
166a2f07..b8be769d 100644 --- a/examples/example_mnist_po2.py +++ b/examples/example_mnist_po2.py @@ -19,14 +19,14 @@ from __future__ import division from __future__ import print_function -import keras.backend as K -from keras.datasets import mnist -from keras.layers import Activation -from keras.layers import Flatten -from keras.layers import Input -from keras.models import Model -from keras.optimizers import Adam -from keras.utils.np_utils import to_categorical +import tensorflow.keras.backend as K +from tensorflow.keras.datasets import mnist +from tensorflow.keras.layers import Activation +from tensorflow.keras.layers import Flatten +from tensorflow.keras.layers import Input +from tensorflow.keras.models import Model +from tensorflow.keras.optimizers import Adam +from tensorflow.keras.utils import to_categorical import numpy as np from qkeras import * # pylint: disable=wildcard-import diff --git a/examples/example_qdense.py b/examples/example_qdense.py index 1ae127ac..e95e22e1 100644 --- a/examples/example_qdense.py +++ b/examples/example_qdense.py @@ -20,19 +20,19 @@ import argparse -from keras.datasets import mnist -from keras.layers import Activation -from keras.layers import Input -from keras.models import Model -from keras.optimizers import Adam -from keras.utils.np_utils import to_categorical +from tensorflow.keras.datasets import mnist +from tensorflow.keras.layers import Activation +from tensorflow.keras.layers import Input +from tensorflow.keras.models import Model +from tensorflow.keras.optimizers import Adam +from tensorflow.keras.utils import to_categorical import numpy as np -from qkeras.qkeras import print_qstats -from qkeras.qkeras import QActivation -from qkeras.qkeras import QDense -from qkeras.qkeras import quantized_bits -from qkeras.qkeras import ternary +from qkeras import print_qstats +from qkeras import QActivation +from qkeras import QDense +from qkeras import quantized_bits +from qkeras import ternary np.random.seed(42) diff --git a/examples/example_qoctave.py b/examples/example_qoctave.py index 481a8c1f..7e188835 100644 --- a/examples/example_qoctave.py +++ b/examples/example_qoctave.py @@ -16,13 +16,13 @@ """QOctave example.""" import numpy as np import sys -from keras import activations -from keras import initializers -import keras.backend as K -from keras.layers import Input -from keras.models import Model -from keras.optimizers import Adam -from keras.utils.np_utils import to_categorical +from tensorflow.keras import activations +from tensorflow.keras import initializers +import tensorflow.keras.backend as K +from tensorflow.keras.layers import Input +from tensorflow.keras.models import Model +from tensorflow.keras.optimizers import Adam +from tensorflow.keras.utils import to_categorical from functools import partial from qkeras import * # pylint: disable=wildcard-import diff --git a/qkeras/__init__.py b/qkeras/__init__.py index 52f2cbb0..7218339a 100644 --- a/qkeras/__init__.py +++ b/qkeras/__init__.py @@ -20,7 +20,8 @@ from .b2t import * # pylint: disable=wildcard-import from .estimate import * # pylint: disable=wildcard-import from .qlayers import * # pylint: disable=wildcard-import +from .qnormalization import * # pylint: disable=wildcard-import from .qoctave import * # pylint: disable=wildcard-import from .safe_eval import * # pylint: disable=wildcard-import -__version__ = "0.5.0" +__version__ = "0.6.0" diff --git a/qkeras/b2t.py b/qkeras/b2t.py index 71edc9cf..e5e8964b 100644 --- a/qkeras/b2t.py +++ b/qkeras/b2t.py @@ -15,7 +15,7 @@ # 
============================================================================== """Implements total/partial Binary to Thermometer decoder.""" -from keras.utils import to_categorical +from tensorflow.keras.utils import to_categorical import numpy as np diff --git a/qkeras/estimate.py b/qkeras/estimate.py index 5028791d..c5c23bd6 100644 --- a/qkeras/estimate.py +++ b/qkeras/estimate.py @@ -30,12 +30,12 @@ from collections import defaultdict -from keras.layers import Activation -from keras.layers import InputLayer -from keras.models import Model import numpy as np import tensorflow.compat.v1 as tf +from tensorflow.keras.layers import Activation +from tensorflow.keras.layers import InputLayer +from tensorflow.keras.models import Model from .qlayers import QActivation from .qlayers import QAveragePooling2D @@ -290,8 +290,9 @@ def get_operation_type(layer, output_cache): # for the input, get tensor input and search the cache that associates # the quantizer with a tensor - if output_cache.get(layer.input, None) is not None: - x_mode, x_bits, x_sign = get_quant_mode(output_cache.get(layer.input)) + if output_cache.get(layer.input.experimental_ref(), None) is not None: + x_mode, x_bits, x_sign = get_quant_mode( + output_cache.get(layer.input.experimental_ref())) else: print("cannot determine presently model for {}".format(layer.name)) return "null", (w_mode, -1), (w_bits, -1), (w_sign, -1) @@ -310,23 +311,28 @@ def create_activation_cache(model): # cache graph tensors' activations for l in model.layers: - output_cache[l.output] = l + output_cache[l.output.experimental_ref()] = l if isinstance(l, QActivation): - output_cache[l.output] = l.quantizer + output_cache[l.output.experimental_ref()] = l.quantizer elif isinstance(l, InputLayer): - output_cache[l.output] = quantized_relu(8, 0) - elif l.__class__.__name__ in ["QDense", "QConv2D", "QConv1D", - "QDepthwiseConv2D"]: - output_cache[l.output] = l.activation + # assume the input is 8-bit positive value + output_cache[l.output.experimental_ref()] = quantized_relu(8, 0) + elif l.__class__.__name__ in [ + "QDense", "QConv2D", "QConv1D", "QDepthwiseConv2D" + ]: + output_cache[l.output.experimental_ref()] = l.activation else: if isinstance(l.input, list): # right now, we just get the first one - we assume this is the leading # one. 
- all_q = [output_cache.get(l.input[i]) for i in range(len(l.input))] + all_q = [ + output_cache.get(l.input[i].experimental_ref()) + for i in range(len(l.input)) + ] q = all_q[0] else: - q = output_cache.get(l.input, None) - output_cache[l.output] = q + q = output_cache.get(l.input.experimental_ref(), None) + output_cache[l.output.experimental_ref()] = q if q is None: raise ValueError("Unknown operation in {}".format(l.name)) @@ -342,12 +348,19 @@ def extract_model_operations(model): operations = {} for layer in model.layers: + + if layer.__class__.__name__ == "InputLayer": + continue + if isinstance(layer.input, list): input_shape = [ - cache_o.get(layer.input[i], layer.input[i].get_shape()) - for i in range(len(layer.input))] + cache_o.get(layer.input[i].experimental_ref(), + layer.input[i].get_shape()) + for i in range(len(layer.input)) + ] else: - input_shape = cache_o.get(layer.input, layer.input.get_shape()) + input_shape = cache_o.get(layer.input.experimental_ref(), + layer.input.get_shape()) # Check if the inputs are a list of Dimensions if isinstance(input_shape, list): @@ -360,7 +373,8 @@ def extract_model_operations(model): input_shape[i] = tuple(shape) output_shape = layer.compute_output_shape(input_shape) - cache_o[layer.output] = output_shape + + cache_o[layer.output.experimental_ref()] = output_shape if layer.__class__.__name__ not in ["QDense", "QConv2D", "QConv1D", "QDepthwiseConv2D"]: @@ -374,11 +388,22 @@ def extract_model_operations(model): weight = layer.get_weights()[0] + kernel_h, kernel_w, _, _ = weight.shape number_of_operations = ( height_o * width_o * channels_o * kernel_h * kernel_w * channels_i) + number_of_weights = (kernel_h * kernel_w * channels_o * channels_i) + + number_of_bias = 0 + if len(layer.get_weights()) > 1: + number_of_bias = layer.get_weights()[1].shape[0] + + weight_quant, bias_quant = layer.get_quantizers() + weight_type = get_quant_mode(weight_quant) + bias_type = get_quant_mode(bias_quant) + elif layer.__class__.__name__ in ["QConv1D"]: _, _, channels_i = input_shape @@ -392,6 +417,15 @@ def extract_model_operations(model): number_of_operations = ( time_o * channels_o * kernel_h * kernel_w * channels_i) + number_of_weights = (kernel_h * kernel_w * channels_o * channels_i) + number_of_bias = 0 + if len(layer.get_weights()) > 1: + number_of_bias = layer.get_weights()[1].shape[0] + + weight_quant, bias_quant = layer.get_quantizers() + weight_type = get_quant_mode(weight_quant) + bias_type = get_quant_mode(bias_quant) + elif layer.__class__.__name__ in ["QDepthwiseConv2D"]: _, _, _, channels_i = input_shape @@ -405,6 +439,16 @@ def extract_model_operations(model): number_of_operations = ( kernel_h * kernel_w * height_o * width_o * channels_i) + number_of_weights = (kernel_h * kernel_w * channels_o * channels_i) + + number_of_bias = 0 + if len(layer.get_weights()) > 1: + number_of_bias = layer.get_weights()[1].shape[0] + + weight_quant, bias_quant = layer.get_quantizers() + weight_type = get_quant_mode(weight_quant) + bias_type = get_quant_mode(bias_quant) + elif layer.__class__.__name__ in ["QDense"]: _, size_i = input_shape @@ -412,13 +456,32 @@ def extract_model_operations(model): number_of_operations = (size_i * size_o) + number_of_weights = size_i * size_o + number_of_bias = 0 + if len(layer.get_weights()) > 1: + number_of_bias = layer.get_weights()[1].shape[0] + + weight_quant, bias_quant = layer.get_quantizers() + weight_type = get_quant_mode(weight_quant) + bias_type = get_quant_mode(bias_quant) + # "number_of_operations" is 
tensor_shape.Dimension type operations[layer.name] = { "type": get_operation_type(layer, cache_q), "number_of_operations": number_of_operations if isinstance(number_of_operations, int) else - number_of_operations.value + number_of_operations.value, + "number_of_weights": + number_of_weights, + # if isinstance(number_of_weights, int) else number_of_weights.value, + "number_of_bias": + number_of_bias, + # if isinstance(number_of_bias, int) else number_of_bias.value, + "type_of_weights": + weight_type, + "type_of_bias": + bias_type, } return operations @@ -446,3 +509,15 @@ def print_qstats(model): for key in sorted(ops_table.keys()): if ops_table[key] > 0: print(" {:30}: {}".format(key, ops_table[key])) + + print("") + print("Weight profiling:") + for name in sorted(model_ops): + w_mode, w_sizes, w_signs = model_ops[name]["type_of_weights"] + b_mode, b_sizes, b_signs = model_ops[name]["type_of_bias"] + w_number = model_ops[name]["number_of_weights"] + b_number = model_ops[name]["number_of_bias"] + print(" {:30} : {:5} ({}-bit unit)".format( + str(name) + "_weights", str(w_number), str(w_sizes))) + print(" {:30} : {:5} ({}-bit unit)".format( + str(name) + "_bias", str(b_number), str(b_sizes))) diff --git a/qkeras/qlayers.py b/qkeras/qlayers.py index 81c01331..49f5b43a 100644 --- a/qkeras/qlayers.py +++ b/qkeras/qlayers.py @@ -40,26 +40,27 @@ import json import warnings -from keras import activations -from keras import constraints -from keras import initializers -from keras import regularizers -import keras.backend as K -from keras.constraints import Constraint -from keras.layers import Activation -from keras.layers import AveragePooling2D -from keras.layers import Conv1D -from keras.layers import Conv2D -from keras.layers import Dense -from keras.layers import Dropout -from keras.layers import InputSpec -from keras.layers import Layer -from keras.models import model_from_json -from keras.utils import conv_utils +import tensorflow as tf + +from tensorflow.keras import activations +from tensorflow.keras import constraints +from tensorflow.keras import initializers +from tensorflow.keras import regularizers +from tensorflow.keras.constraints import Constraint +from tensorflow.keras.layers import Activation +from tensorflow.keras.layers import AveragePooling2D +from tensorflow.keras.layers import Conv1D +from tensorflow.keras.layers import Conv2D +from tensorflow.keras.layers import Dense +from tensorflow.keras.layers import DepthwiseConv2D +from tensorflow.keras.layers import Dropout +from tensorflow.keras.layers import InputSpec +from tensorflow.keras.layers import Layer +from tensorflow.keras.models import model_from_json + import numpy as np import six -import tensorflow.compat.v1 as tf from .safe_eval import safe_eval @@ -75,13 +76,13 @@ def smooth_sigmoid(x): # smaller than hard_simoid but the arithmetic for it is (x >> 3) + # (x >> 4) + 0.5, which is also not bad. 
- return K.clip(0.1875 * x + 0.5, 0.0, 1.0) + return tf.keras.backend.clip(0.1875 * x + 0.5, 0.0, 1.0) def hard_sigmoid(x): """Computes hard_sigmoid function that saturates between 0 and 1.""" - return K.clip(0.5 * x + 0.5, 0.0, 1.0) + return tf.keras.backend.clip(0.5 * x + 0.5, 0.0, 1.0) def binary_sigmoid(x): @@ -109,7 +110,7 @@ def set_internal_sigmoid(mode): elif mode == "smooth": _sigmoid = smooth_sigmoid elif mode == "real": - _sigmoid = K.sigmoid + _sigmoid = tf.sigmoid def binary_tanh(x): @@ -129,27 +130,31 @@ def smooth_tanh(x): def stochastic_round(x): """Performs stochastic rounding to the first decimal point.""" - s = K.sign(x) - s += (1.0 - K.abs(s)) * (2.0 * K.round(K.random_uniform(K.shape(x))) - 1.0) + s = tf.sign(x) + s += (1.0 - tf.abs(s)) * (2.0 * tf.round(tf.random.uniform(tf.shape(x))) - + 1.0) t = tf.floor(x) - (s - 1.0) / 2.0 - p = K.abs(x - t) - f = s * (K.sign(p - K.random_uniform(K.shape(p))) + 1.0) / 2.0 + p = tf.abs(x - t) + f = s * (tf.sign(p - tf.random.uniform(tf.shape(p))) + 1.0) / 2.0 return t + f + def stochastic_round_po2(x): """Performs stochastic rounding for the power of two.""" - y = K.abs(x) - eps = K.epsilon() - log2 = K.log(2.0) - x_log2 = K.round(K.log(y + eps) / log2) - sign = K.sign(x) - po2 = K.cast(K.pow(2.0, K.cast(x_log2, dtype="float32")), dtype="float32") + # TODO(hzhuang): test stochastic_round_po2 and constraint. + # because quantizer is applied after constraint. + y = tf.abs(x) + eps = tf.keras.backend.epsilon() + log2 = tf.keras.backend.log(2.0) + x_log2 = tf.round(tf.keras.backend.log(y + eps) / log2) + sign = tf.sign(x) + po2 = tf.cast(pow(2.0, tf.cast(x_log2, dtype="float32")), dtype="float32") left_val = tf.where(po2 > y, x_log2 - 1, x_log2) right_val = tf.where(po2 > y, x_log2, x_log2 + 1) # sampling in [2**left_val, 2**right_val]. minval = 2 ** left_val maxval = 2 ** right_val - val = K.random_uniform(K.shape(y), minval=minval, maxval=maxval) + val = tf.random.uniform(tf.shape(y), minval=minval, maxval=maxval) # use y as a threshold to keep the probabliy [2**left_val, y, 2**right_val] # so that the mean value of the sample should be y x_po2 = tf.where(y < val, left_val, right_val) @@ -180,26 +185,26 @@ def _round_through(x, use_stochastic_rounding=False): Rounded tensor. """ if use_stochastic_rounding: - return x + K.stop_gradient(-x + stochastic_round(x)) + return x + tf.stop_gradient(-x + stochastic_round(x)) else: - return x + K.stop_gradient(-x + K.round(x)) + return x + tf.stop_gradient(-x + tf.round(x)) def _sign_through(x): """Computes the sign operation using the straight through estimator.""" - # K.sign generates -1, 0 or +1, so it should not be used when we attempt + # tf.sign generates -1, 0 or +1, so it should not be used when we attempt # to generate -1 and +1. 
- k_sign = K.sign(x) + k_sign = tf.sign(x) - return x + K.stop_gradient(-x + k_sign) + return x + tf.stop_gradient(-x + k_sign) def _ceil_through(x): """Computes the ceiling operation using straight through estimator.""" - return x + K.stop_gradient(-x + tf.ceil(x)) + return x + tf.stop_gradient(-x + tf.ceil(x)) # @@ -272,15 +277,15 @@ def __call__(self, x): m = pow(2, unsigned_bits) m_i = pow(2, self.integer) p = x * m / m_i - xq = m_i * K.clip( + xq = m_i * tf.keras.backend.clip( _round_through(p, self.use_stochastic_rounding), self.keep_negative * (-m + self.symmetric), m - 1) / m else: - xq = K.sign(x) - xq += (1.0 - K.abs(xq)) + xq = tf.sign(x) + xq += (1.0 - tf.abs(xq)) if not self.keep_negative: xq = (xq + 1.0) / 2.0 - return x + K.stop_gradient(-x + xq) + return x + tf.stop_gradient(-x + xq) class bernoulli(object): # pylint: disable=invalid-name @@ -314,9 +319,9 @@ def __init__(self, alpha=1.0): def __call__(self, x): p = _sigmoid(x / self.alpha) - k_sign = K.sign(p - K.random_uniform(K.shape(p))) - k_sign += (1.0 - K.abs(k_sign)) - return x + K.stop_gradient(-x + self.alpha * (k_sign + 1.0) / 2.0) + k_sign = tf.sign(p - tf.random.uniform(tf.shape(p))) + k_sign += (1.0 - tf.abs(k_sign)) + return x + tf.stop_gradient(-x + self.alpha * (k_sign + 1.0) / 2.0) class stochastic_ternary(object): # pylint: disable=invalid-name @@ -358,8 +363,8 @@ def __call__(self, x): T = self.threshold # pylint: disable=invalid-name - ones = K.ones_like(p) - zeros = K.zeros_like(p) + ones = tf.ones_like(p) + zeros = tf.zeros_like(p) T0 = np.clip(0.5 + T, 0.5, 1.0) # pylint: disable=invalid-name @@ -372,9 +377,9 @@ def __call__(self, x): c_fm1 = fm1 / f_all c_f0 = (fm1 + f0) / f_all - r = K.random_uniform(K.shape(p)) + r = tf.random.uniform(tf.shape(p)) - return x + K.stop_gradient(-x + self.alpha * tf.where( + return x + tf.stop_gradient(-x + self.alpha * tf.where( r <= c_fm1, -1 * ones, tf.where(r <= c_f0, zeros, ones))) @@ -402,9 +407,9 @@ def __call__(self, x): if self.use_stochastic_rounding: x = _round_through( x, use_stochastic_rounding=self.use_stochastic_rounding) - return x + K.stop_gradient( - -x + self.alpha * tf.where(K.abs(x) < self.threshold, - K.zeros_like(x), K.sign(x))) + return x + tf.stop_gradient( + -x + self.alpha * tf.where(tf.abs(x) < self.threshold, + tf.zeros_like(x), tf.sign(x))) class stochastic_binary(object): # pylint: disable=invalid-name @@ -429,11 +434,11 @@ def __init__(self, alpha=1.0): def __call__(self, x): assert self.alpha != 0 p = _sigmoid(x / self.alpha) - k_sign = K.sign(p - tf.random_uniform(tf.shape(x))) - # we should not need this, but if K.sign is not safe if input is + k_sign = tf.sign(p - tf.random.uniform(tf.shape(x))) + # we should not need this, but if tf.sign is not safe if input is # exactly 0.0 - k_sign += (1.0 - K.abs(k_sign)) - return x + K.stop_gradient(-x + self.alpha * k_sign) + k_sign += (1.0 - tf.abs(k_sign)) + return x + tf.stop_gradient(-x + self.alpha * k_sign) class binary(object): # pylint: disable=invalid-name @@ -467,15 +472,15 @@ def __call__(self, x): x = self.alpha * _round_through( x / self.alpha, use_stochastic_rounding=self.use_stochastic_rounding) - k_sign = K.sign(x) + k_sign = tf.sign(x) if self.use_stochastic_rounding: - k_sign += (1.0 - K.abs(k_sign)) * ( - 2.0 * K.round(K.random_uniform(K.shape(x))) - 1.0) + k_sign += (1.0 - tf.abs(k_sign)) * ( + 2.0 * tf.round(tf.random.uniform(tf.shape(x))) - 1.0) else: - k_sign += (1.0 - K.abs(k_sign)) + k_sign += (1.0 - tf.abs(k_sign)) if self.use_01: k_sign = (k_sign + 1.0) / 2.0 - 
return x + K.stop_gradient(-x + self.alpha * k_sign) + return x + tf.stop_gradient(-x + self.alpha * k_sign) class quantized_relu(object): # pylint: disable=invalid-name @@ -520,12 +525,12 @@ def __call__(self, x): if self.use_sigmoid: p = _sigmoid(x / m_i) * m - xq = m_i * K.clip( + xq = m_i * tf.keras.backend.clip( 2.0 * (_round_through(p, self.use_stochastic_rounding) / m) - 1.0, 0.0, 1.0 - 1.0 / m) else: p = x * m / m_i - xq = m_i * K.clip( + xq = m_i * tf.keras.backend.clip( _round_through(p, self.use_stochastic_rounding) / m, 0.0, 1.0 - 1.0 / m) return xq @@ -557,8 +562,10 @@ def __call__(self, x): m_i = pow(2, self.integer) p = _sigmoid(x / m_i) * m rp = 2.0 * (_round_through(p) / m) - 1.0 - u_law_p = K.sign(rp) * K.log(1 + self.u * K.abs(rp)) / K.log(1 + self.u) - xq = m_i * K.clip(u_law_p, -1.0 + (1.0 * self.symmetric) / m, 1.0 - 1.0 / m) + u_law_p = tf.sign(rp) * tf.keras.backend.log( + 1 + self.u * tf.abs(rp)) / tf.keras.backend.log(1 + self.u) + xq = m_i * tf.keras.backend.clip(u_law_p, -1.0 + + (1.0 * self.symmetric) / m, 1.0 - 1.0 / m) return xq @@ -592,76 +599,119 @@ def __call__(self, x): m = pow(2, non_sign_bits) m_i = pow(2, self.integer) p = _sigmoid(x / m_i) * m - xq = m_i * K.clip( - 2.0 * (_round_through(p, self.use_stochastic_rounding) / m) - 1.0, - -1.0 + (1.0 * self.symmetric) / m, 1.0 - 1.0 / m) + xq = m_i * tf.keras.backend.clip( + 2.0 * + (_round_through(p, self.use_stochastic_rounding) / m) - 1.0, -1.0 + + (1.0 * self.symmetric) / m, 1.0 - 1.0 / m) return xq class quantized_po2(object): # pylint: disable=invalid-name """Quantizes to the closest power of 2.""" - def __init__(self, bits=8, max_value=-1, use_stochastic_rounding=False): + def __init__(self, + bits=8, + max_value=-1, + use_stochastic_rounding=False, + quadratic_approximation=False): self.bits = bits self.max_value = max_value self.use_stochastic_rounding = use_stochastic_rounding + # if True, round to the exponent for sqrt(x), + # so that the return value can be divided by two without remainder. + self.quadratic_approximation = quadratic_approximation def __call__(self, x): + + need_exponent_sign_bit = (self.max_value > 1) non_sign_bits = self.bits - 1 - min_exp = -2**(non_sign_bits - 1) - max_exp = 2**(non_sign_bits - 1) - 1 - eps = K.epsilon() + min_exp = -2**(non_sign_bits - need_exponent_sign_bit) + max_exp = 2**(non_sign_bits - need_exponent_sign_bit) - 1 + eps = tf.keras.backend.epsilon() if min_exp < np.log2(eps): warnings.warn( - "QKeras: min_exp in po2 quantizer is smaller than K.epsilon()") + "QKeras: min_exp in po2 quantizer is smaller than tf.epsilon()") if self.max_value != -1: max_exp = np.round(np.log2(self.max_value + eps)) - x_sign = K.sign(x) - x_sign += (1.0 - K.abs(x_sign)) + x_sign = tf.sign(x) + x_sign += (1.0 - tf.abs(x_sign)) log2 = np.log(2.0) + # if True, round to the exponent for sqrt(x), + # so that the return value can be divided by two without remainder. 
+ if self.quadratic_approximation: + q_factor = 2.0 + else: + q_factor = 1.0 + if self.use_stochastic_rounding: - x_log2 = stochastic_round_po2(x) + if self.quadratic_approximation: + x_log2 = stochastic_round_po2(tf.sqrt(x)) + else: + x_log2 = stochastic_round_po2(x) else: - x_log2 = _round_through(K.log(K.abs(x) + eps) / log2) - return x + K.stop_gradient( - -x + x_sign * K.pow(2.0, K.clip(x_log2, min_exp, max_exp))) + if self.quadratic_approximation: + x_log2 = _round_through(tf.keras.backend.log(tf.sqrt(x) + eps) / log2) + else: + x_log2 = _round_through(tf.keras.backend.log(tf.abs(x) + eps) / log2) + x_clipped = q_factor * tf.keras.backend.clip(x_log2, min_exp, max_exp) + return x + tf.stop_gradient(-x + x_sign * pow(2.0, x_clipped)) class quantized_relu_po2(object): # pylint: disable=invalid-name """Quantizes to the closest power of 2.""" - def __init__(self, bits=8, max_value=-1, use_stochastic_rounding=False): + def __init__(self, bits=8, max_value=-1, use_stochastic_rounding=False, + quadratic_approximation=False): self.bits = bits self.max_value = max_value self.use_stochastic_rounding = use_stochastic_rounding - def __call__(self, x): + # if True, round to the exponent for sqrt(x), + # so that the return value can be divided by two without remainder. + self.quadratic_approximation = quadratic_approximation - min_exp = -2**(self.bits - 1) - max_exp = 2**(self.bits - 1) - 1 + def __call__(self, x): - eps = K.epsilon() + need_exponent_sign_bit = (self.max_value > 1) + min_exp = -2**(self.bits - need_exponent_sign_bit) + max_exp = 2**(self.bits - need_exponent_sign_bit) - 1 + eps = tf.keras.backend.epsilon() if min_exp < np.log2(eps): warnings.warn( - "QKeras: min_exp in relu_po2 quantizer is smaller than K.epsilon()") + "QKeras: min_exp in quantized_relu_po2 quantizer " + "is smaller than tf.epsilon()") log2 = np.log(2.0) if self.max_value != -1: max_exp = np.round(np.log2(self.max_value + eps)) - x = K.maximum(x, 0) + if self.quadratic_approximation: + q_factor = 2.0 + else: + q_factor = 1.0 + x = tf.maximum(x, 0) + if self.use_stochastic_rounding: - x_log2 = stochastic_round_po2(x) + # if True, approximate the power of two to the sqrt(x) + # use q_factor to recover the value in x_clipped. + if self.quadratic_approximation: + x_log2 = stochastic_round_po2(tf.sqrt(x)) + else: + x_log2 = stochastic_round_po2(x) else: - x_log2 = _round_through(K.log(K.abs(x) + eps) / log2) - x_clipped = K.clip(x_log2, min_exp, max_exp) - return x + K.stop_gradient(-x + K.pow(2.0, x_clipped)) + if self.quadratic_approximation: + x_log2 = _round_through(tf.keras.backend.log(tf.sqrt(x) + eps) / log2) + else: + x_log2 = _round_through(tf.keras.backend.log(tf.abs(x) + eps) / log2) + x_clipped = q_factor * tf.keras.backend.clip(x_log2, min_exp, max_exp) + return x + tf.stop_gradient(-x + pow(2.0, x_clipped)) + # # Because it may be hard to get serialization from activation functions, @@ -722,16 +772,25 @@ class Clip(Constraint): # Constrains the weights to be between min/max values. # min_value: the minimum norm for the incoming weights. # max_value: the maximum norm for the incoming weights. + # constraint: previous constraint to be clipped. + # quantizer: quantizer to be applied to constraint. 
- def __init__(self, min_value=0.0, max_value=1.0): + def __init__(self, min_value=0.0, max_value=1.0, + constraint=None, quantizer=None): """Initializes Clip constraint class.""" self.min_value = min_value self.max_value = max_value + self.constraint = constraint + self.quantizer = quantizer def __call__(self, w): """Clips values between min and max values.""" - w = K.clip(w, self.min_value, self.max_value) + if self.constraint: + w = self.constraint(w) + if self.quantizer: + w = self.quantizer(w) + w = tf.keras.backend.clip(w, self.min_value, self.max_value) return w def get_config(self): @@ -809,10 +868,16 @@ def __init__(self, kernel_initializer = get_initializer(kernel_initializer, kernel_range) if kernel_quantizer: - kernel_constraint = Clip(-kernel_range, kernel_range) + if kernel_constraint: + kernel_constraint = constraints.get(kernel_constraint) + kernel_constraint = Clip(-kernel_range, kernel_range, kernel_constraint, + kernel_quantizer) if bias_quantizer: - bias_constraint = Clip(-bias_range, bias_range) + if bias_constraint: + bias_constraint = constraints.get(bias_constraint) + bias_constraint = Clip(-bias_range, bias_range, bias_constraint, + bias_quantizer) self.kernel_quantizer = kernel_quantizer self.bias_quantizer = bias_quantizer @@ -854,13 +919,14 @@ def call(self, inputs): quantized_kernel = self.kernel_quantizer_internal(self.kernel) else: quantized_kernel = self.kernel - output = K.dot(inputs, quantized_kernel) + output = tf.keras.backend.dot(inputs, quantized_kernel) if self.use_bias: if self.bias_quantizer: quantized_bias = self.bias_quantizer_internal(self.bias) else: quantized_bias = self.bias - output = K.bias_add(output, quantized_bias, data_format="channels_last") + output = tf.keras.backend.bias_add(output, quantized_bias, + data_format="channels_last") if self.activation is not None: output = self.activation(output) return output @@ -974,10 +1040,16 @@ def __init__(self, ] if kernel_quantizer: - kernel_constraint = Clip(-kernel_range, kernel_range) + if kernel_constraint: + kernel_constraint = constraints.get(kernel_constraint) + kernel_constraint = Clip(-kernel_range, kernel_range, kernel_constraint, + kernel_quantizer) if bias_quantizer: - bias_constraint = Clip(-bias_range, bias_range) + if bias_constraint: + bias_constraint = constraints.get(bias_constraint) + bias_constraint = Clip(-bias_range, bias_range, bias_constraint, + bias_quantizer) super(QConv1D, self).__init__( filters=filters, @@ -1005,7 +1077,7 @@ def call(self, inputs): else: quantized_kernel = self.kernel - outputs = K.conv1d( + outputs = tf.keras.backend.conv1d( inputs, quantized_kernel, strides=self.strides[0], @@ -1022,7 +1094,7 @@ def call(self, inputs): else: quantized_bias = self.bias - outputs = K.bias_add( + outputs = tf.keras.backend.bias_add( outputs, quantized_bias, data_format=self.data_format) if self.activation is not None: @@ -1110,10 +1182,16 @@ def __init__(self, ] if kernel_quantizer: - kernel_constraint = Clip(-kernel_range, kernel_range) + if kernel_constraint: + kernel_constraint = constraints.get(kernel_constraint) + kernel_constraint = Clip(-kernel_range, kernel_range, kernel_constraint, + kernel_quantizer) if bias_quantizer: - bias_constraint = Clip(-bias_range, bias_range) + if bias_constraint: + bias_constraint = constraints.get(bias_constraint) + bias_constraint = Clip(-bias_range, bias_range, bias_constraint, + bias_quantizer) super(QConv2D, self).__init__( filters=filters, @@ -1142,7 +1220,7 @@ def call(self, inputs): else: quantized_kernel = self.kernel - 
outputs = K.conv2d( + outputs = tf.keras.backend.conv2d( inputs, quantized_kernel, strides=self.strides, @@ -1159,7 +1237,7 @@ def call(self, inputs): else: quantized_bias = self.bias - outputs = K.bias_add( + outputs = tf.keras.backend.bias_add( outputs, quantized_bias, data_format=self.data_format) if self.activation is not None: @@ -1180,7 +1258,7 @@ def get_quantizers(self): return self.quantizers -class QDepthwiseConv2D(Conv2D): +class QDepthwiseConv2D(DepthwiseConv2D): """Creates quantized depthwise conv2d. Copied from mobilenet.""" # most of these parameters follow the implementation of DepthwiseConv2D @@ -1201,7 +1279,7 @@ class QDepthwiseConv2D(Conv2D): def __init__(self, kernel_size, strides=(1, 1), - padding="valid", + padding="VALID", depth_multiplier=1, data_format=None, activation=None, @@ -1221,31 +1299,34 @@ def __init__(self, **kwargs): if depthwise_quantizer: - depthwise_constraint = Clip(-depthwise_range, depthwise_range) + if depthwise_constraint: + depthwise_constraint = constraints.get(depthwise_constraint) + depthwise_constraint = Clip(-depthwise_range, depthwise_range, + depthwise_constraint, depthwise_quantizer) - if use_bias and bias_quantizer: - bias_constraint = Clip(-bias_range, bias_range) + if bias_quantizer: + if bias_constraint: + bias_constraint = constraints.get(bias_constraint) + bias_constraint = Clip(-bias_range, bias_range, bias_constraint, + bias_quantizer) super(QDepthwiseConv2D, self).__init__( - filters=None, kernel_size=kernel_size, strides=strides, padding=padding, data_format=data_format, activation=activation, use_bias=use_bias, + depthwise_regularizer=depthwise_regularizer, bias_regularizer=bias_regularizer, activity_regularizer=activity_regularizer, + depth_multiplier=depth_multiplier, + depthwise_initializer=depthwise_initializer, + bias_initializer=bias_initializer, + depthwise_constraint=depthwise_constraint, bias_constraint=bias_constraint, dilation_rate=dilation_rate, **kwargs) - self.depth_multiplier = depth_multiplier - self.depthwise_initializer = initializers.get(depthwise_initializer) - self.depthwise_regularizer = regularizers.get(depthwise_regularizer) - self.depthwise_constraint = constraints.get(depthwise_constraint) - self.bias_initializer = initializers.get(bias_initializer) - - self.depthwise_constraint = depthwise_constraint self.bias_constraint = bias_constraint self.depthwise_quantizer = depthwise_quantizer @@ -1320,7 +1401,7 @@ def call(self, inputs, training=None): self.depthwise_quantizer_internal(self.depthwise_kernel)) else: quantized_depthwise_kernel = self.depthwise_kernel - outputs = K.depthwise_conv2d( + outputs = tf.keras.backend.depthwise_conv2d( inputs, quantized_depthwise_kernel, strides=self.strides, @@ -1336,7 +1417,7 @@ def call(self, inputs, training=None): quantized_bias = self.bias_quantizer_internal(self.bias) else: quantized_bias = self.bias - outputs = K.bias_add( + outputs = tf.keras.backend.bias_add( outputs, quantized_bias, data_format=self.data_format) if self.activation is not None: @@ -1347,27 +1428,6 @@ def call(self, inputs, training=None): return outputs - def compute_output_shape(self, input_shape): - if self.data_format == "channels_first": - rows = input_shape[2] - cols = input_shape[3] - out_filters = input_shape[1] * self.depth_multiplier - elif self.data_format == "channels_last": - rows = input_shape[1] - cols = input_shape[2] - out_filters = input_shape[3] * self.depth_multiplier - - rows = conv_utils.conv_output_length(rows, self.kernel_size[0], - self.padding, self.strides[0]) - - 
cols = conv_utils.conv_output_length(cols, self.kernel_size[1], - self.padding, self.strides[1]) - - if self.data_format == "channels_first": - return (input_shape[0], out_filters, rows, cols) - elif self.data_format == "channels_last": - return (input_shape[0], rows, cols, out_filters) - def get_config(self): config = super(QDepthwiseConv2D, self).get_config() config.pop("filters") @@ -1395,7 +1455,7 @@ def get_quantizers(self): def QSeparableConv2D(filters, # pylint: disable=invalid-name kernel_size, strides=(1, 1), - padding="valid", + padding="VALID", dilation_rate=(1, 1), depth_multiplier=1, activation=None, @@ -1598,8 +1658,8 @@ def model_save_quantized_weights(model, filename=None): signs = [] for quantizer, weight in zip(layer.get_quantizers(), layer.get_weights()): if quantizer: - weight = K.constant(weight) - weight = K.eval(quantizer(weight)) + weight = tf.constant(weight) + weight = tf.keras.backend.eval(quantizer(weight)) # If quantizer is power-of-2 (quantized_po2 or quantized_relu_po2), # we would like to process it here. diff --git a/qkeras/qnormalization.py b/qkeras/qnormalization.py new file mode 100644 index 00000000..623efb7a --- /dev/null +++ b/qkeras/qnormalization.py @@ -0,0 +1,316 @@ +# Copyright 2019 Google LLC +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ============================================================================== +"""Definition of normalization quantization package.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +from tensorflow.keras import constraints +from tensorflow.keras import initializers +from tensorflow.keras import regularizers +from tensorflow.keras.layers import BatchNormalization +from tensorflow.python.framework import ops +from tensorflow.python.keras.utils import tf_utils +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn + +import numpy as np +import six + +from .qlayers import Clip +from .safe_eval import safe_eval + + +class QBatchNormalization(BatchNormalization): + """Quantized Batch Normalization layer. + For training, mean and variance are not quantized. + For inference, the quantized moving mean and moving variance are used. 
+ + output = (x - mean) / sqrt(var + epsilon) * quantized_gamma + quantized_beta + + """ + + def __init__( + self, + axis=-1, + momentum=0.99, + epsilon=1e-3, + center=True, + scale=True, + activation=None, + beta_initializer='zeros', + gamma_initializer='ones', + moving_mean_initializer='zeros', + moving_variance_initializer='ones', + beta_regularizer=None, + gamma_regularizer=None, + beta_quantizer=None, + gamma_quantizer=None, + mean_quantizer=None, + variance_quantizer=None, + # use quantized_po2 and enforce quadratic approximation + # to get an even exponent for sqrt + beta_range=None, + gamma_range=None, + **kwargs): + + self.beta_quantizer = beta_quantizer + self.gamma_quantizer = gamma_quantizer + self.mean_quantizer = mean_quantizer + self.variance_quantizer = variance_quantizer + self.activation = activation + self.beta_range = beta_range + self.gamma_range = gamma_range + + if isinstance(self.beta_quantizer, six.string_types): + self.beta_quantizer_internal = safe_eval( + self.beta_quantizer, globals()) + else: + self.beta_quantizer_internal = self.beta_quantizer + + if isinstance(self.gamma_quantizer, six.string_types): + self.gamma_quantizer_internal = safe_eval( + self.gamma_quantizer, globals()) + else: + self.gamma_quantizer_internal = self.gamma_quantizer + + if isinstance(self.mean_quantizer, six.string_types): + self.mean_quantizer_internal = safe_eval( + self.mean_quantizer, globals()) + else: + self.mean_quantizer_internal = self.mean_quantizer + + if isinstance(self.variance_quantizer, six.string_types): + self.variance_quantizer_internal = safe_eval( + self.variance_quantizer, globals()) + else: + self.variance_quantizer_internal = self.variance_quantizer + + self.quantizers = [ + self.gamma_quantizer_internal, + self.beta_quantizer_internal, + self.mean_quantizer_internal, + self.variance_quantizer_internal + ] + + if center and beta_quantizer and beta_range: + beta_constraint = Clip(-beta_range, beta_range) + else: + beta_constraint = None + + if scale and gamma_quantizer and gamma_range: + gamma_constraint = Clip(-gamma_range, gamma_range) + else: + gamma_constraint = None + + if kwargs.get('fused', None): + warning.warn('batch normalization fused is disabled ' + 'in qkeras qnormalization.py.') + del kwargs['fused'] + + if kwargs.get('renorm', None): + warning.warn('batch normalization renorm is disabled ' + 'in qkeras qnormalization.py.') + del kwargs['renorm'] + + if kwargs.get('virtual_batch_size', None): + warning.warn('batch normalization virtual_batch_size is disabled ' + 'in qkeras qnormalization.py.') + del kwargs['virtual_batch_size'] + + if kwargs.get('adjustment', None): + warning.warn('batch normalization adjustment is disabled ' + 'in qkeras qnormalization.py.') + del kwargs['adjustment'] + + super(QBatchNormalization, self).__init__( + axis=axis, + momentum=momentum, + epsilon=epsilon, + center=center, + scale=scale, + beta_initializer=beta_initializer, + gamma_initializer=gamma_initializer, + moving_mean_initializer=moving_mean_initializer, + moving_variance_initializer=moving_variance_initializer, + beta_regularizer=beta_regularizer, + gamma_regularizer=gamma_regularizer, + beta_constraint=beta_constraint, + gamma_constraint=gamma_constraint, + fused=False, + renorm=False, + virtual_batch_size=None, + adjustment=None, + **kwargs) + + def call(self, inputs, training=None): + if self.scale and self.gamma_quantizer: + quantized_gamma = self.gamma_quantizer_internal(self.gamma) + else: + quantized_gamma = self.gamma + + if self.center and 
self.beta_quantizer: + quantized_beta = self.beta_quantizer_internal(self.beta) + else: + quantized_beta = self.beta + + if self.mean_quantizer: + quantized_moving_mean = self.mean_quantizer_internal(self.moving_mean) + else: + quantized_moving_mean = self.moving_mean + + if self.variance_quantizer: + quantized_moving_variance = self.variance_quantizer_internal( + self.moving_variance) + else: + quantized_moving_variance = self.moving_variance + + training = self._get_training_value(training) + + # Compute the axes along which to reduce the mean / variance + input_shape = inputs.shape + ndims = len(input_shape) + reduction_axes = [i for i in range(ndims) if i not in self.axis] + + # Broadcasting only necessary for single-axis batch norm where the axis is + # not the last dimension + broadcast_shape = [1] * ndims + broadcast_shape[self.axis[0]] = input_shape.dims[self.axis[0]].value + def _broadcast(v): + if (v is not None and len(v.shape) != ndims and + reduction_axes != list(range(ndims - 1))): + return array_ops.reshape(v, broadcast_shape) + return v + + scale, offset = _broadcast(quantized_gamma), _broadcast(quantized_beta) + + # Determine a boolean value for `training`: could be True, False, or None. + training_value = tf_utils.constant_value(training) + if training_value == False: # pylint: disable=singleton-comparison,g-explicit-bool-comparison + quantized_mean, quantized_variance = (quantized_moving_mean, + quantized_moving_variance) + else: + # Some of the computations here are not necessary when training==False + # but not a constant. However, this makes the code simpler. + keep_dims = len(self.axis) > 1 + mean, variance = self._moments( + math_ops.cast(inputs, self._param_dtype), + reduction_axes, + keep_dims=keep_dims) + + moving_mean = self.moving_mean + moving_variance = self.moving_variance + + mean = tf_utils.smart_cond(training, + lambda: mean, + lambda: ops.convert_to_tensor(moving_mean)) + variance = tf_utils.smart_cond( + training, + lambda: variance, + lambda: ops.convert_to_tensor(moving_variance)) + + new_mean, new_variance = mean, variance + + if self.mean_quantizer: + quantized_mean = self.mean_quantizer_internal(mean) + else: + quantized_mean = mean + + if self.variance_quantizer: + quantized_variance = self.variance_quantizer_internal(variance) + else: + quantized_variance = variance + + if self._support_zero_size_input(): + inputs_size = array_ops.size(inputs) + else: + inputs_size = None + + def _do_update(var, value): + """Compute the updates for mean and variance.""" + return self._assign_moving_average(var, value, self.momentum, + inputs_size) + + def mean_update(): + true_branch = lambda: _do_update(self.moving_mean, new_mean) + false_branch = lambda: self.moving_mean + return tf_utils.smart_cond(training, true_branch, false_branch) + + def variance_update(): + """Update the moving variance.""" + true_branch = lambda: _do_update(self.moving_variance, new_variance) + false_branch = lambda: self.moving_variance + return tf_utils.smart_cond(training, true_branch, false_branch) + + self.add_update(mean_update) + self.add_update(variance_update) + + quantized_mean = math_ops.cast(quantized_mean, inputs.dtype) + quantized_variance = math_ops.cast(quantized_variance, inputs.dtype) + if offset is not None: + offset = math_ops.cast(offset, inputs.dtype) + if scale is not None: + scale = math_ops.cast(scale, inputs.dtype) + # TODO(reedwm): Maybe do math in float32 if given float16 inputs, if doing + # math in float16 hurts validation accuracy of popular models like 
resnet. + outputs = nn.batch_normalization(inputs, + _broadcast(quantized_mean), + _broadcast(quantized_variance), + offset, + scale, + self.epsilon) + # If some components of the shape got lost due to adjustments, fix that. + outputs.set_shape(input_shape) + + return outputs + + def get_config(self): + config = { + 'axis': self.axis, + 'momentum': self.momentum, + 'epsilon': self.epsilon, + 'center': self.center, + 'scale': self.scale, + 'beta_quantizer': self.beta_quantizer, + 'gamma_quantizer': self.gamma_quantizer, + 'mean_quantizer': self.mean_quantizer, + 'variance_quantizer': self.variance_quantizer, + 'beta_initializer': initializers.serialize(self.beta_initializer), + 'gamma_initializer': initializers.serialize(self.gamma_initializer), + 'moving_mean_initializer': + initializers.serialize(self.moving_mean_initializer), + 'moving_variance_initializer': + initializers.serialize(self.moving_variance_initializer), + 'beta_regularizer': regularizers.serialize(self.beta_regularizer), + 'gamma_regularizer': regularizers.serialize(self.gamma_regularizer), + 'beta_constraint': constraints.serialize(self.beta_constraint), + 'gamma_constraint': constraints.serialize(self.gamma_constraint), + 'beta_range': self.beta_range, + 'gamma_range': self.gamma_range, + } + base_config = super(BatchNormalization, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + def compute_output_shape(self, input_shape): + return input_shape + + def get_quantizers(self): + return self.quantizers diff --git a/qkeras/qoctave.py b/qkeras/qoctave.py index f72595be..5d1c2d18 100644 --- a/qkeras/qoctave.py +++ b/qkeras/qoctave.py @@ -21,12 +21,12 @@ import re -from keras.layers import Activation -from keras.layers import Add -from keras.layers import AveragePooling2D -from keras.layers import Conv2D -from keras.layers import SeparableConv2D -from keras.layers import UpSampling2D +from tensorflow.keras.layers import Activation +from tensorflow.keras.layers import Add +from tensorflow.keras.layers import AveragePooling2D +from tensorflow.keras.layers import Conv2D +from tensorflow.keras.layers import SeparableConv2D +from tensorflow.keras.layers import UpSampling2D from .qlayers import QActivation from .qlayers import QAveragePooling2D from .qlayers import QConv2D diff --git a/requirements.txt b/requirements.txt index 4cb53962..d3c72ec6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,9 @@ -keras>=2.2.4 -numpy>=1.14.5,<1.17 -pyparser>=1.0 -pytest==4.6.5 -scipy>=1.2.2,<1.3 -# tensorflow>=1.14.0,<2.0 +tensorflow>=2.1.0rc2 +numpy>=1.16 +pyparser +pytest +scipy==1.2.2 +setuptools>=41.0.0 +argparse>=1.4.0 +pyasn1<0.5.0,>=0.4.6 +requests<3,>=2.21.0 diff --git a/setup.py b/setup.py index 9354579a..6d7c2a60 100644 --- a/setup.py +++ b/setup.py @@ -27,7 +27,7 @@ setuptools.setup( name="QKeras", - version="0.5.0", + version="0.6.0", author="Claudionor N. 
Coelho", author_email="nunescoelho@google.com", maintainer="Hao Zhuang", @@ -39,17 +39,15 @@ description="Quantization package for Keras", long_description=long_description, install_requires=[ - "keras>=2.2.4", - "numpy>=1.14.5,<1.17", - "pyparser>=1.0", - "scipy>=1.2.2,<1.3", + "numpy>=1.16.0", + "scipy==1.2.2", + "pyparser", "setuptools>=41.0.0", - "tensorflow>=1.14.0,<2.0", ], setup_requires=[ "pytest-runner", ], tests_require=[ - "pytest==4.6.5", + "pytest", ], ) diff --git a/tests/qactivation_test.py b/tests/qactivation_test.py index 7aace158..42f602a0 100644 --- a/tests/qactivation_test.py +++ b/tests/qactivation_test.py @@ -18,7 +18,7 @@ from numpy.testing import assert_allclose import pytest -from keras import backend as K +from tensorflow.keras import backend as K from qkeras import binary from qkeras import hard_sigmoid @@ -29,6 +29,8 @@ from qkeras import quantized_po2 from qkeras import quantized_relu_po2 +# TODO(hzhuang, rxuniverse): quantization_po2/_relu_po2, +# test cases with quadratic_approximation def test_smooth_sigmoid(): """Test smooth_sigmoid function.""" diff --git a/tests/qlayers_test.py b/tests/qlayers_test.py index 2aa91fcd..c4b3b4e2 100644 --- a/tests/qlayers_test.py +++ b/tests/qlayers_test.py @@ -16,6 +16,13 @@ """Test layers from qlayers.py.""" import numpy as np +from numpy.testing import assert_allclose +import pytest +from tensorflow.keras import backend as K +from tensorflow.keras.layers import Activation +from tensorflow.keras.layers import Flatten +from tensorflow.keras.layers import Input +from tensorflow.keras.models import Model from qkeras import binary from qkeras import model_save_quantized_weights @@ -27,15 +34,8 @@ from qkeras import quantized_bits from qkeras import ternary -import numpy as np -from numpy.testing import assert_allclose -import pytest -from keras import backend as K -from keras.layers import Activation -from keras.layers import Flatten -from keras.layers import Input -from keras.models import Model +# TODO(hzhuang, rxuniverse): qbatchnormalization tests cases. 
def qdense_util(layer_cls, kwargs=None, @@ -87,6 +87,7 @@ def test_qdense(layer_kwargs, input_data, weight_data, bias_data, weight_data=[weight_data, bias_data], expected_output=expected_output) + def test_qnetwork(): x = x_in = Input((28, 28, 1), name='input') x = QSeparableConv2D( @@ -155,40 +156,38 @@ def test_qnetwork(): all_weights = np.array(all_weights) - # test_qnetwork_weight_quantization - all_weights_signature = np.array([2.0, -6.75, -0.625, -2.0, -0.25, -56.0, - 1.125, -2.625, -0.75]) + all_weights_signature = np.array( + [2., -6.75, -0.625, -2., -0.25, -56., 1.125, -1.625, -1.125]) + assert all_weights.size == all_weights_signature.size assert np.all(all_weights == all_weights_signature) - # test_qnetwork_forward: - y = np.array([[0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, - 5.341e-02, 9.468e-01, 0.000e+00, 0.000e+00, 0.000e+00], - [0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, 5.960e-08, - 0.000e+00, 1.919e-01, 0.000e+00, 0.000e+00, 8.081e-01], - [0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, 2.378e-04, - 0.000e+00, 0.000e+00, 0.000e+00, 2.843e-05, 9.995e-01], - [0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, - 0.000e+00, 1.000e+00, 0.000e+00, 0.000e+00, 0.000e+00], - [0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, - 0.000e+00, 1.000e+00, 0.000e+00, 2.623e-06, 0.000e+00], - [0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, - 7.749e-07, 0.000e+00, 0.000e+00, 1.634e-04, 1.000e+00], - [0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, - 0.000e+00, 1.000e+00, 0.000e+00, 0.000e+00, 0.000e+00], - [0.000e+00, 1.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, - 0.000e+00, 6.557e-07, 0.000e+00, 0.000e+00, 0.000e+00], - [0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, 1.000e+00, - 0.000e+00, 5.960e-08, 0.000e+00, 0.000e+00, 0.000e+00], - [0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, 9.125e-03, - 9.907e-01, 9.418e-06, 0.000e+00, 5.597e-05, 0.000e+00 - ]]).astype(np.float16) + expected_output = np.array([[0.e+00, 0.e+00, 0.e+00, 0.e+00, 0.e+00, + 0.e+00, 1.e+00, 0.e+00, 0.e+00, 0.e+00], + [0.e+00, 0.e+00, 0.e+00, 0.e+00, 0.e+00, + 0.e+00, 1.e+00, 0.e+00, 0.e+00, 0.e+00], + [0.e+00, 0.e+00, 0.e+00, 0.e+00, 0.e+00, + 0.e+00, 0.e+00, 0.e+00, 6.e-08, 1.e+00], + [0.e+00, 0.e+00, 0.e+00, 0.e+00, 0.e+00, + 0.e+00, 1.e+00, 0.e+00, 0.e+00, 0.e+00], + [0.e+00 ,0.e+00, 0.e+00, 0.e+00, 0.e+00, + 0.e+00, 1.e+00, 0.e+00, 0.e+00, 0.e+00], + [0.e+00, 0.e+00, 0.e+00, 0.e+00, 0.e+00, + 0.e+00, 0.e+00, 0.e+00, 5.e-07, 1.e+00], + [0.e+00, 0.e+00, 0.e+00, 0.e+00, 0.e+00, + 0.e+00 ,1.e+00, 0.e+00, 0.e+00, 0.e+00], + [0.e+00, 1.e+00, 0.e+00, 0.e+00, 0.e+00, + 0.e+00 ,0.e+00, 0.e+00, 0.e+00, 0.e+00], + [0.e+00, 0.e+00, 0.e+00, 0.e+00, 1.e+00, + 0.e+00, 0.e+00, 0.e+00, 0.e+00, 0.e+00], + [0.e+00, 0.e+00, 0.e+00, 0.e+00, 0.e+00, + 1.e+00, 0.e+00, 0.e+00, 0.e+00, 0.e+00]]).astype(np.float16) inputs = 2 * np.random.rand(10, 28, 28, 1) - p = model.predict(inputs).astype(np.float16) - assert np.all(p == y) + actual_output = model.predict(inputs).astype(np.float16) + assert_allclose(actual_output, expected_output, rtol=1e-4) def test_qconv1d(): @@ -220,12 +219,16 @@ def test_qconv1d(): inputs = np.random.rand(2, 4, 4) p = model.predict(inputs).astype(np.float16) - + ''' y = np.array([[[0.1309, -1.229], [-0.4165, -2.639], [-0.08105, -2.299], [1.981, -2.195]], [[-0.3174, -3.94], [-0.3352, -2.316], [0.105, -0.833], [0.2115, -2.89]]]).astype(np.float16) - + ''' + y = np.array([[[-2.441, 3.816], [-3.807, -1.426], [-2.684, -1.317], + [-1.659, 0.9834]], + [[-4.99, 1.139], [-2.559, -1.216], [-2.285, 
1.905], + [-2.652, -0.467]]]).astype(np.float16) assert np.all(p == y)
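
The diff above migrates QKeras from standalone Keras to tf.keras, adds QBatchNormalization, and extends print_qstats with a weight profile. The snippet below is not part of the diff; it is a minimal, illustrative sketch of how the migrated 0.6 API fits together, using only layer names, quantizers, and arguments that appear in the changes above. The layer sizes, quantizer settings, and layer names chosen here are arbitrary.

import numpy as np
from tensorflow.keras.layers import Activation, Flatten, Input
from tensorflow.keras.models import Model

from qkeras import (QActivation, QBatchNormalization, QConv2D, QDense,
                    print_qstats, quantized_bits, quantized_po2,
                    quantized_relu_po2, ternary)

# Small functional model exercising the tf.keras-based QKeras layers.
x = x_in = Input((28, 28, 1), name="input")
x = QConv2D(16, (3, 3),
            kernel_quantizer=ternary(threshold=0.1),
            bias_quantizer=quantized_bits(4, 0, 1),
            name="conv")(x)
x = QBatchNormalization(gamma_quantizer=quantized_relu_po2(4, 8),
                        variance_quantizer=quantized_relu_po2(6),
                        beta_quantizer=quantized_po2(4, 4),
                        gamma_range=8, beta_range=4, name="bn")(x)
x = QActivation("quantized_relu(3,1)", name="act")(x)
x = Flatten(name="flatten")(x)
x = QDense(10,
           kernel_quantizer=quantized_bits(4, 0, 1),
           bias_quantizer=quantized_bits(4, 0, 1),
           name="dense")(x)
x = Activation("softmax", name="softmax")(x)

model = Model(inputs=[x_in], outputs=[x])
model.compile(loss="categorical_crossentropy", optimizer="adam",
              metrics=["accuracy"])
model.predict(np.random.rand(2, 28, 28, 1))  # forward-pass smoke test
print_qstats(model)  # also reports the per-layer weight/bias profile added in estimate.py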