feat: layernorm rms replacement for T5 (#107)
* feat: rms replacement base

* fix: add rms kernel replacement

* fix: reformat

* fix: import

* fix: reformat

Co-authored-by: Michaël Benesty <pommedeterresautee@users.noreply.github.com>
gaetansnl and pommedeterresautee authored Oct 27, 2022
1 parent dcb0e83 commit 1463e39
Showing 3 changed files with 26 additions and 3 deletions.
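For context, T5's layer normalization is an RMS norm: it skips mean centering, has no bias, and only rescales activations by the reciprocal of their root mean square. The pattern added in src/kernl/optimizer/layer_norm.py below targets exactly this computation. As a point of reference, here is a minimal eager-mode sketch of that computation; the module name and argument names are illustrative only and not part of the commit:

import torch


class T5StyleRMSNorm(torch.nn.Module):
    """Illustrative RMS norm: scale by 1/sqrt(mean(x^2) + eps), no mean subtraction, no bias."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, v: torch.Tensor) -> torch.Tensor:
        # Compute the variance in fp32 as a plain mean of squares (no centering).
        variance = v.to(torch.float32).pow(2).mean(-1, keepdim=True)
        # Normalize by rsqrt(variance + eps) and apply the learned scale,
        # mirroring the traced pattern matched by replace_layer_norm_rms below.
        return self.weight * (v * torch.rsqrt(variance + self.eps))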
3 changes: 2 additions & 1 deletion src/kernl/optimizer/dynamo_backend.py
@@ -17,7 +17,7 @@

from kernl.optimizer.attention import fuse_attention_pattern_1, fuse_attention_pattern_2
from kernl.optimizer.dropout import remove_dropout
-from kernl.optimizer.layer_norm import replace_layer_norm
+from kernl.optimizer.layer_norm import replace_layer_norm, replace_layer_norm_rms
from kernl.optimizer.linear import replace_all_linear
from kernl.optimizer.normalizer import normalize_operators

@@ -29,4 +29,5 @@ def dynamo_backend_ofi(gm: torch.fx.GraphModule, assume_causal=False):
    fuse_attention_pattern_2(gm, assume_causal)
    replace_all_linear(gm)
    replace_layer_norm(gm)
+   replace_layer_norm_rms(gm)
    return gm
24 changes: 23 additions & 1 deletion src/kernl/optimizer/layer_norm.py
@@ -15,7 +15,7 @@

import torch

-from kernl.implementations.layer_norm import layer_norm
+from kernl.implementations.layer_norm import _layer_norm_fwd_fused_single_pass, layer_norm
from kernl.utils.extended_matcher import replace_pattern


@@ -29,7 +29,12 @@ def layer_norm_wrapper(v: torch.Tensor, layernorm: torch.nn.LayerNorm):
    return layer_norm(v, layernorm.weight, layernorm.bias, layernorm.eps)


+def layer_norm_rms_wrapper(v: torch.Tensor, weight: torch.Tensor, eps: float):
+    return layer_norm(v, weight, None, eps, _layer_norm_fwd_fused_single_pass, use_rms_norm=True)
+
+
torch.fx.wrap("layer_norm_wrapper")
+torch.fx.wrap("layer_norm_rms_wrapper")


def replace_layer_norm(gm: torch.fx.GraphModule):
@@ -50,3 +55,20 @@ def forward(self, v):
        return layer_norm_wrapper(v, self.layernorm)

    replace_pattern(gm, Pattern(), Replacement())
+
+
+def replace_layer_norm_rms(gm: torch.fx.GraphModule):
+    def pattern(v, weight):
+        to_38 = v.to(torch.float32)
+        pow_32 = to_38.pow(2)
+        mean_31 = pow_32.mean(-1, keepdim=True)
+        add_68 = mean_31 + 1e-06
+        rsqrt_31 = torch.rsqrt(add_68)
+        mul_69 = v * rsqrt_31
+        mul_70 = weight * mul_69
+        return mul_70
+
+    def replace(v, weight):
+        return layer_norm_rms_wrapper(v, weight, 1e-06)
+
+    replace_pattern(gm, pattern, replace)
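To see the new pass fire, one can symbolic-trace a function that reproduces the pattern above and run replace_layer_norm_rms on the resulting GraphModule. A rough usage sketch, not part of the commit: the helper below is hypothetical, whether the match succeeds depends on how torch.fx renders the traced graph, and actually executing the rewritten module still requires the fused Triton kernel on GPU.

import torch
import torch.fx

from kernl.optimizer.layer_norm import replace_layer_norm_rms


def t5_rms_norm(v: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Same computation as the `pattern` function above, with eps baked in as 1e-06.
    variance = v.to(torch.float32).pow(2).mean(-1, keepdim=True)
    return weight * (v * torch.rsqrt(variance + 1e-06))


gm = torch.fx.symbolic_trace(t5_rms_norm)
replace_layer_norm_rms(gm)  # rewrites the matched subgraph into a layer_norm_rms_wrapper call
print(gm.code)  # inspect the rewritten graph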
2 changes: 1 addition & 1 deletion src/kernl/utils/extended_matcher.py
@@ -256,7 +256,7 @@ def backtracking(anchor_index, match):
                    backtracking(anchor_index + 1, match)

                # revert to saved_match before matching with current anchor
-                    match = copy.copy(saved_match)
+                match = copy.copy(saved_match)

        match = InternalMatch(anchors=self.pattern_anchors)
        backtracking(0, match)
