
Commit 607ba7c

Author: Matthew
Commit message: respond to review comments
1 parent 4c6dc86 commit 607ba7c

File tree: 2 files changed, +4 -4 lines changed

include/tvm/ir/affine_type.h

Lines changed: 1 addition & 1 deletion

@@ -71,7 +71,7 @@ class TensorAffineTypeNode : public AffineTypeNode {
   RelayExpr zero_point;
   /*! \brief The data type of this type */
   DataType dtype;
-  /*! \brief The data type of this type */
+  /*! \brief The axis for per-channel quantization */
   int axis;
 
   void VisitAttrs(tvm::AttrVisitor* v) {
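As context for this doc-comment fix (not part of the commit itself): the axis field records which dimension the per-channel scales and zero points run along. A minimal Relay sketch of per-channel quantization using such an axis, with a hypothetical 16-output-channel weight tensor and made-up scale/zero-point values:

import numpy as np
from tvm import relay

# Hypothetical per-channel parameters for a (16, 3, 3, 3) weight tensor.
weight = relay.var("weight", shape=(16, 3, 3, 3), dtype="float32")
scale = relay.const(np.full((16,), 0.05, dtype="float32"))
zero_point = relay.const(np.zeros((16,), dtype="int32"))

# axis=0 marks the output-channel dimension that the scale/zero_point vectors
# apply along; this is the role the updated comment describes for the axis field.
q_weight = relay.qnn.op.quantize(weight, scale, zero_point, axis=0, out_dtype="int8")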

python/tvm/relay/transform/fake_quantization_to_integer.py

Lines changed: 3 additions & 3 deletions

@@ -255,12 +255,12 @@ def relu(expr, type_map):
     arg = expr.args[0]
     t = type_map[arg]
     scale_shape = infer_shape(t.scale)
-    zero = relay.const(0, dtype="float32")
+    z_p = t.zero_point
     if len(scale_shape) > 0 and scale_shape[0] > 1:
         b_shape = [1] * len(infer_shape(arg))
         b_shape[t.axis] = -1
-        zero = relay.op.reshape(relay.op.broadcast_to(zero, scale_shape), b_shape)
-        zero = relay.qnn.op.quantize(zero, t.scale, t.zero_point, t.axis, t.dtype)
+        z_p = relay.op.reshape(relay.op.broadcast_to(z_p, scale_shape), b_shape)
+    zero = relay.op.cast(z_p, t.dtype)
     return [relay.op.maximum(arg, fold_constant(zero)), t]
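Why the new relu lowering can use the zero point directly (a sketch of the reasoning with made-up example values, not code from the commit): affine quantization maps q = round(x / scale) + zero_point, so quantizing a real-valued 0.0 always yields the zero point itself. Casting t.zero_point to the target dtype is therefore equivalent to quantizing a float zero, and avoids building an extra qnn.quantize op:

import numpy as np

# q = round(x / scale) + zero_point; for x = 0.0 this is exactly zero_point,
# which is why the diff replaces quantize(0.0, ...) with a cast of the zero point.
scale, zero_point = 0.05, 12  # hypothetical example values
q = int(np.round(0.0 / scale)) + zero_point
assert q == zero_point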
