diff --git a/src/kernl/optimizer/dynamo_backend.py b/src/kernl/optimizer/dynamo_backend.py
index 0c5fafe9..ce04ec48 100644
--- a/src/kernl/optimizer/dynamo_backend.py
+++ b/src/kernl/optimizer/dynamo_backend.py
@@ -17,7 +17,7 @@
 from kernl.optimizer.attention import fuse_attention_pattern_1, fuse_attention_pattern_2
 from kernl.optimizer.dropout import remove_dropout
-from kernl.optimizer.layer_norm import replace_layer_norm
+from kernl.optimizer.layer_norm import replace_layer_norm, replace_layer_norm_rms
 from kernl.optimizer.linear import replace_all_linear
 from kernl.optimizer.normalizer import normalize_operators
 
 
@@ -29,4 +29,5 @@ def dynamo_backend_ofi(gm: torch.fx.GraphModule, assume_causal=False):
     fuse_attention_pattern_2(gm, assume_causal)
     replace_all_linear(gm)
     replace_layer_norm(gm)
+    replace_layer_norm_rms(gm)
     return gm
diff --git a/src/kernl/optimizer/layer_norm.py b/src/kernl/optimizer/layer_norm.py
index c3f33f59..05a98689 100644
--- a/src/kernl/optimizer/layer_norm.py
+++ b/src/kernl/optimizer/layer_norm.py
@@ -15,7 +15,7 @@
 
 import torch
 
-from kernl.implementations.layer_norm import layer_norm
+from kernl.implementations.layer_norm import _layer_norm_fwd_fused_single_pass, layer_norm
 from kernl.utils.extended_matcher import replace_pattern
 
 
@@ -29,7 +29,12 @@ def layer_norm_wrapper(v: torch.Tensor, layernorm: torch.nn.LayerNorm):
     return layer_norm(v, layernorm.weight, layernorm.bias, layernorm.eps)
 
 
+def layer_norm_rms_wrapper(v: torch.Tensor, weight: torch.Tensor, eps: float):
+    return layer_norm(v, weight, None, eps, _layer_norm_fwd_fused_single_pass, use_rms_norm=True)
+
+
 torch.fx.wrap("layer_norm_wrapper")
+torch.fx.wrap("layer_norm_rms_wrapper")
 
 
 def replace_layer_norm(gm: torch.fx.GraphModule):
@@ -50,3 +55,20 @@ def forward(self, v):
             return layer_norm_wrapper(v, self.layernorm)
 
     replace_pattern(gm, Pattern(), Replacement())
+
+
+def replace_layer_norm_rms(gm: torch.fx.GraphModule):
+    def pattern(v, weight):
+        to_38 = v.to(torch.float32)
+        pow_32 = to_38.pow(2)
+        mean_31 = pow_32.mean(-1, keepdim=True)
+        add_68 = mean_31 + 1e-06
+        rsqrt_31 = torch.rsqrt(add_68)
+        mul_69 = v * rsqrt_31
+        mul_70 = weight * mul_69
+        return mul_70
+
+    def replace(v, weight):
+        return layer_norm_rms_wrapper(v, weight, 1e-06)
+
+    replace_pattern(gm, pattern, replace)
diff --git a/src/kernl/utils/extended_matcher.py b/src/kernl/utils/extended_matcher.py
index 662aa34e..de7bd578 100644
--- a/src/kernl/utils/extended_matcher.py
+++ b/src/kernl/utils/extended_matcher.py
@@ -256,7 +256,7 @@
                     backtracking(anchor_index + 1, match)
 
                 # revert to saved_match before matching with current anchor
-                    match = copy.copy(saved_match)
+                match = copy.copy(saved_match)
 
         match = InternalMatch(anchors=self.pattern_anchors)
         backtracking(0, match)
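
Context for reviewers: the subgraph targeted by replace_layer_norm_rms is the T5/LLaMA-style RMSNorm as it appears in a traced graph, which is why the pattern uses the generated variable names (to_38, pow_32, ...) and the literal eps of 1e-06. Below is a minimal plain-PyTorch sketch of the same computation, with illustrative names and no kernl dependency, to make the matched math explicit:

import torch

def rms_norm_reference(v: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Op-for-op, the sequence the FX pattern matches: mean of squares over the
    # last dimension, rsqrt, then scale by the learned weight.
    variance = v.to(torch.float32).pow(2).mean(-1, keepdim=True)
    return weight * (v * torch.rsqrt(variance + eps))

x = torch.randn(2, 4, 8)
w = torch.randn(8)
out = rms_norm_reference(x, w)
assert out.shape == x.shape

Unlike torch.nn.LayerNorm there is no mean subtraction and no bias, which is why layer_norm_rms_wrapper passes None for the bias and sets use_rms_norm=True.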
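
On the extended_matcher.py hunk: assuming the intent of the indentation change is to run the revert for every candidate rather than only after a successful match, the point is that a failed partial match must not leak into the next attempt. A standalone sketch (illustrative names, not kernl code) of the same backtracking shape:

import copy

def assign(slots, candidates, found, index=0, match=None):
    # Try each candidate for the current slot, recursing on success, and
    # restore the saved state after *every* candidate -- not only after
    # successful ones -- mirroring the revert at for-loop level above.
    match = match if match is not None else {}
    if index == len(slots):
        found.append(copy.copy(match))
        return
    saved = copy.copy(match)
    for candidate in candidates[slots[index]]:
        if candidate not in match.values():  # extend the partial match
            match[slots[index]] = candidate
            assign(slots, candidates, found, index + 1, match)
        match = copy.copy(saved)  # revert unconditionally

found: list = []
assign(["a", "b"], {"a": [1, 2], "b": [1, 2]}, found)
assert found == [{"a": 1, "b": 2}, {"a": 2, "b": 1}]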