@@ -847,6 +847,94 @@ def native_group_norm(
847
847
return (out , mean , rstd )
848
848
849
849
850
@register_decomposition(aten.native_group_norm_backward)
@pw_cast_for_opmath
def native_group_norm_backward(
    grad_output: Tensor,
    input: Tensor,
    mean: Tensor,
    rstd: Tensor,
    gamma: Optional[Tensor],
    N: int,
    C: int,
    HxW: int,
    group: int,
    output_mask: List[bool],
) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
    """Decomposition registered for ``aten.native_group_norm_backward``.

    Args:
        grad_output: Incoming gradient; must have the same shape as ``input``.
        input: Tensor holding ``N * C * HxW`` elements.
        mean: Tensor of shape ``(N, group)``.
        rstd: Tensor with the same shape as ``mean``.
        gamma: Optional tensor with ``C`` elements; when ``None`` the input
            gradient is computed as if every channel scale were 1.
        N: Size of the leading (sample) dimension.
        C: Number of channels; must be divisible by ``group``.
        HxW: Number of elements per (sample, channel) pair.
        group: Number of groups.
        output_mask: Three flags selecting which of
            ``(d_input, d_gamma, d_bias)`` are computed.

    Returns:
        ``(d_input, d_gamma, d_bias)``. An entry is ``None`` when its flag in
        ``output_mask`` is false. ``d_input`` has ``input``'s shape and dtype;
        ``d_gamma`` is reshaped to ``(C,)``; ``d_bias`` is ``db`` summed over
        the sample dimension.

    Every precondition above is enforced through the ``utils.check*``
    helpers; the exception raised on failure is whatever those helpers raise.
    """
    utils.check_same_device(
        grad_output, input, mean, rstd, allow_cpu_scalar_tensors=False
    )
    utils.check_same_shape(input, grad_output, allow_cpu_scalar_tensors=False)
    utils.check_same_shape(mean, rstd, allow_cpu_scalar_tensors=False)
    utils.check(
        input.numel() == N * C * HxW,
        lambda: f"Expect input to have {N * C * HxW} elements",
    )
    utils.check(
        mean.shape == (N, group),
        lambda: f"Expect mean to have shape ({N}, {group}), but got {mean.shape}",
    )
    utils.check(
        gamma is None or gamma.numel() == C,
        lambda: f"Expect gamma to have {C} elements but got {gamma.numel() if gamma is not None else -1}",
    )

    # cpg = channels per group; a non-zero remainder means C does not split
    # evenly into `group` groups.
    cpg, _rem = divmod(C, group)
    utils.check(
        _rem == 0,
        lambda: f"Expect number of channels {C} to be evenly-divisible by number of groups {group}",
    )

    # Compute internal gradients: per-(sample, channel) sums over the HxW axis.
    #   ds[n, c] = sum(grad_output * input),  db[n, c] = sum(grad_output)
    # reshape (not view) is used so that inputs whose memory layout cannot be
    # viewed as (N, C, HxW) are still accepted, matching the reshape calls
    # used for d_input below.
    ds = torch.mul(grad_output, input).reshape(N, C, HxW).sum(dim=[2])
    db = grad_output.reshape(N, C, HxW).sum(dim=[2])

    d_input: Optional[Tensor] = None
    d_gamma: Optional[Tensor] = None
    d_bias: Optional[Tensor] = None
    if output_mask[0]:
        # Reciprocal of the number of elements in one group of one sample.
        s = 1.0 / (HxW * cpg)
        if gamma is not None:
            # Weight the per-channel sums by gamma, then reduce each group's
            # channels to get per-(sample, group) values.
            ds_val = torch.mul(ds, gamma.unsqueeze(0)).reshape(N, group, cpg).sum(2)
            db_val = torch.mul(db, gamma.unsqueeze(0)).reshape(N, group, cpg).sum(2)
            c1 = torch.mul(
                rstd.unsqueeze(-1),
                gamma.reshape(1, group, cpg),
            )
        else:
            ds_val = ds.reshape(N, group, cpg).sum(2)
            db_val = db.reshape(N, group, cpg).sum(2)
            # Broadcast rstd to (N, group, cpg). dtype is pinned to rstd's so
            # the result does not depend on the global default dtype.
            c1 = torch.mul(
                rstd.unsqueeze(-1),
                torch.ones((1, group, cpg), dtype=rstd.dtype, device=rstd.device),
            )
        c2 = (db_val * mean - ds_val) * rstd * rstd * rstd * s
        c3 = -c2 * mean - db_val * rstd * s

        # Bring the coefficients to rank 4 so they broadcast against the
        # (N, group, cpg, HxW) views below.
        # NOTE(review): assumes _unsqueeze_to_dim appends trailing singleton
        # dimensions until the tensor has the requested rank - confirm.
        c1 = c1.unsqueeze(-1)
        c2 = _unsqueeze_to_dim(c2, 4)
        c3 = _unsqueeze_to_dim(c3, 4)
        d_input = (
            torch.mul(grad_output.reshape(N, group, cpg, HxW), c1)
            + torch.mul(input.reshape(N, group, cpg, HxW), c2)
            + c3
        )
        d_input = d_input.reshape(input.shape).to(input.dtype)
    if output_mask[1]:
        # Computed whenever requested, independent of whether gamma is None.
        d_gamma = (
            (
                (ds.view(N, group, cpg) - db.view(N, group, cpg) * mean.unsqueeze(-1))
                * rstd.unsqueeze(-1)
            )
            .sum(dim=[0])
            .reshape(C)
        )
    if output_mask[2]:
        d_bias = db.sum(dim=[0])

    return (d_input, d_gamma, d_bias)
936
+
937
+
850
938
def _maybe_cast (x : Optional [Tensor ], dtype ) -> Optional [Tensor ]:
851
939
if x is not None :
852
940
return x .to (dtype )
0 commit comments