DOC Improve documentation for LayerNorm (pytorch#59178)
Summary:
Closes pytorch#51455

I think the current implementation is aggregating over the correct dimensions. Only the length of `normalized_shape` is used to determine which trailing dimensions to aggregate over; the actual values of `normalized_shape` are used when `elementwise_affine=True` to initialize the weights and biases. A short sketch below illustrates this.
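To make that concrete, here is a minimal sketch (my own illustration, not part of the PR): with `elementwise_affine=True` the parameters take the shape given by `normalized_shape`, while the length of `normalized_shape` selects how many trailing dimensions are normalized over.

```python
import torch
import torch.nn as nn

# With elementwise_affine=True (the default), weight and bias are created
# with exactly the shape given by normalized_shape.
m = nn.LayerNorm([20, 64, 64])
assert m.weight.shape == (20, 64, 64)
assert m.bias.shape == (20, 64, 64)

# The length of normalized_shape (here 3) selects how many trailing
# dimensions of the input are aggregated over.
m = nn.LayerNorm([20, 64, 64], elementwise_affine=False)
out = m(torch.randn(10, 20, 64, 64))  # stats over the last 3 dims
```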

This PR updates the docstring to clarify how `normalized_shape` is used. Here is a short script comparing the TensorFlow and PyTorch implementations:

```python
import numpy as np

import torch
import torch.nn as nn

import tensorflow as tf
from tensorflow.keras.layers import LayerNormalization

rng = np.random.RandomState()
x = rng.randn(10, 20, 64, 64).astype(np.float32)
# slightly non-trivial
x[:, :10, ...] = x[:, :10, ...] * 10 + 20
x[:, 10:, ...] = x[:, 10:, ...] * 30 - 100

# Tensorflow Layer norm
x_tf = tf.convert_to_tensor(x)
layer_norm_tf = LayerNormalization(axis=[-3, -2, -1], epsilon=1e-5)
output_tf = layer_norm_tf(x_tf)
output_tf_np = output_tf.numpy()

# PyTorch Layer norm
x_torch = torch.as_tensor(x)
layer_norm_torch = nn.LayerNorm([20, 64, 64], elementwise_affine=False)
output_torch = layer_norm_torch(x_torch)
output_torch_np = output_torch.detach().numpy()

# check tensorflow and pytorch
torch.testing.assert_allclose(output_tf_np, output_torch_np)

# manual computation
manual_output = ((x_torch - x_torch.mean(dim=(-3, -2, -1), keepdim=True)) /
                 (x_torch.var(dim=(-3, -2, -1), keepdim=True, unbiased=False) + 1e-5).sqrt())

torch.testing.assert_allclose(output_torch, manual_output)
```
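As an aside (not part of the original script), the same result can be checked with the functional API; this sketch reuses `x_torch` and `output_torch` from the script above:

```python
import torch.nn.functional as F

# Functional equivalent of nn.LayerNorm([20, 64, 64], elementwise_affine=False):
# no weight/bias are passed, and eps matches the value used above.
output_functional = F.layer_norm(x_torch, [20, 64, 64], eps=1e-5)
torch.testing.assert_allclose(output_functional, output_torch)
```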

To get the layer normalization shown here:

<img width="157" alt="Screen Shot 2021-05-29 at 2 13 52 PM" src="https://user-images.githubusercontent.com/5402633/120080691-1e37f100-c088-11eb-9060-4f263e4cd093.png">

One needs to pass in a `normalized_shape` of length `x.dim() - 1`, containing the sizes of the channel dimension and all spatial dimensions, as sketched below.
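A small sketch of that recipe (assuming a 4-d NCHW input, as in the script above):

```python
import torch
import torch.nn as nn

x = torch.randn(10, 20, 64, 64)         # (N, C, H, W)
normalized_shape = list(x.shape[1:])    # [20, 64, 64]: channel + spatial dims
layer_norm = nn.LayerNorm(normalized_shape, elementwise_affine=False)
out = layer_norm(x)  # per-sample stats over C, H and W jointly
```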

Pull Request resolved: pytorch#59178

Reviewed By: ejguan

Differential Revision: D28931877

Pulled By: jbschlosser

fbshipit-source-id: 193e05205b9085bb190c221428c96d2ca29f2a70
thomasjpfan authored and facebook-github-bot committed Jun 7, 2021
1 parent a30b359 commit 6ff001c
Showing 2 changed files with 21 additions and 14 deletions.
Binary file added docs/source/_static/img/nn/layer_norm.jpg
35 changes: 21 additions & 14 deletions torch/nn/modules/normalization.py
@@ -90,9 +90,10 @@ class LayerNorm(Module):
     .. math::
         y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
 
-    The mean and standard-deviation are calculated separately over the last
-    certain number dimensions which have to be of the shape specified by
-    :attr:`normalized_shape`.
+    The mean and standard-deviation are calculated over the last `D` dimensions, where `D`
+    is the dimension of :attr:`normalized_shape`. For example, if :attr:`normalized_shape`
+    is ``(3, 5)`` (a 2-dimensional shape), the mean and standard-deviation are computed over
+    the last 2 dimensions of the input (i.e. ``input.mean((-2, -1))``).
     :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
     :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.
     The standard-deviation is calculated via the biased estimator, equivalent to
@@ -128,17 +129,23 @@ class LayerNorm(Module):
     Examples::
 
-        >>> input = torch.randn(20, 5, 10, 10)
-        >>> # With Learnable Parameters
-        >>> m = nn.LayerNorm(input.size()[1:])
-        >>> # Without Learnable Parameters
-        >>> m = nn.LayerNorm(input.size()[1:], elementwise_affine=False)
-        >>> # Normalize over last two dimensions
-        >>> m = nn.LayerNorm([10, 10])
-        >>> # Normalize over last dimension of size 10
-        >>> m = nn.LayerNorm(10)
-        >>> # Activating the module
-        >>> output = m(input)
+        >>> # NLP Example
+        >>> batch, sentence_length, embedding_dim = 20, 5, 10
+        >>> embedding = torch.randn(batch, sentence_length, embedding_dim)
+        >>> layer_norm = nn.LayerNorm(embedding_dim)
+        >>> # Activate module
+        >>> layer_norm(embedding)
+        >>> # Image Example
+        >>> N, C, H, W = 20, 5, 10, 10
+        >>> input = torch.randn(N, C, H, W)
+        >>> # Normalize over the last three dimensions (i.e. the channel and spatial dimensions)
+        >>> # as shown in the image below
+        >>> layer_norm = nn.LayerNorm([C, H, W])
+        >>> output = layer_norm(input)
+
+    .. image:: ../_static/img/nn/layer_norm.jpg
+        :scale: 50 %
     """
     __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
     normalized_shape: Tuple[int, ...]
