Skip to content

Commit

Permalink
Internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 436637235
Change-Id: Id737d647a37cbb83a29c16e28b4c3109e4a38419
  • Loading branch information
lishanok authored and copybara-github committed Mar 23, 2022
1 parent 80b5244 commit 7be2baf
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 38 deletions.
57 changes: 42 additions & 15 deletions qkeras/qpooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,29 +60,43 @@ def call(self, inputs):
2. first, we call keras version of averaging first: y1 = keras_average(x)
then multiply it with pool_size^2: y2 = y1 * pool_area
Last, y3 = y2 * quantize(1/ pool_area)
Our numerical anaysis suggests negligible error between 1 and 2. Therefore
we use option #2 here for the simplicity of implementation.
3. Improved based on #2, but multiply x with pool_area before averaging
so that we don't lose precision during averaging. The order now becomes:
first, multiply x with pool_area: y1 = x * pool_area
then we call keras version of averaging: y2 = keras_average(y1)
Last, y3 = y2 * quantize(1/ pool_area)
4. Since there is sum_pooling operation, another solution is to use
depthwise_conv2d with kernel weights = 1 to get the pooling sum. In this
case we don't lose precision due to averaging. However, this solution
will introduce extra weights to the layer, which might break our code
elsewhere.
Since we need to match software and hardware inference numerics, we are now
using #3 in the implementation.
"""

x = super(QAveragePooling2D, self).call(inputs)

if self.average_quantizer:
# Calculates the pool area
if isinstance(self.pool_size, int):
pool_area = self.pool_size * self.pool_size
else:
pool_area = np.prod(self.pool_size)

# Revertes the division results.
x = x * pool_area
# Calculates the pooling average of x*pool_area
x = super(QAveragePooling2D, self).call(inputs*pool_area)

# Quantizes the multiplication factor.
mult_factor = 1.0 / pool_area

q_mult_factor = self.average_quantizer_internal(mult_factor)
q_mult_factor = K.cast_to_floatx(q_mult_factor)

# Computes pooling average.
x = x * q_mult_factor

else:
# Since no quantizer is available, we directly call the keras layer
x = super(QAveragePooling2D, self).call(inputs)

if self.activation is not None:
return self.activation(x)
return x
Expand Down Expand Up @@ -148,22 +162,35 @@ def call(self, inputs):
then multiply it with the denominator(pool_area) used by averaging:
y2 = y1 * pool_area
Last, y3 = y2 * quantize(1/ pool_area)
Our numerical anaysis suggests negligible error between 1 and 2. Therefore
we use option #2 here for the simplicity of implementation.
3. we perform pooling sum, and then multiply the sum with the quantized
inverse multiplication factor to get the average value.
Our previous implementation uses option #2. Yet we observed minor numerical
mismatch between software and hardware inference. Therefore we use #3 as
the current implementation.
"""

x = super(QGlobalAveragePooling2D, self).call(inputs)

if self.average_quantizer:
# Calculates pooling sum.
if self.data_format == "channels_last":
x = K.sum(inputs, axis=[1, 2], keepdims=self.keepdims)
else:
x = K.sum(inputs, axis=[2, 3], keepdims=self.keepdims)

# Calculates the pooling area
pool_area = self.compute_pooling_area(input_shape=inputs.shape)
# Reverts the division results
x = x * pool_area
# Quantizes the multiplication factor

# Quantizes the inverse multiplication factor
mult_factor = 1.0 / pool_area
q_mult_factor = self.average_quantizer_internal(mult_factor)

# Derives average pooling value from pooling sum.
x = x * q_mult_factor

else:
# If quantizer is not available, calls the keras layer.
x = super(QGlobalAveragePooling2D, self).call(inputs)

if self.activation is not None:
return self.activation(x)
return x
Expand Down
23 changes: 10 additions & 13 deletions qkeras/qtools/generate_layer_data_type_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ def generate_layer_data_type_map(
"GlobalAveragePooling2D", "QAveragePooling2D",
"QGlobalAveragePooling2D"]:
(input_quantizer, _) = input_qe_list[0]

