
Commit c60ada3

Merge pull request #50 from learning-chip/mamba2_kernel
benchmark script for simple_gla vs mamba2 kernel
2 parents 9aa2480 + f57a027 commit c60ada3

Lines changed: 86 additions & 0 deletions
"""
Dependencies:
$ pip install mamba-ssm==2.2.2 triton==2.3.1

For correctness check, see:
https://github.com/sustcsonglin/flash-linear-attention/pull/49
"""
import torch
import triton

from fla.ops.simple_gla import chunk_simple_gla

from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined


@triton.testing.perf_report(
    triton.testing.Benchmark(
        # argument names to use as an x-axis for the plot
        x_names=['T'],
        # different possible values for `x_name`
        x_vals=[64] + [128 * 2 ** i for i in range(0, 8)],
        # argument name whose value corresponds to a different line in the plot
        line_arg='provider',
        # possible values for `line_arg`
        line_vals=["chunk_simple_gla", "mamba2_ssd"],
        # label name for the lines
        line_names=["chunk_simple_gla", "mamba2_ssd"],
        # line styles
        styles=[('blue', '-'), ('red', '-')],
        ylabel="Execution Time (ms)",  # label name for the y-axis
        # name for the plot, also used as a file name for saving the plot
        plot_name="Performance",
        args={},
    )
)
def benchmark(T, provider):
    # TODO: also add bwd pass benchmark
    device = 'cuda'
    dtype = torch.bfloat16
    B, H, D = 16, 8, 128
    # TODO: test more shapes
    # TODO: different values for D_V and D_QK
    # TODO: different values for H_Q and H_KV
    final_state = False  # does not impact performance

    # initialize Mamba2-format inputs
    X_mamba = 0.1 * torch.randn(B, T, H, D, dtype=dtype, device=device)
    dt_mamba = torch.ones(B, T, H, dtype=dtype, device=device)
    A_mamba = -0.1 * torch.rand(H, dtype=dtype, device=device)
    B_mamba = 0.1 * torch.randn(B, T, H, D, dtype=dtype, device=device)
    C_mamba = 0.1 * torch.randn(B, T, H, D, dtype=dtype, device=device)

    quantiles = [0.5, 0.2, 0.8]
    if provider == 'chunk_simple_gla':
        # map inputs from Mamba2 format to FLA format
        # C, B, X: [B, T, H, D] -> [B, H, T, D]
        # g: [B, T, H] -> [B, H, T]
        q = C_mamba.transpose(1, 2).contiguous()
        k = B_mamba.transpose(1, 2).contiguous()
        v = X_mamba.transpose(1, 2).contiguous()
        g = (A_mamba * dt_mamba).transpose(1, 2).contiguous()
        # NOTE: whether to include the memory-copy cost of `contiguous()` here
        # depends on the memory layout used by the surrounding non-SSM layers

        results = triton.testing.do_bench(
            lambda: chunk_simple_gla(
                q, k, v, g, scale=1.0, output_final_state=final_state
            ), quantiles=quantiles
        )

    elif provider == 'mamba2_ssd':
        # NOTE: `chunk_size` is configurable in the mamba2 kernel;
        # here it is set to the same hard-coded `BT = 64` as in the simple_gla kernel
        # TODO: benchmark different chunk sizes
        results = triton.testing.do_bench(
            lambda: mamba_chunk_scan_combined(
                X_mamba, dt_mamba, A_mamba, B_mamba, C_mamba,
                chunk_size=64, D=None, return_final_states=final_state
            ),
            quantiles=quantiles
        )
    return results


if __name__ == '__main__':
    benchmark.run(print_data=True, save_path='.')
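
For a quick numerical sanity check of the Mamba2 -> FLA input mapping used above, the sketch below compares the two kernels' outputs directly. It is a minimal sketch, not the correctness test from PR #49: it assumes the same dt = 1 simplification as the benchmark (so `v` maps straight to `X`), uses smaller shapes, and only prints the maximum absolute difference rather than asserting a tolerance.

# Minimal correctness sketch (assumes dt == 1 as in the benchmark above;
# see PR #49 for the actual correctness check)
import torch
from fla.ops.simple_gla import chunk_simple_gla
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined

batch, T, H, D = 2, 256, 8, 64
device, dtype = 'cuda', torch.bfloat16

X = 0.1 * torch.randn(batch, T, H, D, dtype=dtype, device=device)
dt = torch.ones(batch, T, H, dtype=dtype, device=device)
A = -0.1 * torch.rand(H, dtype=dtype, device=device)
B_ssm = 0.1 * torch.randn(batch, T, H, D, dtype=dtype, device=device)
C_ssm = 0.1 * torch.randn(batch, T, H, D, dtype=dtype, device=device)

# Mamba2 SSD output: [batch, T, H, D]
y_ssd = mamba_chunk_scan_combined(X, dt, A, B_ssm, C_ssm, chunk_size=64, D=None)

# same Mamba2 -> FLA mapping as in the benchmark
q = C_ssm.transpose(1, 2).contiguous()
k = B_ssm.transpose(1, 2).contiguous()
v = X.transpose(1, 2).contiguous()
g = (A * dt).transpose(1, 2).contiguous()
out = chunk_simple_gla(q, k, v, g, scale=1.0, output_final_state=False)
y_gla = out[0] if isinstance(out, tuple) else out  # [batch, H, T, D]

# bfloat16 inputs, so only rough agreement is expected
print((y_gla.transpose(1, 2) - y_ssd).abs().max())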
