Commit 67ad688

masahiylc authored and committed

[AMP] Disallow fp16 conversion for summation-like ops (apache#8810)

* [AMP] Disallow fp16 conversion for summation-like ops
* test only structural equality

1 parent 08c5d45 · commit 67ad688

File tree: 2 files changed, +29 -17 lines

python/tvm/relay/transform/mixed_precision.py

Lines changed: 7 additions & 8 deletions
@@ -81,8 +81,6 @@
     "divide",
     "nn.bias_add",
     "nn.batch_norm",
-    "sum",
-    "mean",
     "sqrt",
     "shape_of",
     # Simple activations
@@ -107,15 +105,9 @@
     # "nn.global_max_pool1d", # does not exist yet
     "nn.global_max_pool2d",
     # "nn.global_max_pool3d", # does not exist yet
-    # "nn.global_avg_pool1d", # does not exist yet
-    "nn.global_avg_pool2d",
-    # "nn.global_avg_pool3d", # does not exist yet
     "nn.adaptive_max_pool1d",
     "nn.adaptive_max_pool2d",
     "nn.adaptive_max_pool3d",
-    "nn.adaptive_avg_pool1d",
-    "nn.adaptive_avg_pool2d",
-    "nn.adaptive_avg_pool3d",
 ]
 DEFAULT_NEVER_LIST = [
     # In general if |f(x)| >> |x| for expected inputs then put the op here.
@@ -131,6 +123,13 @@
     # Do not allow arange arguments (begin/end) to be fp16. "end" can be a big fp32 number
     # not representable in fp16.
     "arange",
+    # Ops that could involve a large summation are not allowed in fp16.
+    "nn.global_avg_pool2d",
+    "nn.adaptive_avg_pool1d",
+    "nn.adaptive_avg_pool2d",
+    "nn.adaptive_avg_pool3d",
+    "sum",
+    "mean",
 ]
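Why these ops are moved to the never list: float16 has a largest finite value of roughly 65504 and only 11 bits of significand, so a reduction over many elements can overflow or silently drop small contributions even when every individual input fits comfortably in fp16. A minimal NumPy illustration (not part of this commit, shown only to motivate the change):

import numpy as np

# 70,000 ones sum to 70000.0 in float32, but a float16 accumulator
# tops out at ~65504 and overflows to inf.
ones = np.ones(70_000, dtype=np.float32)
print(ones.sum(dtype=np.float32))                      # 70000.0
print(ones.astype(np.float16).sum(dtype=np.float16))   # inf

# Precision is lost well before overflow: once a float16 partial sum
# reaches 2048, adding 0.5 rounds away to nothing.
print(np.float16(2048.0) + np.float16(0.5))            # 2048.0
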
tests/python/relay/test_to_mixed_precision.py

Lines changed: 22 additions & 9 deletions
@@ -221,23 +221,36 @@ def test_do_not_convert_softmax():
     b = relay.nn.softmax(a)
     mod = tvm.IRModule.from_expr(b)
     mod = tvm.relay.transform.InferType()(mod)
-
-    mod_params = {
-        "a": np.random.uniform(-1, 1, size=shape).astype("float32"),
-    }
-    output_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.0, rtol=0)
-    assert tvm.ir.structural_equal(mod, output_mod)
+    out_mod = ToMixedPrecision("float16")(mod)
+    orig_mod = tvm.relay.transform.InferType()(mod)
+    assert tvm.ir.structural_equal(orig_mod, out_mod)
 
 
 def test_do_not_convert_arange():
     """Arange is a red listed operation and therefore should never be fp16."""
     dtype = "float32"
     arange = relay.arange(relay.const(1, dtype), relay.const(128, dtype))
     mod = tvm.IRModule.from_expr(arange)
-    mod = tvm.relay.transform.InferType()(mod)
+    out_mod = ToMixedPrecision("float16")(mod)
+    orig_mod = tvm.relay.transform.InferType()(mod)
+    assert tvm.ir.structural_equal(orig_mod, out_mod)
 
-    output_mod = verify_mixed_precision_output_close(mod, {}, atol=0.0, rtol=0)
-    assert tvm.ir.structural_equal(mod, output_mod)
+
+def test_do_not_convert_summation():
+    """Ops that could involve a large summation are not allowed in fp16."""
+    shape = [1, 3, 16, 16]
+    a = relay.var("a", shape=shape)
+    ops = [
+        relay.sum,
+        relay.mean,
+        relay.nn.global_avg_pool2d,
+        lambda inp: relay.nn.adaptive_avg_pool2d(inp, (1, 1)),
+    ]
+    for op in ops:
+        mod = tvm.IRModule.from_expr(op(a))
+        out_mod = ToMixedPrecision("float16")(mod)
+        orig_mod = tvm.relay.transform.InferType()(mod)
+        assert tvm.ir.structural_equal(orig_mod, out_mod)
 
 
 def test_green_gray_propagates_simple():
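To see the effect end to end, one can run the pass on a graph that mixes a green-listed op with one of the newly never-listed ops. The sketch below is illustrative only: the conv2d-feeding-mean graph and its variable names are not from this commit. It relies on the same ToMixedPrecision entry point exercised in the tests above, and after this change the expectation is that the convolution is converted to float16 while the mean keeps float32 operands, with casts inserted at the boundary.

import tvm
from tvm import relay
from tvm.relay.transform import InferType, ToMixedPrecision

# Illustrative graph: conv2d (eligible for fp16) feeding a mean (summation-like).
data = relay.var("data", shape=(1, 3, 32, 32), dtype="float32")
weight = relay.var("weight", shape=(8, 3, 3, 3), dtype="float32")
conv = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1))
out = relay.mean(conv, axis=[2, 3])

mod = tvm.IRModule.from_expr(out)
mod = InferType()(mod)
amp_mod = ToMixedPrecision("float16")(mod)
print(amp_mod)  # conv2d should appear in float16; mean should still consume float32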
