Commit

Internal change
PiperOrigin-RevId: 341160130
Change-Id: I2772e76ba3bd23209dedc7bde24bae7308d8371f
qkeras-robot authored and copybara-github committed Nov 7, 2020
1 parent 41aec91 commit 3cc8719
Showing 3 changed files with 253 additions and 18 deletions.
81 changes: 64 additions & 17 deletions qkeras/quantizers.py
@@ -419,8 +419,15 @@ def __str__(self):
str(int(self.use_stochastic_rounding)))
return "quantized_bits(" + ",".join(flags) + ")"

def __call__(self, x):
def __call__(self, x, qnoise_factor=1.0):
"""Computes fixedpoint quantization of x."""
# qnoise_factor: a scalar from 0 to 1 that represents the level of
# quantization noise to add. The output is the weighted sum
# (1 - qnoise_factor) * unquantized_x + qnoise_factor * quantized_x.

x = K.cast_to_floatx(x)

# quantized_bits with "1" bit becomes a binary implementation.
unsigned_bits = self.bits - self.keep_negative
m = pow(2, unsigned_bits)
@@ -468,7 +475,7 @@ def __call__(self, x):
x = m_i * x
xq = m_i * z / m
self.scale = scale
return x + tf.stop_gradient(-x + scale * xq)
return x + tf.stop_gradient(qnoise_factor * (-x + scale * xq))
else:
scale = self.alpha

@@ -485,7 +492,7 @@ def __call__(self, x):
xq = (xq + 1.0) / 2.0

self.scale = scale
return x + tf.stop_gradient(-x + scale * xq)
return x + tf.stop_gradient(qnoise_factor * (-x + scale * xq))

def _set_trainable_parameter(self):
if self.alpha is None:
@@ -1216,22 +1223,41 @@ class quantized_relu(BaseQuantizer):  # pylint: disable=invalid-name
use_sigmoid: if true, we apply sigmoid to input to normalize it.
negative_slope: slope when activation < 0, needs to be power of 2.
use_stochastic_rounding: if true, we perform stochastic rounding.
relu_upper_bound: A float representing an upper bound of the unquantized
relu. If None, we apply relu without the upper bound when
"is_quantized_clip" is set to false (true by default).
Note: the quantized relu derives its own upper bound from the quantization
parameters (bits and integer), so relu_upper_bound should be set
consistently with those parameters. "is_quantized_clip"
has precedence over "relu_upper_bound" for backward compatibility.
is_quantized_clip: A boolean representing whether the inputs are clipped to
the maximum value represented by the quantization parameters. This
parameter is deprecated, and the default is set to True for backwards
compatibility. Users are encouraged to use "relu_upper_bound" instead.
Returns:
Function that performs relu + quantization to bits >= 0.
"""

def __init__(self, bits=8, integer=0, use_sigmoid=0,
negative_slope=0, use_stochastic_rounding=False):
def __init__(self,
bits=8,
integer=0,
use_sigmoid=0,
negative_slope=0.0,
use_stochastic_rounding=False,
relu_upper_bound=None,
is_quantized_clip=True):
super(quantized_relu, self).__init__()
self.bits = bits
self.integer = integer
self.use_sigmoid = use_sigmoid
self.negative_slope = negative_slope
self.use_stochastic_rounding = use_stochastic_rounding
self.relu_upper_bound = relu_upper_bound
self.is_quantized_clip = is_quantized_clip

assert negative_slope >= 0.0
if negative_slope != 0:
if negative_slope != 0.0:
assert np.mod(np.log2(negative_slope), 1) == 0

def __str__(self):
@@ -1244,12 +1270,30 @@ def __str__(self):
flags.append(str(int(self.use_stochastic_rounding)))
return "quantized_relu(" + ",".join(flags) + ")"

def __call__(self, x):
non_sign_bits = self.bits - (self.negative_slope != 0)
def __call__(self, x, qnoise_factor=1.0):
# qnoise_factor: a scalar from 0 to 1 that represents the level of
# quantization noise to add. The output is the weighted sum
# (1 - qnoise_factor) * unquantized_relu + qnoise_factor * quantized_relu.

non_sign_bits = self.bits - (self.negative_slope != 0.0)
x = K.cast_to_floatx(x)
m = K.cast_to_floatx(pow(2, non_sign_bits))
m_i = K.cast_to_floatx(pow(2, self.integer))
x_uq = tf.where(
x <= m_i, K.relu(x, alpha=self.negative_slope), tf.ones_like(x) * m_i)

# is_quantized_clip has precedence over relu_upper_bound for backward
# compatibility.
if self.is_quantized_clip:
m_f = K.cast_to_floatx(
pow(tf.constant(2., tf.float32), self.integer - non_sign_bits))
x_u = tf.where(x <= m_i - m_f, K.relu(x, alpha=self.negative_slope),
tf.ones_like(x) * (m_i - m_f))
elif self.relu_upper_bound is not None:
x_u = tf.where(x <= self.relu_upper_bound,
K.relu(x, alpha=self.negative_slope),
tf.ones_like(x) * self.relu_upper_bound)
else:
x_u = K.relu(x, alpha=self.negative_slope)

if self.use_sigmoid:
p = _sigmoid(x / m_i) * m
@@ -1260,19 +1304,21 @@ def __call__(self, x):
neg_factor = 1 / (self.negative_slope * m)
xq = xq + m_i * self.negative_slope * tf.keras.backend.clip(
2.0 * (_round_through(p * self.negative_slope,
self.use_stochastic_rounding) * neg_factor) - 1.0,
-1.0, 0.0)
self.use_stochastic_rounding) * neg_factor) -
1.0, -1.0, 0.0)
else:
p = x * m / m_i
xq = m_i * tf.keras.backend.clip(
_round_through(p, self.use_stochastic_rounding) / m, 0.0,
1.0 - 1.0 / m)
if self.negative_slope > 0:
neg_factor = 1 / (self.negative_slope * m)
xq = xq + m_i * self.negative_slope * (tf.keras.backend.clip(
_round_through(p * self.negative_slope,
self.use_stochastic_rounding) * neg_factor, -1.0, 0.0))
return x_uq + tf.stop_gradient(-x_uq + xq)
xq = xq + m_i * self.negative_slope * (
tf.keras.backend.clip(
_round_through(p * self.negative_slope,
self.use_stochastic_rounding) * neg_factor, -1.0,
0.0))
return x_u + tf.stop_gradient(qnoise_factor * (-x_u + xq))

def max(self):
"""Get the maximum value that quantized_relu can represent."""
@@ -1320,7 +1366,8 @@ def get_config(self):
"integer": self.integer,
"use_sigmoid": self.use_sigmoid,
"negative_slope": self.negative_slope,
"use_stochastic_rounding": self.use_stochastic_rounding
"use_stochastic_rounding": self.use_stochastic_rounding,
"relu_upper_bound": self.relu_upper_bound
}
return config

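For reference, a minimal usage sketch of the new qnoise_factor argument (not part of this commit): it blends the float and quantized values, and the linear 0-to-1 ramp over ten steps below is an illustrative assumption, not a schedule defined by QKeras.

```python
import numpy as np
from qkeras.quantizers import quantized_bits

# Same quantizer configuration as the test below:
# 1 sign bit, 1 integer bit, 2 fractional bits.
qb = quantized_bits(bits=4, integer=1, symmetric=True, keep_negative=True, alpha=1)

x = np.array([0.0, 0.6, -0.6, 2.0], dtype=np.float32)

total_steps = 10
for step in range(total_steps + 1):
    qnoise_factor = step / total_steps  # assumed linear ramp from 0 to 1
    y = qb(x, qnoise_factor=qnoise_factor)
    # y == (1 - qnoise_factor) * x + qnoise_factor * quantized(x),
    # so the output moves gradually from the float values to the quantized ones.
```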
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,4 +1,4 @@
tensorflow>=2.1.0rc0
tensorflow>=2.1.0rc0, <2.4.0rc0
numpy>=1.16
pyparser
scipy>=1.4.1
188 changes: 188 additions & 0 deletions tests/qnoise_test.py
@@ -0,0 +1,188 @@
# Copyright 2020 Google LLC
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test gradual quantization noise injection with quantizers of quantizers.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import logging
from numpy.testing import assert_allclose
from numpy.testing import assert_equal
import pytest
from tensorflow.keras import backend as K
from qkeras.quantizers import quantized_bits
from qkeras.quantizers import quantized_relu


def test_qnoise_quantized_bits():
# 1 sign bit, 1 integer bit, and 2 fractional bits.
bits = 4
integer = 1
symmetric = True
keep_negative = True
alpha = 1
use_stochastic_rounding = False

qb = quantized_bits(
bits=bits,
integer=integer,
symmetric=symmetric,
keep_negative=keep_negative,
alpha=alpha,
use_stochastic_rounding=use_stochastic_rounding)

inputs = np.array([0.0, 0.5, -0.5, 0.6, -0.6, 2.0, -2.0], dtype=np.float32)
x = np.array([0.0, 0.5, -0.5, 0.6, -0.6, 2.0, -2.0], dtype=np.float32)
xq = np.array([0.0, 0.5, -0.5, 0.5, -0.5, 1.75, -1.75], dtype=np.float32)
x_xq = 0.5 * (x + xq)

# no quantization
x_q_0 = qb(inputs, qnoise_factor=0.0)
assert_equal(x_q_0, x)

# full quantization
x_q_1 = qb(inputs, qnoise_factor=1.0)
assert_equal(x_q_1, xq)

# mixing half and half of x and xq
x_q_05 = qb(inputs, qnoise_factor=0.5)
assert_equal(x_q_05, x_xq)


def test_qnoise_quantized_relu():
# 0 sign bit, 1 integer bit, and 3 fractional bits.
bits = 4
integer = 1
use_sigmoid = False
negative_slope = 0
use_stochastic_rounding = False

# input to quantized relu
inputs = np.array([0.0, 0.5, -0.5, 0.6, 2.0, 3.0], dtype=np.float32)
# float relu
x = np.array([0.0, 0.5, 0.0, 0.6, 2.0, 3.0], dtype=np.float32)
# float relu with upper bound 1.5
x_ub = np.array([0.0, 0.5, 0.0, 0.6, 1.5, 1.5], dtype=np.float32)
# float relu with quantized clipping
x_clipped = np.array([0.0, 0.5, 0.0, 0.6, 1.875, 1.875], dtype=np.float32)
# quantized relu
xq = np.array([0.0, 0.5, 0.0, 0.625, 1.875, 1.875], dtype=np.float32)

# mixing half and half
x_xq = 0.5 * (x + xq)
x_clipped_xq = 0.5 * (x_clipped + xq)
x_ub_xq = 0.5 * (x_ub + xq)

######################
# No relu upper bound
######################
relu_upper_bound = None
qr = quantized_relu(
bits=bits,
integer=integer,
use_sigmoid=use_sigmoid,
negative_slope=negative_slope,
use_stochastic_rounding=use_stochastic_rounding,
relu_upper_bound=relu_upper_bound)

######################
# Relu upper bound
######################
relu_upper_bound = 1.5
qr_ub = quantized_relu(
bits=bits,
integer=integer,
use_sigmoid=use_sigmoid,
negative_slope=negative_slope,
use_stochastic_rounding=use_stochastic_rounding,
relu_upper_bound=relu_upper_bound)

#########################################
# No relu upper bound
# No quantized clip for float relu
#########################################
qr.is_quantized_clip = False

# no quantization
x_q_0 = qr(inputs, qnoise_factor=0.0)
assert_equal(x_q_0, x)

# full quantization
x_q_1 = qr(inputs, qnoise_factor=1.0)
assert_equal(x_q_1, xq)

# mixing half and half
x_q_05 = qr(inputs, qnoise_factor=0.5)
assert_equal(x_q_05, x_xq)

#########################################
# No relu upper bound
# Quantized clip for float relu
#########################################
qr.is_quantized_clip = True

# no quantization
x_q_0 = qr(inputs, qnoise_factor=0.0)
assert_equal(x_q_0, x_clipped)

# full quantization
x_q_1 = qr(inputs, qnoise_factor=1.0)
assert_equal(x_q_1, xq)

# mixing half and half
x_q_05 = qr(inputs, qnoise_factor=0.5)
assert_equal(x_q_05, x_clipped_xq)

#########################################
# Relu upper bound
# No quantized clip for float relu
#########################################
qr_ub.is_quantized_clip = False

# no quantization
x_q_0 = qr_ub(inputs, qnoise_factor=0.0)
assert_equal(x_q_0, x_ub)

# full quantization
x_q_1 = qr_ub(inputs, qnoise_factor=1.0)
assert_equal(x_q_1, xq)

# mixing half and half
x_q_05 = qr_ub(inputs, qnoise_factor=0.5)
assert_equal(x_q_05, x_ub_xq)

#########################################
# Relu upper bound
# Quantized clip for float relu
# (The quantized clip has precedence over the relu upper bound.)
#########################################
qr_ub.is_quantized_clip = True

# no quantization
x_q_0 = qr_ub(inputs, qnoise_factor=0.0)
assert_equal(x_q_0, x_clipped)

# full quantization
x_q_1 = qr_ub(inputs, qnoise_factor=1.0)
assert_equal(x_q_1, xq)

# mixing half and half
x_q_05 = qr_ub(inputs, qnoise_factor=0.5)
assert_equal(x_q_05, x_clipped_xq)


if __name__ == "__main__":
pytest.main([__file__])
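
As a quick cross-check (a sketch, not part of the commit), the clipping constants expected by the quantized_relu test above follow directly from bits=4 and integer=1 with negative_slope=0:

```python
# Reproduce the clipping constants used in the expected arrays above.
non_sign_bits = 4           # bits, no sign bit since negative_slope == 0
m = 2 ** non_sign_bits      # 16 quantization levels
m_i = 2 ** 1                # 2.0, set by integer=1
step = m_i / m              # 0.125, the quantization step
max_q = m_i - step          # 1.875: the quantized-clip bound and the largest
                            # quantized output, matching x_clipped and xq
print(step, max_q)          # 0.125 1.875
```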
