Commit 90ebce9

bingyanghuang authored and luotao1 committed
QAT int8 MKL-DNN transformation pass (#17819)
1 parent 377f9e6 commit 90ebce9

File tree

3 files changed: +425 -0 lines changed

python/paddle/fluid/contrib/slim/quantization/__init__.py

Lines changed: 3 additions & 0 deletions

@@ -20,6 +20,9 @@
 from .quantization_strategy import *
 from . import mkldnn_post_training_strategy
 from .mkldnn_post_training_strategy import *
+from . import quantization_mkldnn_pass
+from .quantization_mkldnn_pass import *

 __all__ = quantization_pass.__all__ + quantization_strategy.__all__
 __all__ += mkldnn_post_training_strategy.__all__
+__all__ += quantization_mkldnn_pass.__all__
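With these two re-exports, TransformForMkldnnPass becomes part of the package's public API and can be imported directly from the quantization package, exactly as the docstring example in the new module below does:

from paddle.fluid.contrib.slim.quantization import TransformForMkldnnPass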
python/paddle/fluid/contrib/slim/quantization/quantization_mkldnn_pass.py

Lines changed: 229 additions & 0 deletions

@@ -0,0 +1,229 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from .... import core
from ....framework import IrGraph
from ....framework import IrNode

__all__ = ['TransformForMkldnnPass']

class TransformForMkldnnPass(object):
    """
    Convert the IrGraph generated by the QuantizationFreezePass into an
    MKL-DNN-supported INT8 IrGraph. The following transformations are done in
    this pass:
    1. Convert the int8-range weights with float32 data type, which are
       generated by the QuantizationFreezePass, to float32-range weights with
       float32 data type by using the corresponding scales. This conversion is
       needed because the MKL-DNN INT8 conv2d kernel currently only accepts
       float32 weights as input and performs the weight quantization inside
       the kernel.
    2. Create a new conv2d op with the converted weights, link its output to
       the fake_dequantize_abs_max op's output, and set the conv2d attribute
       "force_fp32_output" to true.
    3. Transform the fake_quantize_xx ops into quantize ops.
    4. Remove the fake_dequantize_abs_max ops.
    """

    def __init__(self, scope=None, place=None):
        """
        Args:
            scope(fluid.Scope): scope is used to initialize the new parameters.
            place(fluid.CPUPlace): place is used to initialize the new parameters.

        Examples:
        .. code-block:: python
            # The original graph will be rewritten.
            import paddle.fluid as fluid
            from paddle.fluid.contrib.slim.quantization \
                import TransformForMkldnnPass
            from paddle.fluid.framework import IrGraph
            from paddle.fluid import core

            graph = IrGraph(core.Graph(fluid.Program().desc), for_test=False)
            place = fluid.CPUPlace()
            mkldnn_pass = TransformForMkldnnPass(fluid.global_scope(), place)
            mkldnn_pass.apply(graph)
        """

        self._scope = scope
        self._place = place

        self.quantize_type = [
            'fake_quantize_moving_average_abs_max',
            'fake_quantize_range_abs_max'
        ]
        self.dequantize_type = ['fake_dequantize_max_abs']

        self._quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
        self._conv_ops = ['conv2d', 'depthwise_conv2d']

        self.InScale = {}
        self.max_range = {}
        self.conv_new_output = {}
        self.s8_max = 127
        # Temporary code for keeping the mul op as fake quantization.
        # TODO Intel: Remove the following code when the mul int8 MKL-DNN
        # kernel is enabled.
        self.mul_input_id = []
        self.mul_output_id = []

    def apply(self, graph):
        """
        Quantize the graph for running MKL-DNN INT8 inference. According to
        the activation quantization type, the graph will transform the fake
        quantize ops into quantize ops and remove the fake dequantize ops.

        Args:
            graph(IrGraph): the applied graph.
        """

        assert isinstance(graph,
                          IrGraph), 'graph must be the instance of IrGraph.'
        ops = graph.all_op_nodes()

        persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
        # Collect the InScales and max_range to calculate the new scales for
        # the MKL-DNN INT8 conv2d
        for op_node in ops:
            if op_node.name() in self.dequantize_type:
                input_name = op_node.input("X")[0]
                scale_name = op_node.input("Scale")[0]
                self.InScale[input_name] = self._load_param(self._scope,
                                                            scale_name)[0]
                self.max_range[input_name] = op_node.op().attr("max_range")
                self.conv_new_output[input_name] = op_node.output("Out")[0]
            # Temporary graph transform for keeping the mul op
            # TODO Intel: Remove the following code
            elif op_node.name() in ['mul']:
                input_node = graph._find_node_by_name(op_node.inputs,
                                                      op_node.input('X')[0])
                output_node = graph._find_node_by_name(op_node.outputs,
                                                       op_node.output('Out')[0])
                self.mul_input_id.append(input_node.id())
                self.mul_output_id.append(output_node.id())

        for op_node in ops:
            if op_node.name() in self._conv_ops:
                self._transform_to_conv_mkldnn(graph, op_node)
            elif op_node.name() in self.quantize_type:
                self._transform_to_quantize_mkldnn(graph, op_node)
            elif op_node.name() in self.dequantize_type:
                self._remove_fake_dequantize_op(graph, op_node)
        self._remove_unused_var_nodes(graph)
        return graph

    def _transform_to_conv_mkldnn(self, graph, op_node):
        weight_name = op_node.input("Filter")[0]
        output_name = op_node.output("Output")[0]
        # Convert the int8 range weights to fp32 range weights
        weight = self._load_param(self._scope, weight_name)
        w_fp32 = np.divide(
            np.multiply(weight, 127), self.max_range[output_name])
        w_fp32 = w_fp32.reshape(weight.shape)
        self._restore_var(weight_name, w_fp32)
        input_var_node = graph._find_node_by_name(op_node.inputs,
                                                  op_node.input("Input")[0])
        weight_var_node = graph._find_node_by_name(op_node.inputs, weight_name)

        # Set fake_dequantize_abs_max's output as the new output of conv2d
        output_var_node = graph._find_node_by_name(
            graph.all_var_nodes(), self.conv_new_output[output_name])
        attrs = {
            name: op_node.op().attr(name)
            for name in op_node.op().attr_names()
        }

        conv_op_node = graph.create_op_node(
            op_type='conv2d',
            attrs=attrs,
            inputs={'Input': input_var_node,
                    'Filter': weight_var_node},
            outputs={'Output': output_var_node})

        # Calculate the scales of the MKL-DNN INT8 conv2d from the QAT scales
        scale_in = self.s8_max / self.InScale[output_name]
        scale_w = []
        scale_w.append(self.max_range[output_name] / self.s8_max)

        conv_op_node.set_attr("Scale_weights", scale_w)
        conv_op_node.set_attr("Scale_in", scale_in)
        conv_op_node.set_attr("Scale_out", 1.0)
        conv_op_node.set_attr("use_mkldnn", 1)
        conv_op_node.set_attr("force_fp32_output", 1)
        graph.link_to(input_var_node, conv_op_node)
        graph.link_to(weight_var_node, conv_op_node)
        graph.link_to(conv_op_node, output_var_node)
        graph.safe_remove_nodes(op_node)

    def _transform_to_quantize_mkldnn(self, graph, op_node):
        """
        Transform a fake_quantize_xx op into a quantize MKL-DNN op in the graph.
        """
        input_var_node = graph._find_node_by_name(op_node.inputs,
                                                  op_node.input("X")[0])
        output_var_node = graph._find_node_by_name(op_node.outputs,
                                                   op_node.output("Out")[0])
        if output_var_node.id() in self.mul_input_id:
            return
        else:
            scale_in = self.s8_max / self._load_param(
                self._scope, op_node.input("InScale")[0])[0]
            quant_op_node = graph.create_op_node(
                op_type='quantize',
                attrs={
                    'data_format': 'MKLDNNLAYOUT',
                    'use_mkldnn': 1,
                    'Scale': scale_in,
                    'is_negative_input': 1
                },
                inputs={'Input': input_var_node},
                outputs={'Output': output_var_node})
            graph.link_to(input_var_node, quant_op_node)
            graph.link_to(quant_op_node, output_var_node)
            graph.safe_remove_nodes(op_node)

    def _remove_fake_dequantize_op(self, graph, op_node):
        input_var_node = graph._find_node_by_name(op_node.inputs,
                                                  op_node.input("X")[0])
        if input_var_node.id() in self.mul_output_id:
            return
        else:
            graph.safe_remove_nodes(op_node)

    def _load_param(self, scope, param_name):
        return np.array(scope.find_var(param_name).get_tensor())

    def _restore_var(self, name, array):
        tensor = self._scope.find_var(name).get_tensor()
        tensor.set(array, self._place)

    def _remove_unused_var_nodes(self, graph):
        all_used_vars = set()
        ops = graph.all_op_nodes()
        for op_node in ops:
            for input_node in op_node.inputs:
                all_used_vars.add(input_node)
            for output_node in op_node.outputs:
                all_used_vars.add(output_node)

        all_used_vars = {n.node for n in all_used_vars}
        all_unused_vars = {
            n
            for n in filter(lambda node: node.node not in all_used_vars,
                            graph.all_var_nodes())
        }
        graph.safe_remove_nodes(all_unused_vars)
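Putting the pieces together, the intended workflow is: freeze a QAT-trained graph with QuantizationFreezePass (referenced in the class docstring above), then run this pass so the frozen graph ends up with real quantize ops and MKL-DNN INT8 conv2d ops. Below is a minimal sketch of that flow, not part of this commit; qat_program is a placeholder for a QAT inference program whose parameters (including the quantization scales) have already been loaded into fluid.global_scope().

import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass
from paddle.fluid.contrib.slim.quantization import TransformForMkldnnPass

scope = fluid.global_scope()
place = fluid.CPUPlace()

# Placeholder: in practice this is the QAT-trained inference program whose
# parameters (including quantization scales) have been loaded into `scope`.
qat_program = fluid.Program()

graph = IrGraph(core.Graph(qat_program.desc), for_test=True)

# Freeze the QAT graph first: this produces the int8-range weights and the
# fake_quantize_* / fake_dequantize_max_abs ops that this pass expects.
freeze_pass = QuantizationFreezePass(scope=scope, place=place)
freeze_pass.apply(graph)

# Rewrite the frozen graph for MKL-DNN INT8 execution.
mkldnn_pass = TransformForMkldnnPass(scope=scope, place=place)
mkldnn_pass.apply(graph)

mkldnn_program = graph.to_program()

For each conv2d matched with a fake_dequantize_max_abs op, the pass sets Scale_in = 127 / InScale and Scale_weights = [max_range / 127]; since force_fp32_output is enabled, the convolution output stays float32 and Scale_out is left at 1.0.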
