add benchmarks for the rms norm backward kernel
sjjeong94 committed Oct 3, 2024
1 parent 44c23d8 commit 6bcc77c
Showing 6 changed files with 69 additions and 51 deletions.
25 changes: 13 additions & 12 deletions benchmarks/layer_norm.py
@@ -25,21 +25,22 @@ def bench_layer_norm_modulation(batch_size, seq_len, embed_dim, provider, device
     scale = torch.randn([batch_size, 1, embed_dim], device=device)
     shift = torch.randn([batch_size, 1, embed_dim], device=device)
 
-    def y_fwd():
-        if provider == "triton":
-            return layer_norm_modulation(x, scale, shift)
-        if provider == "torch_compile":
-            return layer_norm_modulation_torch_compile(x, scale, shift)
-        if provider == "torch":
-            return layer_norm_modulation_torch(x, scale, shift)
+    if provider == "triton":
+        fwd = lambda: layer_norm_modulation(x, scale, shift)
+    elif provider == "torch_compile":
+        fwd = lambda: layer_norm_modulation_torch_compile(x, scale, shift)
+    elif provider == "torch":
+        fwd = lambda: layer_norm_modulation_torch(x, scale, shift)
+    else:
+        raise Exception("invalid provider")
 
-    gbps = lambda ms: 2 * x.numel() * x.element_size() / ms * 1e-6
-    ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, quantiles=[0.5, 0.2, 0.8], rep=500)
+    ms, min_ms, max_ms = triton.testing.do_bench(fwd, quantiles=[0.5, 0.2, 0.8])
 
+    gbps = lambda ms: 2 * x.numel() * x.element_size() / ms * 1e-6
     return gbps(ms), gbps(max_ms), gbps(min_ms)
 
 
 # Benchmark
-result_dir = "./results"
-os.makedirs(result_dir, exist_ok=True)
-bench_layer_norm_modulation.run(save_path=result_dir, print_data=True)
+fwd_dir = "./results/fwd"
+os.makedirs(fwd_dir, exist_ok=True)
+bench_layer_norm_modulation.run(print_data=True, save_path=fwd_dir)
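
A note on the reported throughput: the gbps lambda counts only one read and one write of x (scale/shift traffic is ignored), and bytes per millisecond times 1e-6 equals GB/s. A small worked example with hypothetical numbers:

# Hypothetical example of the gbps formula used by these benchmarks.
# Only x is counted: one read plus one write; scale/shift traffic is ignored.
numel = 4 * 1024 * 3072          # x of shape [4, 1024, 3072] (illustrative)
bytes_moved = 2 * numel * 4      # float32 is 4 bytes/element -> 100,663,296 bytes
ms = 0.25                        # hypothetical median time from do_bench, in milliseconds
gbps = bytes_moved / ms * 1e-6   # bytes/ms * 1e-6 == GB/s -> ~402.7 GB/s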
44 changes: 30 additions & 14 deletions benchmarks/rms_norm.py
@@ -19,26 +19,42 @@
args={"batch_size": 4, "num_heads": 24, "head_dim": 128},
)
)
def bench_rms_norm(batch_size, num_heads, seq_len, head_dim, provider, device="cuda"):
def bench_rms_norm(batch_size, num_heads, seq_len, head_dim, provider, device="cuda", mode="forward"):
# create data
x = torch.randn([batch_size, num_heads, seq_len, head_dim], device=device)
scale = torch.randn([head_dim], device=device)

def y_fwd():
if provider == "triton":
return rms_norm(x, scale)
if provider == "torch_compile":
return rms_norm_torch_compile(x, scale)
if provider == "torch":
return rms_norm_torch(x, scale)
dy = torch.randn_like(x)

if provider == "triton":
fwd = lambda: rms_norm(x, scale)
elif provider == "torch_compile":
fwd = lambda: rms_norm_torch_compile(x, scale)
elif provider == "torch":
fwd = lambda: rms_norm_torch(x, scale)
else:
raise Exception("invalid provider")

x.requires_grad_(True)
scale.requires_grad_(True)
if mode == "fwd":
func = fwd
elif mode == "bwd":
y = fwd()
bwd = lambda: y.backward(dy, retain_graph=True)
func = bwd
else:
raise Exception("invalid mode")

ms, min_ms, max_ms = triton.testing.do_bench(func, quantiles=[0.5, 0.2, 0.8])

gbps = lambda ms: 2 * x.numel() * x.element_size() / ms * 1e-6
ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, quantiles=[0.5, 0.2, 0.8], rep=500)

return gbps(ms), gbps(max_ms), gbps(min_ms)


# Benchmark
result_dir = "./results"
os.makedirs(result_dir, exist_ok=True)
bench_rms_norm.run(save_path=result_dir, print_data=True)
fwd_dir = "./results/fwd"
bwd_dir = "./results/bwd"
os.makedirs(fwd_dir, exist_ok=True)
os.makedirs(bwd_dir, exist_ok=True)
bench_rms_norm.run(print_data=True, save_path=fwd_dir, mode="fwd")
bench_rms_norm.run(print_data=True, save_path=bwd_dir, mode="bwd")
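
For context, a minimal sketch of the operation the new "bwd" mode exercises. rms_norm_reference, the eps value, and the tensor shapes below are illustrative stand-ins, not the repository's rms_norm_torch, which may differ:

import torch

def rms_norm_reference(x: torch.Tensor, scale: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Normalize by the root-mean-square over the last dimension, then apply the learned scale.
    rms = torch.sqrt(torch.mean(x * x, dim=-1, keepdim=True) + eps)
    return x / rms * scale

# Same pattern the "bwd" mode times: run the forward once, then replay backward
# with a fixed upstream gradient dy (retain_graph keeps the graph alive across calls).
x = torch.randn(4, 24, 256, 128, requires_grad=True)
scale = torch.randn(128, requires_grad=True)
dy = torch.randn_like(x)
y = rms_norm_reference(x, scale)
y.backward(dy, retain_graph=True)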
37 changes: 19 additions & 18 deletions benchmarks/rope.py
@@ -21,25 +21,26 @@
 )
 def bench_apply_rope(batch_size, num_heads, seq_len, head_dim, provider, device="cuda"):
     # create data
-    xq = torch.randn([batch_size, num_heads, seq_len, head_dim], device=device)
-    xk = torch.randn([batch_size, num_heads, seq_len, head_dim], device=device)
-    freqs_cis = torch.randn([1, 1, seq_len, head_dim // 2, 2, 2], device=device)
-
-    def y_fwd():
-        if provider == "triton":
-            return apply_rope(xq, xk, freqs_cis)
-        if provider == "torch_compile":
-            return apply_rope_torch_compile(xq, xk, freqs_cis)
-        if provider == "torch":
-            return apply_rope_torch(xq, xk, freqs_cis)
-
-    gbps = lambda ms: 2 * xq.numel() * xq.element_size() / ms * 1e-6
-    ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, quantiles=[0.5, 0.2, 0.8], rep=500)
-
+    q = torch.randn([batch_size, num_heads, seq_len, head_dim], device=device)
+    k = torch.randn([batch_size, num_heads, seq_len, head_dim], device=device)
+    pe = torch.randn([1, 1, seq_len, head_dim // 2, 2, 2], device=device)
+
+    if provider == "triton":
+        fwd = lambda: apply_rope(q, k, pe)
+    elif provider == "torch_compile":
+        fwd = lambda: apply_rope_torch_compile(q, k, pe)
+    elif provider == "torch":
+        fwd = lambda: apply_rope_torch(q, k, pe)
+    else:
+        raise Exception("invalid provider")
+
+    ms, min_ms, max_ms = triton.testing.do_bench(fwd, quantiles=[0.5, 0.2, 0.8])
+
+    gbps = lambda ms: 2 * q.numel() * q.element_size() / ms * 1e-6
     return gbps(ms), gbps(max_ms), gbps(min_ms)
 
 
 # Benchmark
-result_dir = "./results"
-os.makedirs(result_dir, exist_ok=True)
-bench_apply_rope.run(save_path=result_dir, print_data=True)
+fwd_dir = "./results/fwd"
+os.makedirs(fwd_dir, exist_ok=True)
+bench_apply_rope.run(print_data=True, save_path=fwd_dir)
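
For reference, the pe tensor of shape [1, 1, seq_len, head_dim // 2, 2, 2] stores a 2x2 rotation per pair of channels. Below is a sketch of one common eager formulation compatible with that layout (a FLUX-style rotation); the repository's apply_rope_torch is assumed to be along these lines but may differ, and apply_rope_sketch plus the shapes used are illustrative only:

import torch

def apply_rope_sketch(q: torch.Tensor, k: torch.Tensor, pe: torch.Tensor):
    # View the last dim as channel pairs and apply the per-position 2x2 rotation stored in pe.
    q_ = q.float().reshape(*q.shape[:-1], -1, 1, 2)
    k_ = k.float().reshape(*k.shape[:-1], -1, 1, 2)
    q_out = pe[..., 0] * q_[..., 0] + pe[..., 1] * q_[..., 1]
    k_out = pe[..., 0] * k_[..., 0] + pe[..., 1] * k_[..., 1]
    return q_out.reshape(*q.shape).type_as(q), k_out.reshape(*k.shape).type_as(k)

q = torch.randn(1, 24, 128, 128)
k = torch.randn(1, 24, 128, 128)
pe = torch.randn(1, 1, 128, 64, 2, 2)
q_rot, k_rot = apply_rope_sketch(q, k, pe)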
2 changes: 1 addition & 1 deletion tests/test_layer_norm.py
@@ -12,7 +12,7 @@ def test_layer_norm_modulation(batch_size, seq_len, embed_dim, device):
     # create data
     x = torch.randn([batch_size, seq_len, embed_dim], device=device)
     scale = torch.randn([batch_size, 1, embed_dim], device=device)
-    shift = torch.randn([batch_size, 1, embed_dim], device=device)
+    shift = torch.randn_like(scale)
     # forward pass
     y_tri = layer_norm_modulation(x, scale, shift)
     y_ref = layer_norm_modulation_torch(x, scale, shift)
2 changes: 1 addition & 1 deletion tests/test_rms_norm.py
@@ -14,7 +14,7 @@ def test_rms_norm(batch_size, num_heads, seq_len, head_dim, device):
     # create data
     x = torch.randn([batch_size, num_heads, seq_len, head_dim], device=device)
     w = torch.randn([head_dim], device=device)
-    dy = torch.randn([batch_size, num_heads, seq_len, head_dim], device=device)
+    dy = torch.randn_like(x)
     x.requires_grad_(True)
     w.requires_grad_(True)
 
10 changes: 5 additions & 5 deletions tests/test_rope.py
@@ -11,12 +11,12 @@
 @pytest.mark.parametrize("device", ["cuda"])
 def test_apply_rope(batch_size, num_heads, seq_len, head_dim, device):
     # create data
-    xq = torch.randn([batch_size, num_heads, seq_len, head_dim], device=device)
-    xk = torch.randn([batch_size, num_heads, seq_len, head_dim], device=device)
-    freqs_cis = torch.randn([1, 1, seq_len, head_dim // 2, 2, 2], device=device)
+    q = torch.randn([batch_size, num_heads, seq_len, head_dim], device=device)
+    k = torch.randn_like(q)
+    pe = torch.randn([1, 1, seq_len, head_dim // 2, 2, 2], device=device)
     # forward pass
-    q_tri, k_tri = apply_rope(xq, xk, freqs_cis)
-    q_ref, k_ref = apply_rope_torch(xq, xk, freqs_cis)
+    q_tri, k_tri = apply_rope(q, k, pe)
+    q_ref, k_ref = apply_rope_torch(q, k, pe)
     # compare
     torch.testing.assert_close(q_tri, q_ref, atol=1e-5, rtol=0)
     torch.testing.assert_close(k_tri, k_ref, atol=1e-5, rtol=0)
