
Commit 4ffbdcd

Matthew Brookhart and Josh Fromm authored
[Relay][Quantization] Per-Channel FQ2I (#8883)
* WIP support per-channel quantization
* more WIP
* More WIP
* fix issue with per-channel bias_add
* Fix fake quantize tests (#4)
* Fixed fake quantize issues.
* Formatting.
* Cleanup unused imports
* Fix real int8 tests.
* Add Relu
* One more little one (#5)
* Fix requantize shape bug.
* Non-working Per-channel Dense
* Fix legalization for non spatial operators. (#6)
* Fix axis checks for end2end functionality.
* fix axis normalization
* fix lint
* fix lint again
* Per channel fq2i (#8)
* Fix bug in requantize dimension expansion.
* Format.
* respond to review comments

Co-authored-by: Josh Fromm <jwfromm@octoml.ai>
1 parent e0aac94 commit 4ffbdcd

File tree

15 files changed: +315 −59 lines changed

include/tvm/ir/affine_type.h

Lines changed: 6 additions & 2 deletions

@@ -71,24 +71,28 @@ class TensorAffineTypeNode : public AffineTypeNode {
   RelayExpr zero_point;
   /*! \brief The data type of this type */
   DataType dtype;
+  /*! \brief The axis for per-channel quantization */
+  int axis;

   void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("scale", &scale);
     v->Visit("zero_point", &zero_point);
     v->Visit("dtype", &dtype);
+    v->Visit("axis", &axis);
   }

   bool SEqualReduce(const TensorAffineTypeNode* other, SEqualReducer equal) const {
     equal->MarkGraphNode();
     return equal(scale, other->scale) && equal(zero_point, other->zero_point) &&
-           equal(dtype, other->dtype);
+           equal(dtype, other->dtype) && equal(axis, other->axis);
   }

   void SHashReduce(SHashReducer hash_reduce) const {
     hash_reduce->MarkGraphNode();
     hash_reduce(scale);
     hash_reduce(zero_point);
     hash_reduce(dtype);
+    hash_reduce(axis);
   }

   static constexpr const char* _type_key = "TensorAffineType";
@@ -101,7 +105,7 @@ class TensorAffineTypeNode : public AffineTypeNode {
  */
 class TensorAffineType : public AffineType {
  public:
-  TVM_DLL TensorAffineType(RelayExpr scale, RelayExpr zero_point, DataType dtype);
+  TVM_DLL TensorAffineType(RelayExpr scale, RelayExpr zero_point, DataType dtype, int axis);

   TVM_DEFINE_OBJECT_REF_METHODS(TensorAffineType, AffineType, TensorAffineTypeNode);
 };

python/tvm/ir/affine_type.py

Lines changed: 7 additions & 2 deletions

@@ -48,10 +48,15 @@ class TensorAffineType(AffineType):

     dtype : str
         The content data type.
+
+    axis : int
+        The axis for per-channel quantization.
     """

-    def __init__(self, scale, zero_point, dtype):
-        self.__init_handle_by_constructor__(_ffi_api.TensorAffineType, scale, zero_point, dtype)
+    def __init__(self, scale, zero_point, dtype, axis=-1):
+        self.__init_handle_by_constructor__(
+            _ffi_api.TensorAffineType, scale, zero_point, dtype, axis
+        )


 @tvm._ffi.register_object("TupleAffineType")
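
For illustration, a minimal sketch (with made-up scales and shapes) of how the extended constructor can be used once this change is in place; omitting axis keeps the previous per-tensor behavior:

import numpy as np
from tvm import relay
from tvm.ir import TensorAffineType

# One scale per output channel; the zero point stays scalar here.
scales = relay.const(np.array([0.25, 0.5, 0.125], dtype="float32"))
zero_point = relay.const(0, dtype="int32")

# Per-channel affine type along axis 0, e.g. the O axis of an OIHW kernel.
per_channel = TensorAffineType(scales, zero_point, "int8", axis=0)

# axis defaults to -1, preserving the old per-tensor semantics.
per_tensor = TensorAffineType(relay.const(0.1), zero_point, "int8")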

python/tvm/relay/frontend/onnx.py

Lines changed: 1 addition & 1 deletion

@@ -490,7 +490,7 @@ def _impl_v1(cls, inputs, attr, params):
                 attr["dilations"] = [1] + list(attr["dilations"])
             if "pads" in attr:
                 attr["pads"] = [0, attr["pads"][0], 0, attr["pads"][1]]
-
+        attr["channels"] = kernel_shapes[0][0]
         out = AttrCvt(
             op_name=dimension_picker("conv"),
             transforms={

python/tvm/relay/qnn/op/legalizations.py

Lines changed: 29 additions & 4 deletions

@@ -20,6 +20,7 @@

 import tvm
 from tvm import relay
+from tvm._ffi.base import TVMError
 from .. import op as reg

 #################################################
@@ -139,11 +140,35 @@ def helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay_op):
     data, kernel, input_zero_point, kernel_zero_point, _, _ = inputs

     shift_data = relay.subtract(
-        relay.cast(data, dtype="int16"), relay.cast(input_zero_point, "int16")
-    )
-    shift_kernel = relay.subtract(
-        relay.cast(kernel, dtype="int16"), relay.cast(kernel_zero_point, "int16")
+        relay.cast(data, dtype="int16"), relay.cast(input_zero_point, dtype="int16")
     )
+    # If kernel zero point is a scalar we can directly subtract it.
+    if len(types[3].shape) == 0:
+        shift_kernel = relay.subtract(
+            relay.cast(kernel, dtype="int16"), relay.cast(kernel_zero_point, dtype="int16")
+        )
+    # Otherwise it needs to be broadcast.
+    else:
+        # Determine output axis of kernel for spatial operations.
+        if hasattr(attrs, "kernel_layout"):
+            output_axis = tvm.tir.layout(attrs["kernel_layout"]).index_of("O")
+        # For dense operations, broadcast to [N, K] layout.
+        elif isinstance(attrs, relay.op.op_attrs.DenseAttrs):
+            output_axis = 0
+        # For matrix multiplication instead expand to [K, N] layout.
+        elif isinstance(attrs, relay.op.op_attrs.MatmulAttrs):
+            output_axis = 1
+        else:
+            raise TVMError(
+                "Legalization of %s is not yet supported with per channel parameters"
+                % str(type(attrs))
+            )
+
+        shift_kernel = relay.nn.bias_add(
+            relay.cast(kernel, dtype="int16"),
+            relay.cast(kernel_zero_point, dtype="int16"),
+            output_axis,
+        )
     new_attrs = {k: attrs[k] for k in attrs.keys()}
     return relay_op(shift_data, shift_kernel, **new_attrs)
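
To make the broadcast concrete, here is a small sketch of the pattern the helper now emits for a 1-D kernel zero point (shapes are hypothetical): nn.bias_add broadcasts the vector along the kernel's output axis, which tvm.tir.layout(...).index_of("O") locates for spatial layouts.

import numpy as np
import tvm
from tvm import relay

kernel = relay.var("kernel", shape=(16, 8, 3, 3), dtype="int8")  # OIHW layout
kernel_zp = relay.const(np.full(16, 3, dtype="int32"))  # one zero point per O channel

output_axis = tvm.tir.layout("OIHW").index_of("O")  # -> 0
shift_kernel = relay.nn.bias_add(
    relay.cast(kernel, dtype="int16"),
    relay.cast(kernel_zp, dtype="int16"),
    output_axis,
)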

python/tvm/relay/qnn/op/qnn.py

Lines changed: 7 additions & 2 deletions

@@ -276,8 +276,10 @@ def conv2d(
 ):
     r"""Quantized 2D convolution.

-    This operator convolves quantized data with quantized kernel. The scale of
-    the output quantized tensor is the product of the kernel_scale and
+    This operator convolves quantized data with quantized kernel.
+    If doing Per-channel quantization, qnn expects the kernel_zero_scale
+    and optionally the kernel_zero_point will be 1-D vectors instead of scalars.
+    The scale of the output quantized tensor is the product of the kernel_scale and
     input_scale of the input quantized tensors. The zero point of the output
     quantized tensor is 0. By default, the dtype of output is int32. Please also
     refer to Requantize operator to understand how to scale back the int32
@@ -544,6 +546,9 @@ def dense(

     `Y = X * W`

+    If doing Per-channel quantization, qnn expects the kernel_zero_scale
+    and optionally the kernel_zero_point will be 1-D vectors instead of scalars.
+
     Parameters
     ----------
     data : tvm.relay.Expr
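
As a hedged example of the documented convention, a per-channel qnn.conv2d call might look like the sketch below (shapes and scale values are illustrative only); the kernel scale is a 1-D vector with one entry per output channel, while scalar arguments keep the per-tensor behavior:

import numpy as np
from tvm import relay

data = relay.var("data", shape=(1, 8, 32, 32), dtype="int8")  # NCHW
kernel = relay.var("kernel", shape=(16, 8, 3, 3), dtype="int8")  # OIHW

out = relay.qnn.op.conv2d(
    data,
    kernel,
    input_zero_point=relay.const(0, "int32"),
    kernel_zero_point=relay.const(0, "int32"),
    input_scale=relay.const(0.1, "float32"),
    # One scale per output channel enables per-channel quantization.
    kernel_scale=relay.const(np.linspace(0.01, 0.2, 16).astype("float32")),
    kernel_size=(3, 3),
    channels=16,
    padding=(1, 1),
)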

python/tvm/relay/transform/fake_quantization_to_integer.py

Lines changed: 68 additions & 11 deletions

@@ -18,13 +18,22 @@
 import tvm
 from tvm import relay
 from tvm.ir import TensorAffineType, TupleAffineType
+from tvm.tir import bijective_layout
 from ..op import register_fake_quantization_to_integer


 def fold_constant(expr):
     return relay.transform.FoldConstantExpr(expr, tvm.IRModule())


+def get_zeros(scale):
+    return fold_constant(relay.op.cast(relay.op.zeros_like(scale), "int32"))
+
+
+def infer_shape(expr):
+    return relay.transform.InferType()(tvm.IRModule.from_expr(expr))["main"].body.checked_type.shape
+
+
 @register_fake_quantization_to_integer("qnn.dequantize")
 def dequantize(expr, type_map):
     """Remove dequantize op"""
@@ -52,8 +61,13 @@ def quantize(expr, type_map):
         expr.args[1],
         expr.args[2],
         out_dtype=expr.attrs.out_dtype,
+        axis=t.axis,
     )
-    return [out, TensorAffineType(expr.args[1], expr.args[2], expr.attrs.out_dtype)]
+
+    return [
+        out,
+        TensorAffineType(expr.args[1], expr.args[2], expr.attrs.out_dtype, expr.attrs.axis),
+    ]


 def register_unary_identity(op_name):
@@ -94,14 +108,19 @@ def bias_add(expr, type_map):
     b_t = type_map[b]
     in_scale = fold_constant(x_t.scale)
     in_zero_point = fold_constant(x_t.zero_point)
-    if not tvm.ir.structural_equal(x_t, b_t):
+    if not (
+        tvm.ir.structural_equal(x_t.scale, b_t.scale)
+        and tvm.ir.structural_equal(x_t.zero_point, b_t.zero_point)
+        and tvm.ir.structural_equal(x_t.dtype, b_t.dtype)
+    ):
         b = relay.qnn.op.requantize(
             b,
             b_t.scale,
             b_t.zero_point,
             in_scale,
             in_zero_point,
             out_dtype=x_t.dtype,
+            axis=0,
         )
     out = relay.op.nn.bias_add(x, b, **expr.attrs)
     return [out, x_t]
@@ -116,11 +135,13 @@ def conv2d(expr, type_map):
     x_t = type_map[x]
     w_t = type_map[weight]
     conv_scale = fold_constant(x_t.scale * w_t.scale)
-    conv_zp = relay.const(0)
+    conv_zp = get_zeros(conv_scale)
     out = relay.qnn.op.conv2d(
         x, weight, x_t.zero_point, w_t.zero_point, x_t.scale, w_t.scale, **attrs
     )
-    return [out, TensorAffineType(conv_scale, conv_zp, out.attrs.out_dtype)]
+    out_layout = attrs["out_layout"] if attrs["out_layout"] != "" else attrs["data_layout"]
+    out_axis = bijective_layout(out_layout, "NCHW").backward_index(list(range(4)))[1]
+    return [out, TensorAffineType(conv_scale, conv_zp, out.attrs.out_dtype, out_axis.value)]


 @register_fake_quantization_to_integer("nn.dense")
@@ -132,11 +153,11 @@ def dense(expr, type_map):
     x_t = type_map[x]
     w_t = type_map[weight]
     dense_scale = fold_constant(x_t.scale * w_t.scale)
-    dense_zp = relay.const(0)
+    dense_zp = get_zeros(dense_scale)
     out = relay.qnn.op.dense(
         x, weight, x_t.zero_point, w_t.zero_point, x_t.scale, w_t.scale, **attrs
     )
-    return [out, TensorAffineType(dense_scale, dense_zp, out.attrs.out_dtype)]
+    return [out, TensorAffineType(dense_scale, dense_zp, out.attrs.out_dtype, 1)]


 @register_fake_quantization_to_integer("nn.batch_matmul")
@@ -148,7 +169,7 @@ def batch_matmul(expr, type_map):
     matmul_scale = fold_constant(x_t.scale * y_t.scale)
     matmul_zp = relay.const(0)
     out = relay.qnn.op.batch_matmul(x, y, x_t.zero_point, y_t.zero_point, x_t.scale, y_t.scale)
-    return [out, TensorAffineType(matmul_scale, matmul_zp, out.attrs.out_dtype)]
+    return [out, TensorAffineType(matmul_scale, matmul_zp, out.attrs.out_dtype, x_t.axis)]


 @register_fake_quantization_to_integer("concatenate")
@@ -198,19 +219,52 @@ def clip(expr, type_map):
     amax = expr.attrs.a_max
     scale = fold_constant(t.scale)
     z_p = fold_constant(t.zero_point)
-    if isinstance(scale, relay.expr.Constant) and isinstance(z_p, relay.expr.Constant):
+    if (
+        isinstance(scale, relay.expr.Constant)
+        and scale.data.numpy().size == 1
+        and isinstance(z_p, relay.expr.Constant)
+        and z_p.data.numpy().size == 1
+    ):
         scale = scale.data.numpy().item()
         z_p = z_p.data.numpy().item()
         new_min = int(amin / scale + z_p)
         new_max = int(amax / scale + z_p)
         out = relay.op.clip(arg, new_min, new_max)
     else:
-        amin = relay.op.round(relay.op.const(amin) / scale + z_p)
-        amax = relay.op.round(relay.op.const(amax) / scale + z_p)
-        out = relay.op.minimum(relay.op.maximum(arg, amin), amax)
+        if not isinstance(amin, relay.expr.Constant):
+            amin = relay.op.const(amin)
+        if not isinstance(amax, relay.expr.Constant):
+            amax = relay.op.const(amax)
+
+        scale_shape = infer_shape(scale)
+        if len(scale_shape) > 0 and scale_shape[0] > 1:
+            b_shape = [1] * len(infer_shape(arg))
+            b_shape[t.axis] = -1
+            amin = relay.op.reshape(relay.op.broadcast_to(amin, scale_shape), b_shape)
+            amax = relay.op.reshape(relay.op.broadcast_to(amax, scale_shape), b_shape)
+        amin = relay.qnn.op.quantize(amin, scale, z_p, t.axis, t.dtype)
+        amax = relay.qnn.op.quantize(amax, scale, z_p, t.axis, t.dtype)
+        out = relay.op.minimum(relay.op.maximum(arg, fold_constant(amin)), fold_constant(amax))
+
     return [out, t]


+@register_fake_quantization_to_integer("nn.relu")
+def relu(expr, type_map):
+    """Rewrite a relu op"""
+    arg = expr.args[0]
+    t = type_map[arg]
+    scale_shape = infer_shape(t.scale)
+    z_p = t.zero_point
+    assert len(scale_shape) <= 1
+    if len(scale_shape) == 1 and scale_shape[0] > 1:
+        b_shape = [1] * len(infer_shape(arg))
+        b_shape[t.axis] = -1
+        z_p = relay.op.reshape(relay.op.broadcast_to(z_p, scale_shape), b_shape)
+    zero = relay.op.cast(z_p, t.dtype)
+    return [relay.op.maximum(arg, fold_constant(zero)), t]
+
+
 @register_fake_quantization_to_integer("nn.pad")
 def pad(expr, type_map):
     """Rewite an nn.pad op"""
@@ -231,6 +285,7 @@ def pad(expr, type_map):
             t.scale,
             t.zero_point,
             out_dtype=t.dtype,
+            axis=pad_t.axis,
         )
     else:
         ## If the pad-value is a constant, we need to quantize it
@@ -319,6 +374,7 @@ def binary(expr, type_map):
             out_t.scale,
             out_t.zero_point,
             out_dtype=out_t.dtype,
+            axis=left_t.axis,
         )

     if right_t != out_t:
@@ -329,6 +385,7 @@ def binary(expr, type_map):
             out_t.scale,
             out_t.zero_point,
             out_dtype=out_t.dtype,
+            axis=right_t.axis,
         )
     out = op(left, right)
     return [out, out_t]
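
Putting the pieces together, a minimal end-to-end sketch (modeled on the per-channel test cases added in this PR, with illustrative shapes and scales) runs the pass over a fake-quantized conv2d whose weights are dequantized per channel along axis 0:

import numpy as np
import tvm
from tvm import relay

x = relay.var("x", shape=(1, 8, 32, 32), dtype="int8")
w = relay.var("w", shape=(16, 8, 3, 3), dtype="int8")
one = relay.const(1.0)
zero = relay.const(0)
w_scale = relay.const(np.linspace(0.01, 0.2, 16).astype("float32"))

# Fake-quantized graph: dequantize -> float conv2d -> quantize.
op = relay.op.nn.conv2d(
    relay.qnn.op.dequantize(x, one, zero),
    relay.qnn.op.dequantize(w, w_scale, zero, axis=0),  # per-channel weights
    kernel_size=[3, 3],
)
op = relay.qnn.op.quantize(op, one, zero, out_dtype="int8")

mod = tvm.IRModule.from_expr(op)
mod = relay.transform.InferType()(mod)
mod = relay.transform.FakeQuantizationToInteger()(mod)  # rewrites to integer qnn ops
print(mod)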

src/ir/affine_type.cc

Lines changed: 6 additions & 4 deletions

@@ -30,26 +30,28 @@ namespace tvm {
 using tvm::ReprPrinter;
 using namespace tvm::runtime;

-TensorAffineType::TensorAffineType(RelayExpr scale, RelayExpr zero_point, DataType dtype) {
+TensorAffineType::TensorAffineType(RelayExpr scale, RelayExpr zero_point, DataType dtype,
+                                   int axis) {
   ObjectPtr<TensorAffineTypeNode> n = make_object<TensorAffineTypeNode>();
   n->scale = std::move(scale);
   n->zero_point = std::move(zero_point);
   n->dtype = std::move(dtype);
+  n->axis = std::move(axis);
   data_ = std::move(n);
 }

 TVM_REGISTER_NODE_TYPE(TensorAffineTypeNode);

 TVM_REGISTER_GLOBAL("ir.TensorAffineType")
-    .set_body_typed([](RelayExpr scale, RelayExpr zero_point, DataType dtype) {
-      return TensorAffineType(scale, zero_point, dtype);
+    .set_body_typed([](RelayExpr scale, RelayExpr zero_point, DataType dtype, int axis) {
+      return TensorAffineType(scale, zero_point, dtype, axis);
     });

 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     .set_dispatch<TensorAffineTypeNode>([](const ObjectRef& ref, ReprPrinter* p) {
       auto* node = static_cast<const TensorAffineTypeNode*>(ref.get());
       p->stream << "TensorAffineType(" << node->scale << ", " << node->zero_point << ", "
-                << node->dtype << ")";
+                << node->dtype << ", " << node->axis << ")";
     });

 TupleAffineType::TupleAffineType(Array<TensorAffineType> types) {

src/relay/qnn/op/convolution.cc

Lines changed: 6 additions & 5 deletions

@@ -495,7 +495,7 @@ Expr Conv2DSecondTerm(const Expr& padded_data, const Expr& kernel_zero_point,
  * \param input_zero_point The input zero point expr.
  * \param param The qnn conv2d attributes.
  * \param out_channels The number of output channels.
- * \return The sequence of Relay operatos for term3.
+ * \return The sequence of Relay operators for term3.
  * \note The term3 looks like this
  *
  *   Sigma(c,r,s) zp_a * QW(k, c, r, s)
@@ -625,7 +625,7 @@ Expr Conv2DCombineTerms(const Expr& term1, const Expr& term2, const Expr& term3,
  * \node Lowering of the qnn.conv2d operator
  * A quantized tensor is represented in following manner
  *        A = scale_a x (QA - zp_A)
- * where QA is quantized tensor, scale_a and zp_A are quantizations
+ * where QA is quantized tensor, scale_a and zp_A are quantization
  * params.
  *
  * Quantized convolution will convolve two quantized tensors and returns a
@@ -662,8 +662,8 @@ Expr Conv2DCombineTerms(const Expr& term1, const Expr& term2, const Expr& term3,
  * a workaround, we fall back to simpler lowering using int32 conv if
  * the conv is dilated. We fallback also in case of grouped conv.
  *
- * For depthwise, we can similarly unroll the computation. The intial compute is as follows
- * wehere cm = channel_multiplier
+ * For depthwise, we can similarly unroll the computation. The initial compute is as follows
+ * where cm = channel_multiplier
  *
  *  Qc(n, oc, oh, ow) = Sigma(r, s) (Qw(oc/m, oc%/m, r, s) - zp_w)
  *     * (Qa(n, oc/cm, oh + r, ow + s) - zp_a)
@@ -693,12 +693,13 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
   Expr kernel_zero_point = new_args[3];
   const auto* param = attrs.as<Conv2DAttrs>();
   ICHECK(param != nullptr);
-  // Assertion checks for exisiing support.
+  // Assertion checks for existing support.
   ICHECK(param->data_layout == "NCHW" || param->data_layout == "NHWC")
       << "qnn.conv2d supports only NCHW/NHWC input data layout.";
   ICHECK(param->kernel_layout == "OIHW" || param->kernel_layout == "HWIO" ||
          param->kernel_layout == "HWOI")
       << "qnn.conv2d supports only OIHW/HWIO/HWOI kernel data layout.";
+  ICHECK(param->kernel_size.defined()) << "qnn.conv2d requires kernel size to be specified.";

   int batch_size, in_channels, out_channels, kernel_h, kernel_w, channel_multiplier;
   std::tie(batch_size, in_channels, out_channels, kernel_h, kernel_w, channel_multiplier) =
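
The new ICHECK means frontends must populate kernel_size before lowering. A hedged sketch (shapes and scales are illustrative) of a call that satisfies the check, followed by the qnn canonicalization pass that enforces it:

import tvm
from tvm import relay

data = relay.var("data", shape=(1, 8, 32, 32), dtype="int8")
kernel = relay.var("kernel", shape=(16, 8, 3, 3), dtype="int8")

out = relay.qnn.op.conv2d(
    data,
    kernel,
    relay.const(0, "int32"),  # input_zero_point
    relay.const(0, "int32"),  # kernel_zero_point
    relay.const(0.1, "float32"),  # input_scale
    relay.const(0.1, "float32"),  # kernel_scale
    kernel_size=(3, 3),  # required: the ICHECK above fires if left undefined
    channels=16,
)

mod = relay.transform.InferType()(tvm.IRModule.from_expr(out))
mod = relay.qnn.transform.CanonicalizeOps()(mod)  # lowers qnn.conv2d to Relay ops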
