2525import paddle
2626
2727
def group_norm_naive_for_general_dimension(x, scale, bias, epsilon, groups):
    """Reference (numpy) implementation of group normalization for
    tensors of any rank >= 2.

    The original group-norm reference only supported 4-D tensors; this
    version normalizes each (sample, group) slice over all of its
    elements, for inputs of shape (N, C, ...).

    Args:
        x (np.ndarray): input of shape (N, C, ...), rank >= 2.
        scale (np.ndarray): per-channel scale (gamma), broadcastable
            against the channel dimension C.
        bias (np.ndarray): per-channel shift (beta), broadcastable
            against the channel dimension C.
        epsilon (float): small constant added to the variance for
            numerical stability.
        groups (int): number of groups; C must be divisible by groups.

    Returns:
        np.ndarray with the same shape as ``x``.
    """
    input_shape = x.shape
    N, C = x.shape[0], x.shape[1]
    G = groups
    # Collapse each (sample, group) slice to one row and normalize it.
    x = x.reshape((N * G, -1))
    mean = np.mean(x, axis=1, keepdims=True)
    var = np.var(x, axis=1, keepdims=True)
    output = (x - mean) / np.sqrt(var + epsilon)
    # Broadcast scale/bias over the channel axis. The number of trailing
    # singleton dims is derived from the input rank instead of being
    # hard-coded to two, so ranks other than 4 broadcast correctly
    # (the old (-1, 1, 1) reshape mis-aligned axes for rank != 4).
    param_shape = (-1,) + (1,) * (len(input_shape) - 2)
    output = output.reshape(input_shape) * scale.reshape(
        param_shape) + bias.reshape(param_shape)
    return output
41+
42+
2843class TestDygraphGroupNormv2 (unittest .TestCase ):
2944 def test_dygraph (self ):
3045 places = [fluid .CPUPlace ()]
3146 if core .is_compiled_with_cuda () and core .op_support_gpu ("group_norm" ):
3247 places .append (fluid .CUDAPlace (0 ))
48+ shapes = [[2 , 2 , 2 , 2 ], [2 , 2 , 4 ], [4 , 2 ], [4 , 2 , 6 , 6 , 2 ],
49+ [2 , 2 , 2 , 2 , 2 , 2 ]]
3350 for p in places :
34- shape = [2 , 2 , 2 , 2 ]
3551
3652 def compute_v1 (x ):
3753 with fluid .dygraph .guard (p ):
@@ -62,23 +78,26 @@ def attr_data_format():
6278
6379 self .assertRaises (ValueError , attr_data_format )
6480
65- x = np .random .randn (* shape ).astype ("float32" )
66- y1 = compute_v1 (x )
67- y2 = compute_v2 (x )
68- result = np .allclose (y1 , y2 , atol = 1e-5 )
69- if not result :
70- print ("y1:" , y1 , "\t y2:" , y2 )
71- self .assertTrue (result )
72- test_weight_bias_false ()
73- test_nn_exception ()
81+ for shape in shapes :
82+ x = np .random .randn (* shape ).astype ("float32" )
83+ y1 = compute_v1 (x )
84+ y2 = compute_v2 (x )
85+ result = np .allclose (y1 , y2 , atol = 1e-5 )
86+ if not result :
87+ print ("y1:" , y1 , "\t y2:" , y2 )
88+ self .assertTrue (result )
89+ test_weight_bias_false ()
90+ test_nn_exception ()
7491
7592 def test_static (self ):
93+ paddle .enable_static ()
7694 places = [fluid .CPUPlace ()]
7795 if core .is_compiled_with_cuda () and core .op_support_gpu ("group_norm" ):
7896 places .append (fluid .CUDAPlace (0 ))
97+ shapes = [[2 , 6 , 2 , 2 ], [2 , 6 , 4 ], [4 , 6 ], [4 , 6 , 6 , 6 , 2 ],
98+ [4 , 6 , 2 , 2 , 2 , 2 ]]
7999 for p in places :
80100 exe = fluid .Executor (p )
81- shape = [2 , 6 , 2 , 2 ]
82101
83102 def compute_v1 (x_np ):
84103 with program_guard (Program (), Program ()):
@@ -98,10 +117,39 @@ def compute_v2(x_np):
98117 r = exe .run (feed = {'x' : x_np }, fetch_list = [y ])[0 ]
99118 return r
100119
101- x = np .random .randn (* shape ).astype ("float32" )
102- y1 = compute_v1 (x )
103- y2 = compute_v2 (x )
104- self .assertTrue (np .allclose (y1 , y2 , atol = 1e-5 ))
120+ for shape in shapes :
121+ x = np .random .randn (* shape ).astype ("float32" )
122+ y1 = compute_v1 (x )
123+ y2 = compute_v2 (x )
124+ self .assertTrue (np .allclose (y1 , y2 , atol = 1e-5 ))
125+
126+
class TestGroupNormAPIV2_With_General_Dimensions(unittest.TestCase):
    def test_numerical_accuracy(self):
        # Compare paddle.nn.GroupNorm (dygraph mode) against the numpy
        # reference implementation for inputs of rank 2 through 6,
        # using both 6 groups and 2 groups over 6 channels.
        paddle.disable_static()
        test_shapes = [(2, 6), (2, 6, 4), (2, 6, 4, 4), (2, 6, 6, 6, 2),
                       (2, 6, 6, 6, 2, 3)]
        devices = [fluid.CPUPlace()]
        if core.is_compiled_with_cuda() and core.op_support_gpu("group_norm"):
            devices.append(fluid.CUDAPlace(0))

        for device in devices:
            for case_shape in test_shapes:
                weight = np.array([1]).astype("float32")
                offset = np.array([0]).astype("float32")
                sample = np.random.random(case_shape).astype("float32")
                # Expected outputs from the numpy reference.
                ref_g6 = group_norm_naive_for_general_dimension(
                    sample, weight, offset, epsilon=1e-5, groups=6)
                ref_g2 = group_norm_naive_for_general_dimension(
                    sample, weight, offset, epsilon=1e-5, groups=2)

                layer_g6 = paddle.nn.GroupNorm(num_channels=6, num_groups=6)
                layer_g2 = paddle.nn.GroupNorm(num_channels=6, num_groups=2)
                sample_tensor = paddle.to_tensor(sample)
                out_g6 = layer_g6(sample_tensor).numpy()
                out_g2 = layer_g2(sample_tensor).numpy()
                self.assertTrue(np.allclose(out_g6, ref_g6, atol=1e-5))
                self.assertTrue(np.allclose(out_g2, ref_g2, atol=1e-5))
105153
106154
107155if __name__ == '__main__' :
0 commit comments