Implementation of h-swish activation quantizer
PiperOrigin-RevId: 368709712
Change-Id: I9ed7e48f1e9d43f2ece804074eadc2272e367332
lishanok authored and copybara-github committed Apr 15, 2021
1 parent 9ca7ec2 commit 4faff4a
Showing 2 changed files with 186 additions and 1 deletion.
149 changes: 148 additions & 1 deletion qkeras/quantizers.py
@@ -1420,7 +1420,7 @@ def __init__(self,
    self.use_variables = use_variables

  def __str__(self):
-   # Convert Tensors to printable strings by converting to a numpy array and
+   # Converts Tensors to printable strings by converting to a numpy array and
    # then using regex to remove brackets when there is only one integer bit
    integer_bits = re.sub(
        r"\[(\d)\]", r"\g<1>",
@@ -2197,6 +2197,153 @@ def get_config(self):
    return config


class quantized_hswish(quantized_bits):  # pylint: disable=invalid-name
  """Computes a quantized hard swish to a number of bits.

  Equation of the h-swish function in MobileNet v3:
    hswish(x) = x * ReluY(x + relu_shift) / Y,
  where ReluY is relu with upper bound Y = relu_upper_bound.

  Attributes:
    bits: number of bits to perform quantization, also known as word length.
    integer: number of integer bits.
    symmetric: if True, the quantization is in symmetric mode, which puts a
      restricted range on the quantizer. Otherwise, it is in asymmetric mode,
      which uses the full range.
    alpha: a tensor or None, the scaling factor per channel.
      If None, the scaling factor is 1 for all channels.
    use_stochastic_rounding: if True, performs stochastic rounding. This
      parameter is passed on to the underlying quantizer quantized_bits,
      which is used to quantize h-swish.
    scale_axis: which axis to calculate scale from.
    qnoise_factor: float. A scalar from 0 to 1 that represents the level of
      quantization noise to add. This controls the amount of quantization
      noise added to the outputs by changing the weighted sum
      (1 - qnoise_factor) * unquantized_x + qnoise_factor * quantized_x.
    var_name: String or None. A variable name shared between the tf.Variables
      created in the build function. If None, it is generated automatically.
    use_ste: Bool. Whether to use the "straight-through estimator" (STE)
      method or not.
    use_variables: Bool. Whether to make the quantizer variables dynamic
      tf.Variables or not.
    relu_shift: integer, the shift amount applied to the input of the
      unquantized relu.
    relu_upper_bound: integer, an upper bound of the unquantized relu. If
      None, relu is applied without an upper bound when "is_quantized_clip"
      is set to False (True by default).
      Note: the quantized relu derives its upper bound from the quantization
      parameters (bits and integer), so it is important to set
      relu_upper_bound consistently with those parameters.
      "is_quantized_clip" has precedence over "relu_upper_bound" for
      backward compatibility.
  """

  def __init__(self,
               bits=8,
               integer=0,
               symmetric=0,
               alpha=None,
               use_stochastic_rounding=False,
               scale_axis=None,
               qnoise_factor=1.0,
               var_name=None,
               use_ste=True,
               use_variables=False,
               relu_shift: int = 3,
               relu_upper_bound: int = 6):
    super(quantized_hswish, self).__init__(
        bits=bits,
        integer=integer,
        symmetric=symmetric,
        keep_negative=True,
        alpha=alpha,
        use_stochastic_rounding=use_stochastic_rounding,
        scale_axis=scale_axis,
        qnoise_factor=qnoise_factor,
        var_name=var_name,
        use_ste=use_ste,
        use_variables=use_variables)

    self.relu_shift = relu_shift
    self.relu_upper_bound = relu_upper_bound

  def __str__(self):
    """Converts Tensors to printable strings."""

    integer_bits = (
        re.sub(r"\[(\d)\]", r"\g<1>",
               str(self.integer.numpy() if isinstance(self.integer, tf.Variable)
                   else self.integer)))
    # re.sub returns a string; verify it reduced to a single integer value.
    assert re.match(r"^-?\d+$", integer_bits)

    flags = [str(self.bits),
             integer_bits,
             str(int(self.symmetric)),
             "relu_shift=" + str(self.relu_shift),
             "relu_upper_bound=" + str(self.relu_upper_bound)]

    if not self.keep_negative:
      flags.append("keep_negative=False")
    if self.alpha:
      alpha = str(self.alpha)
      if isinstance(self.alpha, six.string_types):
        alpha = "'" + alpha + "'"
      flags.append("alpha=" + alpha)
    if self.use_stochastic_rounding:
      flags.append("use_stochastic_rounding=" +
                   str(int(self.use_stochastic_rounding)))
    return "quantized_hswish(" + ",".join(flags) + ")"

  def __call__(self, x):
    assert self.relu_upper_bound > 0, (
        f"relu_upper_bound must be a positive value, "
        f"found {self.relu_upper_bound} instead")
    assert self.relu_shift > 0, (
        f"relu_shift must be a positive value, "
        f"found {self.relu_shift} instead")

    x = K.cast_to_floatx(x)
    shift_x = x + self.relu_shift
    relu_x = tf.where(shift_x <= self.relu_upper_bound,
                      K.relu(shift_x, alpha=0.),
                      tf.ones_like(shift_x) * self.relu_upper_bound)

    hswish_x = tf.math.multiply(x, relu_x) / self.relu_upper_bound
    return super(quantized_hswish, self).__call__(hswish_x)

  def min(self):
    """Gets the minimum value that quantized_hswish can represent."""

    # Get the minimum value that the number of bits can represent.
    min_quant = super(quantized_hswish, self).min()

    # At the negative end, the hswish function becomes
    #   x * (x + relu_shift) / relu_upper_bound;
    # the minimum value of this parabolic function is
    #   -relu_shift^2 / (4 * relu_upper_bound).
    denom = 4 * self.relu_upper_bound
    min_parabolic = -self.relu_shift * self.relu_shift / denom

    if min_quant >= min_parabolic:
      return min_quant

    # Return the quantized value of min_parabolic.
    return super(quantized_hswish, self).__call__(min_parabolic)
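  # Worked check (illustrative, not in the commit): with the defaults
  # relu_shift=3 and relu_upper_bound=6, the parabola x * (x + 3) / 6 attains
  # its minimum at x = -3/2, giving -9 / 24 = -0.375, which matches the
  # -0.375 entries in the test expectations below.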

  def get_config(self):
    """Adds relu_shift and relu_upper_bound to the config dictionary."""

    base_config = super(quantized_hswish, self).get_config()
    config = {
        "relu_shift": self.relu_shift,
        "relu_upper_bound": self.relu_upper_bound
    }
    out_config = dict(
        list(base_config.items()) + list(config.items()))
    return out_config


def get_quantizer(identifier):
  """Gets the quantizer.
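As a reading aid (not part of the commit), here is a minimal NumPy sketch of the unquantized h-swish that __call__ evaluates before handing the result to quantized_bits; the name hswish_reference is illustrative only:

import numpy as np

def hswish_reference(x, relu_shift=3.0, relu_upper_bound=6.0):
  """Unquantized h-swish: x * ReluY(x + relu_shift) / Y with Y = relu_upper_bound."""
  relu_x = np.clip(x + relu_shift, 0.0, relu_upper_bound)
  return x * relu_x / relu_upper_bound

x = np.array([-3.0, -1.5, 0.0, 1.0, 4.0])
print(hswish_reference(x))  # -> [0., -0.375, 0., 0.6667, 4.] (approximately)

The quantizer then maps these values through quantized_bits, which is why the first test case below expects 0.625 for x = 1 and 3.875 (the largest value representable with bits=6, integer=2, keep_negative=True) for x = 4.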
38 changes: 38 additions & 0 deletions tests/qactivation_test.py
@@ -35,6 +35,7 @@
from qkeras import stochastic_binary
from qkeras import stochastic_ternary
from qkeras import ternary
from qkeras import quantized_hswish

@pytest.mark.parametrize(
    'bits, max_value, use_stochastic_rounding, quadratic_approximation, '
@@ -573,5 +574,42 @@ def test_stochastic_ternary_inference_mode(alpha, threshold, test_values, expect
  assert_allclose(result, expected_values, rtol=1e-05)


@pytest.mark.parametrize(
    # y = x * relu6(x + 3) / 6; in the first case the total word length is
    # 6 bits with 2 integer bits, and the quantization is in asymmetric mode.
    ('bits, integer, symmetric, relu_shift, relu_upper_bound,'
     'test_values, expected_values'), [
         (
             6, 2, 0, 3, 6,
             np.array([[-3.0, -2.0, -1.0, -0.5, 0.0, 0.5, 1, 4, 10]],
                      dtype=K.floatx()),
             np.array([[0., -0.375, -0.375, -0.25, 0., 0.25, 0.625,
                        3.875, 3.875]], dtype=K.floatx()),
         ),
         (
             6, 4, 1, 3, 6,
             np.array([[-10.0, -2.0, -2.3, -0.25, 0.0, 0.5, 1, 4, 10]],
                      dtype=K.floatx()),
             np.array([[0., -0.5, -0.5, 0., 0., 0.5, 0.5, 4., 10.]],
                      dtype=K.floatx()),
         ),
         (
             2, 0, 0, 3, 6,
             np.array([[-10.0, -2.0, -2.3, -0.25, 0.0, 0.5, 1, 4, 10]],
                      dtype=K.floatx()),
             np.array([[0., -0.5, -0.5, 0., 0., 0.5, 0.5, 0.5, 0.5]],
                      dtype=K.floatx()),
         ),
     ])
def test_quantized_hswish(bits, integer, symmetric, relu_shift,
                          relu_upper_bound, test_values, expected_values):
  x = K.placeholder(ndim=2)
  f = K.function(
      [x], [quantized_hswish(bits, integer, symmetric, relu_shift=relu_shift,
                             relu_upper_bound=relu_upper_bound)(x)])
  result = f([test_values])[0]
  assert_allclose(result, expected_values, rtol=1e-05)


if __name__ == '__main__':
  pytest.main([__file__])
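A hypothetical eager-mode usage sketch (assuming a TF2 eager environment; not part of the commit), mirroring the first test case by calling the quantizer directly:

import tensorflow as tf
from qkeras import quantized_hswish

quantizer = quantized_hswish(bits=6, integer=2, symmetric=0,
                             relu_shift=3, relu_upper_bound=6)
inputs = tf.constant([[-3.0, -1.5, 0.0, 1.0, 4.0]])
print(quantizer(inputs).numpy())  # -0.375 at x = -1.5; 4.0 clips to 3.875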
