Add support for float type in estimate.py
PiperOrigin-RevId: 368566596
Change-Id: Ib3cc15e8c86b8e1b80a01663b13aa6bd7964d4d3
lishanok authored and copybara-github committed Apr 15, 2021
1 parent 295cd9b commit 9ca7ec2
Showing 3 changed files with 86 additions and 17 deletions.
77 changes: 63 additions & 14 deletions qkeras/estimate.py
@@ -32,6 +32,7 @@

import numpy as np
import tensorflow.compat.v1 as tf
from absl import logging

from tensorflow.keras.layers import Activation
from tensorflow.keras.layers import InputLayer
@@ -244,11 +245,17 @@ def get_quant_mode(quant):
("quantized_ulaw", 0, -1, 1),
("quantized_tanh", 0, -1, 1),
("quantized_po2", 1, -1, 1),
("quantized_relu_po2", 1, -1, 0)
("quantized_relu_po2", 1, -1, 0),
("float", 5, 32, 1)
]

for (inst, mode, bits, sign) in modes:
if quant.__class__.__name__ == inst:
if not quant or getattr(quant, "__name__", None) == "linear":
# if quantizer not specified or linear, we use float type
if inst == "float":
return (mode, bits, sign)

elif quant.__class__.__name__ == inst:
if bits == -1:
bits = int(quant.bits)
if (
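With this change, a quantizer that is missing (None) or the keras "linear" activation falls through to the new "float" row instead of going unmatched. A minimal sketch of the intended behavior (import paths assumed; expected outputs follow from the modes table above):

    from qkeras.estimate import get_quant_mode
    from qkeras.quantizers import quantized_bits

    # A real quantizer resolves as before: mode 0 (quantized_bits), 4 bits, signed.
    print(get_quant_mode(quantized_bits(4, 0)))  # expected: (0, 4, 1)

    # No quantizer now maps to the float row: mode 5, 32 bits, signed.
    print(get_quant_mode(None))                  # expected: (5, 32, 1)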
@@ -269,39 +276,47 @@ def get_operation_type(layer, output_cache):
Determines operator strength according to the following table.
                                         x
                     qb(n)    +/-,exp   t(-1,0,+1)  b(-1,+1)  b(0,1)
      qb(n)            *      << >>,-      ?,-         ?,-       ?
      +/-,exp       << >>,-      +         ?,-          ^       ?,-
  w   t(-1,0,+1)      ?,-       ?,-        ?,^         ?,^       ^
      b(-1,+1)        ?,-        ^         ?,^          ^        ^
      b(0,1)           ?        ?,-         ^           ^        ^

                     qb(n)    +/-,exp   t(-1,0,+1)  b(-1,+1)  b(0,1)   float
      qb(n)            *      << >>,-      ?,-         ?,-       ?       *
      +/-,exp       << >>,-      +         ?,-          ^       ?,-      *
  w   t(-1,0,+1)      ?,-       ?,-        ?,^         ?,^       ^       *
      b(-1,+1)        ?,-        ^         ?,^          ^        ^       *
      b(0,1)           ?        ?,-         ^           ^        ^       *
      float            *         *          *           *        *       *
Arguments:
layer: layer in Keras to determine the operation strength.
output_cache: cache of input tensor bit sizes.
Returns:
One of "multiplier", "adder", "barrel", "mux", "xor", "neg".
One of "mult", "fmult", "adder", "barrel", "mux", "xor".
Note: "mult" represents quantized bit multiplier, "fmult" represents
floating point multiplier.
"""

wx_table = [
["mult", "barrel", "mux", "mux", "mux"],
["barrel", "adder", "mux", "xor", "mux"],
["mux", "mux", "mux", "mux", "xor"],
["mux", "xor", "mux", "xor", "xor"],
["mux", "mux", "xor", "xor", "xor"]
["mult", "barrel", "mux", "mux", "mux", "fmult"],
["barrel", "adder", "mux", "xor", "mux", "fmult"],
["mux", "mux", "mux", "mux", "xor", "fmult"],
["mux", "xor", "mux", "xor", "xor", "fmult"],
["mux", "mux", "xor", "xor", "xor", "fmult"],
["fmult", "fmult", "fmult", "fmult", "fmult", "fmult"],
]

# check if this is a quantized layer (QDense, QConv, QDepthwise)
if hasattr(layer, "get_quantizers"):
w_quant = layer.get_quantizers()[0]
w_mode, w_bits, w_sign = get_quant_mode(w_quant)
if w_mode == "float":
logging.warning("%s kernel is unquantized!", layer.name)

# for the input, get tensor input and search the cache that associates
# the quantizer with a tensor
if output_cache.get(layer.input.experimental_ref(), None) is not None:
x_mode, x_bits, x_sign = get_quant_mode(
output_cache.get(layer.input.experimental_ref()))
if x_mode == "float":
logging.warning("%s input is unquantized!", layer.name)
else:
print("cannot determine presently model for {}".format(layer.name))
return "null", (w_mode, -1), (w_bits, -1), (w_sign, -1)
@@ -422,6 +437,11 @@ def extract_model_operations(in_model):
weight_type = get_quant_mode(weight_quant)
bias_type = get_quant_mode(bias_quant)

if weight_type[0] == "float":
logging.warning("%s kernel is unquantized!", layer.name)
if bias_type[0] == "float":
logging.warning("%s bias is unquantized!", layer.name)

elif layer.__class__.__name__ in ["QConv1D"]:

_, _, channels_i = input_shape
@@ -444,6 +464,11 @@
weight_type = get_quant_mode(weight_quant)
bias_type = get_quant_mode(bias_quant)

if weight_type[0] == 5:  # mode 5: float (unquantized)
logging.warning("%s kernel is unquantized!", layer.name)
if bias_type[0] == 5:
logging.warning("%s bias is unquantized!", layer.name)

elif layer.__class__.__name__ in ["QDepthwiseConv2D"]:

_, _, _, channels_i = input_shape
@@ -467,6 +492,11 @@
weight_type = get_quant_mode(weight_quant)
bias_type = get_quant_mode(bias_quant)

if weight_type[0] == 5:  # mode 5: float (unquantized)
logging.warning("%s kernel is unquantized!", layer.name)
if bias_type[0] == 5:
logging.warning("%s bias is unquantized!", layer.name)

elif layer.__class__.__name__ in ["QSeparableConv1D"]:

_, _, channels_i = input_shape
@@ -495,6 +525,13 @@
weight_type = [depthwise_type, pointwise_type]
bias_type = get_quant_mode(bias_quant)

if depthwise_type[0] == 5:  # mode 5: float (unquantized)
logging.warning("%s depthwise kernel is unquantized!", layer.name)
if pointwise_type[0] == 5:
logging.warning("%s pointwise kernel is unquantized!", layer.name)
if bias_type[0] == 5:
logging.warning("%s bias is unquantized!", layer.name)

elif layer.__class__.__name__ in ["QSeparableConv2D"]:

_, _, _, channels_i = input_shape
@@ -523,6 +560,13 @@
weight_type = [depthwise_type, pointwise_type]
bias_type = get_quant_mode(bias_quant)

if depthwise_type[0] == 5:  # mode 5: float (unquantized)
logging.warning("%s depthwise kernel is unquantized!", layer.name)
if pointwise_type[0] == 5:
logging.warning("%s pointwise kernel is unquantized!", layer.name)
if bias_type[0] == 5:
logging.warning("%s bias is unquantized!", layer.name)

elif layer.__class__.__name__ in ["QDense"]:

_, size_i = input_shape
@@ -539,6 +583,11 @@
weight_type = get_quant_mode(weight_quant)
bias_type = get_quant_mode(bias_quant)

if weight_type[0] == 5:  # mode 5: float (unquantized)
logging.warning("%s kernel is unquantized!", layer.name)
if bias_type[0] == 5:
logging.warning("%s bias is unquantized!", layer.name)

# "number_of_operations" is tensor_shape.Dimension type
operations[layer.name] = {
"type":
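The net effect of the checks above is that estimation now tolerates unquantized layers instead of assuming every kernel and bias has a quantizer; a rough sketch of what a caller sees (layer names and exact log formatting depend on the model):

    from qkeras.estimate import print_qstats

    print_qstats(model)  # model: any Keras/QKeras model
    # for each float kernel or bias, the log gains a line like:
    #   WARNING ... conv2d kernel is unquantized!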
8 changes: 5 additions & 3 deletions qkeras/utils.py
@@ -928,6 +928,7 @@ def get_model_sparsity(model, per_layer=False, allow_list=None):

# Calculate the sparsity layer by layer
layer_sparsity = []
total_sparsity = 0.
all_weights = []
for layer in model.layers:
if hasattr(layer, "quantizers") and layer.__class__.__name__ in allow_list:
@@ -944,9 +945,10 @@
layer_weights = np.concatenate(layer_weights)
layer_sparsity.append((layer.name, np.mean(layer_weights == 0)))

# Average the sparsity for the entire model
all_weights = np.concatenate(all_weights)
total_sparsity = np.mean(all_weights == 0)
if len(all_weights) > 0:
# Average the sparsity for the entire model
all_weights = np.concatenate(all_weights)
total_sparsity = np.mean(all_weights == 0)
if per_layer:
return (total_sparsity, layer_sparsity)
else:
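The new guard keeps np.concatenate from being called on an empty list when no layer matches allow_list, so total_sparsity defaults to 0. instead of raising. A usage sketch (model is any Keras/QKeras model):

    from qkeras.utils import get_model_sparsity

    total = get_model_sparsity(model)
    total, per_layer = get_model_sparsity(model, per_layer=True)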
18 changes: 18 additions & 0 deletions tests/print_qstats_test.py
@@ -24,6 +24,8 @@

from qkeras.estimate import print_qstats
from qkeras.utils import model_quantize
from qkeras import QConv2D
from qkeras.quantizers import binary


def create_network():
@@ -35,6 +37,16 @@ def create_network():
return Model(inputs=xi, outputs=x)


def create_mix_network():

xi = Input((28, 28, 1))
x = QConv2D(32, (3, 3), kernel_quantizer=binary())(xi)
x = Activation("relu")(x)
x = Conv2D(32, (3, 3))(x)
x = Activation("softmax")(x)
return Model(inputs=xi, outputs=x)


def test_conversion_print_qstats():
# this tests if references in tensorflow are working properly.
m = create_network()
@@ -51,6 +63,12 @@ def test_conversion_print_qstats():
qq.summary()
print_qstats(qq)

# test if print_qstats works with unquantized layers
print_qstats(m)

# test if print_qstats works with a mix of quantized and unquantized layers
m1 = create_mix_network()
print_qstats(m1)

if __name__ == "__main__":
pytest.main([__file__])
