Skip to content

Commit b2b3d1f

Browse files
author
Matthew
committed
respond to review comments
1 parent 55a1f3d commit b2b3d1f

File tree

4 files changed

+13
-3
lines changed

4 files changed

+13
-3
lines changed

python/tvm/ir/affine_type.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ class TensorAffineType(AffineType):
4848
4949
dtype : str
5050
The content data type.
51+
52+
axis : int
53+
The axis for per-channel quantization.
5154
"""
5255

5356
def __init__(self, scale, zero_point, dtype, axis=-1):

python/tvm/relay/qnn/op/qnn.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -276,8 +276,10 @@ def conv2d(
276276
):
277277
r"""Quantized 2D convolution.
278278
279-
This operator convolves quantized data with quantized kernel. The scale of
280-
the output quantized tensor is the product of the kernel_scale and
279+
This operator convolves quantized data with quantized kernel.
280+
If doing Per-channel quantization, qnn expects the kernel_zero_scale
281+
and optionally the kernel_zero_point will be 1-D vectors instead of scalars.
282+
The scale of the output quantized tensor is the product of the kernel_scale and
281283
input_scale of the input quantized tensors. The zero point of the output
282284
quantized tensor is 0. By default, the dtype of output is int32. Please also
283285
refer to Requantize operator to understand how to scale back the int32
@@ -544,6 +546,9 @@ def dense(
544546
545547
`Y = X * W`
546548
549+
If doing Per-channel quantization, qnn expects the kernel_zero_scale
550+
and optionally the kernel_zero_point will be 1-D vectors instead of scalars.
551+
547552
Parameters
548553
----------
549554
data : tvm.relay.Expr

python/tvm/relay/transform/fake_quantization_to_integer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,8 @@ def relu(expr, type_map):
256256
t = type_map[arg]
257257
scale_shape = infer_shape(t.scale)
258258
z_p = t.zero_point
259-
if len(scale_shape) > 0 and scale_shape[0] > 1:
259+
assert len(scale_shape) <= 1
260+
if len(scale_shape) == 1 and scale_shape[0] > 1:
260261
b_shape = [1] * len(infer_shape(arg))
261262
b_shape[t.axis] = -1
262263
z_p = relay.op.reshape(relay.op.broadcast_to(z_p, scale_shape), b_shape)

src/relay/qnn/op/dense.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ bool QnnDenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
6262
}
6363
ICHECK(IsScalarType(types[2], DataType::Int(32))); // input_zero_point
6464
ICHECK(IsScalarType(types[4], DataType::Float(32))); // input_scale
65+
// weight_zero_point can be a scalar or a vector of the same shape as the weight_scale
6566
AssignType(types[5], DataType::Float(32), param->units, reporter); // weight_scale
6667

6768
ICHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0.";

0 commit comments

Comments
 (0)