
Commit c3c8b23

jwfromm authored and wweic committed
[Relay] Option to select which convolution layers are quantized. (apache#3173)
* Stashing for later maybe.
* Added new option to leave specific layers unquantized.
* Better error checking.
* remove unneeded import
* tab to spaces
* pylint fixes
* more pylint fixes
1 parent 87edf6f commit c3c8b23

File tree (5 files changed: +55, -6 lines changed)

* python/tvm/relay/quantize/_annotate.py
* python/tvm/relay/quantize/quantize.py
* src/relay/pass/quantize.cc
* src/relay/pass/quantize.h
* topi/python/topi/cuda/conv2d.py

python/tvm/relay/quantize/_annotate.py

Lines changed: 44 additions & 5 deletions
@@ -156,6 +156,13 @@ def conv2d_rewrite(ref_call, new_args, ctx):
     if cnt < current_qconfig().skip_k_conv:
         _set_conv_counter(cnt + 1)
         return None
+
+    if current_qconfig().skip_conv_layers is not None:
+        leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
+        if cnt in leave_alone_indices:
+            _set_conv_counter(cnt + 1)
+            return None
+
     _set_conv_counter(cnt + 1)
 
     lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
@@ -168,6 +175,7 @@ def conv2d_rewrite(ref_call, new_args, ctx):
         rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)
 
     expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
+
     return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
 
 
@@ -178,6 +186,11 @@ def dense_rewrite(ref_call, new_args, ctx):
     cnt = _conv_counter()
     if cnt < current_qconfig().skip_k_conv:
         return None
+    if current_qconfig().skip_conv_layers is not None:
+        leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
+        if cnt - 1 in leave_alone_indices:
+            return None
+
     lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
     rhs_expr, rhs_kind = _get_expr_kind(new_args[1])
 
@@ -194,8 +207,13 @@ def multiply_rewrite(ref_call, new_args, ctx):
 @register_annotate_function("multiply")
 def multiply_rewrite(ref_call, new_args, ctx):
     """Rewrite function for multiply."""
-    if _conv_counter() <= current_qconfig().skip_k_conv:
+    cnt = _conv_counter()
+    if cnt <= current_qconfig().skip_k_conv:
         return None
+    if current_qconfig().skip_conv_layers is not None:
+        leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
+        if cnt - 1 in leave_alone_indices:
+            return None
 
     lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
     rhs_expr, rhs_kind = _get_expr_kind(new_args[1])
@@ -216,8 +234,13 @@ def multiply_rewrite(ref_call, new_args, ctx):
 @register_annotate_function("add")
 def add_rewrite(ref_call, new_args, ctx):
     """Rewrite function for add."""
-    if _conv_counter() <= current_qconfig().skip_k_conv:
+    cnt = _conv_counter()
+    if cnt <= current_qconfig().skip_k_conv:
         return None
+    if current_qconfig().skip_conv_layers is not None:
+        leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
+        if cnt - 1 in leave_alone_indices:
+            return None
 
     lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
     rhs_expr, rhs_kind = _get_expr_kind(new_args[1])
@@ -244,8 +267,13 @@ def add_rewrite(ref_call, new_args, ctx):
 
 def identity_rewrite(ref_call, new_args, ctx):
     """Simply forward the original operation"""
-    if _conv_counter() <= current_qconfig().skip_k_conv:
+    cnt = _conv_counter()
+    if cnt <= current_qconfig().skip_k_conv:
         return None
+    if current_qconfig().skip_conv_layers is not None:
+        leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
+        if cnt - 1 in leave_alone_indices:
+            return None
 
     x_expr, x_kind = _get_expr_kind(new_args[0])
     if x_kind is None:
@@ -262,8 +290,14 @@ def identity_rewrite(ref_call, new_args, ctx):
 
 def pool2d_rewrite(ref_call, new_args, ctx):
     """Rewrite function for max pool2d"""
-    if _conv_counter() <= current_qconfig().skip_k_conv:
+    cnt = _conv_counter()
+    if cnt <= current_qconfig().skip_k_conv:
         return None
+    if current_qconfig().skip_conv_layers is not None:
+        leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
+        if cnt - 1 in leave_alone_indices:
+            return None
+
     expr, x_kind = _get_expr_kind(new_args[0])
 
     if x_kind is None:
@@ -280,8 +314,13 @@ def pool2d_rewrite(ref_call, new_args, ctx):
 @register_annotate_function("concatenate")
 def concatenate_rewrite(ref_call, new_args, ctx):
     """Rewrite function for concatenate"""
-    if _conv_counter() <= current_qconfig().skip_k_conv:
+    cnt = _conv_counter()
+    if cnt <= current_qconfig().skip_k_conv:
         return None
+    if current_qconfig().skip_conv_layers is not None:
+        leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
+        if cnt - 1 in leave_alone_indices:
+            return None
 
     input_tuple = new_args[0]
     expr_list = [_get_expr_kind(x)[0] for x in input_tuple]
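All of the rewrites above share the same guard, with one subtlety: `conv2d_rewrite` compares `cnt` itself, because it reads the counter before bumping it for the conv2d it is visiting, while every other rewrite compares `cnt - 1`, the index of the most recently visited conv2d (the counter has already moved past it). A minimal sketch of how the repeated check could be factored out; `_layer_is_skipped` is a hypothetical helper, not part of this commit:

```python
def _layer_is_skipped(index):
    """Hypothetical helper: True if conv layer `index` was marked to stay
    unquantized via the skip_conv_layers option."""
    skip = current_qconfig().skip_conv_layers  # assumed available in module scope
    if skip is None:
        return False
    return index in [int(x) for x in skip]

# conv2d_rewrite would test _layer_is_skipped(cnt) before incrementing the
# counter; the dense/multiply/add/identity/pool2d/concatenate rewrites would
# test _layer_is_skipped(cnt - 1), the index of the conv2d they follow.
```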

python/tvm/relay/quantize/quantize.py

Lines changed: 5 additions & 0 deletions
@@ -71,6 +71,7 @@ class QConfig(NodeBase):
         "dtype_activation": "int32",
         "global_scale": 8.0,
         "skip_k_conv": 1,
+        "skip_conv_layers": None,
         "round_for_shift": True,
         "store_lowbit_output": True,
         "debug_enabled_ops": None,
@@ -139,6 +140,10 @@ def qconfig(**kwargs):
     skip_k_conv: int
         The number of skipped conv2d.
 
+    skip_conv_layers: list
+        Different way of specifying which layers to avoid. Provide a list of indices
+        that indicate which conv2d layers to leave untouched.
+
     round_for_shift: boolean
         Whether to add bias for rounding during shift.
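For context, this is roughly how the new option is meant to be used. A hedged sketch: the exact top-level `quantize()` entry point and its signature vary across TVM versions, and `net`/`params` are assumed to come from an earlier Relay frontend import:

```python
from tvm import relay

# Leave conv2d layers 0 and 2 unquantized (indices follow the order in which
# conv2d ops are encountered during annotation); all other layers are
# quantized as usual.
with relay.quantize.qconfig(skip_conv_layers=[0, 2]):
    qnet = relay.quantize.quantize(net, params=params)
```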

src/relay/pass/quantize.cc

Lines changed: 1 addition & 0 deletions
@@ -596,6 +596,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
   p->stream << "nbit_activation=" << op->nbit_activation << ", ";
   p->stream << "global_scale=" << op->global_scale << ", ";
   p->stream << "skip_k_conv==" << op->skip_k_conv << ", ";
+  p->stream << "skip_conv_layers==" << op->skip_conv_layers << ", ";
   p->stream << "round_for_shift==" << op->round_for_shift << ", ";
   p->stream << "store_lowbit_output==" << op->store_lowbit_output << ", ";
   p->stream << "debug_enabled_ops==" << op->debug_enabled_ops << ", ";

src/relay/pass/quantize.h

Lines changed: 2 additions & 0 deletions
@@ -126,6 +126,7 @@ class QConfigNode : public Node {
   DataType dtype_activation = Int(32);
   double global_scale = 8.0;
   int skip_k_conv = 1;
+  Array<Expr> skip_conv_layers = Array<Expr>(NodePtr<Node>(nullptr));
   bool round_for_shift = true;
   bool store_lowbit_output = true;
   Array<Expr> debug_enabled_ops = Array<Expr>(NodePtr<Node>(nullptr));
@@ -140,6 +141,7 @@ class QConfigNode : public Node {
     v->Visit("dtype_activation", &dtype_activation);
     v->Visit("global_scale", &global_scale);
     v->Visit("skip_k_conv", &skip_k_conv);
+    v->Visit("skip_conv_layers", &skip_conv_layers);
     v->Visit("round_for_shift", &round_for_shift);
     v->Visit("store_lowbit_output", &store_lowbit_output);
     v->Visit("debug_enabled_ops", &debug_enabled_ops);

topi/python/topi/cuda/conv2d.py

Lines changed: 3 additions & 1 deletion
@@ -104,7 +104,9 @@ def conv2d_cuda(cfg, data, kernel, strides, padding, dilation, layout='NCHW', ou
         return winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dtype,
                              pre_computed=False)
     if cfg.template_key == 'int8':
-        return conv2d_NCHWc_int8(cfg, data, kernel, strides, padding, dilation, layout, out_dtype)
+        if (data.dtype == 'int8' or data.dtype == 'uint8'):
+            return conv2d_NCHWc_int8(
+                cfg, data, kernel, strides, padding, dilation, layout, out_dtype)
 
     if layout == 'NCHW':
         return nn.conv2d_nchw(data, kernel, strides, padding, dilation, out_dtype)
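The added dtype check means the 'int8' schedule template is only honored when the input data is actually quantized; for float inputs, control now falls through to the layout-based implementations below instead of calling the int8 kernel on data it cannot handle. A condensed sketch of the resulting dispatch order; `pick_impl` is a hypothetical stand-in, not the real function:

```python
def pick_impl(template_key, data_dtype, layout):
    """Hypothetical condensation of conv2d_cuda's dispatch after this change."""
    if template_key == 'int8' and data_dtype in ('int8', 'uint8'):
        return 'conv2d_NCHWc_int8'
    if layout == 'NCHW':
        return 'nn.conv2d_nchw'
    # ... winograd and other layouts are handled elsewhere in the real function ...
    return 'fallback'
```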
