1 change: 1 addition & 0 deletions .gitignore
@@ -10,3 +10,4 @@ build/
.DS_Store
*.key
.cache
.vscode/
266 changes: 266 additions & 0 deletions src/parallax/metal/paged_attention/kernel.py
@@ -0,0 +1,266 @@
import os
from typing import Dict, List

import mlx.core as mx

# Cache for compiled kernels
_KERNELS: Dict[str, object] = {}


def _get_metal_source(filename):
path = os.path.join(os.path.dirname(__file__), filename)
with open(path, "r") as f:
return f.read()


def _type_to_string(dtype: mx.Dtype) -> str:
if dtype == mx.float32:
return "float"
elif dtype == mx.float16:
return "half"
elif dtype == mx.bfloat16:
        # Metal 3.1+ exposes a native bfloat type (bfloat16_t); older toolchains would
        # need to fall back to half. We emit bfloat16_t and assume compiler support.
        return "bfloat16_t"
else:
raise ValueError(f"Unsupported dtype for paged attention: {dtype}")


def _get_kernel(
name: str,
filename: str,
input_names: List[str],
output_names: List[str],
dtype: mx.Dtype = mx.float32,
):
type_str = _type_to_string(dtype)
kernel_key = f"{name}_{type_str}"

if kernel_key not in _KERNELS:
source = _get_metal_source(filename)
# Simple template substitution
source = source.replace("{{T}}", type_str)

header = """
#include <metal_stdlib>
using namespace metal;
"""
_KERNELS[kernel_key] = mx.fast.metal_kernel(
name=name, # Internal name for MLX JIT cache (not used for dispatch if we hold the object)
input_names=input_names,
output_names=output_names,
source=source,
header=header,
)
return _KERNELS[kernel_key]


def reshape_and_cache(
key: mx.array, # (batch, num_kv_heads, 1, head_dim)
value: mx.array, # (batch, num_kv_heads, 1, head_dim)
key_cache: mx.array, # (num_layers, num_blocks, num_kv_heads, block_size, head_dim)
value_cache: mx.array,
block_tables: mx.array, # (batch, max_blocks)
context_lengths: mx.array, # (batch,)
block_size: int,
layer_idx: int,
):
"""
Writes new keys and values into the Paged KV Cache using a custom Metal kernel.
NOTE: This performs an in-place update on key_cache/value_cache buffers.
"""
batch_size = key.shape[0]
num_kv_heads = key.shape[1]
head_dim = key.shape[3]
num_layers = key_cache.shape[0]
num_blocks = key_cache.shape[1]

dtype = key.dtype
if key_cache.dtype != dtype:
raise ValueError(f"Key cache dtype {key_cache.dtype} does not match key dtype {dtype}")

# 1. Prepare inputs
indices = context_lengths - 1
block_indices_in_table = indices // block_size
offsets = indices % block_size

batch_indices = mx.arange(batch_size)
physical_block_numbers = block_tables[batch_indices, block_indices_in_table]

slot_mapping = physical_block_numbers.astype(mx.int64) * block_size + offsets.astype(mx.int64)
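    # Worked example: with block_size = 16 and a context length of 35, the new token sits at
    # index 34 -> table entry 34 // 16 = 2, offset 34 % 16 = 2. If block_tables[b, 2] == 7,
    # the flat cache slot is 7 * 16 + 2 = 114.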

# 2. Prepare Constants
key_stride = num_kv_heads * head_dim
value_stride = num_kv_heads * head_dim

def mk_int(val):
return mx.array(val, dtype=mx.int32)

c_key_stride = mk_int(key_stride)
c_val_stride = mk_int(value_stride)
c_num_kv = mk_int(num_kv_heads)
c_head_dim = mk_int(head_dim)
c_block_size = mk_int(block_size)
c_layer_idx = mk_int(layer_idx)
c_num_layers = mk_int(num_layers)
c_num_blocks = mk_int(num_blocks)

# Inputs list
inputs = [
key,
value,
key_cache,
value_cache,
slot_mapping,
c_key_stride,
c_val_stride,
c_num_kv,
c_head_dim,
c_block_size,
c_layer_idx,
c_num_layers,
c_num_blocks,
]

# Input names (just for declaration)
input_names = [
"key",
"value",
"key_cache",
"value_cache",
"slot_mapping",
"key_stride",
"value_stride",
"num_kv_heads",
"head_dim",
"block_size",
"layer_idx",
"num_layers",
"num_blocks",
]

# 3. Get and Launch Kernel
kernel = _get_kernel(
name="reshape_and_cache_kernel",
filename="reshape_and_cache.metal",
input_names=input_names,
output_names=["dummy_out"],
dtype=dtype,
)

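    # Launch one thread per (kv_head, head_dim) element along x; one grid row per batch entry.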
grid = (num_kv_heads * head_dim, batch_size, 1)
thread_group = (min(1024, num_kv_heads * head_dim), 1, 1)

# Execute
outputs = kernel(
inputs=inputs,
grid=grid,
threadgroup=thread_group,
output_shapes=[(1,)],
        output_dtypes=[mx.float32],  # Dummy output; the kernel writes in place into key_cache/value_cache
verbose=False,
)

mx.eval(outputs)

return key_cache, value_cache


def paged_attention(
queries: mx.array,
key_cache: mx.array,
value_cache: mx.array,
block_tables: mx.array,
context_lengths: mx.array,
block_size: int,
scale: float,
num_kv_heads: int,
layer_idx: int,
) -> mx.array:
"""
Paged Attention using Metal Kernel.
"""
batch_size = queries.shape[0]
num_heads = queries.shape[1]
dtype = queries.dtype

    if queries.ndim == 4:
        # Decode expects a single query token per sequence; collapse the length-1 axis.
        if queries.shape[2] != 1:
            raise ValueError(f"paged_attention expects a single query token, got seq_len={queries.shape[2]}")
        queries = queries.squeeze(2)

head_dim = queries.shape[2]
num_layers = key_cache.shape[0]
num_total_blocks = key_cache.shape[1]
max_blocks = block_tables.shape[1]

# Prepare Constants
def mk_int(val):
return mx.array(val, dtype=mx.int32)

c_num_heads = mk_int(num_heads)
c_num_kv_heads = mk_int(num_kv_heads)
c_head_dim = mk_int(head_dim)
c_block_size = mk_int(block_size)
c_max_blocks = mk_int(max_blocks)
c_layer_idx = mk_int(layer_idx)
c_num_layers = mk_int(num_layers)
c_num_total_blocks = mk_int(num_total_blocks)
c_scale = mx.array(scale, dtype=mx.float32)

inputs = [
queries,
key_cache,
value_cache,
block_tables,
context_lengths,
c_num_heads,
c_num_kv_heads,
c_head_dim,
c_block_size,
c_max_blocks,
c_layer_idx,
c_num_layers,
c_num_total_blocks,
c_scale,
]

input_names = [
"queries",
"key_cache",
"value_cache",
"block_tables",
"context_lengths",
"num_heads",
"num_kv_heads",
"head_dim",
"block_size",
"max_blocks",
"layer_idx",
"num_layers",
"num_total_blocks",
"scale",
]

kernel = _get_kernel(
name="paged_attention_kernel",
filename="paged_attention_kernel.metal",
input_names=input_names,
output_names=["output"],
        dtype=dtype,  # Kernel is compiled and cached per dtype, e.g. under paged_attention_kernel_half
)

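    # One 32-thread SIMD group per attention head (grid.x = num_heads * 32); one grid row per batch entry.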
grid = (num_heads * 32, batch_size, 1)
thread_group = (32, 1, 1)

outputs = kernel(
inputs=inputs,
grid=grid,
threadgroup=thread_group,
output_shapes=[(batch_size, num_heads, head_dim)],
output_dtypes=[dtype], # Output matches input dtype
verbose=False,
)

out = outputs[0]
return out[:, :, None, :]
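For context, a minimal usage sketch of the two entry points above (not part of the diff): shapes, sizes, and the random test data are illustrative assumptions, and the import path assumes the package is importable as parallax.

import mlx.core as mx

from parallax.metal.paged_attention.kernel import paged_attention, reshape_and_cache

# Illustrative sizes (hypothetical; head_dim kept at 128 to match the kernel's limit).
num_layers, num_blocks, num_kv_heads, block_size, head_dim = 2, 64, 4, 16, 128
num_heads, batch_size, max_blocks_per_seq = 8, 2, 8

key_cache = mx.zeros((num_layers, num_blocks, num_kv_heads, block_size, head_dim), dtype=mx.float16)
value_cache = mx.zeros_like(key_cache)
# In practice a block allocator fills this; here each sequence simply gets a distinct block range.
block_tables = mx.arange(batch_size * max_blocks_per_seq, dtype=mx.int32).reshape(batch_size, max_blocks_per_seq)
context_lengths = mx.array([1, 1], dtype=mx.int32)  # one token (the current one) per sequence

key = mx.random.normal((batch_size, num_kv_heads, 1, head_dim)).astype(mx.float16)
value = mx.random.normal((batch_size, num_kv_heads, 1, head_dim)).astype(mx.float16)
queries = mx.random.normal((batch_size, num_heads, 1, head_dim)).astype(mx.float16)

reshape_and_cache(key, value, key_cache, value_cache, block_tables, context_lengths, block_size, layer_idx=0)
out = paged_attention(
    queries, key_cache, value_cache, block_tables, context_lengths,
    block_size, scale=head_dim ** -0.5, num_kv_heads=num_kv_heads, layer_idx=0,
)
print(out.shape)  # (batch_size, num_heads, 1, head_dim)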
124 changes: 124 additions & 0 deletions src/parallax/metal/paged_attention/paged_attention_kernel.metal
@@ -0,0 +1,124 @@

// Inputs:
// queries, key_cache, value_cache, block_tables, context_lengths
// output (output array)
// num_heads, num_kv_heads, head_dim, block_size, max_blocks, layer_idx,
// num_layers, num_total_blocks, scale (All pointers)

uint3 gid = thread_position_in_grid;
uint3 tid = thread_position_in_threadgroup;

// Each (batch, head) pair is handled by one 32-thread SIMD group:
// grid.x = num_heads * 32 and grid.y = batch_size, so the head index is gid.x / 32
// and tid.x is the lane within the SIMD group.

int head_idx = gid.x / 32;
int batch_idx = gid.y;

// Dereference constants
int _num_heads = num_heads;
int _num_kv_heads = num_kv_heads;
int _head_dim = head_dim;
int _block_size = block_size;
int _max_blocks = max_blocks;
int _layer_idx = layer_idx;
int _num_total_blocks = num_total_blocks;
float _scale = scale;

if (head_idx >= _num_heads)
return;

int kv_head_idx = head_idx / (_num_heads / _num_kv_heads);

// Load Query
// Q: [batch, num_heads, head_dim]
// Thread i loads elements i, i+32, ...

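// Each lane keeps up to 4 query elements in registers (q_vec), so this kernel assumes
// head_dim <= 128; the `i < 128` guards below enforce that bound.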
float q_vec[4] = {0.0f, 0.0f, 0.0f, 0.0f};

int q_offset = batch_idx * _num_heads * _head_dim + head_idx * _head_dim;

for (int i = tid.x; i < _head_dim; i += 32) {
if (i < 128) {
q_vec[i / 32] = queries[q_offset + i];
}
}

// Running statistics for Softmax
float m_i = -INFINITY;
float l_i = 0.0f;
float acc_vec[4] = {0.0f, 0.0f, 0.0f, 0.0f};

int context_len = context_lengths[batch_idx];
int num_context_blocks = (context_len + _block_size - 1) / _block_size;

// Strides
long layer_stride =
(long)_num_total_blocks * _num_kv_heads * _block_size * _head_dim;
long block_stride = _num_kv_heads * _block_size * _head_dim;
long head_stride = _block_size * _head_dim;

long layer_offset = _layer_idx * layer_stride;

// Iterate over blocks
for (int b = 0; b < num_context_blocks; b++) {
int block_num = block_tables[batch_idx * _max_blocks + b];

long block_base =
layer_offset + block_num * block_stride + kv_head_idx * head_stride;

int tokens_in_block = _block_size;
if (b == num_context_blocks - 1) {
tokens_in_block = context_len % _block_size;
if (tokens_in_block == 0)
tokens_in_block = _block_size;
}

for (int t = 0; t < tokens_in_block; t++) {
// Compute Dot Product Q * K[t]
float score = 0.0f;
for (int i = tid.x; i < _head_dim; i += 32) {
// offset inside block: t * head_dim + i
float k_val = key_cache[block_base + t * _head_dim + i];

if (i < 128) {
score += q_vec[i / 32] * k_val;
}
}

// SIMD Reduction for score
score = simd_sum(score);
score *= _scale;

// Softmax update
float m_prev = m_i;
m_i = max(m_prev, score);
float alpha = exp(m_prev - m_i);
float beta = exp(score - m_i);

l_i = l_i * alpha + beta;

// Accumulate V
for (int i = tid.x; i < _head_dim; i += 32) {
float v_val = value_cache[block_base + t * _head_dim + i];
if (i < 128) {
acc_vec[i / 32] = acc_vec[i / 32] * alpha + v_val * beta;
}
}
}
}

// Finalize Output
for (int i = 0; i < 4; i++) {
acc_vec[i] /= l_i;
}

int out_offset = batch_idx * _num_heads * _head_dim + head_idx * _head_dim;

for (int i = tid.x; i < _head_dim; i += 32) {
if (i < 128) {
output[out_offset + i] = ({{T}})acc_vec[i / 32];
}
}
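For reference, the per-token update above is the standard streaming (online) softmax: the running max m_i, the normalizer l_i, and the accumulator are rescaled by alpha whenever a larger score appears. A small Python sketch of the same recurrence (illustrative only, not part of the diff):

import math

def online_softmax_weighted_sum(scores, values):
    # Streaming softmax(scores) @ values, mirroring the m_i / l_i / acc_vec recurrence.
    m, l = -math.inf, 0.0
    acc = [0.0] * len(values[0])
    for s, v in zip(scores, values):
        m_new = max(m, s)
        alpha = math.exp(m - m_new)  # rescales everything accumulated so far
        beta = math.exp(s - m_new)   # weight of the new token
        l = l * alpha + beta
        acc = [a * alpha + x * beta for a, x in zip(acc, v)]
        m = m_new
    return [a / l for a in acc]  # equals softmax(scores) applied to values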