Skip to content

Commit

Permalink
Simplify LayerNorm2d implementation.
Browse files Browse the repository at this point in the history
  • Loading branch information
datumbox committed Jan 31, 2022
1 parent 1ab9030 commit cafa02d
Showing 1 changed file with 2 additions and 8 deletions.
10 changes: 2 additions & 8 deletions torchvision/prototype/models/convnext.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,11 @@


class LayerNorm2d(nn.LayerNorm):
def __init__(self, *args: Any, **kwargs: Any) -> None:
self.channels_last = kwargs.pop("channels_last", False)
super().__init__(*args, **kwargs)

def forward(self, x: Tensor) -> Tensor:
# TODO: Benchmark this against the approach described at https://github.com/pytorch/vision/pull/5197#discussion_r786251298
if not self.channels_last:
x = x.permute(0, 2, 3, 1)
x = x.permute(0, 2, 3, 1)
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
if not self.channels_last:
x = x.permute(0, 3, 1, 2)
x = x.permute(0, 3, 1, 2)
return x


Expand Down

0 comments on commit cafa02d

Please sign in to comment.