@@ -1263,6 +1263,38 @@ def verify(shape, axis=-1):
1263
1263
verify ((2 , 5 , 6 ))
1264
1264
1265
1265
1266
@tvm.testing.uses_gpu
def test_forward_group_norm():
    """Check the Relay MXNet frontend's GroupNorm conversion against MXNet's own output."""

    def verify(shape, num_groups=1):
        # Random input data plus per-channel affine parameters (channel axis is dim 1).
        data_np = np.random.uniform(size=shape).astype("float32")
        gamma_np = np.random.uniform(size=(shape[1])).astype("float32")
        beta_np = np.random.uniform(size=(shape[1])).astype("float32")
        # Reference result computed directly through MXNet's imperative API.
        ref_res = mx.nd.GroupNorm(
            data=mx.nd.array(data_np),
            gamma=mx.nd.array(gamma_np),
            beta=mx.nd.array(beta_np),
            num_groups=num_groups,
        )
        # Symbolic graph for the same op, converted to a Relay module.
        mx_sym = mx.sym.GroupNorm(
            mx.sym.var("x"), mx.sym.var("gamma"), mx.sym.var("beta"), num_groups=num_groups
        )
        shape_dict = {"x": data_np.shape, "gamma": gamma_np.shape, "beta": beta_np.shape}
        mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict)
        # Evaluate on every enabled target with both executors and compare to the reference.
        for target, dev in tvm.testing.enabled_targets():
            for kind in ["graph", "debug"]:
                intrp = relay.create_executor(kind, mod=mod, ctx=dev, target=target)
                op_res = intrp.evaluate()(data_np, gamma_np, beta_np)
                tvm.testing.assert_allclose(
                    op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5
                )

    verify((1, 4, 2), num_groups=4)
    # TODO(trevmorr): MXNet GroupNorm implementation is bugged for cases when num_groups != num_channels
    # https://github.com/apache/incubator-mxnet/pull/18199
    # verify((1, 4, 2, 3), num_groups=2)
    # verify((1, 4, 2, 3))


1266
1298
@tvm .testing .uses_gpu
1267
1299
def test_forward_one_hot ():
1268
1300
def verify (indices_shape , depth , on_value , off_value , dtype ):
0 commit comments