
Commit

remove layernorm duplicate refs
shubhambhokare1 committed Oct 19, 2024
1 parent 02579ae commit c5ef395
Showing 120 changed files with 165 additions and 280 deletions.
53 changes: 18 additions & 35 deletions docs/Changelog.md
@@ -28978,43 +28978,28 @@ This version of the operator has been available since version 23 of the default

### <a name="RMSNormalization-23"></a><a name="rmsnormalization-23">**RMSNormalization-23**</a>

-This is RMS normalization defined in ONNX as function.
-The overall computation can be split into two stages.
-The first stage is standardization, which makes the
-normalized elements have zero mean and unit variances.
-The computation required by standardization can be
+This is RMS normalization defined in ONNX as a function, as described in the paper https://arxiv.org/pdf/1910.07467.
+The overall computation can be split into two stages. The first stage normalizes the elements by their root mean
+square, giving them unit RMS (unlike layer normalization, no mean is subtracted). The root mean squared norm is
+taken over the last D dimensions, where D is the dimension of normalized_shape. For example, if normalized_shape
+is (3, 5) (a 2-dimensional shape), the RMS norm is computed over the last 2 dimensions of the input. The
+computation required by this stage can be
described by the following equations.
```
-Mean = ReduceMean<axes=normalized_axes>(X)
-D = Sub(X, Mean)
-DD = Mul(D, D)
-Var = ReduceMean<axes=normalized_axes>(DD)
-VarEps = Add(Var, epsilon)
-StdDev = Sqrt(VarEps)
-InvStdDev = Reciprocal(StdDev)
-Normalized = Mul(D, InvStdDev)
+XSquared = Mul(X, X)
+XSquaredMean = ReduceMean<axes=normalized_axes>(XSquared)
+MeanSquareEps = Add(XSquaredMean, epsilon)
+RMS = Sqrt(MeanSquareEps)
+Normalized = Div(X, RMS)
```
-where `normalized_axes` is `[axis, ..., rank of X - 1]`.
-The variables `Var` and `StdDev` stand for variance and
-standard deviation, respectively. The second output is
-`Mean` and the last one is `InvStdDev`.
-Depending on `stash_type` attribute, the actual computation
-must happen in different floating-point precision.
-For example, if `stash_type` is 1, this operator casts
-all input variables to 32-bit float, perform the computation, and
-finally cast `Normalized` back to the original type of `X`.
-The second stage then scales and shifts the outcome of the
-first stage using
+where `normalized_axes` is `[axis, ..., rank of X - 1]`. The variable `RMS` stands for the root mean square.
+The second stage then scales the outcome of the first stage using:
```
Y = Mul(Normalized, Scale)
```
The second stage doesn't depend on `stash_type`.
All equations are in [this syntax](https://github.com/onnx/onnx/blob/main/docs/Syntax.md).
The same variables (i.e., inputs, outputs, and attributes) use
the same names in the equations above and in this operator's definition.
Let `d[i]` indicate the i-th dimension of `X`.
If `X`'s shape is `[d[0], ..., d[axis-1], d[axis], ..., d[rank-1]]`,
-the shape of `Mean` and `InvStdDev` is `[d[0], ..., d[axis-1], 1, ..., 1]`.
+the shape of `RMS` is `[d[0], ..., d[axis-1], 1, ..., 1]`.
`Y` and `X` have the same shape. This operator supports unidirectional broadcasting
(tensor `Scale` should be unidirectionally broadcastable to tensor `X`);
for more details please check [the doc](Broadcasting.md).
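To make the two stages concrete, here is a minimal NumPy sketch of the computation described above (the function name, the test shapes, and the epsilon default are illustrative, not taken from this diff):

```python
import numpy as np

def rms_norm(X, scale, axis=-1, epsilon=1e-5):
    # Normalize over dimensions [axis, ..., rank of X - 1].
    normalized_axes = tuple(range(axis % X.ndim, X.ndim))
    # Stage one: divide X by its root mean square over those axes.
    mean_square = np.mean(X * X, axis=normalized_axes, keepdims=True)
    normalized = X / np.sqrt(mean_square + epsilon)
    # Stage two: elementwise scale (unidirectionally broadcast against X).
    return normalized * scale

X = np.random.randn(2, 3, 5).astype(np.float32)
scale = np.random.randn(5).astype(np.float32)
Y = rms_norm(X, scale)
assert Y.shape == X.shape
```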
@@ -29031,25 +29016,23 @@ This version of the operator has been available since version 23 of the default
<dt><tt>epsilon</tt> : float (default is 1e-05)</dt>
<dd>The epsilon value to use to avoid division by zero.</dd>
<dt><tt>stash_type</tt> : int (default is 1)</dt>
-<dd>type used for stash mean/inv_std_var</dd>
+<dd>The floating-point precision used for stage one's computation.</dd>
</dl>
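As a hedged illustration of what `stash_type` controls (assuming the ONNX convention that the default value 1 denotes float32), stage one can run in higher precision than the input, with the result cast back afterwards:

```python
import numpy as np

def rms_norm_stashed(X, scale, epsilon=1e-5):
    # Illustrative sketch: with stash_type = 1, stage one is computed in
    # float32 even when X is, e.g., float16.
    x32 = X.astype(np.float32)
    rms = np.sqrt(np.mean(x32 * x32, axis=-1, keepdims=True) + epsilon)
    normalized = (x32 / rms).astype(X.dtype)  # cast back to X's type
    return normalized * scale                 # stage two is unaffected
```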

#### Inputs

<dl>
<dt><tt>X</tt> : T</dt>
-<dd>Input data tensor from the previous layer.</dd>
+<dd>Input data tensor. In general, the shape is (N, C, D1, D2, ..., Dn) for n-dimensional data, where D1 to Dn are the spatial dimension sizes, N is the batch size, and C is the number of channels. The root mean squared norm is taken over the last D dimensions, where D is determined by the axis attribute.</dd>
<dt><tt>scale</tt> : V</dt>
<dd>Scale tensor.</dd>
</dl>

-#### Outputs (1 - 2)
+#### Outputs

<dl>
<dt><tt>Y</tt> : V</dt>
-<dd>Output data tensor.</dd>
-<dt><tt>InvStdVar</tt> (optional) : U</dt>
-<dd>Saved inverse standard variance used during training to speed up gradient computation.</dd>
+<dd>Output data tensor. Same shape as X.</dd>
</dl>

#### Type Constraints
79 changes: 31 additions & 48 deletions docs/Operators.md
@@ -21103,43 +21103,28 @@ expect(

### <a name="RMSNormalization"></a><a name="rmsnormalization">**RMSNormalization**</a>

-This is RMS normalization defined in ONNX as function.
-The overall computation can be split into two stages.
-The first stage is standardization, which makes the
-normalized elements have zero mean and unit variances.
-The computation required by standardization can be
+This is RMS normalization defined in ONNX as a function, as described in the paper https://arxiv.org/pdf/1910.07467.
+The overall computation can be split into two stages. The first stage normalizes the elements by their root mean
+square, giving them unit RMS (unlike layer normalization, no mean is subtracted). The root mean squared norm is
+taken over the last D dimensions, where D is the dimension of normalized_shape. For example, if normalized_shape
+is (3, 5) (a 2-dimensional shape), the RMS norm is computed over the last 2 dimensions of the input. The
+computation required by this stage can be
described by the following equations.
```
-Mean = ReduceMean<axes=normalized_axes>(X)
-D = Sub(X, Mean)
-DD = Mul(D, D)
-Var = ReduceMean<axes=normalized_axes>(DD)
-VarEps = Add(Var, epsilon)
-StdDev = Sqrt(VarEps)
-InvStdDev = Reciprocal(StdDev)
-Normalized = Mul(D, InvStdDev)
+XSquared = Mul(X, X)
+XSquaredMean = ReduceMean<axes=normalized_axes>(XSquared)
+MeanSquareEps = Add(XSquaredMean, epsilon)
+RMS = Sqrt(MeanSquareEps)
+Normalized = Div(X, RMS)
```
-where `normalized_axes` is `[axis, ..., rank of X - 1]`.
-The variables `Var` and `StdDev` stand for variance and
-standard deviation, respectively. The second output is
-`Mean` and the last one is `InvStdDev`.
-Depending on `stash_type` attribute, the actual computation
-must happen in different floating-point precision.
-For example, if `stash_type` is 1, this operator casts
-all input variables to 32-bit float, perform the computation, and
-finally cast `Normalized` back to the original type of `X`.
-The second stage then scales and shifts the outcome of the
-first stage using
+where `normalized_axes` is `[axis, ..., rank of X - 1]`. The variable `RMS` stands for the root mean square.
+The second stage then scales the outcome of the first stage using:
```
Y = Mul(Normalized, Scale)
```
The second stage doesn't depend on `stash_type`.
All equations are in [this syntax](https://github.com/onnx/onnx/blob/main/docs/Syntax.md).
The same variables (i.e., inputs, outputs, and attributes) use
the same names in the equations above and in this operator's definition.
Let `d[i]` indicate the i-th dimension of `X`.
If `X`'s shape is `[d[0], ..., d[axis-1], d[axis], ..., d[rank-1]]`,
-the shape of `Mean` and `InvStdDev` is `[d[0], ..., d[axis-1], 1, ..., 1]`.
+the shape of `RMS` is `[d[0], ..., d[axis-1], 1, ..., 1]`.
`Y` and `X` have the same shape. This operator supports unidirectional broadcasting
(tensor `Scale` should be unidirectionally broadcastable to tensor `X`);
for more details please check [the doc](Broadcasting.md).
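The example snippets further down this page call two helpers, `_rms_normalization` and `calculate_normalized_shape`, whose definitions lie outside this diff; a minimal reconstruction consistent with the equations above (my sketch, not the repository's exact code) might be:

```python
import numpy as np

def calculate_normalized_shape(x_shape, axis):
    # Dimensions [axis, ..., rank - 1] are the ones normalized over.
    start = axis + len(x_shape) if axis < 0 else axis
    return x_shape[start:]

def _rms_normalization(x, w, axis=-1, epsilon=1e-5):
    axes = tuple(range(axis + x.ndim if axis < 0 else axis, x.ndim))
    rms = np.sqrt(np.mean(x * x, axis=axes, keepdims=True) + epsilon)
    return ((x / rms) * w).astype(x.dtype)
```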
@@ -21156,25 +21141,23 @@ This version of the operator has been available since version 23 of the default
<dt><tt>epsilon</tt> : float (default is 1e-05)</dt>
<dd>The epsilon value to use to avoid division by zero.</dd>
<dt><tt>stash_type</tt> : int (default is 1)</dt>
-<dd>type used for stash mean/inv_std_var</dd>
+<dd>The floating-point precision used for stage one's computation.</dd>
</dl>

#### Inputs

<dl>
<dt><tt>X</tt> : T</dt>
-<dd>Input data tensor from the previous layer.</dd>
+<dd>Input data tensor. In general, the shape is (N, C, D1, D2, ..., Dn) for n-dimensional data, where D1 to Dn are the spatial dimension sizes, N is the batch size, and C is the number of channels. The root mean squared norm is taken over the last D dimensions, where D is determined by the axis attribute.</dd>
<dt><tt>scale</tt> : V</dt>
<dd>Scale tensor.</dd>
</dl>

-#### Outputs (1 - 2)
+#### Outputs

<dl>
<dt><tt>Y</tt> : V</dt>
-<dd>Output data tensor.</dd>
-<dt><tt>InvStdVar</tt> (optional) : U</dt>
-<dd>Saved inverse standard variance used during training to speed up gradient computation.</dd>
+<dd>Output data tensor. Same shape as X.</dd>
</dl>

#### Type Constraints
@@ -21200,12 +21183,12 @@ X = np.random.randn(3, 4).astype(np.float32)
def case(axis: int) -> None:
    normalized_shape = calculate_normalized_shape(X.shape, axis)
    W = np.random.randn(*normalized_shape).astype(np.float32)
-    Y, inv_std_dev = _rms_normalization(X, W, axis=axis)
+    Y = _rms_normalization(X, W, axis=axis)

    node = onnx.helper.make_node(
        "RMSNormalization",
        inputs=["X", "W"],
-        outputs=["Y", "InvStdDev"],
+        outputs=["Y"],
        axis=axis,
    )

@@ -21214,7 +21197,7 @@ def case(axis: int) -> None:
    else:
        name = f"test_rms_normalization_2d_axis{axis}"

-    expect(node, inputs=[X, W], outputs=[Y, inv_std_dev], name=name)
+    expect(node, inputs=[X, W], outputs=[Y], name=name)

for i in range(len(X.shape)):
    case(i)
@@ -21234,11 +21217,11 @@ X = np.random.randn(2, 3, 5).astype(np.float32)
def case(axis: int) -> None:
    normalized_shape = calculate_normalized_shape(X.shape, axis)
    W = np.random.randn(*normalized_shape).astype(np.float32)
-    Y, inv_std_dev = _rms_normalization(X, W, axis, epsilon)
+    Y = _rms_normalization(X, W, axis, epsilon)
    node = onnx.helper.make_node(
        "RMSNormalization",
        inputs=["X", "W"],
-        outputs=["Y", "InvStdDev"],
+        outputs=["Y"],
        axis=axis,
        epsilon=epsilon,
    )
@@ -21248,7 +21231,7 @@ def case(axis: int) -> None:
    else:
        name = f"test_rms_normalization_3d_axis{axis}_epsilon"

-    expect(node, inputs=[X, W], outputs=[Y, inv_std_dev], name=name)
+    expect(node, inputs=[X, W], outputs=[Y], name=name)

for i in range(len(X.shape)):
    case(i)
@@ -21264,23 +21247,23 @@ for i in range(len(X.shape)):
```python
X = np.random.randn(2, 3, 4, 5).astype(np.float32)

-# Default axis in LayerNormalization is -1.
+# Default axis in RMSNormalization is -1.
normalized_shape = calculate_normalized_shape(X.shape, -1)
W = np.random.randn(*normalized_shape).astype(np.float32)
# Axis defaults to -1 in the reference implementation.
-Y, inv_std_dev = _rms_normalization(X, W)
+Y = _rms_normalization(X, W)

# Not specifying the axis attribute means -1.
node = onnx.helper.make_node(
    "RMSNormalization",
    inputs=["X", "W"],
-    outputs=["Y", "InvStdDev"],
+    outputs=["Y"],
)

expect(
    node,
    inputs=[X, W],
-    outputs=[Y, inv_std_dev],
+    outputs=[Y],
    name="test_rms_normalization_default_axis",
)
```
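To run a case like the one above end to end, one option is a sketch along these lines (assuming opset 23 and that your installed `onnx` release's `ReferenceEvaluator` already implements this operator):

```python
import numpy as np
import onnx
from onnx import TensorProto, helper
from onnx.reference import ReferenceEvaluator

X = np.random.randn(2, 3, 4, 5).astype(np.float32)
W = np.random.randn(5).astype(np.float32)

node = helper.make_node("RMSNormalization", inputs=["X", "W"], outputs=["Y"])
graph = helper.make_graph(
    [node],
    "rms_norm_demo",
    [
        helper.make_tensor_value_info("X", TensorProto.FLOAT, X.shape),
        helper.make_tensor_value_info("W", TensorProto.FLOAT, W.shape),
    ],
    [helper.make_tensor_value_info("Y", TensorProto.FLOAT, X.shape)],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 23)])
(Y,) = ReferenceEvaluator(model).run(None, {"X": X, "W": W})
assert Y.shape == X.shape
```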
@@ -21297,12 +21280,12 @@ X = np.random.randn(2, 3, 4, 5).astype(np.float32)
def case(axis: int) -> None:
    normalized_shape = calculate_normalized_shape(X.shape, axis)
    W = np.random.randn(*normalized_shape).astype(np.float32)
-    Y, inv_std_dev = _rms_normalization(X, W, axis)
+    Y = _rms_normalization(X, W, axis)

    node = onnx.helper.make_node(
        "RMSNormalization",
        inputs=["X", "W"],
-        outputs=["Y", "InvStdDev"],
+        outputs=["Y"],
        axis=axis,
    )

@@ -21311,7 +21294,7 @@ def case(axis: int) -> None:
    else:
        name = f"test_rms_normalization_4d_axis{axis}"

-    expect(node, inputs=[X, W], outputs=[Y, inv_std_dev], name=name)
+    expect(node, inputs=[X, W], outputs=[Y], name=name)

for i in range(len(X.shape)):
    case(i)
26 changes: 13 additions & 13 deletions docs/TestCoverage.md
@@ -14510,12 +14510,12 @@ X = np.random.randn(3, 4).astype(np.float32)
def case(axis: int) -> None:
    normalized_shape = calculate_normalized_shape(X.shape, axis)
    W = np.random.randn(*normalized_shape).astype(np.float32)
-    Y, inv_std_dev = _rms_normalization(X, W, axis=axis)
+    Y = _rms_normalization(X, W, axis=axis)

    node = onnx.helper.make_node(
        "RMSNormalization",
        inputs=["X", "W"],
-        outputs=["Y", "InvStdDev"],
+        outputs=["Y"],
        axis=axis,
    )

@@ -14524,7 +14524,7 @@ def case(axis: int) -> None:
    else:
        name = f"test_rms_normalization_2d_axis{axis}"

-    expect(node, inputs=[X, W], outputs=[Y, inv_std_dev], name=name)
+    expect(node, inputs=[X, W], outputs=[Y], name=name)

for i in range(len(X.shape)):
    case(i)
@@ -14542,11 +14542,11 @@ X = np.random.randn(2, 3, 5).astype(np.float32)
def case(axis: int) -> None:
    normalized_shape = calculate_normalized_shape(X.shape, axis)
    W = np.random.randn(*normalized_shape).astype(np.float32)
-    Y, inv_std_dev = _rms_normalization(X, W, axis, epsilon)
+    Y = _rms_normalization(X, W, axis, epsilon)
    node = onnx.helper.make_node(
        "RMSNormalization",
        inputs=["X", "W"],
-        outputs=["Y", "InvStdDev"],
+        outputs=["Y"],
        axis=axis,
        epsilon=epsilon,
    )
@@ -14556,7 +14556,7 @@ def case(axis: int) -> None:
    else:
        name = f"test_rms_normalization_3d_axis{axis}_epsilon"

-    expect(node, inputs=[X, W], outputs=[Y, inv_std_dev], name=name)
+    expect(node, inputs=[X, W], outputs=[Y], name=name)

for i in range(len(X.shape)):
    case(i)
@@ -14570,23 +14570,23 @@ for i in range(len(X.shape)):
```python
X = np.random.randn(2, 3, 4, 5).astype(np.float32)

-# Default axis in LayerNormalization is -1.
+# Default axis in RMSNormalization is -1.
normalized_shape = calculate_normalized_shape(X.shape, -1)
W = np.random.randn(*normalized_shape).astype(np.float32)
# Axis defaults to -1 in the reference implementation.
-Y, inv_std_dev = _rms_normalization(X, W)
+Y = _rms_normalization(X, W)

# Not specifying the axis attribute means -1.
node = onnx.helper.make_node(
    "RMSNormalization",
    inputs=["X", "W"],
-    outputs=["Y", "InvStdDev"],
+    outputs=["Y"],
)

expect(
    node,
    inputs=[X, W],
-    outputs=[Y, inv_std_dev],
+    outputs=[Y],
    name="test_rms_normalization_default_axis",
)
```
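Since the point of this commit is that RMS normalization is not a duplicate of layer normalization, a quick numerical check (illustrative only) shows the two differ whenever the mean over the normalized axes is nonzero:

```python
import numpy as np

X = np.random.randn(2, 3, 5).astype(np.float32)
eps = 1e-5

# RMS normalization: no mean subtraction.
rms_out = X / np.sqrt(np.mean(X * X, axis=-1, keepdims=True) + eps)

# Layer normalization (unscaled): centers before normalizing.
mean = np.mean(X, axis=-1, keepdims=True)
ln_out = (X - mean) / np.sqrt(np.var(X, axis=-1, keepdims=True) + eps)

print(np.max(np.abs(rms_out - ln_out)))  # generally far from zero
```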
@@ -14601,12 +14601,12 @@ X = np.random.randn(2, 3, 4, 5).astype(np.float32)
def case(axis: int) -> None:
    normalized_shape = calculate_normalized_shape(X.shape, axis)
    W = np.random.randn(*normalized_shape).astype(np.float32)
-    Y, inv_std_dev = _rms_normalization(X, W, axis)
+    Y = _rms_normalization(X, W, axis)

    node = onnx.helper.make_node(
        "RMSNormalization",
        inputs=["X", "W"],
-        outputs=["Y", "InvStdDev"],
+        outputs=["Y"],
        axis=axis,
    )

@@ -14615,7 +14615,7 @@ def case(axis: int) -> None:
    else:
        name = f"test_rms_normalization_4d_axis{axis}"

-    expect(node, inputs=[X, W], outputs=[Y, inv_std_dev], name=name)
+    expect(node, inputs=[X, W], outputs=[Y], name=name)

for i in range(len(X.shape)):
    case(i)
