Skip to content

Commit a64b9e3

Browse files
manuelcandalesvmpuri
authored and
vmpuri
committed
Replace RMSNorm by nn.RMSNorm (#1464)
In this PR we replace torchchat's own [RMSNorm](https://github.com/pytorch/torchchat/blob/f4ae60fc936328c7ebd4551019733dc0942c42f9/torchchat/model.py#L931-L942) implementation by nn.RMSNorm, and we bump the PyTorch pin to capture the massive speed up (30x-40x) to RMSNorm on MPS backend introduced in pytorch/pytorch#145301 Preliminary benchmarks on an M1 Pro with 16GB RAM, show a 33% speed up on token generation when running Llama 3.2 1B with 4-bit quantization Motivation: Token generation on MPS backend is currently CPU bound, because of MPSGraph overhead. Surprisingly, the ops that are impacting performance the most are simple ones: mul, copy_, add, where, mean, rsqrt, sub, cat, stack. Experiments on an M1 Pro show that each of those op calls on the MPS backend, has at least 20us of CPU overhead. Also, these ops dominate the graph. For example, in aggregate, these ops are called 770 times for each token, when running Llama 3.2 1B. Compare that to SDPA which is called only 33 times, and linear which is called 113 times. - mul is called 275 times per token - copy_ is called 202 times per token - add is called 97 times per token - where is called 34 times per token - mean is called 33 times per token - rsqrt is called 33 times per token - sub is called 32 times per token - cat is called 32 times per token - stack is called 32 times per token Currently, torchchat's own [RMSNorm](https://github.com/pytorch/torchchat/blob/f4ae60fc936328c7ebd4551019733dc0942c42f9/torchchat/model.py#L931-L942) operation is basically implemented like this: ``` norm = x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) output = norm(x.float()).type_as(x) * weight ``` This means that a single call to torchchat's RMSNorm involves 3 calls to `aten::mul` and calls to `aten::rsqrt`, `aten::mean` and `aten::add`. RMSNorm is called 33 times for each token. Hence, RMSNorm contributes 5 * 33 = 165 of those 770 op calls.
1 parent 8662471 commit a64b9e3

File tree

2 files changed

+6
-20
lines changed

2 files changed

+6
-20
lines changed

install/install_requirements.sh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,13 @@ echo "Using pip executable: $PIP_EXECUTABLE"
5151
# NOTE: If a newly-fetched version of the executorch repo changes the value of
5252
# PYTORCH_NIGHTLY_VERSION, you should re-run this script to install the necessary
5353
# package versions.
54-
PYTORCH_NIGHTLY_VERSION=dev20250119
54+
PYTORCH_NIGHTLY_VERSION=dev20250124
5555

5656
# Nightly version for torchvision
57-
VISION_NIGHTLY_VERSION=dev20250119
57+
VISION_NIGHTLY_VERSION=dev20250124
5858

5959
# Nightly version for torchtune
60-
TUNE_NIGHTLY_VERSION=dev20250119
60+
TUNE_NIGHTLY_VERSION=dev20250124
6161

6262
# The pip repository that hosts nightly torch packages. cpu by default.
6363
# If cuda is available, based on presence of nvidia-smi, install the pytorch nightly

torchchat/model.py

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -657,7 +657,7 @@ def __init__(self, config: TransformerArgs) -> None:
657657
self.layers[str(layer_id)] = TransformerBlock(config)
658658

659659
if config.stage_idx == config.n_stages - 1:
660-
self.norm = RMSNorm(config.dim, eps=config.norm_eps)
660+
self.norm = nn.RMSNorm(config.dim, eps=config.norm_eps)
661661
self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
662662
if config.tie_word_embeddings:
663663
self.output.weight = self.tok_embeddings.weight
@@ -751,8 +751,8 @@ def __init__(self, config: TransformerArgs) -> None:
751751
super().__init__()
752752
self.attention = Attention(config)
753753
self.feed_forward = FeedForward(config)
754-
self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
755-
self.attention_norm = RMSNorm(config.dim, config.norm_eps)
754+
self.ffn_norm = nn.RMSNorm(config.dim, config.norm_eps)
755+
self.attention_norm = nn.RMSNorm(config.dim, config.norm_eps)
756756
# None for llama architecture, set for granite architectures
757757
self.residual_multiplier = (
758758
config.residual_multiplier
@@ -928,20 +928,6 @@ def forward(self, x: Tensor) -> Tensor:
928928
return self.w2(F.silu(self.w1(x)) * self.w3(x))
929929

930930

931-
class RMSNorm(nn.Module):
932-
def __init__(self, dim: int, eps: float = 1e-5):
933-
super().__init__()
934-
self.eps = eps
935-
self.weight = nn.Parameter(torch.ones(dim))
936-
937-
def _norm(self, x):
938-
return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
939-
940-
def forward(self, x: Tensor) -> Tensor:
941-
output = self._norm(x.float()).type_as(x)
942-
return output * self.weight
943-
944-
945931
def apply_scaling(freqs: torch.Tensor, rope_scaling: Dict[str, Any]):
946932
# Check for the presence of the required keys
947933
required_keys = {

0 commit comments

Comments
 (0)