From 671ff87210388f0d9b5f3e4dbaf49afc163b4a49 Mon Sep 17 00:00:00 2001 From: Shan Li Date: Wed, 25 Nov 2020 12:23:31 -0800 Subject: [PATCH] Internal change PiperOrigin-RevId: 344297537 Change-Id: Iafc94fd95cf1f96d9261599146bd66effa2ccb75 --- qkeras/__init__.py | 1 + qkeras/autoqkeras/utils.py | 2 +- qkeras/estimate.py | 31 ++- qkeras/qconv2d_batchnorm.py | 352 ++++++++++++++++++++++++++ qkeras/utils.py | 194 +++++++++++++-- tests/automatic_conversion_test.py | 75 ++++++ tests/bn_folding_test.py | 380 +++++++++++++++++++++++++++++ 7 files changed, 1003 insertions(+), 32 deletions(-) create mode 100644 qkeras/qconv2d_batchnorm.py create mode 100644 tests/bn_folding_test.py diff --git a/qkeras/__init__.py b/qkeras/__init__.py index e25a21da..83969341 100644 --- a/qkeras/__init__.py +++ b/qkeras/__init__.py @@ -31,6 +31,7 @@ from .safe_eval import * # pylint: disable=wildcard-import #from .qtools.run_qtools import QTools #from .qtools.settings import cfg +from .qconv2d_batchnorm import QConv2DBatchnorm assert tf.executing_eagerly(), "QKeras requires TF with eager execution mode on" diff --git a/qkeras/autoqkeras/utils.py b/qkeras/autoqkeras/utils.py index 33fe7ec4..dd4f2c70 100644 --- a/qkeras/autoqkeras/utils.py +++ b/qkeras/autoqkeras/utils.py @@ -38,7 +38,7 @@ def print_qmodel_summary(q_model): if "Dense" in layer.__class__.__name__: print("u={} ".format(layer.units), end="") elif layer.__class__.__name__ in [ - "Conv2D", "QConv2D", "Conv1D", "QConv1D"]: + "Conv2D", "QConv2D", "Conv1D", "QConv1D", "QConv2DBatchnorm"]: print("f={} ".format(layer.filters), end="") quantizers = layer.get_quantizers() for q in range(len(quantizers)): diff --git a/qkeras/estimate.py b/qkeras/estimate.py index cc9f1e5b..e3c9e990 100644 --- a/qkeras/estimate.py +++ b/qkeras/estimate.py @@ -37,6 +37,7 @@ from tensorflow.keras.layers import InputLayer from tensorflow.keras.models import Model +from .qconv2d_batchnorm import QConv2DBatchnorm from .qlayers import QActivation from .qlayers import QAdaptiveActivation from .qlayers import QDense @@ -50,17 +51,18 @@ from .quantizers import quantized_tanh from .quantizers import quantized_ulaw from .utils import get_model_sparsity +from .utils import convert_folded_model_to_normal -def analyze_accumulator(model, x, verbose=False): +def analyze_accumulator(in_model, x, verbose=False): """Analyzes the distribution of weights to specify size of accumulators. Computes the maximum number of bits for the accumulator assuming the inputs have a distribution given by the dictionary x. for each output channel i: - max_positive_value[i] = sum(positive) w[i] + positive(bias[i]) - max_negative_value[i] = sum(negative) w[i] + negative(bias[i]) + max_positive_value[i] = sum(w[i]) + bias[i] for the positive weights + max_negative_value[i] = sum(w[i]) + bias[i] for the negative weights max_value = max( max_positive_value[i] * positive(x) + @@ -79,14 +81,19 @@ def analyze_accumulator(model, x, verbose=False): in the future, we want to provide a sample and compute this automatically Arguments: - model: model to be evaluated - x: input distribution - verbose: if true, print statistics messages + in_model: keras model object, model to be evaluated + x: dictionary of the form: { layer_name: (min_value, max_value) } + input distribution + verbose: boolean, if true, print statistics messages Returns: dictionary containing { layer_name: accumulator_size } """ + # this function converts a folded model to a "normal" model. 
It replaces folded + layers (e.g., QConv2DBatchnorm) with QConv2D layers whenever possible. + model = convert_folded_model_to_normal(in_model) + acc_sizes = {} for layer in model.layers: @@ -146,17 +153,18 @@ def analyze_accumulator(model, x, verbose=False): def analyze_accumulator_from_sample( - model, x_sample, mode="conservative", verbose=False): - + in_model, x_sample, mode="conservative", verbose=False): """Extracts range of inputs of quantized layers from samples.""" # mode is one of "conservative", "sampled" - if mode not in ["conservative", "sampled"]: ValueError("'mode' has to be 'conservative' or 'sampled'") - # get layer names of quantized layers (QDense and QConv2D) + # this function converts a folded model to a "normal" model. It replaces folded + # layers (e.g., QConv2DBatchnorm) with QConv2D layers whenever possible. + model = convert_folded_model_to_normal(in_model) + # get layer names of quantized layers (QDense and QConv2D) layer_names = [ layer.name for layer in model.layers if (isinstance(layer, QDepthwiseConv2D) or isinstance(layer, QConv2D) or @@ -348,9 +356,10 @@ def create_activation_cache(model): return output_cache -def extract_model_operations(model): +def extract_model_operations(in_model): """Determines types of operations for convolutions.""" + model = convert_folded_model_to_normal(in_model) cache_q = create_activation_cache(model) cache_o = {} diff --git a/qkeras/qconv2d_batchnorm.py b/qkeras/qconv2d_batchnorm.py new file mode 100644 index 00000000..ae8f499f --- /dev/null +++ b/qkeras/qconv2d_batchnorm.py @@ -0,0 +1,352 @@ +# Copyright 2019 Google LLC +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# ============================================================================== +"""Fold batchnormalization with previous qconv or qdense layers.""" + +import numpy as np +import tensorflow as tf +from tensorflow.keras import layers +from tensorflow.keras.models import Model + +from .qconvolutional import QConv2D +from .quantizers import * +from tensorflow.python.framework import smart_cond as tf_utils +from tensorflow.python.ops import math_ops + +tf.compat.v2.enable_v2_behavior() + + +class QConv2DBatchnorm(QConv2D): + """Fold batchnormalization with a previous qconv2d layer.""" + + def __init__( + self, + # qconv2d params + filters, + kernel_size, + strides=(1, 1), + padding="valid", + data_format="channels_last", + dilation_rate=(1, 1), + activation=None, + use_bias=True, + kernel_initializer="he_normal", + bias_initializer="zeros", + kernel_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + bias_constraint=None, + kernel_quantizer=None, + bias_quantizer=None, + + # batchnorm params + axis=-1, + momentum=0.99, + epsilon=0.001, + center=True, + scale=True, + beta_initializer="zeros", + gamma_initializer="ones", + moving_mean_initializer="zeros", + moving_variance_initializer="ones", + beta_regularizer=None, + gamma_regularizer=None, + beta_constraint=None, + gamma_constraint=None, + renorm=False, + renorm_clipping=None, + renorm_momentum=0.99, + fused=None, + trainable=True, + virtual_batch_size=None, + adjustment=None, + name=None, + + # other params + ema_freeze_delay=300000, + folding_mode="ema_stats_folding", + **kwargs): + + """Initialize a composite layer that folds conv2d and batch normalization. + + The first group of parameters correponds to the initialization parameters + of a qconv2d layer. check qkeras.qconvolutional.qconv2d for details. + + The 2nd group of parameters corresponds to the initialization parameters + of a BatchNormalization layer. Check keras.layers.normalization.BatchNorma + lizationBase for details. + + The 3rd group of parameters corresponds to the initialization parameters + specific to this class. + + ema_freeze_delay: int. number of steps before batch normalization mv_mean + and mv_variance will be frozen and used in the folded layer. + folding_mode: string + "ema_stats_folding": mimic tflite which uses the ema statistics to + fold the kernel to suppress quantization induced jitter then performs + the correction to have a similar effect of using the current batch + statistics. + "batch_stats_folding": use batch mean and variance to fold kernel first; + after enough training steps switch to moving_mean and moving_variance + for kernel folding. 
+ """ + + # initialize the qconv2d part of the composite layer + super(QConv2DBatchnorm, self).__init__( + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + dilation_rate=dilation_rate, + activation=activation, + use_bias=use_bias, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + kernel_constraint=kernel_constraint, + bias_constraint=bias_constraint, + kernel_quantizer=kernel_quantizer, + bias_quantizer=bias_quantizer) + + # initialization of batchnorm part of the composite layer + self.batchnorm = layers.BatchNormalization( + axis=axis, momentum=momentum, epsilon=epsilon, center=center, + scale=scale, beta_initializer=beta_initializer, + gamma_initializer=gamma_initializer, + moving_mean_initializer=moving_mean_initializer, + moving_variance_initializer=moving_variance_initializer, + beta_regularizer=beta_regularizer, + gamma_regularizer=gamma_regularizer, + beta_constraint=beta_constraint, gamma_constraint=gamma_constraint) + + self.ema_freeze_delay = ema_freeze_delay + assert folding_mode in ["ema_stats_folding", "batch_stats_folding"] + self.folding_mode = folding_mode + self._name = name + + def build(self, input_shape): + super(QConv2DBatchnorm, self).build(input_shape) + + # create untrainable folded weights that can be exported later for zpm + input_channel = self._get_input_channel(input_shape) + kernel_shape = self.kernel_size + (input_channel // self.groups, + self.filters) + # folded quantized kernel and bias + self.folded_kernel_quantized = self.add_weight( + name="folded_kernel_quantized", + shape=kernel_shape, + initializer=self.kernel_initializer, + regularizer=self.kernel_regularizer, + constraint=self.kernel_constraint, + trainable=False, + dtype=self.dtype) + + self.folded_bias_quantized = self.add_weight( + name="folded_bias_quantized", + shape=(self.filters,), + initializer=self.bias_initializer, + regularizer=self.bias_regularizer, + constraint=self.bias_constraint, + trainable=False, + dtype=self.dtype) + + # self._iteration (i.e., training_steps) is initialized with -1. When + # loading a ckpt, it restores the number of training steps that have been + # previously trained; when training from scratch, it starts counting from -1.
+ # TODO(lishanok): develop a way to count iterations outside layer + self._iteration = tf.Variable(-1, trainable=False, name="iteration", + dtype=tf.int64) + + def call(self, inputs, training=None): + + # numpy value, mark the layer is in training + training = self.batchnorm._get_training_value(training) # pylint: disable=protected-access + + # checking if to update batchnorm params + bn_training = tf.math.logical_and(training, tf.math.less_equal( + self._iteration, self.ema_freeze_delay)) + + kernel = self.kernel + + # run conv to produce output for the following batchnorm + conv_outputs = tf.keras.backend.conv2d( + inputs, + kernel, + strides=self.strides, + padding=self.padding, + data_format=self.data_format, + dilation_rate=self.dilation_rate) + + if self.use_bias: + bias = self.bias + conv_outputs = tf.keras.backend.bias_add( + conv_outputs, bias, data_format=self.data_format) + else: + bias = 0 + + _ = self.batchnorm(conv_outputs, training=bn_training) + if training is True: + # The following operation is only performed during training + + self._iteration.assign_add(tf_utils.smart_cond( + training, lambda: tf.constant(1, tf.int64), + lambda: tf.constant(0, tf.int64))) + + # calcuate mean and variance from current batch + bn_shape = conv_outputs.shape + ndims = len(bn_shape) + reduction_axes = [i for i in range(ndims) if i not in self.batchnorm.axis] + keep_dims = len(self.batchnorm.axis) > 1 + mean, variance = self.batchnorm._moments( # pylint: disable=protected-access + math_ops.cast(conv_outputs, self.batchnorm._param_dtype), # pylint: disable=protected-access + reduction_axes, + keep_dims=keep_dims) + # get batchnorm weights + gamma = self.batchnorm.gamma + beta = self.batchnorm.beta + moving_mean = self.batchnorm.moving_mean + moving_variance = self.batchnorm.moving_variance + + if self.folding_mode == "batch_stats_folding": + # using batch mean and variance in the initial training stage + # after sufficient training, switch to moving mean and variance + new_mean = tf_utils.smart_cond( + bn_training, lambda: mean, lambda: moving_mean) + new_variance = tf_utils.smart_cond( + bn_training, lambda: variance, lambda: moving_variance) + + # get the inversion factor so that we replace division by multiplication + inv = math_ops.rsqrt(new_variance + self.batchnorm.epsilon) + if gamma is not None: + inv *= gamma + # fold bias with bn stats + folded_bias = inv * (bias - new_mean) + beta + + elif self.folding_mode == "ema_stats_folding": + # We always scale the weights with a correction factor to the long term + # statistics prior to quantization. This ensures that there is no jitter + # in the quantized weights due to batch to batch variation. During the + # initial phase of training, we undo the scaling of the weights so that + # outputs are identical to regular batch normalization. We also modify + # the bias terms correspondingly. After sufficient training, switch from + # using batch statistics to long term moving averages for batch + # normalization. 
+ + # use batch stats for calculating bias before bn freeze, and use moving + # stats after bn freeze + mv_inv = math_ops.rsqrt(moving_variance + self.batchnorm.epsilon) + batch_inv = math_ops.rsqrt(variance + self.batchnorm.epsilon) + + if gamma is not None: + mv_inv *= gamma + batch_inv *= gamma + folded_bias = tf_utils.smart_cond( + bn_training, + lambda: batch_inv * (bias - mean) + beta, + lambda: mv_inv * (bias - moving_mean) + beta) + # moving stats are always used to fold the kernel in tflite; before bn freeze + # an additional correction factor will be applied to the conv2d output + inv = mv_inv + else: + raise ValueError("folding_mode must be 'ema_stats_folding' or 'batch_stats_folding'") + + # wrap conv kernel with bn parameters + folded_kernel = inv * kernel + # quantize the folded kernel + if self.kernel_quantizer is not None: + q_folded_kernel = self.kernel_quantizer_internal(folded_kernel) + else: + q_folded_kernel = folded_kernel + + # If loaded from a ckpt, bias_quantizer is the ckpt value. Otherwise, if the + # layer is called for the first time, the bias quantizer is None and we need + # to derive the bias quantizer type according to the accumulator type + if not self.bias_quantizer_internal: + # TODO(lishanok): implement an "eager-mode" quantizer map + pass + + if self.bias_quantizer_internal is not None: + q_folded_bias = self.bias_quantizer_internal(folded_bias) + else: + q_folded_bias = folded_bias + + # set value for the folded weights + self.folded_kernel_quantized.assign(q_folded_kernel, read_value=False) + self.folded_bias_quantized.assign(q_folded_bias, read_value=False) + + applied_kernel = q_folded_kernel + applied_bias = q_folded_bias + else: + applied_kernel = self.folded_kernel_quantized + applied_bias = self.folded_bias_quantized + # calculate conv2d output using the quantized folded kernel + folded_outputs = tf.keras.backend.conv2d( + inputs, + applied_kernel, + strides=self.strides, + padding=self.padding, + data_format=self.data_format, + dilation_rate=self.dilation_rate) + if training is True and self.folding_mode == "ema_stats_folding": + batch_inv = math_ops.rsqrt(variance + self.batchnorm.epsilon) + y_corr = tf_utils.smart_cond( + bn_training, + lambda: (math_ops.sqrt(moving_variance + self.batchnorm.epsilon) * + math_ops.rsqrt(variance + self.batchnorm.epsilon)), + lambda: tf.constant(1.0, shape=moving_variance.shape)) + folded_outputs = math_ops.mul(folded_outputs, y_corr) + + folded_outputs = tf.keras.backend.bias_add( + folded_outputs, + applied_bias, + data_format=self.data_format) + if self.activation is not None: + return self.activation(folded_outputs) + + return folded_outputs + + def get_config(self): + base_config = super().get_config() + bn_config = self.batchnorm.get_config() + config = {"ema_freeze_delay": self.ema_freeze_delay, + "folding_mode": self.folding_mode} + name = base_config["name"] + out_config = dict( + list(base_config.items()) + + list(bn_config.items()) + list(config.items())) + + # names from different configs override each other; use the base layer name + # as this layer's config name + out_config["name"] = name + return out_config + + def get_quantization_config(self): + return { + "kernel_quantizer": str(self.kernel_quantizer_internal), + "bias_quantizer": str(self.bias_quantizer_internal), + "activation": str(self.activation), + "filters": str(self.filters) + } + + def get_quantizers(self): + return self.quantizers + + def get_folded_quantized_weight(self): + return [self.folded_kernel_quantized.numpy(), + self.folded_bias_quantized.numpy()] diff --git a/qkeras/utils.py
b/qkeras/utils.py index a5b0cbf5..974c8f3f 100644 --- a/qkeras/utils.py +++ b/qkeras/utils.py @@ -42,6 +42,7 @@ from tensorflow_model_optimization.python.core.sparsity.keras import prunable_layer from .qlayers import Clip +from .qconv2d_batchnorm import QConv2DBatchnorm from .qlayers import QActivation from .qlayers import QAdaptiveActivation from .qpooling import QAveragePooling2D @@ -86,6 +87,7 @@ "QGRU", "QBidirectional", "QBatchNormalization", + "QConv2DBatchnorm", ] # This is a list of the state variable names of the QKeras layers and quantizers @@ -177,6 +179,10 @@ def model_save_quantized_weights(model, filename=None): """ + # this function converts a folded model to a "normal" model. It replaces folded + # layers (e.g., QConv2DBatchnorm) with QConv2D layers whenever possible. + model = convert_folded_model_to_normal(model) + saved_weights = {} print("... quantizing model") @@ -280,12 +286,53 @@ def get_config(quantizer_config, layer, layer_class, parameter=None): return quantizer +def find_layers_to_fold(model): + """Find conv/dense layers that need to be folded with the bn layers that follow them. + + Args: + model: input model + + Returns: + new model without bn layers + list of layers that need to be folded + + Note: currently only supports sequential models + """ + + # TODO(lishanok): extend this function to non-sequential models + layers = list(model.layers) + layers_to_fold = [] + + prev_layer = None + prev_x = layers[0].output + + for i in range(1, len(layers)): + layer = layers[i] + + if layer.__class__.__name__ not in ["BatchNormalization", + "QBatchNormalization"]: + x = layer(prev_x) + prev_x = x + prev_layer = layer + else: + # current layer is a bn layer; mark prev layer in the to_fold list + if prev_layer.__class__.__name__ in [ + "Conv2D", "Dense", "QConv2D", "QDense", "DepthwiseConv2D", + "QDepthwiseConv2D"]: + layers_to_fold.append(prev_layer.name) + + new_model = Model(inputs=model.inputs, outputs=x) + + return new_model, layers_to_fold + + def model_quantize(model, quantizer_config, activation_bits, custom_objects=None, transfer_weights=False, - prefer_qadaptiveactivation=False): + prefer_qadaptiveactivation=False, + enable_bn_folding=False): """Creates a quantized model from non-quantized model. The quantized model translation is based on json interface of Keras, @@ -359,11 +406,21 @@ def model_quantize(model, qmodel. prefer_qadaptiveactivation: Bool. If true, try to use QAdaptiveActivation over QActivation whenever possible + enable_bn_folding: Bool. If true, fold conv/dense layers with + the batch normalization layers that follow them whenever possible, e.g., + replacing a Conv2D + BatchNormalization pair with one QConv2DBatchnorm layer Returns: qmodel with quantized operations and custom_objects.
""" + if enable_bn_folding: + # remove bn layers from the model and find a list of layers to fold + model, layers_to_fold = find_layers_to_fold(model) + if len(layers_to_fold) == 0: + # no layers to fold, no need to perform folding + enable_bn_folding = False + if not custom_objects: custom_objects = {} @@ -414,7 +471,12 @@ def quantize_rnn(layer, quantizer_config): # Activation converts activation functions if layer["class_name"] in ["Dense", "Conv1D", "Conv2D", "Conv2DTranspose"]: - q_name = "Q" + layer["class_name"] + if (layer["class_name"] in ["Dense", "Conv2D"] and enable_bn_folding and + layer["name"] in layers_to_fold): + # only fold if current layer is followed by BN layer + q_name = "Q" + layer["class_name"] + "Batchnorm" + else: + q_name = "Q" + layer["class_name"] # needs to add kernel/bias quantizers kernel_quantizer = get_config( quantizer_config, layer, q_name, "kernel_quantizer") @@ -422,6 +484,16 @@ def quantize_rnn(layer, quantizer_config): bias_quantizer = get_config( quantizer_config, layer, q_name, "bias_quantizer") + if (kernel_quantizer is None and + q_name == "Q" + layer["class_name"] + "Batchnorm"): + # try none-folded layer quantizer as a back up + kernel_quantizer = get_config( + quantizer_config, layer, "Q" + layer["class_name"], + "kernel_quantizer") + bias_quantizer = get_config( + quantizer_config, layer, "Q" + layer["class_name"], + "bias_quantizer") + # this is to avoid unwanted transformations if kernel_quantizer is None: continue @@ -441,23 +513,35 @@ def quantize_rnn(layer, quantizer_config): quantize_activation(layer_config, activation_bits) elif layer["class_name"] == "DepthwiseConv2D": + if enable_bn_folding and layer.name in layers_to_fold: + q_name = "QDepthwiseConv2DBatchnorm" + else: + q_name = "QDepthwiseConv2D" + # needs to add kernel/bias quantizers - depthwise_quantizer = get_config(quantizer_config, layer, - "QDepthwiseConv2D", "depthwise_quantizer") - bias_quantizer = get_config(quantizer_config, layer, - "QDepthwiseConv2D", "bias_quantizer") + depthwise_quantizer = get_config(quantizer_config, layer, q_name, + "depthwise_quantizer") + bias_quantizer = get_config(quantizer_config, layer, q_name, + "bias_quantizer") + + if depthwise_quantizer is None and q_name == "QDepthwiseConv2DBatchnorm": + # try none-folded layer quantizer as a back up + depthwise_quantizer = get_config( + quantizer_config, layer, "QDepthwiseConv2D", "depthwise_quantizer") + bias_quantizer = get_config( + quantizer_config, layer, "QDepthwiseConv2D", "bias_quantizer") # this is to avoid unwanted transformations if depthwise_quantizer is None: continue - layer["class_name"] = "QDepthwiseConv2D" + layer["class_name"] = q_name layer_config["depthwise_quantizer"] = depthwise_quantizer layer_config["bias_quantizer"] = bias_quantizer # if activation is present, add activation here - quantizer = get_config(quantizer_config, layer, - "QDepthwiseConv2D", "activation_quantizer",) + quantizer = get_config(quantizer_config, layer, q_name, + "activation_quantizer",) if quantizer: layer_config["activation"] = quantizer @@ -467,16 +551,19 @@ def quantize_rnn(layer, quantizer_config): elif layer["class_name"] in ["SimpleRNN", "LSTM", "GRU"]: quantize_rnn(layer, quantizer_config) - elif layer['class_name'] == 'Bidirectional': + elif layer["class_name"] == "Bidirectional": forward_layer_quantizer_config = { - layer_config['layer']['config']['name'] : get_config(quantizer_config, - layer, "QBidirectional") } - quantize_rnn(layer['config']['layer'], forward_layer_quantizer_config) + 
layer_config["layer"]["config"]["name"]: + get_config(quantizer_config, layer, "QBidirectional") + } + quantize_rnn(layer["config"]["layer"], forward_layer_quantizer_config) if "backward_layer" in layer_config: backward_layer_quantizer_config = { - layer_config['backward_layer']['config']['name'] : get_config(quantizer_config, - layer, "QBidirectional") } - quantize_rnn(layer['config']['backward_layer'], backward_layer_quantizer_config) + layer_config["backward_layer"]["config"]["name"]: + get_config(quantizer_config, layer, "QBidirectional") + } + quantize_rnn(layer["config"]["backward_layer"], + backward_layer_quantizer_config) layer["class_name"] = "QBidirectional" elif layer["class_name"] == "Activation": @@ -608,7 +695,7 @@ def quantize_rnn(layer, quantizer_config): # if transfer_weights is true, we load the weights from model to qmodel - if transfer_weights: + if transfer_weights and not enable_bn_folding: for layer, qlayer in zip(model.layers, qmodel.layers): if layer.get_weights(): qlayer.set_weights(copy.deepcopy(layer.get_weights())) @@ -647,6 +734,8 @@ def _add_supported_quantized_objects(custom_objects): custom_objects["quantized_po2"] = quantized_po2 custom_objects["quantized_relu_po2"] = quantized_relu_po2 + custom_objects["QConv2DBatchnorm"] = QConv2DBatchnorm + def clone_model(model, custom_objects=None): """Clones model with custom_objects.""" @@ -777,7 +866,8 @@ def get_model_sparsity(model, per_layer=False, allow_list=None): "QDepthwiseConv2D", "DepthwiseConv2D", "QSeparableConv2D", "SeparableConv2D", "QOctaveConv2D", "QSimpleRNN", "RNN", "QLSTM", "QGRU", - "QConv2DTranspose", "Conv2DTranspose" + "QConv2DTranspose", "Conv2DTranspose", + "QConv2DBatchnorm" ] # Quantize the model weights for a more accurate sparsity calculation @@ -806,6 +896,8 @@ def get_model_sparsity(model, per_layer=False, allow_list=None): def quantized_model_debug(model, X_test, plot=False): """Debugs and plots model weights and activations.""" + + model = convert_folded_model_to_normal(model) outputs = [] output_names = [] @@ -835,8 +927,8 @@ def quantized_model_debug(model, X_test, plot=False): if alpha != 1.0: print(" a[{: 8.4f} {:8.4f}]".format(np.min(alpha), np.max(alpha))) if plot and layer.__class__.__name__ in [ - "QConv1D", "QConv2D", "QConv2DTranspose", "QDense", "QActivation", - "QAdaptiveActivation", "QSimpleRNN", "QLSTM", "QGRU", "QBidirectional" + "QConv1D", "QConv2D", "QConv2DTranspose", "QDense", "QActivation", + "QAdaptiveActivation", "QSimpleRNN", "QLSTM", "QGRU", "QBidirectional" ]: plt.hist(p.flatten(), bins=25) plt.title(layer.name + "(output)") @@ -904,3 +996,65 @@ def quantized_model_dump(model, print("writing the layer output tensor to ", filename) with open(filename, "w") as fid: tensor_data.astype(np.float32).tofile(fid) + + +def convert_folded_model_to_normal(model): + """Convert a sequential model with batchnorm folded layer to a normal model. + + Replace the folded layers with a normal qconv/qdense layer. + Set the weights in the normal layer with the folded weights + in the folded layer. + + we need to convert a folded model to a normal model before we pass it to zpm. + + Arguments: + model: model with folded layers. + + Returns: + A model that replaces folded layers (e.g., QConv2DBatchnorm) with normal + qkeras layers (e.g., QConv2D). This model can be passed on to hardware + generator (zpm) so that hardware doesn't see batch normalization + parameters. 
+ """ + + layer_list = list(model.layers) + x = layer_list[0].output + + for i in range(1, len(layer_list)): + layer = layer_list[i] + + if layer.__class__.__name__ not in ["QConv2DBatchnorm"]: + x = layer_list[i](x) + + else: + # get layer config from the composite layer + config = layer.get_config() + + # set layer config for QConv2D layer by first creating a tmp + # QConv2D object and generate template for its config + qconv2d = QConv2D(filters=1, kernel_size=(2, 2)) + qconv2d_cfg = qconv2d.get_config() + + # set qconv2d config according to the values in the composite layer + for key in qconv2d_cfg.keys(): + qconv2d_cfg[key] = config[key] + + # in case use_bias is False in the composite layer, + # we need to set it True because we have folded bias + qconv2d_cfg["use_bias"] = True + + # create a qconv2d layer from config and replace old layer with it + qconv2d = QConv2D.from_config(qconv2d_cfg) + x = qconv2d(x) + + # transfer weights from composite layer to the qconv2d layer + for weight in layer.weights: + if "folded_kernel_quantized" in weight.name: + folded_kernel_quantized = weight.numpy() + elif "folded_bias_quantized" in weight.name: + folded_bias_quantized = weight.numpy() + qconv2d.set_weights([folded_kernel_quantized, folded_bias_quantized]) + + new_model = Model(inputs=model.inputs, outputs=x) + + return new_model diff --git a/tests/automatic_conversion_test.py b/tests/automatic_conversion_test.py index a1a925ec..5a079599 100644 --- a/tests/automatic_conversion_test.py +++ b/tests/automatic_conversion_test.py @@ -33,6 +33,16 @@ def create_network(): x = QConv2D(32, (3, 3), activation="quantized_relu(4)")(x) return Model(inputs=xi, outputs=x) +def create_network_with_bn(): + xi = Input((28,28,1)) + x = Conv2D(32, (3, 3))(xi) + x = BatchNormalization(axis=-1)(x) + x = Activation("relu", name='relu_act')(x) + x = Conv2D(32, (3, 3), activation="relu")(x) + x = Activation("softmax")(x) + x = QConv2D(32, (3, 3), activation="quantized_relu(4)")(x) + return Model(inputs=xi, outputs=x) + def create_network_sequential(): model = Sequential([ Conv2D(32, (3, 3), input_shape=(28,28,1)), @@ -152,5 +162,70 @@ def test_sequential_model_conversion(): qq = model_quantize(m, d, 4) assert str(qq.layers[2].activation) == "quantized_relu(4,0)" + +def test_folded_layer_conversion(): + # create a sequential model with conv2d layer and activation layers + m1 = create_network() + + # create a sequantial model with conv2d layer followed by bn layer + m2 = create_network_with_bn() + + # quantization config + d = { + "QConv2D": { + "kernel_quantizer": "binary", + "bias_quantizer": "binary" + }, + "QConv2DBatchnorm": { + "kernel_quantizer": "ternary", + "bias_quantizer": "ternary", + }, + "relu_act": { + "relu": "quantized_relu(8)" + } + } + + # test when model has no layer to fold + # desired behavior: un-folded layers + qq1 = model_quantize(m1, d, 4, enable_bn_folding=True) + assert qq1.layers[1].__class__.__name__ == "QConv2D" + assert str(qq1.layers[1].quantizers[0]).startswith("binary") + + # test when the 1st conv2d layers needs to fold but the 2nd conv2d layer + # does not (not followed by bn layer) + # desired behavior: 1st conv2d is folded, 2nd conv2d unfolded + qq2 = model_quantize(m2, d, 4, enable_bn_folding=True) + assert qq2.layers[1].__class__.__name__ == "QConv2DBatchnorm" + assert str(qq2.layers[1].quantizers[0]).startswith("ternary") + assert qq2.layers[3].__class__.__name__ == "QConv2D" + assert str(qq2.layers[3].quantizers[0]).startswith("binary") + + # test when there are layers to fold 
but folding is disabled + # desired behavior: all conv2d layers unfolded + qq3 = model_quantize(m2, d, 4, enable_bn_folding=False) + assert qq3.layers[1].__class__.__name__ == "QConv2D" + assert str(qq3.layers[1].quantizers[0]).startswith("binary") + assert qq3.layers[2].__class__.__name__ == "BatchNormalization" + assert str(qq3.layers[3].quantizer).startswith("quantized_relu") + + # test when QConv2DBatchnorm quantizer is not given in config + # desired behavior: quantizers for QConv2DBatchnorm layer fall back to QConv2D + # quantizers + d = { + "QConv2D": { + "kernel_quantizer": "binary", + "bias_quantizer": "binary" + }, + "relu_act": { + "relu": "quantized_relu(8)" + } + } + qq4 = model_quantize(m2, d, 4, enable_bn_folding=True) + assert qq4.layers[1].__class__.__name__ == "QConv2DBatchnorm" + assert str(qq4.layers[1].quantizers[0]).startswith("binary") + assert qq4.layers[3].__class__.__name__ == "QConv2D" + assert str(qq4.layers[3].quantizers[0]).startswith("binary") + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/bn_folding_test.py b/tests/bn_folding_test.py new file mode 100644 index 00000000..06439c75 --- /dev/null +++ b/tests/bn_folding_test.py @@ -0,0 +1,380 @@ +# Copyright 2019 Google LLC +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests layers from folded_layers.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import numpy as np +from numpy.testing import assert_allclose +from numpy.testing import assert_raises +import tempfile + +import tensorflow as tf +from tensorflow.keras import layers +from tensorflow.keras.models import Model +from tensorflow.keras.backend import clear_session +from tensorflow.keras.utils import to_categorical +from tensorflow.keras import metrics + +from qkeras import QConv2DBatchnorm +from qkeras import QConv2D +from qkeras import utils as qkeras_utils + + +def get_qconv2d_model(input_shape, kernel_size, kernel_quantizer=None): + num_class = 2 + + x = x_in = layers.Input(input_shape, name="input") + + x = QConv2D( + filters=2, kernel_size=kernel_size, strides=(4, 4), + kernel_initializer="ones", + bias_initializer="zeros", use_bias=False, + kernel_quantizer=kernel_quantizer, bias_quantizer=None, + name="conv2d")(x) + + x = layers.BatchNormalization( + axis=-1, + momentum=0.99, + epsilon=0.001, + center=True, + scale=True, + beta_initializer="zeros", + gamma_initializer="ones", + moving_mean_initializer="zeros", + moving_variance_initializer="ones", + beta_regularizer=None, + gamma_regularizer=None, + beta_constraint=None, + gamma_constraint=None, + renorm=False, + renorm_clipping=None, + renorm_momentum=0.99, + fused=None, + trainable=True, + virtual_batch_size=None, + adjustment=None, + name="bn")( + x) + x = layers.Flatten(name="flatten")(x) + x = layers.Dense(num_class, use_bias=False, kernel_initializer="ones", + name="dense")(x) + x = layers.Activation("softmax", name="softmax")(x) + model = Model(inputs=[x_in], outputs=[x]) + return model + + +def get_qconv2d_batchnorm_model(input_shape, kernel_size, folding_mode, + kernel_quantizer=None): + num_class = 2 + + x = x_in = layers.Input(input_shape, name="input") + x = QConv2DBatchnorm( + filters=2, kernel_size=kernel_size, strides=(4, 4), + kernel_initializer="ones", bias_initializer="zeros", use_bias=False, + kernel_quantizer=kernel_quantizer, beta_initializer="zeros", + gamma_initializer="ones", moving_mean_initializer="zeros", + moving_variance_initializer="ones", folding_mode=folding_mode, + name="foldconv2d")(x) + + x = layers.Flatten(name="flatten")(x) + x = layers.Dense(num_class, use_bias=False, kernel_initializer="ones", + name="dense")(x) + x = layers.Activation("softmax", name="softmax")(x) + model = Model(inputs=[x_in], outputs=[x]) + return model + + +def get_models_with_one_layer(kernel_quantizer, folding_mode, ema_freeze_delay): + + x_shape = (2, 2, 1) + loss_fn = tf.keras.losses.MeanSquaredError() + optimizer = tf.keras.optimizers.SGD(learning_rate=1e-3) + + # define a model with seperate conv2d and bn layers + x = x_in = layers.Input(x_shape, name="input") + x = QConv2D( + filters=2, kernel_size=(2, 2), strides=(4, 4), + kernel_initializer="ones", + bias_initializer="zeros", use_bias=False, + kernel_quantizer=kernel_quantizer, bias_quantizer=None, + name="conv2d")(x) + x = layers.BatchNormalization( + axis=-1, + momentum=0.99, + epsilon=0.001, + center=True, + scale=True, + beta_initializer="zeros", + gamma_initializer="ones", + moving_mean_initializer="zeros", + moving_variance_initializer="ones", + beta_regularizer=None, + gamma_regularizer=None, + beta_constraint=None, + gamma_constraint=None, + renorm=False, + renorm_clipping=None, + renorm_momentum=0.99, + 
fused=None, + trainable=True, + virtual_batch_size=None, + adjustment=None, + name="bn")(x) + unfold_model = Model(inputs=[x_in], outputs=[x]) + unfold_model.compile(loss=loss_fn, optimizer=optimizer, metrics="acc") + + x = x_in = layers.Input(x_shape, name="input") + x = QConv2DBatchnorm( + filters=2, kernel_size=(2, 2), strides=(4, 4), + kernel_initializer="ones", bias_initializer="zeros", use_bias=False, + kernel_quantizer=kernel_quantizer, beta_initializer="zeros", + gamma_initializer="ones", moving_mean_initializer="zeros", + moving_variance_initializer="ones", folding_mode=folding_mode, + ema_freeze_delay=ema_freeze_delay, + name="foldconv2d")(x) + fold_model = Model(inputs=[x_in], outputs=[x]) + fold_model.compile(loss=loss_fn, optimizer=optimizer, metrics="acc") + + return (unfold_model, fold_model) + + +def get_debug_model(model): + layer_output_list = [] + for layer in model.layers: + if layer.__class__.__name__ not in ["Flatten", "InputLayer"]: + layer_output_list.append(layer.output) + + debug_model = Model(inputs=model.inputs, outputs=layer_output_list) + return debug_model + + +def generate_dataset(train_size=10, + batch_size=5, + input_shape=(3, 3, 1), + num_class=2): + """create tf.data.Dataset with shape: (N,) + input_shape.""" + + x_train = np.random.randint( + 4, size=(train_size, input_shape[0], input_shape[1], input_shape[2])) + x_train = np.random.rand( + train_size, input_shape[0], input_shape[1], input_shape[2]) + + y_train = np.random.randint(num_class, size=train_size) + y_train = to_categorical(y_train, num_class) + + train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)) + train_ds = train_ds.batch(batch_size) + return train_ds + + +def run_training(model, epochs, loss_fn, loss_metric, optimizer, + train_ds, do_print=False): + + # Iterate over epochs. + for epoch in range(epochs): + if do_print: + print("- epoch {} -".format(epoch)) + + # Iterate over the batches of the dataset. 
+ for step, (x_batch_train, y_batch_train) in enumerate(train_ds): + if do_print: + print("\n - step {} -".format(step)) + with tf.GradientTape() as tape: + predictions = model(x_batch_train, training=True) + + if epoch == epochs - 1: + if do_print: + print("y_pred:", predictions) + print("y:", y_batch_train) + output_predictions = predictions + + # Compute loss + loss = loss_fn(y_batch_train, predictions) + + grads = tape.gradient(loss, model.trainable_weights) + if do_print: + if epoch == epochs - 1: + # print("old trainable:", model.trainable_weights) + print("grads:", grads) + optimizer.apply_gradients(zip(grads, model.trainable_weights)) + + if do_print: + if epoch == epochs - 1: + # print("new trainable:", model.trainable_weights) + print("loss:", loss) + loss_metric(loss) + if do_print: + if epoch == epochs - 1: + print("mean loss = %.4f" % (loss_metric.result())) + + return output_predictions + + +def test_loading(): + """Test to load model using different approahches.""" + + loss_fn = tf.keras.losses.MeanSquaredError() + loss_metric = metrics.Mean() + optimizer = tf.keras.optimizers.SGD(learning_rate=1e-3) + x_shape = (2, 2, 1) + + custom_objects = {} + qkeras_utils._add_supported_quantized_objects(custom_objects) + + train_ds = generate_dataset(train_size=1, batch_size=1, + input_shape=x_shape, num_class=2) + + model_fold = get_qconv2d_batchnorm_model( + input_shape=x_shape, kernel_size=(2, 2), + folding_mode="ema_stats_folding") + model_fold.compile(loss=loss_fn, optimizer=optimizer, metrics="acc") + + run_training(model_fold, 10, loss_fn, loss_metric, optimizer, train_ds, + do_print=False) + + # test load model from json to ensure saving/loading model architecture works + json_string = model_fold.to_json() + clear_session() + model_from_json = qkeras_utils.quantized_model_from_json(json_string) + assert json_string == model_from_json.to_json() + + # test reload model from hdf5 files to ensure saving/loading works + _, fname = tempfile.mkstemp(".h5") + model_fold.save(fname) + model_loaded = qkeras_utils.load_qmodel(fname) + weight1 = model_fold.layers[1].get_folded_quantized_weight() + weight2 = model_loaded.layers[1].get_folded_quantized_weight() + assert_allclose(weight1[0], weight2[0], rtol=1e-4) + assert_allclose(weight1[1], weight2[1], rtol=1e-4) + + # test convert a folded model to a normal model for zpm + # the kernel/bias weight in the normal model should be the same as the folded + # kernel/bias in the folded model + normal_model = qkeras_utils.convert_folded_model_to_normal(model_fold) + weight2 = normal_model.layers[1].get_weights() + assert_allclose(weight1[0], weight2[0], rtol=1e-4) + assert_allclose(weight1[1], weight2[1], rtol=1e-4) + + +def test_same_training_and_prediction(): + """test if fold/unfold layer has the same training and prediction output.""" + + epochs = 5 + loss_fn = tf.keras.losses.MeanSquaredError() + loss_metric = metrics.Mean() + optimizer = tf.keras.optimizers.SGD(learning_rate=1e-3) + + x_shape = (2, 2, 1) + kernel = np.array([[[[1., 1.]], [[1., 0.]]], [[[1., 1.]], [[0., 1.]]]]) + gamma = np.array([2., 1.]) + beta = np.array([0., 1.]) + moving_mean = np.array([1., 1.]) + moving_variance = np.array([1., 2.]) + iteration = np.array(-1) + folded_kernel_quantized = np.array([[[[1.99900079, 0.706930101]], + [[1.99900079, 0]]], + [[[1.99900079, 0.706930101]], + [[0, 0.706930101]]]]) + folded_bias_quantized = np.array([-1.99900079, 0.293069899]) + train_ds = generate_dataset(train_size=10, batch_size=10, input_shape=x_shape, + num_class=2) + + 
(unfold_model, fold_model_batch) = get_models_with_one_layer( + kernel_quantizer=None, folding_mode="batch_stats_folding", + ema_freeze_delay=10) + (_, fold_model_ema) = get_models_with_one_layer( + kernel_quantizer=None, folding_mode="ema_stats_folding", + ema_freeze_delay=10) + + unfold_model.layers[1].set_weights([kernel]) + unfold_model.layers[2].set_weights( + [gamma, beta, moving_mean, moving_variance]) + fold_model_batch.layers[1].set_weights([ + kernel, gamma, beta, folded_kernel_quantized, folded_bias_quantized, + iteration, moving_mean, moving_variance + ]) + fold_model_ema.layers[1].set_weights([ + kernel, gamma, beta, folded_kernel_quantized, folded_bias_quantized, + iteration, moving_mean, moving_variance + ]) + + # check if prediction is the same + y1 = unfold_model.predict(train_ds) + y2_batch = fold_model_batch.predict(train_ds) + y2_ema = fold_model_ema.predict(train_ds) + assert_allclose(y1, y2_batch, rtol=1e-4) + assert_allclose(y1, y2_ema, rtol=1e-4) + + # check if training for a number of epochs, and before bn freeeze, models + # reached the same point + y1 = run_training(unfold_model, epochs, loss_fn, loss_metric, optimizer, + train_ds, do_print=False) + y2_batch = run_training(fold_model_batch, epochs, loss_fn, loss_metric, + optimizer, train_ds, do_print=False) + y2_ema = run_training(fold_model_ema, epochs, loss_fn, loss_metric, optimizer, + train_ds, do_print=False) + assert_allclose(y1, y2_batch, rtol=1e-4) + assert_allclose(y1, y2_ema, rtol=1e-4) + + # check if training for long enough (after bn freezes), unfold model and fold + # models should be different, but the two folding modes should be the same + epochs = 5 + iteration = np.array(8) + (unfold_model, fold_model_batch) = get_models_with_one_layer( + kernel_quantizer=None, folding_mode="batch_stats_folding", + ema_freeze_delay=10) + (_, fold_model_ema) = get_models_with_one_layer( + kernel_quantizer=None, folding_mode="ema_stats_folding", + ema_freeze_delay=10) + unfold_model.layers[1].set_weights([kernel]) + unfold_model.layers[2].set_weights( + [gamma, beta, moving_mean, moving_variance]) + fold_model_batch.layers[1].set_weights([ + kernel, gamma, beta, folded_kernel_quantized, folded_bias_quantized, + iteration, moving_mean, moving_variance + ]) + fold_model_ema.layers[1].set_weights([ + kernel, gamma, beta, folded_kernel_quantized, folded_bias_quantized, + iteration, moving_mean, moving_variance + ]) + y1 = run_training( + unfold_model, + epochs, + loss_fn, + loss_metric, + optimizer, + train_ds, + do_print=False) + y2_batch = run_training( + fold_model_batch, + epochs, + loss_fn, + loss_metric, + optimizer, + train_ds, + do_print=False) + y2_ema = run_training( + fold_model_ema, + epochs, + loss_fn, + loss_metric, + optimizer, + train_ds, + do_print=False) + assert_raises(AssertionError, assert_allclose, y1, y2_batch, rtol=1e-4) + assert_allclose(y2_batch, y2_ema, rtol=1e-4)
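
Note on the constants hard-coded in test_same_training_and_prediction: they follow directly from the folding equations used in QConv2DBatchnorm.call() once the moving statistics are used, i.e. inv = gamma / sqrt(moving_variance + epsilon), folded_kernel = inv * kernel, and folded_bias = inv * (bias - moving_mean) + beta. The short numpy sketch below (an illustration, not part of the patch) reproduces the expected folded_kernel_quantized and folded_bias_quantized values:

import numpy as np

# constants copied from test_same_training_and_prediction
kernel = np.array([[[[1., 1.]], [[1., 0.]]], [[[1., 1.]], [[0., 1.]]]])
gamma = np.array([2., 1.])
beta = np.array([0., 1.])
moving_mean = np.array([1., 1.])
moving_variance = np.array([1., 2.])
bias = 0.0       # use_bias=False in the test model, so the conv bias is zero
epsilon = 0.001  # BatchNormalization default used by the folded layer

# per output channel: scale kernel and bias by gamma / sqrt(var + eps)
inv = gamma / np.sqrt(moving_variance + epsilon)
folded_kernel = inv * kernel                  # broadcasts over the channel axis
folded_bias = inv * (bias - moving_mean) + beta

print(folded_kernel[0, 0, 0])  # ~[ 1.99900079  0.70693010]
print(folded_bias)             # ~[-1.99900079  0.29306990]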
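
For context, the end-to-end workflow this patch enables, as exercised by test_folded_layer_conversion and test_loading, is roughly the sketch below. The model and quantizer choices are illustrative only; the qkeras calls (model_quantize with enable_bn_folding and convert_folded_model_to_normal) are the ones added in this patch.

from tensorflow.keras.layers import Activation, BatchNormalization, Conv2D, Input
from tensorflow.keras.models import Model

from qkeras import utils as qkeras_utils

# a plain float model: Conv2D immediately followed by BatchNormalization
x = x_in = Input((28, 28, 1))
x = Conv2D(16, (3, 3), name="conv")(x)
x = BatchNormalization(name="bn")(x)
x = Activation("relu")(x)
float_model = Model(inputs=x_in, outputs=x)

# quantizer config for folded and un-folded conv layers
config = {
    "QConv2D": {"kernel_quantizer": "binary", "bias_quantizer": "binary"},
    "QConv2DBatchnorm": {"kernel_quantizer": "ternary",
                         "bias_quantizer": "ternary"},
}

# with enable_bn_folding=True the conv/bn pair becomes one QConv2DBatchnorm layer
qmodel = qkeras_utils.model_quantize(
    float_model, config, activation_bits=4, enable_bn_folding=True)

# ... train qmodel as usual ...

# before exporting to hardware (zpm), unfold back into a plain QConv2D model
deploy_model = qkeras_utils.convert_folded_model_to_normal(qmodel)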