
Commit 97d82ed

feat(kernel): add paged attn naive kernel (#273)
1 parent 5483a74 commit 97d82ed

File tree

5 files changed: +724, -0 lines changed


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -10,3 +10,4 @@ build/
 .DS_Store
 *.key
 .cache
+.vscode/

Lines changed: 266 additions & 0 deletions
@@ -0,0 +1,266 @@
import os
from typing import Dict, List

import mlx.core as mx

# Cache for compiled kernels
_KERNELS: Dict[str, object] = {}


def _get_metal_source(filename):
    path = os.path.join(os.path.dirname(__file__), filename)
    with open(path, "r") as f:
        return f.read()


def _type_to_string(dtype: mx.Dtype) -> str:
    if dtype == mx.float32:
        return "float"
    elif dtype == mx.float16:
        return "half"
    elif dtype == mx.bfloat16:
        # Metal 3.1+ supports bfloat, typically via bfloat16_t or using half.
        # For now we map to bfloat16_t assuming compiler support.
        return "bfloat16_t"
    else:
        raise ValueError(f"Unsupported dtype for paged attention: {dtype}")


def _get_kernel(
    name: str,
    filename: str,
    input_names: List[str],
    output_names: List[str],
    dtype: mx.Dtype = mx.float32,
):
    type_str = _type_to_string(dtype)
    kernel_key = f"{name}_{type_str}"

    if kernel_key not in _KERNELS:
        source = _get_metal_source(filename)
        # Simple template substitution
        source = source.replace("{{T}}", type_str)

        header = """
#include <metal_stdlib>
using namespace metal;
"""
        _KERNELS[kernel_key] = mx.fast.metal_kernel(
            name=name,  # Internal name for MLX JIT cache (not used for dispatch if we hold the object)
            input_names=input_names,
            output_names=output_names,
            source=source,
            header=header,
        )
    return _KERNELS[kernel_key]


def reshape_and_cache(
    key: mx.array,  # (batch, num_kv_heads, 1, head_dim)
    value: mx.array,  # (batch, num_kv_heads, 1, head_dim)
    key_cache: mx.array,  # (num_layers, num_blocks, num_kv_heads, block_size, head_dim)
    value_cache: mx.array,
    block_tables: mx.array,  # (batch, max_blocks)
    context_lengths: mx.array,  # (batch,)
    block_size: int,
    layer_idx: int,
):
    """
    Writes new keys and values into the paged KV cache using a custom Metal kernel.
    NOTE: This performs an in-place update on the key_cache/value_cache buffers.
    """
    batch_size = key.shape[0]
    num_kv_heads = key.shape[1]
    head_dim = key.shape[3]
    num_layers = key_cache.shape[0]
    num_blocks = key_cache.shape[1]

    dtype = key.dtype
    if key_cache.dtype != dtype:
        raise ValueError(f"Key cache dtype {key_cache.dtype} does not match key dtype {dtype}")

    # 1. Prepare inputs
    indices = context_lengths - 1
    block_indices_in_table = indices // block_size
    offsets = indices % block_size

    batch_indices = mx.arange(batch_size)
    physical_block_numbers = block_tables[batch_indices, block_indices_in_table]

    slot_mapping = physical_block_numbers.astype(mx.int64) * block_size + offsets.astype(mx.int64)

    # 2. Prepare constants
    key_stride = num_kv_heads * head_dim
    value_stride = num_kv_heads * head_dim

    def mk_int(val):
        return mx.array(val, dtype=mx.int32)

    c_key_stride = mk_int(key_stride)
    c_val_stride = mk_int(value_stride)
    c_num_kv = mk_int(num_kv_heads)
    c_head_dim = mk_int(head_dim)
    c_block_size = mk_int(block_size)
    c_layer_idx = mk_int(layer_idx)
    c_num_layers = mk_int(num_layers)
    c_num_blocks = mk_int(num_blocks)

    # Inputs list
    inputs = [
        key,
        value,
        key_cache,
        value_cache,
        slot_mapping,
        c_key_stride,
        c_val_stride,
        c_num_kv,
        c_head_dim,
        c_block_size,
        c_layer_idx,
        c_num_layers,
        c_num_blocks,
    ]

    # Input names (just for declaration)
    input_names = [
        "key",
        "value",
        "key_cache",
        "value_cache",
        "slot_mapping",
        "key_stride",
        "value_stride",
        "num_kv_heads",
        "head_dim",
        "block_size",
        "layer_idx",
        "num_layers",
        "num_blocks",
    ]

    # 3. Get and launch kernel
    kernel = _get_kernel(
        name="reshape_and_cache_kernel",
        filename="reshape_and_cache.metal",
        input_names=input_names,
        output_names=["dummy_out"],
        dtype=dtype,
    )

    grid = (num_kv_heads * head_dim, batch_size, 1)
    thread_group = (min(1024, num_kv_heads * head_dim), 1, 1)

    # Execute
    outputs = kernel(
        inputs=inputs,
        grid=grid,
        threadgroup=thread_group,
        output_shapes=[(1,)],
        output_dtypes=[mx.float32],  # Dummy output dtype usually doesn't matter
        verbose=False,
    )

    mx.eval(outputs)

    return key_cache, value_cache

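The slot mapping above is the only non-trivial host-side arithmetic: the newest token of each sequence lands at a flat slot index derived from its logical position and the block table. A minimal worked sketch of that computation (the numbers and the block_tables row are made up for illustration and are not part of this commit):

    import mlx.core as mx

    # Hypothetical: one sequence with 19 tokens in context, block_size = 16.
    block_size = 16
    context_lengths = mx.array([19])             # (batch,)
    block_tables = mx.array([[7, 3, 5, 0]])      # (batch, max_blocks), made-up physical block ids

    indices = context_lengths - 1                # newest token sits at position 18
    block_in_table = indices // block_size       # logical block 1
    offset = indices % block_size                # position 2 inside that block
    physical = block_tables[mx.arange(1), block_in_table]            # physical block 3
    slot = physical.astype(mx.int64) * block_size + offset           # 3 * 16 + 2 = 50
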
def paged_attention(
    queries: mx.array,
    key_cache: mx.array,
    value_cache: mx.array,
    block_tables: mx.array,
    context_lengths: mx.array,
    block_size: int,
    scale: float,
    num_kv_heads: int,
    layer_idx: int,
) -> mx.array:
    """
    Paged attention using a custom Metal kernel.
    """
    batch_size = queries.shape[0]
    num_heads = queries.shape[1]
    dtype = queries.dtype

    if queries.ndim == 4:
        if queries.shape[2] != 1:
            pass
        queries = queries.squeeze(2)

    head_dim = queries.shape[2]
    num_layers = key_cache.shape[0]
    num_total_blocks = key_cache.shape[1]
    max_blocks = block_tables.shape[1]

    # Prepare constants
    def mk_int(val):
        return mx.array(val, dtype=mx.int32)

    c_num_heads = mk_int(num_heads)
    c_num_kv_heads = mk_int(num_kv_heads)
    c_head_dim = mk_int(head_dim)
    c_block_size = mk_int(block_size)
    c_max_blocks = mk_int(max_blocks)
    c_layer_idx = mk_int(layer_idx)
    c_num_layers = mk_int(num_layers)
    c_num_total_blocks = mk_int(num_total_blocks)
    c_scale = mx.array(scale, dtype=mx.float32)

    inputs = [
        queries,
        key_cache,
        value_cache,
        block_tables,
        context_lengths,
        c_num_heads,
        c_num_kv_heads,
        c_head_dim,
        c_block_size,
        c_max_blocks,
        c_layer_idx,
        c_num_layers,
        c_num_total_blocks,
        c_scale,
    ]

    input_names = [
        "queries",
        "key_cache",
        "value_cache",
        "block_tables",
        "context_lengths",
        "num_heads",
        "num_kv_heads",
        "head_dim",
        "block_size",
        "max_blocks",
        "layer_idx",
        "num_layers",
        "num_total_blocks",
        "scale",
    ]

    kernel = _get_kernel(
        name="paged_attention_kernel",
        filename="paged_attention_kernel.metal",
        input_names=input_names,
        output_names=["output"],
        dtype=dtype,  # This will generate paged_attention_kernel_half etc.
    )

    grid = (num_heads * 32, batch_size, 1)
    thread_group = (32, 1, 1)

    outputs = kernel(
        inputs=inputs,
        grid=grid,
        threadgroup=thread_group,
        output_shapes=[(batch_size, num_heads, head_dim)],
        output_dtypes=[dtype],  # Output matches input dtype
        verbose=False,
    )

    out = outputs[0]
    return out[:, :, None, :]
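
For orientation, a minimal end-to-end sketch of how the two entry points compose in a single decode step. The dimensions, dtypes, and block table below are hypothetical and chosen only for illustration; the one hard constraint visible in the kernel is head_dim <= 128 (it accumulates into four registers per 32-thread SIMD group):

    import mlx.core as mx

    # Hypothetical model/cache dimensions.
    num_layers, num_blocks, num_kv_heads, block_size, head_dim = 2, 64, 4, 16, 128
    num_heads, batch = 8, 1

    cache_shape = (num_layers, num_blocks, num_kv_heads, block_size, head_dim)
    key_cache = mx.zeros(cache_shape, dtype=mx.float16)
    value_cache = mx.zeros(cache_shape, dtype=mx.float16)

    block_tables = mx.array([[0, 1, 2, 3]], dtype=mx.int32)  # (batch, max_blocks)
    context_lengths = mx.array([1], dtype=mx.int32)          # first decoded token

    # Per-layer decode step: q is (batch, num_heads, 1, head_dim),
    # new k/v are (batch, num_kv_heads, 1, head_dim).
    q = mx.random.normal((batch, num_heads, 1, head_dim)).astype(mx.float16)
    k = mx.random.normal((batch, num_kv_heads, 1, head_dim)).astype(mx.float16)
    v = mx.random.normal((batch, num_kv_heads, 1, head_dim)).astype(mx.float16)

    reshape_and_cache(k, v, key_cache, value_cache, block_tables,
                      context_lengths, block_size, layer_idx=0)
    out = paged_attention(q, key_cache, value_cache, block_tables,
                          context_lengths, block_size, scale=head_dim ** -0.5,
                          num_kv_heads=num_kv_heads, layer_idx=0)
    # out has shape (batch, num_heads, 1, head_dim)
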
Lines changed: 124 additions & 0 deletions
@@ -0,0 +1,124 @@
// Inputs:
// queries, key_cache, value_cache, block_tables, context_lengths
// output (output array)
// num_heads, num_kv_heads, head_dim, block_size, max_blocks, layer_idx,
// num_layers, num_total_blocks, scale (all pointers)

uint3 gid = thread_position_in_grid;
uint3 tid = thread_position_in_threadgroup;

// Each threadgroup handles one head.
// Assuming threadgroup size is 32 (SIMD width).
// head_idx comes from the group index,
// or we can calculate it from gid if the grid is linear.
// grid.x = num_heads * 32.

int head_idx = gid.x / 32;
int batch_idx = gid.y;

// Dereference constants
int _num_heads = num_heads;
int _num_kv_heads = num_kv_heads;
int _head_dim = head_dim;
int _block_size = block_size;
int _max_blocks = max_blocks;
int _layer_idx = layer_idx;
int _num_total_blocks = num_total_blocks;
float _scale = scale;

if (head_idx >= _num_heads)
    return;

int kv_head_idx = head_idx / (_num_heads / _num_kv_heads);

// Load query
// Q: [batch, num_heads, head_dim]
// Thread i loads elements i, i+32, ...

float q_vec[4] = {0.0f, 0.0f, 0.0f, 0.0f};

int q_offset = batch_idx * _num_heads * _head_dim + head_idx * _head_dim;

for (int i = tid.x; i < _head_dim; i += 32) {
    if (i < 128) {
        q_vec[i / 32] = queries[q_offset + i];
    }
}

// Running statistics for softmax
float m_i = -INFINITY;
float l_i = 0.0f;
float acc_vec[4] = {0.0f, 0.0f, 0.0f, 0.0f};

int context_len = context_lengths[batch_idx];
int num_context_blocks = (context_len + _block_size - 1) / _block_size;

// Strides
long layer_stride =
    (long)_num_total_blocks * _num_kv_heads * _block_size * _head_dim;
long block_stride = _num_kv_heads * _block_size * _head_dim;
long head_stride = _block_size * _head_dim;

long layer_offset = _layer_idx * layer_stride;

// Iterate over blocks
for (int b = 0; b < num_context_blocks; b++) {
    int block_num = block_tables[batch_idx * _max_blocks + b];

    long block_base =
        layer_offset + block_num * block_stride + kv_head_idx * head_stride;

    int tokens_in_block = _block_size;
    if (b == num_context_blocks - 1) {
        tokens_in_block = context_len % _block_size;
        if (tokens_in_block == 0)
            tokens_in_block = _block_size;
    }

    for (int t = 0; t < tokens_in_block; t++) {
        // Compute dot product Q * K[t]
        float score = 0.0f;
        for (int i = tid.x; i < _head_dim; i += 32) {
            // offset inside block: t * head_dim + i
            float k_val = key_cache[block_base + t * _head_dim + i];

            if (i < 128) {
                score += q_vec[i / 32] * k_val;
            }
        }

        // SIMD reduction for score
        score = simd_sum(score);
        score *= _scale;

        // Softmax update
        float m_prev = m_i;
        m_i = max(m_prev, score);
        float alpha = exp(m_prev - m_i);
        float beta = exp(score - m_i);

        l_i = l_i * alpha + beta;

        // Accumulate V
        for (int i = tid.x; i < _head_dim; i += 32) {
            float v_val = value_cache[block_base + t * _head_dim + i];
            if (i < 128) {
                acc_vec[i / 32] = acc_vec[i / 32] * alpha + v_val * beta;
            }
        }
    }
}

// Finalize output
for (int i = 0; i < 4; i++) {
    acc_vec[i] /= l_i;
}

int out_offset = batch_idx * _num_heads * _head_dim + head_idx * _head_dim;

for (int i = tid.x; i < _head_dim; i += 32) {
    if (i < 128) {
        output[out_offset + i] = ({{T}})acc_vec[i / 32];
    }
}
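
The kernel is a naive streaming-softmax: one 32-thread SIMD group per (batch, head) walks the sequence block by block and rescales the running accumulator as each new score arrives. For checking its output on small inputs, here is a hedged pure-MLX reference of the same computation; this gather-based version is a sketch written for this note, not part of the commit:

    import mlx.core as mx

    def paged_attention_reference(queries, key_cache, value_cache, block_tables,
                                  context_lengths, block_size, scale,
                                  num_kv_heads, layer_idx):
        # queries: (batch, num_heads, 1, head_dim); caches laid out as in reshape_and_cache.
        batch, num_heads = queries.shape[0], queries.shape[1]
        group = num_heads // num_kv_heads
        outs = []
        for b in range(batch):
            ctx = int(context_lengths[b].item())
            nblocks = (ctx + block_size - 1) // block_size
            blocks = [int(block_tables[b, i].item()) for i in range(nblocks)]
            # Gather this sequence's keys/values: (num_kv_heads, ctx, head_dim).
            k = mx.concatenate([key_cache[layer_idx, p] for p in blocks], axis=1)[:, :ctx]
            v = mx.concatenate([value_cache[layer_idx, p] for p in blocks], axis=1)[:, :ctx]
            # Repeat KV heads so query head h reads kv head h // group (same GQA mapping as the kernel).
            k = mx.repeat(k, group, axis=0)
            v = mx.repeat(v, group, axis=0)
            q = queries[b, :, 0]                                   # (num_heads, head_dim)
            scores = (q[:, None, :] * k).sum(axis=-1) * scale      # (num_heads, ctx)
            probs = mx.softmax(scores.astype(mx.float32), axis=-1).astype(v.dtype)
            outs.append((probs[:, :, None] * v).sum(axis=1))       # (num_heads, head_dim)
        return mx.stack(outs)[:, :, None, :]                       # matches paged_attention's output shape
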
