add rms norm backward kernel
sjjeong94 committed Sep 24, 2024
1 parent 8c4a9f2 commit e27cf1b
Showing 5 changed files with 85 additions and 17 deletions.
benchmarks/layer_norm.py (4 changes: 2 additions & 2 deletions)
@@ -22,8 +22,8 @@
def bench_layer_norm_modulation(batch_size, seq_len, embed_dim, provider, device="cuda"):
# create data
x = torch.randn(batch_size, seq_len, embed_dim).to(device)
- scale = torch.randn(batch_size, embed_dim).to(device)
- shift = torch.randn(batch_size, embed_dim).to(device)
+ scale = torch.randn(batch_size, 1, embed_dim).to(device)
+ shift = torch.randn(batch_size, 1, embed_dim).to(device)

def y_fwd():
if provider == "triton":
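The reshaped scale and shift, now (batch_size, 1, embed_dim), broadcast over the sequence dimension of x, matching the shapes already used in tests/test_layer_norm.py. A quick illustration of the intended broadcasting (the shapes below are illustrative, not taken from this repo):

import torch

x = torch.randn(2, 16, 64)      # (batch_size, seq_len, embed_dim)
scale = torch.randn(2, 1, 64)   # (batch_size, 1, embed_dim)
shift = torch.randn(2, 1, 64)
out = x * scale + shift         # broadcasts over seq_len
print(out.shape)                # torch.Size([2, 16, 64])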
tests/test_layer_norm.py (6 changes: 3 additions & 3 deletions)
@@ -10,9 +10,9 @@
@pytest.mark.parametrize("device", ["cuda"])
def test_layer_norm_modulation(batch_size, seq_len, embed_dim, device):
# create data
- x = torch.randn(batch_size, seq_len, embed_dim).to(device)
- scale = torch.randn(batch_size, 1, embed_dim).to(device)
- shift = torch.randn(batch_size, 1, embed_dim).to(device)
+ x = torch.randn([batch_size, seq_len, embed_dim], device=device)
+ scale = torch.randn([batch_size, 1, embed_dim], device=device)
+ shift = torch.randn([batch_size, 1, embed_dim], device=device)
# forward pass
y_tri = layer_norm_modulation(x, scale, shift)
y_ref = layer_norm_modulation_torch(x, scale, shift)
tests/test_rms_norm.py (28 changes: 24 additions & 4 deletions)
@@ -10,11 +10,31 @@
@pytest.mark.parametrize("head_dim", [128, 256, 512])
@pytest.mark.parametrize("device", ["cuda"])
def test_rms_norm(batch_size, num_heads, seq_len, head_dim, device):

# create data
- x = torch.randn(batch_size, num_heads, seq_len, head_dim).to(device)
- scale = torch.randn(head_dim).to(device)
+ x = torch.randn([batch_size, num_heads, seq_len, head_dim], device=device)
+ w = torch.randn([head_dim], device=device)
+ dy = torch.randn([batch_size, num_heads, seq_len, head_dim], device=device)
+ x.requires_grad_(True)
+ w.requires_grad_(True)

# forward pass
- y_tri = rms_norm(x, scale)
- y_ref = rms_norm_torch(x, scale)
+ y_tri = rms_norm(x, w)
+ y_ref = rms_norm_torch(x, w)

# backward pass (triton)
y_tri.backward(dy, retain_graph=True)
dx_tri = x.grad.clone()
dw_tri = w.grad.clone()
x.grad = None
w.grad = None

# backward pass (torch)
y_ref.backward(dy, retain_graph=True)
dx_ref = x.grad.clone()
dw_ref = w.grad.clone()

# compare
torch.testing.assert_close(y_tri, y_ref, atol=1e-5, rtol=0)
torch.testing.assert_close(dx_tri, dx_ref, atol=1e-3, rtol=0)
torch.testing.assert_close(dw_tri, dw_ref, atol=1e-3, rtol=0)
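The reference rms_norm_torch that these new backward assertions rely on is imported by the test but not shown in this diff. A minimal sketch of what such a reference presumably computes (an assumption; the actual repo implementation may differ, with eps assumed to match the forward's 1e-6 default):

import torch

def rms_norm_reference(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # RMS-normalize over the last dimension, then apply the learned scale w.
    rstd = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x * rstd * w

Backpropagating through such a reference (y = rms_norm_reference(x, w); y.backward(dy)) produces the dx_ref and dw_ref gradients that the test compares against the Triton kernel's output.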
tests/test_rope.py (6 changes: 3 additions & 3 deletions)
@@ -11,9 +11,9 @@
@pytest.mark.parametrize("device", ["cuda"])
def test_apply_rope(batch_size, num_heads, seq_len, head_dim, device):
# create data
- xq = torch.randn(batch_size, num_heads, seq_len, head_dim).to(device)
- xk = torch.randn(batch_size, num_heads, seq_len, head_dim).to(device)
- freqs_cis = torch.randn(1, 1, seq_len, head_dim // 2, 2, 2).to(device)
+ xq = torch.randn([batch_size, num_heads, seq_len, head_dim], device=device)
+ xk = torch.randn([batch_size, num_heads, seq_len, head_dim], device=device)
+ freqs_cis = torch.randn([1, 1, seq_len, head_dim // 2, 2, 2], device=device)
# forward pass
q_tri, k_tri = apply_rope(xq, xk, freqs_cis)
q_ref, k_ref = apply_rope_torch(xq, xk, freqs_cis)
triton_kernels/kernels/normalization.py (58 changes: 53 additions & 5 deletions)
@@ -116,6 +116,39 @@ def _rms_norm_fwd(
tl.store(Y + cols, y, mask=mask)


@triton.jit
def _rms_norm_bwd(
dY,
dX,
dW,
X,
W,
Rstd,
stride,
N,
BLOCK_SIZE: tl.constexpr,
):
row = tl.program_id(0)
X += row * stride
dY += row * stride
dX += row * stride
dW += row * stride
cols = tl.arange(0, BLOCK_SIZE)
mask = cols < N
dy = tl.load(dY + cols, mask=mask, other=0.0)
x = tl.load(X + cols, mask=mask, other=0.0)
w = tl.load(W + cols, mask=mask, other=0.0)
rstd = tl.load(Rstd + row)

m = dy * w
dx = rstd * m
dx += rstd * -(1 / N) * rstd * rstd * tl.sum(m * x, axis=0) * x
dw = dy * (x * rstd)

tl.store(dX + cols, dx, mask=mask)
tl.store(dW + cols, dw, mask=mask)


class _rms_norm(torch.autograd.Function):
@staticmethod
def forward(ctx, x: torch.Tensor, scale: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
@@ -137,16 +170,31 @@ def forward(ctx, x: torch.Tensor, scale: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
num_warps=num_warps,
num_ctas=1,
)
- ctx.save_for_backward(x, scale, rstd)
+ ctx.save_for_backward(x_arg, scale, rstd)
ctx.BLOCK_SIZE = BLOCK_SIZE
ctx.num_warps = num_warps
ctx.eps = eps
return y

def backward(ctx, dy: torch.Tensor) -> torch.Tensor:
- # TODO: implement backward pass
- x, s, r = ctx.saved_tensors
- return x, s, None
+ dx = torch.empty_like(dy)
+ dy_arg = dy.view(-1, dy.shape[-1])
+ M, N = dy_arg.shape
+ x, w, r = ctx.saved_tensors
+ dw = torch.empty_like(x)
+ _rms_norm_bwd[(M,)](
+     dy_arg,
+     dx,
+     dw,
+     x,
+     w,
+     r,
+     x.stride(0),
+     N,
+     BLOCK_SIZE=ctx.BLOCK_SIZE,
+     num_warps=ctx.num_warps,
+ )
+ dw = torch.sum(dw, dim=0)
+ return dx, dw, None


rms_norm = _rms_norm.apply

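For reference, with m_i = w_i dy_i and the per-row r = rstd = (1/N * sum_i x_i^2 + eps)^{-1/2} saved by the forward pass, the gradients computed by _rms_norm_bwd per row are (a derivation sketch, treating eps as a constant):

\[
\frac{\partial L}{\partial x_i} = r\, m_i \;-\; \frac{r^3}{N}\, x_i \sum_{j=1}^{N} m_j x_j,
\qquad
\frac{\partial L}{\partial w_i} = dy_i\, x_i\, r \quad \text{(per row)} .
\]

The first expression is exactly dx = rstd * m - (1/N) * rstd^3 * sum(m * x) * x in the kernel; the per-row weight partials are then reduced across rows by dw = torch.sum(dw, dim=0) in the Python backward to give the full gradient with respect to the shared scale.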