Skip to content

Commit 21db1eb

Browse files
author
Matthew Brookhart
authored
[FQ2I] Fix a rounding error on AvgPool when input and output affine scales differ (#12577)
cc @sfvaroglu @AndrewZhaoLuo
1 parent f7c1436 commit 21db1eb

File tree

2 files changed

+61
-18
lines changed

2 files changed

+61
-18
lines changed

python/tvm/relay/transform/fake_quantization_to_integer.py

Lines changed: 55 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -114,33 +114,79 @@ def adaptive_avgpool1d(expr, type_map):
114114
"""Rewrite an adaptive avgpool op"""
115115
arg = expr.args[0]
116116
t = type_map[arg]
117-
arg = relay.op.cast(arg, "int32")
117+
out_t = type_map[expr]
118+
if not (
119+
approx_equal(t.scale, out_t.scale)
120+
and approx_equal(t.zero_point, out_t.zero_point)
121+
and tvm.ir.structural_equal(t.dtype, out_t.dtype)
122+
):
123+
arg = relay.qnn.op.requantize(
124+
arg,
125+
t.scale,
126+
t.zero_point,
127+
out_t.scale,
128+
out_t.zero_point,
129+
out_dtype="int32",
130+
axis=t.axis,
131+
)
132+
else:
133+
arg = relay.op.cast(arg, "int32")
118134
output_size = expr.attrs.output_size
119135
out = relay.op.nn.adaptive_avg_pool1d(arg, output_size)
120-
out = relay.op.cast(out, t.dtype)
121-
return [out, t]
136+
return [out, TensorAffineType(out_t.scale, out_t.zero_point, "int32", out_t.axis)]
122137

123138

124139
@register_fake_quantization_to_integer("nn.avg_pool2d")
125140
def avgpool2d(expr, type_map):
126141
"""Rewrite an avgpool op"""
127142
arg = expr.args[0]
128143
t = type_map[arg]
129-
arg = relay.op.cast(arg, "int32")
144+
out_t = type_map[expr]
145+
if not (
146+
approx_equal(t.scale, out_t.scale)
147+
and approx_equal(t.zero_point, out_t.zero_point)
148+
and tvm.ir.structural_equal(t.dtype, out_t.dtype)
149+
):
150+
arg = relay.qnn.op.requantize(
151+
arg,
152+
t.scale,
153+
t.zero_point,
154+
out_t.scale,
155+
out_t.zero_point,
156+
out_dtype="int32",
157+
axis=t.axis,
158+
)
159+
else:
160+
arg = relay.op.cast(arg, "int32")
130161
out = relay.op.nn.avg_pool2d(arg, **expr.attrs)
131-
out = relay.op.cast(out, t.dtype)
132-
return [out, t]
162+
return [out, TensorAffineType(out_t.scale, out_t.zero_point, "int32", out_t.axis)]
133163

134164

135165
@register_fake_quantization_to_integer("nn.global_avg_pool2d")
136166
def global_avgpool2d(expr, type_map):
137167
"""Rewrite a global_avgpool op"""
138168
arg = expr.args[0]
139169
t = type_map[arg]
140-
arg = relay.op.cast(arg, "int32")
170+
out_t = type_map[expr]
171+
out_t = type_map[expr]
172+
if not (
173+
approx_equal(t.scale, out_t.scale)
174+
and approx_equal(t.zero_point, out_t.zero_point)
175+
and tvm.ir.structural_equal(t.dtype, out_t.dtype)
176+
):
177+
arg = relay.qnn.op.requantize(
178+
arg,
179+
t.scale,
180+
t.zero_point,
181+
out_t.scale,
182+
out_t.zero_point,
183+
out_dtype="int32",
184+
axis=t.axis,
185+
)
186+
else:
187+
arg = relay.op.cast(arg, "int32")
141188
out = relay.op.nn.global_avg_pool2d(arg)
142-
out = relay.op.cast(out, t.dtype)
143-
return [out, t]
189+
return [out, TensorAffineType(out_t.scale, out_t.zero_point, "int32", out_t.axis)]
144190

145191

146192
@register_fake_quantization_to_integer("broadcast_to")

tests/python/relay/test_pass_fake_quantization_to_integer.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -281,10 +281,9 @@ def test_fake_quantize_maxpool():
281281
def test_fake_quantize_adaptive_avgpool1d(output_size):
282282
x = relay.var("x", shape=[1, 128, 768], dtype="int8")
283283

284-
zero = relay.const(0)
285-
x = relay.qnn.op.dequantize(x, relay.const(2.0), zero)
284+
x = relay.qnn.op.dequantize(x, relay.const(2.0), relay.const(-12))
286285
op = relay.op.nn.adaptive_avg_pool1d(x, output_size)
287-
op = relay.qnn.op.quantize(op, relay.const(2.0), zero)
286+
op = relay.qnn.op.quantize(op, relay.const(0.5), relay.const(10))
288287

289288
x_np = np.random.randint(-128, 127, size=[1, 128, 768], dtype="int8")
290289

@@ -294,10 +293,9 @@ def test_fake_quantize_adaptive_avgpool1d(output_size):
294293
def test_fake_quantize_avgpool():
295294
x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8")
296295

297-
zero = relay.const(0)
298-
x = relay.qnn.op.dequantize(x, relay.const(2.0), zero)
296+
x = relay.qnn.op.dequantize(x, relay.const(2.0), relay.const(-12))
299297
op = relay.op.nn.avg_pool2d(x, [3, 3])
300-
op = relay.qnn.op.quantize(op, relay.const(2.0), zero)
298+
op = relay.qnn.op.quantize(op, relay.const(0.5), relay.const(10))
301299

302300
x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8")
303301

@@ -307,10 +305,9 @@ def test_fake_quantize_avgpool():
307305
def test_fake_quantize_global_avg_pool():
308306
x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8")
309307

310-
zero = relay.const(0)
311-
x = relay.qnn.op.dequantize(x, relay.const(2.0), zero)
308+
x = relay.qnn.op.dequantize(x, relay.const(2.0), relay.const(-12))
312309
op = relay.op.nn.global_avg_pool2d(x)
313-
op = relay.qnn.op.quantize(op, relay.const(2.0), zero)
310+
op = relay.qnn.op.quantize(op, relay.const(0.5), relay.const(10))
314311

315312
x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8")
316313

0 commit comments

Comments
 (0)