
[BUG] RMSNorm implementation seems wrong #2380

Closed
@laclouis5

Description

Describe the bug

The RMSNorm implementation seems to use the variance (torch.var) instead of the root mean square:

from typing import List, Optional

import torch


def rms_norm(
    x: torch.Tensor,
    normalized_shape: List[int],
    weight: Optional[torch.Tensor] = None,
    eps: float = 1e-5,
):
    norm_ndim = len(normalized_shape)
    if torch.jit.is_scripting():
        # ndim = len(x.shape)
        # dims = list(range(ndim - norm_ndim, ndim))  # this doesn't work on pytorch <= 1.13.x
        # NOTE -ve dims cause torchscript to crash in some cases, out of options to work around
        assert norm_ndim == 1
        v = torch.var(x, dim=-1).unsqueeze(-1)  # ts crashes with -ve dim + keepdim=True
    else:
        dims = tuple(range(-1, -norm_ndim - 1, -1))
        v = torch.var(x, dim=dims, keepdim=True)
    x = x * torch.rsqrt(v + eps)
    if weight is not None:
        x = x * weight
    return x

This yields a different result from the PyTorch RMSNorm implementation in my tests: torch.var subtracts the mean before averaging the squared values (and uses the unbiased estimator by default), whereas RMSNorm divides by sqrt(mean(x**2) + eps), so the two only coincide when the input has zero mean along the normalized dimensions.
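
A minimal reproduction sketch of the mismatch, assuming PyTorch >= 2.4 so that torch.nn.functional.rms_norm is available (tensor shapes and eps are arbitrary choices for illustration):

import torch
import torch.nn.functional as F

x = torch.randn(2, 8, 16)
eps = 1e-5

# timm-style normalization: variance over the last dim (mean is subtracted, unbiased by default)
v = torch.var(x, dim=-1, keepdim=True)
y_timm = x * torch.rsqrt(v + eps)

# PyTorch reference: divides by sqrt(mean(x**2) + eps), no mean subtraction
y_ref = F.rms_norm(x, normalized_shape=(16,), eps=eps)

print(torch.allclose(y_timm, y_ref, atol=1e-5))  # False for generic inputs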

Expected behavior
Same results between the native PyTorch operator and the timm operator.
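
For reference, a sketch of a root-mean-square based variant that matches F.rms_norm semantics (only an illustration of the expected math, not necessarily the fix adopted in timm):

from typing import List, Optional

import torch


def rms_norm_fixed(
    x: torch.Tensor,
    normalized_shape: List[int],
    weight: Optional[torch.Tensor] = None,
    eps: float = 1e-5,
):
    # mean of squares over the normalized dims, no mean subtraction
    dims = tuple(range(-len(normalized_shape), 0))
    ms = x.pow(2).mean(dim=dims, keepdim=True)
    x = x * torch.rsqrt(ms + eps)
    if weight is not None:
        x = x * weight
    return x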

Desktop (please complete the following information):

  • OS: macOS
  • This repository version: 1.0.12
  • PyTorch version: 2.5
