
Commit 5ae9099

[kernel] Add RMSLayerNorm triton kernel (#5262)
* add layerrmsnorm triton kernel
* add layerrmsnorm kernel
* modify the atol and rtol in the test file
* remove the mean-computation logic, and update the names of the kernel functions and files
* add benchmark of rms norm
1 parent 86b63f7 commit 5ae9099

File tree

4 files changed: +103 -62 lines changed

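In short, the kernel switches from standard LayerNorm to RMSNorm: the per-row mean subtraction and the bias term are removed, and each row is rescaled only by the reciprocal root mean square of its elements. A minimal PyTorch sketch of the two formulas for comparison (illustration only; layernorm_ref and rmsnorm_ref are hypothetical helpers, not code from this commit):

import torch

def layernorm_ref(x, w, b, eps):
    # LayerNorm: subtract the per-row mean, scale by 1/sqrt(var + eps), then apply weight and bias.
    mean = x.mean(dim=-1, keepdim=True)
    var = (x - mean).pow(2).mean(dim=-1, keepdim=True)
    return (x - mean) * torch.rsqrt(var + eps) * w + b

def rmsnorm_ref(x, w, eps):
    # RMSNorm: no mean subtraction and no bias; scale by 1/sqrt(mean(x**2) + eps), then apply weight.
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * w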

colossalai/kernel/triton/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -10,7 +10,7 @@
 if HAS_TRITON:
     from .context_attn_unpad import context_attention_unpadded
     from .flash_decoding import flash_decoding_fwd
-    from .fused_layernorm import layer_norm
+    from .rms_layernorm import rms_layernorm
     from .gptq_triton import gptq_fused_linear_triton
     from .kvcache_copy import copy_kv_to_blocked_cache
     from .no_pad_rotary_embedding import rotary_embedding
@@ -21,7 +21,7 @@
         "flash_decoding_fwd",
         "copy_kv_to_blocked_cache",
         "softmax",
-        "layer_norm",
+        "rms_layernorm",
         "gptq_fused_linear_triton",
         "rotary_embedding",
     ]
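After this rename, callers import rms_layernorm instead of layer_norm and no longer pass a bias tensor. A hedged usage sketch (shapes, dtype, and eps are illustrative assumptions; the kernel expects CUDA tensors and stores its output as fp16):

import torch
from colossalai.kernel.triton import rms_layernorm

hidden_size = 1024  # assumed hidden dimension, for illustration only
x = torch.randn(16, hidden_size, dtype=torch.float16, device="cuda")
weight = torch.ones(hidden_size, dtype=torch.float16, device="cuda")

# One Triton program normalizes each row of x; note there is no bias argument anymore.
y = rms_layernorm(x, weight, eps=1e-5)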

colossalai/kernel/triton/fused_layernorm.py renamed to colossalai/kernel/triton/rms_layernorm.py

Lines changed: 10 additions & 17 deletions
@@ -14,34 +14,28 @@
 # https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html

 @triton.jit
-def _layer_norm_fwd_fused(
+def _rmsnorm_kernel(
     X,  # pointer to the input
     Y,  # pointer to the output
     W,  # pointer to the weights
-    B,  # pointer to the biases
     stride,  # how much to increase the pointer when moving by 1 row
     N,  # number of columns in X
     eps,  # epsilon to avoid division by zero
     BLOCK_SIZE: tl.constexpr,
 ):
+
+    # This triton kernel implements Root Mean Square Layer Norm (RMSNorm).
+
     # Map the program id to the row of X and Y it should compute.
     row = tl.program_id(0)
     Y += row * stride
     X += row * stride
-    # Compute mean
-    mean = 0
-    _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
-    for off in range(0, N, BLOCK_SIZE):
-        cols = off + tl.arange(0, BLOCK_SIZE)
-        a = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
-        _mean += a
-    mean = tl.sum(_mean, axis=0) / N
     # Compute variance
     _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
     for off in range(0, N, BLOCK_SIZE):
         cols = off + tl.arange(0, BLOCK_SIZE)
         x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
-        x = tl.where(cols < N, x - mean, 0.0)
+        x = tl.where(cols < N, x, 0.0)
         _var += x * x
     var = tl.sum(_var, axis=0) / N
     rstd = 1 / tl.sqrt(var + eps)
@@ -50,15 +44,14 @@ def _layer_norm_fwd_fused(
         cols = off + tl.arange(0, BLOCK_SIZE)
         mask = cols < N
         w = tl.load(W + cols, mask=mask)
-        b = tl.load(B + cols, mask=mask)
         x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)
-        x_hat = (x - mean) * rstd
-        y = x_hat * w + b
+        x_hat = x * rstd
+        y = x_hat * w
         # Write output
         tl.store(Y + cols, y.to(tl.float16), mask=mask)

 @torch.no_grad()
-def layer_norm(x, weight, bias, eps):
+def rms_layernorm(x, weight, eps):
     # allocate output
     y = torch.empty_like(x)
     # reshape input data into 2D tensor
@@ -72,7 +65,7 @@ def layer_norm(x, weight, bias, eps):
     # heuristics for number of warps
     num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
     # enqueue kernel
-    _layer_norm_fwd_fused[(M,)](
-        x_arg, y, weight, bias, x_arg.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
+    _rmsnorm_kernel[(M,)](
+        x_arg, y, weight, x_arg.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
     )
     return y
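The kernel accumulates the per-row sum of squares in fp32, rescales by 1 / sqrt(mean(x * x) + eps), applies the weight, and stores the result as fp16; the wrapper launches one program per row of the flattened 2D input. A quick way to sanity-check the output against a plain PyTorch reference (a sketch under those assumptions, not part of the commit; rmsnorm_torch is a hypothetical helper):

import torch
from colossalai.kernel.triton import rms_layernorm

def rmsnorm_torch(x, weight, eps):
    # Mirror the kernel: compute in fp32, then cast the normalized row back to fp16.
    x_f32 = x.float()
    var = x_f32.pow(2).mean(dim=-1, keepdim=True)
    return (x_f32 * torch.rsqrt(var + eps) * weight.float()).to(torch.float16)

x = torch.randn(8, 128, dtype=torch.float16, device="cuda")
w = torch.ones(128, dtype=torch.float16, device="cuda")
diff = (rms_layernorm(x, w, eps=1e-5) - rmsnorm_torch(x, w, 1e-5)).abs().max()
print(f"max abs difference vs. PyTorch reference: {diff.item():.2e}")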

tests/test_infer_ops/triton/test_layernorm_triton.py

Lines changed: 0 additions & 43 deletions
This file was deleted.
(new test file)

Lines changed: 91 additions & 0 deletions
import pytest
import torch
from packaging import version
import triton

from colossalai.kernel.triton import rms_layernorm
from colossalai.testing.utils import parameterize
from transformers.models.llama.modeling_llama import LlamaRMSNorm

try:
    pass

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")


@pytest.mark.skipif(
    not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
@parameterize("M", [2, 4, 8, 16])
@parameterize("N", [64, 128])
def test_layer_norm(M, N):
    dtype = torch.float16
    eps = 1e-5
    x_shape = (M, N)
    w_shape = (x_shape[-1],)
    weight = torch.ones(w_shape, dtype=dtype, device="cuda")
    rms_norm = LlamaRMSNorm(hidden_size=N, eps=eps).cuda()
    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")

    y_triton = rms_layernorm(x, weight, eps=eps)
    y_llama = rms_norm.forward(x).to(dtype)

    assert torch.allclose(y_triton, y_llama, atol=1e-5, rtol=1e-5)


# Triton benchmark plot attributions
configs = [
    triton.testing.Benchmark(
        x_names=["SEQUENCE_TOTAL"],
        x_vals=[i for i in range(128, 1025, 128)],
        line_arg="provider",
        line_vals=["llama_rms_layernorm", "triton_rms_layernorm"],
        line_names=["llama_rms_layernorm", "triton_rms_layernorm"],
        styles=[("red", "-"), ("blue", "-")],
        ylabel="ms",
        plot_name=f"RMSNorm benchmarking results",
        args={"HIDDEN_SIZE": 1024},
    )
]


@triton.testing.perf_report(configs)
def benchmark_rms_layernorm(
    provider: str,
    SEQUENCE_TOTAL: int,
    HIDDEN_SIZE: int,
):
    warmup = 10
    rep = 100

    dtype = torch.float16
    eps = 1e-5
    x_shape = (SEQUENCE_TOTAL, HIDDEN_SIZE)
    w_shape = (x_shape[-1],)
    weight = torch.ones(w_shape, dtype=dtype, device="cuda")
    rms_norm = LlamaRMSNorm(hidden_size=HIDDEN_SIZE, eps=eps).cuda()
    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")

    if provider == "llama_rms_layernorm":
        fn = lambda: rms_norm.forward(x).to(dtype)
    elif provider == "triton_rms_layernorm":
        fn = lambda: rms_layernorm(x, weight, eps=eps)
    else:
        raise ValueError("Undefined provider.")

    ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)

    return ms


if __name__ == "__main__":
    test_layer_norm()
    # benchmark_rms_layernorm.run(save_path=".")
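To reproduce the benchmark mentioned in the commit message, the commented-out call at the bottom can be enabled; triton.testing then sweeps SEQUENCE_TOTAL from 128 to 1024 and writes the plot named by plot_name to save_path. A minimal invocation sketch (print_data is a standard triton.testing report option and is assumed to be available in the Triton version used here):

# Sketch: run the correctness test, then the benchmark sweep and save the plot locally.
test_layer_norm()
benchmark_rms_layernorm.run(save_path=".", print_data=True)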
