Merged
11 changes: 11 additions & 0 deletions python/paddle/static/quantization/quanter.py
@@ -23,6 +23,7 @@
from ...base.framework import IrGraph, core
from ..log_helper import get_logger
from .quantization_pass import (
AddQuantDequantForResidual,
AddQuantDequantPass,
ConvertToInt8Pass,
OutScaleForInferencePass,
@@ -370,6 +371,16 @@ def quant_aware(
for sub_graph in sub_graphs:
transform_pass.apply(sub_graph)

residual_pass = AddQuantDequantForResidual(
scope=scope,
place=place,
quant_bits=config['activation_bits'],
is_test=is_test,
)

    for sub_graph in sub_graphs:
residual_pass.apply(sub_graph)

if len(quant_dequant_ops) > 0:
qdq_func = (
AddQuantDequantPassV2
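The new block follows the same construct-then-apply pattern as the transform pass above it. A minimal sketch of the construction in isolation — the locals below are illustrative stand-ins for quant_aware's own variables, and 'activation_bits' is the only config key the new pass reads:

import paddle
from paddle.static.quantization.quantization_pass import (
    AddQuantDequantForResidual,
)

scope = paddle.static.global_scope()  # stand-in for quant_aware's scope
place = paddle.CPUPlace()             # a CUDAPlace or 'gpu:0' also works
config = {'activation_bits': 8}       # illustrative QAT config

residual_pass = AddQuantDequantForResidual(
    scope=scope,
    place=place,
    quant_bits=config['activation_bits'],
    is_test=False,  # trainable QAT graph
)
# quant_aware then applies it to every sub-graph, as in the hunk above.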
108 changes: 108 additions & 0 deletions python/paddle/static/quantization/quantization_pass.py
@@ -3511,3 +3511,111 @@ def _insert_quant_dequant_op(self, graph, var_node):
graph.link_to(zero_point_node, dequant_op_node)
graph.link_to(dequant_op_node, dequant_var_node)
return dequant_var_node


class AddQuantDequantForResidual:
"""
    Quantize residual connections: insert quant and dequant ops on the
    skip-connection input of each eligible ``elementwise_add``.
"""

def __init__(
self,
scope,
place,
quant_bits=8,
is_test=True,
):
"""
Args:
            scope(static.Scope): The scope used to initialize the new quantization parameters.
            place(paddle.CPUPlace|paddle.CUDAPlace|str): The place on which to initialize the new parameters.
                If it is a string, it can be ``cpu`` or ``gpu:x``, where ``x`` is the GPU index.
            quant_bits(int, optional): Quantization bit number for the residual activations. Default is 8.
            is_test(bool, optional): Whether the graph is built for inference rather than training. Default is True.
"""
self._place = _get_paddle_place(place)
self._scope = scope
self._quant_bits = quant_bits
self._is_test = is_test
assert self._scope is not None, "scope must not be None."
assert self._place is not None, "place must not be None."

def apply(self, graph):
"""
Args:
graph(IrGraph): the target graph.
"""
assert isinstance(
graph, IrGraph
        ), 'graph must be an instance of IrGraph.'
weight_var_names = self._all_weight_node_names(graph)
var_node_names_with_order = self._var_name_order(graph)
for op in graph.all_op_nodes():
if op.name() != 'elementwise_add':
continue
first_input_name = op.inputs[0].name()
second_input_name = op.inputs[1].name()
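            # Bias-style adds are not residual connections: skip the op if
            # either input is a (possibly cast) weight.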
if (
first_input_name in weight_var_names
or second_input_name in weight_var_names
):
continue
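            # The input consumed earlier in topological order is the skip
            # branch: in `out = f(x) + x`, x feeds the block's first op while
            # f(x) is only consumed by the add itself.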
skip_node = (
op.inputs[0]
if var_node_names_with_order[first_input_name]
< var_node_names_with_order[second_input_name]
else op.inputs[1]
)
self._insert_quant_dequant(graph, skip_node, op)

def _all_weight_node_names(self, graph):
"""
        Return the names of all weight variables (including cast copies of weights).
"""
weight_var_names = [
node.name() for node in graph.all_persistable_nodes()
]
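        # Under AMP, a weight may reach its consumers through a cast op, so
        # count the cast outputs of persistable inputs as weights as well.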
for op in graph.all_op_nodes():
if op.name() == 'cast' and op.inputs[0].persistable():
weight_var_names.append(op.outputs[0].name())

return weight_var_names

def _var_name_order(self, graph):
"""
        Return a dict mapping each variable name to the topological index of the
        first op that consumes it. The earlier-consumed input of a residual add
        is the skip branch, since it is produced before the main-branch output.
"""
ordered_ops = graph.topology_sort()
var_node_names_with_order = {}
for idx, op_node in enumerate(ordered_ops):
for in_var_node in op_node.inputs:
in_var_name = in_var_node.name()
if var_node_names_with_order.get(in_var_name) is None:
var_node_names_with_order[in_var_name] = idx

return var_node_names_with_order

def _insert_quant_dequant(self, graph, var_node, op):
"""
        Insert per-tensor quantize_linear and dequantize_linear nodes between var_node and op.
"""
insert_quant_pass = InsertQuantizeLinear(
self._place,
self._scope,
quant_bits=self._quant_bits,
quant_axis=-1,
channel_wise=False,
is_test=self._is_test,
)
quant_var_name = var_node.name() + '.skip'
op_role = op.op().attr("op_role")
(
quant_var_node,
scale_var_node,
) = insert_quant_pass.insert_quant_op(
graph, var_node, var_name=quant_var_name, op_role=op_role
)
dequant_var_node = insert_quant_pass.insert_dequant_op(
graph, quant_var_node, scale_var_node, op_role
)
graph.update_input_link(var_node, dequant_var_node, op)
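
End to end, the pass can be exercised on a toy residual block. A hedged, self-contained sketch — the toy network and names below are illustrative, and it assumes the legacy static-graph IR, where paddle.add lowers to an elementwise_add op:

import paddle
from paddle.base.framework import IrGraph, core
from paddle.static.quantization.quantization_pass import (
    AddQuantDequantForResidual,
)

paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name='x', shape=[-1, 16], dtype='float32')
    y = paddle.static.nn.fc(x, size=16)  # weight quantization is out of scope here
    out = paddle.add(y, x)               # residual; lowers to elementwise_add

place = paddle.CPUPlace()
paddle.static.Executor(place).run(startup_prog)  # materialize fc weights in scope

graph = IrGraph(core.Graph(main_prog.desc), for_test=True)
AddQuantDequantForResidual(
    scope=paddle.static.global_scope(),
    place=place,
    quant_bits=8,
    is_test=True,
).apply(graph)

# `x` is consumed before `y` in topological order, so the quant/dequant pair
# lands on the skip input `x`; fc's internal bias add is skipped because one
# of its inputs is a weight.
op_types = [op.name() for op in graph.all_op_nodes()]
assert op_types.count('quantize_linear') == 1
assert op_types.count('dequantize_linear') == 1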
24 changes: 6 additions & 18 deletions test/collective/fleet/test_fleet_qat_meta_optimizer.py
@@ -33,7 +33,7 @@ def __init__(self, input_size, output_size):

def forward(self, x):
x = self.linear1(x)
-        x = self.linear2(x)
+        x = self.linear2(x) + x
x = self.linear3(x)
return x

@@ -91,10 +91,12 @@ def valid_program(self, train_prog, eval_prog):
self.assertEqual(
ops_type.count('matmul_v2'), 3
) # SimpleNet has 3 linear layers
-        self.assertEqual(ops_type.count('quantize_linear'), 6)
-        # There are three linear layers and each layer has this op in weight.
+        self.assertEqual(ops_type.count('quantize_linear'), 7)
+        # Each of the three linear layers has two quantize_linear ops (one for
+        # the weight, one for the input activation), and the skip-connection
+        # input adds one more: 3 * 2 + 1 = 7.
         self.assertEqual(
-            ops_type.count('dequantize_linear'), 6
+            ops_type.count('dequantize_linear'), 7
         )  # Each dequantize_linear follows a quantize_linear, so the counts match.

def test_fleet_with_qat(self):
@@ -113,9 +115,6 @@ def test_fleet_with_qat(self):
else base.CPUPlace()
)
eval_prog = train_prog.clone(for_test=True)
-        optimizer.qat_init(
-            place, scope=paddle.static.global_scope(), test_program=eval_prog
-        )
self.execute_program(train_prog, startup_prog, input_x, input_y)
self.valid_program(train_prog, eval_prog)

@@ -125,17 +124,6 @@ def setup_strategy(self, strategy):
strategy.qat = True
strategy.amp = True

-    def valid_program(self, train_prog, eval_prog):
-        ops_type = [op.type for op in train_prog.block(0).ops]
-        self.assertEqual(
-            ops_type.count('matmul_v2'), 3
-        )  # SimpleNet has 3 linear layers
-        self.assertEqual(ops_type.count('quantize_linear'), 6)
-        # There are three linear layers and each layer has this op in weight.
-        self.assertEqual(
-            ops_type.count('dequantize_linear'), 6
-        )  # Dequantize Op will follow quantize op (fake quantize), so the number is same.


if __name__ == "__main__":
unittest.main()