161 changes: 160 additions & 1 deletion examples/deepseek_v32/README.md
@@ -3,7 +3,166 @@
```
deepseek_v32/
├── README.md # This file
├── fp8_mqa_logits.py # FP8 Indexer
├── figures/ # Figures and diagrams
├── inference/ # Inference implementation folder
├── fp8_lighting_indexer.py # FP8 lighting indexer
├── sparse_mla_fwd.py # Sparse MLA forward implementation
├── sparse_mla_fwd_pipelined.py # Pipelined implementation of sparse MLA forward pass
├── topk_selector.py # Top-k selector implementation
```

## File Descriptions

### Architecture Overview

![DeepSeek V3.2 Architecture](./figures/v32_arch.png)

The architecture diagram above highlights three key components (shown in green) that correspond to our kernel implementations:

1. **Lightning Indexer** (`fp8_lighting_indexer.py`) - Computes FP8 relevance logits between each query and the KV tokens, which drive the sparse token selection
2. **Top-k Selector** (`topk_selector.py`) - Selects the top-k most relevant tokens for sparse attention computation
3. **Multi-Query Attention** (`sparse_mla_fwd.py` and `sparse_mla_fwd_pipelined.py`) - Core attention mechanism with a sparse MLA (Multi-head Latent Attention) forward pass

### Lightning Indexer

Looking at the architecture diagram, the Lightning Indexer sits at the bottom right. It takes the input hidden states and produces compressed representations `{q^A_{t,i}}`, `{k^R_t}`, and `{w^I_{t,j}}`. These FP8-quantized index vectors are what feed into the top-k selector.

The main kernel `mqa_attn_return_logits_kernel` computes similarity scores between query and key indices:

```python
T.gemm(
index_k_shared,
index_q_shared,
s,
transpose_B=True,
clear_accum=True,
policy=T.GemmWarpPolicy.FullCol,
)
```

After the matmul, we apply ReLU and aggregate across heads with learned weights:

```python
for bn_i, bq_i, h_i in T.Parallel(block_N, block_Q, heads):
s_reshaped[bn_i, bq_i, h_i] = (
T.max(s[bn_i, bq_i * heads + h_i], 0) * weights[bq_i, h_i]
) * index_k_scale_fragment[bn_i]

T.reduce_sum(s_reshaped, logits, dim=-1, clear=True)
```
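
Putting the two snippets together, each logit is a weighted sum over heads of the ReLU'd per-head dot products, scaled by the key's FP8 dequantization scale (written sigma_s below; `index_k_scale` in the kernel):

```math
\text{logits}_{t,s} = \sigma_{s} \sum_{h=1}^{H} w_{t,h}\,\mathrm{ReLU}\!\left(q_{t,h} \cdot k_{s}\right)
```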

The result is a `[seq_len, seq_len_kv]` logits matrix. For long sequences, the kernel uses per-token bounds (`CuSeqLenKS`, `CuSeqLenKE`) to skip irrelevant KV positions:

```python
for bq_i in T.serial(block_Q):
cu_k_s_min[0] = T.min(cu_k_s_min[0], T.min(CuSeqLenKS[seq_len_i + bq_i], seq_len_kv))
for bq_i in T.serial(block_Q):
cu_k_e_max[0] = T.max(cu_k_e_max[0], T.min(CuSeqLenKE[seq_len_i + bq_i], seq_len_kv))
```

The pipelined loop then only processes keys in the `[cu_k_s_min, cu_k_e_max)` range, which is crucial for handling variable-length sequences in distributed training.
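
For intuition, here is a hedged PyTorch reference of the same logits computation (function and argument names are illustrative, and the `-inf` masking convention is an assumption; the repository's own reference may differ in details):

```python
import torch

def indexer_logits_ref(q, kv, weights, k_scale, cu_k_s, cu_k_e):
    """Reference for the lightning-indexer logits (fp32, no tiling).

    q:        [S, H, D]  query index vectors
    kv:       [S_kv, D]  key index vectors
    weights:  [S, H]     learned per-head weights
    k_scale:  [S_kv]     per-key FP8 dequantization scales
    cu_k_s/e: [S]        per-query valid KV range [start, end)
    """
    scores = torch.einsum("shd,td->sht", q.float(), kv.float())    # per-head q . k
    scores = torch.relu(scores) * weights[:, :, None].float()      # ReLU, then head weights
    logits = scores.sum(dim=1) * k_scale[None, :].float()          # sum over heads, key scale
    # Mask KV positions outside each query's [start, end) window (fill value is an assumption).
    pos = torch.arange(kv.shape[0], device=kv.device)[None, :]
    valid = (pos >= cu_k_s[:, None]) & (pos < cu_k_e[:, None])
    return logits.masked_fill(~valid, float("-inf"))
```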

### Top-k Selector

The Top-k Selector takes the logits matrix from the indexer and picks the top-k indices for each query. In the architecture diagram, this sits between the Lightning Indexer and the Multi-Query Attention block. The output indices tell the attention layer which KV tokens to actually load and process.
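
Functionally, the selector is equivalent to a per-query `torch.topk` over the valid KV prefix. A hedged reference formulation (argument names are illustrative; the kernel's value lies in avoiding a full sort and a materialized mask):

```python
import torch

def topk_indices_ref(logits, topk, lengths=None):
    """Pick the `topk` highest-scoring KV positions for each query.

    logits:  [S, S_kv] indexer scores
    lengths: optional [S] count of valid KV positions per query (variable-length/causal)
    """
    masked = logits.clone()
    if lengths is not None:
        pos = torch.arange(logits.shape[1], device=logits.device)[None, :]
        masked.masked_fill_(pos >= lengths[:, None], float("-inf"))
    return masked.topk(topk, dim=-1).indices   # [S, topk]
```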

The implementation uses a radix-sort-based approach that processes floats as unsigned integers. Stage 1 does a quick 8-bit pass over the whole sequence:

```python
for s in T.serial(T.ceildiv(seq_len, BLOCK_SIZE)):
input_idx = s*BLOCK_SIZE+tx
if input_idx < l_end_idx and input_idx >= l_start_idx and input_idx < seq_len:
inval_int16 = convert_to_uint16(input[bx, input_idx])
T.atomic_add(s_histogram[inval_int16], 1)
```

The `convert_to_uint16` function maps floats to uint16 such that larger floats map to larger integers. After building a histogram and doing a cumulative sum, we find the threshold bin:

```python
if s_histogram[tx] > l_new_topk and s_histogram[tx + 1] <= l_new_topk:
s_threshold_bin_id[0] = tx
```

Elements above the threshold go directly to the output. Elements in the threshold bin get collected for further processing:

```python
if l_bin_id32 > l_threshold_bin_id:
pos = T.atomic_add(s_histogram[l_bin_id32+1], 1, return_prev=True)
index[bx, pos] = input_idx
elif l_bin_id32 == l_threshold_bin_id and l_new_topk > 0:
pos = T.atomic_add(s_num_input[0], 1, return_prev=True)
s_input_idx[0, pos] = input_idx
```

Stage 2 refines the threshold bin with up to 4 rounds of 8-bit radix sort, processing progressively higher bits. This gives exact top-k selection without sorting the entire sequence.
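
Both stages depend on the float-to-integer conversion being order-preserving. A hedged NumPy sketch of the standard bit trick for float32 (the kernel's `convert_to_uint16` may differ in the exact bit manipulation and input dtype):

```python
import numpy as np

def float32_to_ordered_uint16(x: np.ndarray) -> np.ndarray:
    """Map float32 values to uint16 buckets so that larger floats get larger integers."""
    bits = np.ascontiguousarray(x, dtype=np.float32).view(np.uint32)
    sign = bits & np.uint32(0x80000000)
    # Non-negative floats: set the sign bit. Negative floats: flip every bit.
    bits = np.where(sign != 0, ~bits, bits | np.uint32(0x80000000))
    return (bits >> 16).astype(np.uint16)   # keep the top 16 bits as the radix bucket

vals = np.array([-3.5, -0.25, 0.0, 0.25, 3.5], dtype=np.float32)
buckets = float32_to_ordered_uint16(vals)
assert np.all(np.diff(buckets.astype(np.int32)) >= 0)   # ordering preserved (ties allowed)
```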

### Sparse MLA Forward

The Sparse MLA kernel is where the actual attention computation happens. In the architecture diagram, this is the large "Multi-Query Attention (Core Attention)" block at the top. It takes the selected top-k indices and computes attention only over those tokens.

Turning dense MLA into sparse MLA requires surprisingly few changes - essentially just modifying how we iterate and load KV tokens. The key difference from dense MLA (see `../deepseek_mla/example_mla_decode.py`) is the iteration pattern. Dense MLA iterates over all KV positions:

```python
# Dense MLA: iterate over full sequence
loop_range = T.ceildiv(seqlen_kv, block_N)
for k in T.Pipelined(loop_range, num_stages=2):
T.copy(KV[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], KV_shared)
# ... compute attention over this block
```

Sparse MLA only loads KV positions selected by the top-k selector:

```python
# Sparse MLA: iterate over selected indices only
for i_i in T.Pipelined(NI, num_stages=num_stages):
for bi_i, d_i in T.Parallel(BI, D):
KV_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, d_i]
# ... compute attention over selected tokens
```

This reduces compute from O(seq_len * seq_len_kv) to O(seq_len * topk). The causal mask is enforced by checking whether each index position is valid:

```python
for bi_i in T.Parallel(BI):
mask[bi_i] = Indices[b_i, s_i, g_i, i_i * BI + bi_i] <= max_kv_i
```

Beyond this sparse indexing, the rest of the attention computation (online softmax, output accumulation) follows the same pattern as dense MLA.
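
To make that data flow concrete, here is a hedged, deliberately simplified PyTorch reference (it ignores the rope/value-dimension split, FP8 details, and padding of invalid indices; names are illustrative):

```python
import torch
import torch.nn.functional as F

def sparse_mla_fwd_ref(q, kv, indices, sm_scale=None):
    """Attention restricted to the KV tokens chosen by the top-k selector.

    q:       [S, H, D]   queries (all heads read the same compressed latent KV)
    kv:      [S_kv, D]   compressed latent KV cache
    indices: [S, topk]   positions chosen by the top-k selector
    """
    S, H, D = q.shape
    sm_scale = D ** -0.5 if sm_scale is None else sm_scale
    kv_sel = kv[indices]                                            # [S, topk, D] gather selected tokens
    logits = torch.einsum("shd,std->sht", q.float(), kv_sel.float()) * sm_scale
    # Causal validity: a selected position must not exceed the query's own position.
    valid = indices[:, None, :] <= torch.arange(S, device=q.device)[:, None, None]
    probs = F.softmax(logits.masked_fill(~valid, float("-inf")), dim=-1)
    return torch.einsum("sht,std->shd", probs, kv_sel.float()).to(q.dtype)   # [S, H, D]
```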

### Sparse MLA Forward (Pipelined)

The pipelined version (`sparse_mla_fwd_pipelined.py`) is a manual pipeline implementation designed to match the schedule of [FlashMLA](https://github.com/deepseek-ai/FlashMLA/blob/main/csrc/sm90/prefill/sparse/fwd.cu). It achieves close to 600 TFlops on H800 SXM by carefully orchestrating memory and compute pipelines.

The key difference is splitting the warp groups into specialized roles:

```python
if tx < 128:
# Consumer 0: computes left half of output (D//2 dimensions)
# Handles QK matmul, softmax, and PV for left half

elif tx >= 128 and tx < 256:
# Consumer 1: computes right half of output (D//2 dimensions)
# Only does PV matmul for right half

elif tx >= 256:
# Producer: loads KV data from global memory
# Uses async copy with barriers to feed consumers
```

The producer thread group (tx >= 256) uses double buffering with barriers to keep consumers fed:

```python
# Producer alternates between two buffers
for i_i in T.serial(T.ceildiv(NI, 2)):
# Buffer 0
T.barrier_wait(bar_k_0_free[0], ((i_i & 1) ^ 1))
# ... load KV into buffer 0
T.cp_async_barrier_noinc(bar_k_0_ready[0])

# Buffer 1
T.barrier_wait(bar_k_1_free[0], ((i_i & 1) ^ 1))
# ... load KV into buffer 1
T.cp_async_barrier_noinc(bar_k_1_ready[0])
```

Consumer threads wait on barriers and process buffers as they become ready. This manual orchestration hides memory latency behind compute, which is why it outperforms the simpler auto-pipelined version. The output dimension is also split in half so that the two consumer groups can work in parallel on different parts of the matmul.
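
To see the double-buffering pattern outside the GPU kernel, here is a small CPU analogue using Python threads and semaphores (purely illustrative: the kernel uses hardware barriers and asynchronous copies, not semaphores):

```python
import threading

NUM_TILES = 8
buffers = [None, None]                                       # two staging buffers ("shared memory")
ready = [threading.Semaphore(0), threading.Semaphore(0)]     # signaled when a buffer has been filled
free = [threading.Semaphore(1), threading.Semaphore(1)]      # signaled when a buffer has been consumed

def producer():
    for i in range(NUM_TILES):
        buf = i & 1                                   # alternate between buffer 0 and buffer 1
        free[buf].acquire()                           # wait for the consumer to release this buffer
        buffers[buf] = list(range(i * 4, i * 4 + 4))  # "load a KV tile" into the buffer
        ready[buf].release()                          # signal that the tile is ready

def consumer():
    total = 0
    for i in range(NUM_TILES):
        buf = i & 1
        ready[buf].acquire()                          # wait for the producer to fill this buffer
        total += sum(buffers[buf])                    # "compute attention" on the tile
        free[buf].release()                           # hand the buffer back to the producer
    print("processed", NUM_TILES, "tiles, checksum:", total)

threads = [threading.Thread(target=producer), threading.Thread(target=consumer)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

Because each buffer has its own ready/free pair, the producer can fill buffer 1 while the consumer is still working on buffer 0, which is the same latency hiding the kernel achieves with its `bar_k_*_free` / `bar_k_*_ready` barriers.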
Binary file added examples/deepseek_v32/figures/v32_arch.png
@@ -258,10 +258,7 @@ def ref_fp8_mqa_logits(q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor,
cost = mask.sum()
return logits, cost


if __name__ == "__main__":
torch.manual_seed(0)
S, SKV, H, HKV, D, kv_stride = 4096, 8192, 32, 1, 64, 1
def test_fp8_lighting_indexer(S=4096, SKV=8192, H=32, HKV=1, D=64, kv_stride=1):
q = torch.randn(S, H, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16)
kv = torch.randn(SKV, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16)
weights = torch.randn(S, H, device="cuda", dtype=torch.float32)
@@ -304,3 +301,6 @@ def logits_fn():
logits_tflops = logits_flops / (logits_ms * 1e-3) / 1e12
print(f"logits_tflops: {logits_tflops}, logits_ms: {logits_ms}")
print(f"cost_ref: {cost_ref}")

if __name__ == "__main__":
test_fp8_lighting_indexer()
10 changes: 10 additions & 0 deletions examples/deepseek_v32/inference/.claude/settings.local.json
@@ -0,0 +1,10 @@
{
"permissions": {
"allow": [
"Read(//weka-hg/prod/deepseek/permanent/wanglei/tilelang/examples/deepseek_v32/**)",
"Read(//weka-hg/prod/deepseek/permanent/wanglei/tilelang/examples/deepseek_mla/**)"
],
"deny": [],
"ask": []
}
}
14 changes: 14 additions & 0 deletions examples/deepseek_v32/inference/README.md
@@ -0,0 +1,14 @@
# DeepSeek V3.2

First, convert the Hugging Face model weights to the format required by our inference demo. Set `MP` to match your available GPU count:
```bash
cd inference
export EXPERTS=256
python convert.py --hf-ckpt-path ${HF_CKPT_PATH} --save-path ${SAVE_PATH} --n-experts ${EXPERTS} --model-parallel ${MP}
```

Launch the interactive chat interface and start exploring DeepSeek's capabilities:
```bash
export CONFIG=config_671B_v3.2.json
torchrun --nproc-per-node ${MP} generate.py --ckpt-path ${SAVE_PATH} --config ${CONFIG} --interactive
```
26 changes: 26 additions & 0 deletions examples/deepseek_v32/inference/config_671B_v3.2.json
@@ -0,0 +1,26 @@
{
"vocab_size": 129280,
"dim": 7168,
"inter_dim": 18432,
"moe_inter_dim": 2048,
"n_layers": 61,
"n_dense_layers": 3,
"n_heads": 128,
"n_routed_experts": 256,
"n_shared_experts": 1,
"n_activated_experts": 8,
"n_expert_groups": 8,
"n_limited_groups": 4,
"route_scale": 2.5,
"score_func": "sigmoid",
"q_lora_rank": 1536,
"kv_lora_rank": 512,
"qk_nope_head_dim": 128,
"qk_rope_head_dim": 64,
"v_head_dim": 128,
"dtype": "fp8",
"scale_fmt": "ue8m0",
"index_n_heads": 64,
"index_head_dim": 128,
"index_topk": 2048
}
100 changes: 100 additions & 0 deletions examples/deepseek_v32/inference/convert.py
@@ -0,0 +1,100 @@
import os
import shutil
from argparse import ArgumentParser
from glob import glob
from tqdm import tqdm, trange

import torch
from safetensors.torch import safe_open, save_file

mapping = {
"embed_tokens": ("embed", 0),
"input_layernorm": ("attn_norm", None),
"post_attention_layernorm": ("ffn_norm", None),
"q_proj": ("wq", 0),
"q_a_proj": ("wq_a", None),
"q_a_layernorm": ("q_norm", None),
"q_b_proj": ("wq_b", 0),
"kv_a_proj_with_mqa": ("wkv_a", None),
"kv_a_layernorm": ("kv_norm", None),
"kv_b_proj": ("wkv_b", 0),
"o_proj": ("wo", 1),
"gate": ("gate", None),
"gate_proj": ("w1", 0),
"down_proj": ("w2", 1),
"up_proj": ("w3", 0),
"norm": ("norm", None),
"lm_head": ("head", 0),
"scale": ("scale", None),
"wq_b": ("wq_b", None),
"wk": ("wk", None),
"k_norm": ("k_norm", None),
"weights_proj": ("weights_proj", None),
}


def main(hf_ckpt_path, save_path, n_experts, mp):
"""
Converts and saves model checkpoint files into a specified format.

Args:
hf_ckpt_path (str): Path to the directory containing the input checkpoint files.
save_path (str): Path to the directory where the converted checkpoint files will be saved.
n_experts (int): Total number of experts in the model.
mp (int): Model parallelism factor.

Returns:
None
"""
torch.set_num_threads(8)
n_local_experts = n_experts // mp
state_dicts = [{} for _ in range(mp)]

for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))):
with safe_open(file_path, framework="pt", device="cpu") as f:
for name in f.keys():
if "model.layers.61" in name:
continue
param: torch.Tensor = f.get_tensor(name)
if name.startswith("model."):
name = name[len("model."):]
name = name.replace("self_attn", "attn")
name = name.replace("mlp", "ffn")
name = name.replace("weight_scale_inv", "scale")
name = name.replace("e_score_correction_bias", "bias")
key = name.split(".")[-2]
assert key in mapping, f"Key {key} not found in mapping"
new_key, dim = mapping[key]
name = name.replace(key, new_key)
for i in range(mp):
new_param = param
if "experts" in name and "shared_experts" not in name:
idx = int(name.split(".")[-3])
if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts:
continue
elif dim is not None:
assert param.size(
dim) % mp == 0, f"Dimension {dim} must be divisible by {mp}"
shard_size = param.size(dim) // mp
new_param = param.narrow(dim, i * shard_size, shard_size).contiguous()
state_dicts[i][name] = new_param

os.makedirs(save_path, exist_ok=True)

for i in trange(mp):
save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors"))

for file_path in glob(os.path.join(hf_ckpt_path, "*token*")):
new_file_path = os.path.join(save_path, os.path.basename(file_path))
shutil.copyfile(file_path, new_file_path)


if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--hf-ckpt-path", type=str, required=True)
parser.add_argument("--save-path", type=str, required=True)
parser.add_argument("--n-experts", type=int, required=True)
parser.add_argument("--model-parallel", type=int, required=True)
args = parser.parse_args()
assert args.n_experts % args.model_parallel == 0, "Number of experts must be divisible by model parallelism"
main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel)
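
For non-expert weights that declare a shard dimension, the conversion loop splits that dimension evenly across ranks with `narrow`. A toy illustration (not part of the script):

```python
import torch

param = torch.arange(12).reshape(4, 3)   # stand-in weight of shape [4, 3]
mp, dim = 2, 0                           # two ranks, shard along dimension 0
shard_size = param.size(dim) // mp       # 4 // 2 = 2 rows per rank
shards = [param.narrow(dim, i * shard_size, shard_size).contiguous() for i in range(mp)]
print(shards[0])   # rows 0-1 go to rank 0's model0-mp2.safetensors
print(shards[1])   # rows 2-3 go to rank 1's model1-mp2.safetensors
```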