qtools_average_quantizer = None
# This is a hack. We don't want to implement a new accumulator class
# just for averagpooling. So we re-use accumulator type in conv/dense
# layers which need multiplier and kernel as input parameters.
Expand Down Expand Up @@ -447,18 +447,15 @@ def generate_layer_data_type_map(
output_quantizer = update_output_quantizer_in_graph(
graph, node_id, quantizer_factory, layer_quantizer, for_reference)

layer_data_type_map[layer] = LayerDataType(
input_quantizer_list,
multiplier,
accumulator,
None,
None,
None,
None,
output_quantizer,
output_shapes,
operation_count
)
layer_data_type_map[layer] = {
"input_quantizer_list": input_quantizer_list,
"average_quantizer": qtools_average_quantizer,
"pool_sum_accumulator": accumulator,
"pool_avg_multiplier": multiplier,
"output_quantizer": output_quantizer,
"output_shapes": output_shapes,
"operation_count": operation_count
}

# If it's a Quantized Activation layer.
elif node_type in ["QActivation", "QAdaptiveActivation", "Activation"]:
Expand Down
39 changes: 35 additions & 4 deletions qkeras/qtools/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def map_to_json(mydict):
output_dict["source_quantizers"] = q_list

def get_val(feature, key):
# Return feature[key] or feature.key
if isinstance(feature, dict):
val = feature.get(key, None)
else:
Expand All @@ -139,7 +140,28 @@ def get_val(feature, key):

def set_layer_item(layer_item, key, feature, shape=None,
is_compound_datatype=False, output_key_name=None):

"""Generates the quantizer entry to a given layer_item.
This function extracts relevanant quantizer fields using the key (
quantizer name) from a given feature (layer entry from layer_data_type_map).
Args:
layer_item: Layer entry in the output dictionary. It includes the
info such as quantizers, output shape, etc. of each layer
key: Quantizer, such as kernel/bias quantizer, etc. If feature
feature: layer_data_type_map entry of each layer. This feature will be
parsed and converted to layer_item for the output dictionary.
shape: quantizer input shape
is_compound_datatype: Bool. Wether the quantizer is a compound
or unitary quantizer type. For example, kernel quantizer and bias
quantizer are unitary quantizer types, multiplier and accumulator
are compound quantizer types.
output_key_name: str. Change key to output_key_name in layer_item. If
None, will use the existing key.
Return:
None
"""
val = get_val(feature, key)
if val is not None:
quantizer = val
Expand Down Expand Up @@ -172,10 +194,20 @@ def set_layer_item(layer_item, key, feature, shape=None,
"variance_quantizer", "variance_quantizer"]:
set_layer_item(layer_item, key=key, feature=feature)

for key in ["internal_divide_quantizer", "internal_divide_quantizer",
for key in ["internal_divide_quantizer",
"internal_multiplier", "internal_accumulator"]:
set_layer_item(layer_item, key=key, feature=feature,
is_compound_datatype=True)

elif layer_item["layer_type"] in [
"AveragePooling2D", "AvgPool2D", "GlobalAvgPool2D",
"GlobalAveragePooling2D", "QAveragePooling2D",
"QGlobalAveragePooling2D"]:
set_layer_item(layer_item, key="average_quantizer", feature=feature)
for key in ["pool_sum_accumulator", "pool_avg_multiplier"]:
set_layer_item(layer_item, key=key, feature=feature,
is_compound_datatype=True)

else:
# populate the feature to dictionary
set_layer_item(layer_item, key="weight_quantizer", feature=feature,
Expand All @@ -201,8 +233,7 @@ def set_layer_item(layer_item, key, feature, shape=None,
set_layer_item(layer_item, key="fused_accumulator", feature=feature,
is_compound_datatype=True)

layer_item["operation_count"] = get_val(feature, "operation_count")

layer_item["operation_count"] = get_val(feature, "operation_count")
output_dict[layer.name] = layer_item

return output_dict
48 changes: 44 additions & 4 deletions tests/qpooling_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,9 @@ def test_qpooling_in_qtools():
dtype_dict = interface.map_to_json(layer_map)

# Checks the QAveragePpooling layer datatype
multiplier = dtype_dict["pooling"]["multiplier"]
accumulator = dtype_dict["pooling"]["accumulator"]
multiplier = dtype_dict["pooling"]["pool_avg_multiplier"]
accumulator = dtype_dict["pooling"]["pool_sum_accumulator"]
average_quantizer = dtype_dict["pooling"]["average_quantizer"]
output = dtype_dict["pooling"]["output_quantizer"]

assert_equal(multiplier["quantizer_type"], "quantized_bits")
Expand All @@ -225,9 +226,15 @@ def test_qpooling_in_qtools():
assert_equal(output["int_bits"], 1)
assert_equal(output["is_signed"], 1)

assert_equal(average_quantizer["quantizer_type"], "binary")
assert_equal(average_quantizer["bits"], 1)
assert_equal(average_quantizer["int_bits"], 1)
assert_equal(average_quantizer["is_signed"], 1)

# Checks the QGlobalAveragePooling layer datatype
multiplier = dtype_dict["global_pooling"]["multiplier"]
accumulator = dtype_dict["global_pooling"]["accumulator"]
multiplier = dtype_dict["global_pooling"]["pool_avg_multiplier"]
accumulator = dtype_dict["global_pooling"]["pool_sum_accumulator"]
average_quantizer = dtype_dict["global_pooling"]["average_quantizer"]
output = dtype_dict["global_pooling"]["output_quantizer"]

assert_equal(multiplier["quantizer_type"], "quantized_bits")
Expand All @@ -247,6 +254,39 @@ def test_qpooling_in_qtools():
assert_equal(output["int_bits"], 2)
assert_equal(output["is_signed"], 1)

assert_equal(average_quantizer["quantizer_type"], "quantized_bits")
assert_equal(average_quantizer["bits"], 4)
assert_equal(average_quantizer["int_bits"], 1)
assert_equal(average_quantizer["is_signed"], 1)


def test_QAveragePooling_output():
# Checks if the output of QAveragePooling layer with average_quantizer
# is correct.
x = np.ones(shape=(2, 6, 6, 1))
x[0, 0, :, :] = 0
x = tf.constant(x)

y = QAveragePooling2D(
pool_size=(3, 3),
strides=3,
padding="valid",
average_quantizer="quantized_bits(8, 1, 1)")(x)
y = y.numpy()
assert np.all(y == [[[[0.65625], [0.65625]], [[0.984375], [0.984375]]],
[[[0.984375], [0.984375]], [[0.984375], [0.984375]]]])


def test_QGlobalAveragePooling_output():
# Checks if the output of QGlobalAveragePooling layer with average_quantizer
# is correct.
x = np.ones(shape=(2, 3, 3, 2))
x[0, 0, 1, :] = 0
x = tf.constant(x)
y = QGlobalAveragePooling2D(average_quantizer="quantized_bits(8, 1, 1)")(x)
y = y.numpy()
assert np.all(y == np.array([[0.875, 0.875], [0.984375, 0.984375]]))


if __name__ == "__main__":
pytest.main([__file__])
4 changes: 2 additions & 2 deletions tests/qtools_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,12 +566,12 @@ def test_pooling():
model = pooling_qmodel()
dtype_dict = run(model, input_quantizers)

accumulator = dtype_dict["avg_pooling"]["accumulator"]
accumulator = dtype_dict["avg_pooling"]["pool_sum_accumulator"]
assert accumulator["quantizer_type"] == "quantized_bits"
assert accumulator["bits"] == 10
assert accumulator["int_bits"] == 3

accumulator = dtype_dict["global_avg_pooling"]["accumulator"]
accumulator = dtype_dict["global_avg_pooling"]["pool_sum_accumulator"]
assert accumulator["quantizer_type"] == "quantized_bits"
assert accumulator["bits"] == 16
assert accumulator["int_bits"] == 9
Expand Down

0 comments on commit 7be2baf

Please sign in to comment.