@@ -221,23 +221,36 @@ def test_do_not_convert_softmax():
     b = relay.nn.softmax(a)
     mod = tvm.IRModule.from_expr(b)
     mod = tvm.relay.transform.InferType()(mod)
-
-    mod_params = {
-        "a": np.random.uniform(-1, 1, size=shape).astype("float32"),
-    }
-    output_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.0, rtol=0)
-    assert tvm.ir.structural_equal(mod, output_mod)
+    out_mod = ToMixedPrecision("float16")(mod)
+    orig_mod = tvm.relay.transform.InferType()(mod)
+    assert tvm.ir.structural_equal(orig_mod, out_mod)
 
 
 def test_do_not_convert_arange():
     """Arange is a red listed operation and therefore should never be fp16."""
     dtype = "float32"
     arange = relay.arange(relay.const(1, dtype), relay.const(128, dtype))
     mod = tvm.IRModule.from_expr(arange)
-    mod = tvm.relay.transform.InferType()(mod)
+    out_mod = ToMixedPrecision("float16")(mod)
+    orig_mod = tvm.relay.transform.InferType()(mod)
+    assert tvm.ir.structural_equal(orig_mod, out_mod)
 
-    output_mod = verify_mixed_precision_output_close(mod, {}, atol=0.0, rtol=0)
-    assert tvm.ir.structural_equal(mod, output_mod)
+
+def test_do_not_convert_summation():
+    """Ops that could involve a large summation are not allowed in fp16."""
+    shape = [1, 3, 16, 16]
+    a = relay.var("a", shape=shape)
+    ops = [
+        relay.sum,
+        relay.mean,
+        relay.nn.global_avg_pool2d,
+        lambda inp: relay.nn.adaptive_avg_pool2d(inp, (1, 1)),
+    ]
+    for op in ops:
+        mod = tvm.IRModule.from_expr(op(a))
+        out_mod = ToMixedPrecision("float16")(mod)
+        orig_mod = tvm.relay.transform.InferType()(mod)
+        assert tvm.ir.structural_equal(orig_mod, out_mod)
 
 
 def test_green_gray_propagates_simple():
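
For context, the three rewritten tests all follow the same pattern: run the ToMixedPrecision pass directly on the module, type-infer the original module as a baseline (the tests compare against an InferType'd copy, which suggests the pass output carries inferred types), and assert the pass left the graph untouched. Below is a minimal standalone sketch of that pattern, assuming only the public tvm.relay API; the helper name assert_not_converted is hypothetical and not part of this PR.

import tvm
from tvm import relay
from tvm.relay.transform import InferType, ToMixedPrecision


def assert_not_converted(mod: tvm.IRModule):
    # Hypothetical helper, not part of this PR: checks that the
    # mixed-precision pass is a no-op on `mod`.
    out_mod = ToMixedPrecision("float16")(mod)
    # Mirror the PR's tests: infer types on the original module before
    # the structural comparison with the pass output.
    orig_mod = InferType()(mod)
    assert tvm.ir.structural_equal(orig_mod, out_mod)


# Example usage: arange is red-listed, so the pass should leave it in fp32.
arange = relay.arange(relay.const(1, "float32"), relay.const(128, "float32"))
assert_not_converted(tvm.IRModule.from_expr(arange))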