Skip to content

Commit b8e1c54

Browse files
malfetpytorchmergebot
authored andcommitted
[Prim] Implement group_norm_backward (pytorch#84037)
Test plan: CI, i.e. `python3 test_decomp.py -v -k test_comprehensive_nn_functional_group_norm` plus: ``` #!/usr/bin/env python3.8 import torch func = torch.ops.aten.native_group_norm_backward.default decomp = torch._decomp.decomposition_table[func] for args in ( (torch.rand(1, 6, 3), torch.rand(1, 6, 3), torch.rand(1, 2), torch.rand(1, 2), torch.rand(6), 1, 6, 3, 2, [True, True, True]), (torch.rand(64, 768, 7, 7), torch.rand(64, 768, 7, 7), torch.rand(64, 1), torch.rand(64, 1), torch.rand(768), 64, 768, 49, 1, [True, True, True])): nrc=func(*args) drc=decomp(*args) for i in range(len(nrc)): print(i, torch.max(nrc[i]-drc[i])) print(all(torch.allclose(x, y) for (x, y) in zip(nrc, drc))) ``` Pull Request resolved: pytorch#84037 Approved by: https://github.com/Chillee, https://github.com/ngimel
1 parent 2436cf8 commit b8e1c54

File tree

1 file changed

+88
-0
lines changed

1 file changed

+88
-0
lines changed

torch/_decomp/decompositions.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -847,6 +847,94 @@ def native_group_norm(
847847
return (out, mean, rstd)
848848

849849

850+
@register_decomposition(aten.native_group_norm_backward)
851+
@pw_cast_for_opmath
852+
def native_group_norm_backward(
853+
grad_output: Tensor,
854+
input: Tensor,
855+
mean: Tensor,
856+
rstd: Tensor,
857+
gamma: Optional[Tensor],
858+
N: int,
859+
C: int,
860+
HxW: int,
861+
group: int,
862+
output_mask: List[bool],
863+
) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
864+
utils.check_same_device(
865+
grad_output, input, mean, rstd, allow_cpu_scalar_tensors=False
866+
)
867+
utils.check_same_shape(input, grad_output, allow_cpu_scalar_tensors=False)
868+
utils.check_same_shape(mean, rstd, allow_cpu_scalar_tensors=False)
869+
utils.check(
870+
input.numel() == N * C * HxW,
871+
lambda: f"Expect input to have { N * C * HxW} elements",
872+
)
873+
utils.check(
874+
mean.shape == (N, group),
875+
lambda: f"Expect mean to have shape ({N}, {group}, but got {mean.shape}",
876+
)
877+
utils.check(
878+
gamma is None or gamma.numel() == C,
879+
lambda: f"Expect gamma to have {C} elements but got {gamma.numel() if gamma is not None else -1}",
880+
)
881+
882+
cpg, _rem = divmod(C, group)
883+
utils.check(
884+
_rem == 0,
885+
lambda: f"Expect number of channels {C} to be evenly-divisible by number of groups {group}",
886+
)
887+
888+
# Compute Internal gradients
889+
ds = torch.mul(grad_output, input).view(N, C, HxW).sum(dim=[2])
890+
db = grad_output.view(N, C, HxW).sum(dim=[2])
891+
892+
d_input: Optional[Tensor] = None
893+
d_gamma: Optional[Tensor] = None
894+
d_bias: Optional[Tensor] = None
895+
if output_mask[0]:
896+
s = 1.0 / (HxW * cpg)
897+
if gamma is not None:
898+
ds_val = torch.mul(ds, gamma.unsqueeze(0)).reshape(N, group, cpg).sum(2)
899+
db_val = torch.mul(db, gamma.unsqueeze(0)).reshape(N, group, cpg).sum(2)
900+
c1 = torch.mul(
901+
rstd.unsqueeze(-1),
902+
gamma.reshape(1, group, cpg),
903+
)
904+
else:
905+
ds_val = ds.reshape(N, group, cpg).sum(2)
906+
db_val = db.reshape(N, group, cpg).sum(2)
907+
c1 = torch.mul(
908+
rstd.unsqueeze(-1),
909+
torch.ones((1, group, cpg), device=rstd.device),
910+
)
911+
c2 = (db_val * mean - ds_val) * rstd * rstd * rstd * s
912+
c3 = -c2 * mean - db_val * rstd * s
913+
914+
c1 = c1.unsqueeze(-1)
915+
c2 = _unsqueeze_to_dim(c2, 4)
916+
c3 = _unsqueeze_to_dim(c3, 4)
917+
d_input = (
918+
torch.mul(grad_output.reshape(N, group, cpg, HxW), c1)
919+
+ torch.mul(input.reshape(N, group, cpg, HxW), c2)
920+
+ c3
921+
)
922+
d_input = d_input.reshape(input.shape).to(input.dtype)
923+
if output_mask[1]:
924+
d_gamma = (
925+
(
926+
(ds.view(N, group, cpg) - db.view(N, group, cpg) * mean.unsqueeze(-1))
927+
* rstd.unsqueeze(-1)
928+
)
929+
.sum(dim=[0])
930+
.reshape(C)
931+
)
932+
if output_mask[2]:
933+
d_bias = db.sum(dim=[0])
934+
935+
return (d_input, d_gamma, d_bias)
936+
937+
850938
def _maybe_cast(x: Optional[Tensor], dtype) -> Optional[Tensor]:
851939
if x is not None:
852940
return x.to(dtype)

0 commit comments

Comments
 (0)