[Example] Add topk into sparse mla example and append some docs #901
Conversation
…t new directory structure and file descriptions for deepseek_v32 example. Added sections for architecture overview, Lightning Indexer, Top-k Selector, and Sparse MLA Forward implementations.
…_v32 example scripts
- Added per-file ignores for the inference directory in `pyproject.toml`.
- Refactored code in `topk_selector.py`, `convert.py`, `generate.py`, `kernel.py`, and `model.py` to enhance readability by adjusting spacing and line breaks.
- Ensured consistent formatting across function definitions and assertions for better clarity.
Walkthrough

Adds a complete DeepSeek v3.2 inference example: documentation, config and requirements, HF-to-sharded safetensors conversion, interactive/batch generation, TileLang FP8 kernels and top-k selector, a distributed Transformer (MLA/MoE/indexer), tests, and local tooling/settings. No changes to existing public APIs outside new modules.
Sequence Diagram(s)

sequenceDiagram
  autonumber
  actor U as User
  participant C as convert.py
  participant HF as HF safetensors (*.safetensors)
  participant S as Sharded outputs (model{i}-mp{mp}.safetensors)
  participant T as Token files
  U->>C: Run convert.py (--hf-ckpt-path, --save-path, --n-experts, --model-parallel)
  C->>HF: Load and iterate tensors
  C->>C: Remap names, compute shard slices, shard expert/dense tensors
  C->>S: Save per‑shard safetensors
  C->>T: Copy tokenizer files
  C-->>U: Return saved model path
sequenceDiagram
  autonumber
  actor U as User
  participant G as generate.py
  participant Tok as AutoTokenizer
  participant M as Transformer (model.py)
  participant K as FP8 Kernels / Topk (kernel.py / topk_selector.py)
  U->>G: torchrun generate.py --ckpt-path --config ...
  G->>Tok: Tokenize prompt(s)
  G->>M: Load model shards, init distributed
  loop autoregressive steps
    G->>M: Forward(tokens)
    M->>K: FP8 quant / GEMM / Index / Topk
    K-->>M: Intermediate tensors
    M-->>G: Logits
    G->>G: Sample/argmax, append token(s)
  end
  G->>Tok: Decode output
  G-->>U: Print/emit completions
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes
Pre-merge checks and finishing touches: ❌ Failed checks (2 warnings) | ✅ Passed checks (1 passed)
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run  We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Actionable comments posted: 2
🧹 Nitpick comments (18)
examples/deepseek_v32/inference/README.md (1)
1-14: Enhance documentation with prerequisites and guidance. The workflow is clearly structured, but users may need additional context:
- Hardware requirements: No mention of GPU count, VRAM requirements, or how to choose the `MP` value
- Model download: Missing instructions for obtaining `HF_CKPT_PATH` (Hugging Face model location)
- EXPERTS parameter: No explanation of why 256 or how to adjust it
- Generation modes: Only `--interactive` is shown; batch mode is not documented

Consider adding a prerequisites section:
## Prerequisites

- **Hardware**: 8x A100 80GB GPUs (or adjust `MP` for your setup)
- **Model weights**: Download from Hugging Face:

  ```bash
  export HF_CKPT_PATH=/path/to/deepseek-v3.2-671B
  # huggingface-cli download deepseek-ai/DeepSeek-V3.2 --local-dir $HF_CKPT_PATH
  ```

- Experts: Set to match your model configuration (256 for v3.2)

Usage

1. Convert Model Weights
...
2. Interactive Generation
...
3. Batch Generation (Optional)

torchrun --nproc-per-node ${MP} generate.py --ckpt-path ${SAVE_PATH} --config ${CONFIG} --prompts prompts.txt

This helps users understand requirements before running commands.

examples/deepseek_v32/inference/convert.py (1)

`56-57`: **Document the reason for skipping layer 61.** The skip of `model.layers.61` is hardcoded without explanation. Consider adding a comment explaining why this specific layer is excluded from conversion.

```diff
 for name in f.keys():
+    # Skip layer 61: [reason for exclusion]
     if "model.layers.61" in name:
         continue
```

examples/deepseek_v32/inference/kernel.py (1)
8-12: Consider using the non-deprecated fast math configuration. According to the TileLang documentation, `TL_DISABLE_FAST_MATH` is deprecated and will be removed in version 0.1.7. Consider using `TL_ENABLE_FAST_MATH: False` instead.

Apply this diff:

 pass_configs = {
     tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
     tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
-    tilelang.PassConfigKey.TL_DISABLE_FAST_MATH: True,
+    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False,
 }

examples/deepseek_v32/topk_selector.py (1)
180-185: Public API wrapper is straightforward. The function correctly allocates the output tensor and invokes the kernel.

Minor note: As flagged by static analysis, `seq_len` is unpacked on line 181 but not used in the function body (it's passed implicitly via the `input` tensor shape). This is harmless but could be simplified:

 def tl_topk(input, starts, ends, topk):
-    batch, seq_len = input.shape
+    batch, _ = input.shape
     indexes = torch.zeros(batch, topk, dtype=torch.int32, device=input.device)
11-11: Use absolute import for kernel module.The relative import
from kernel import act_quant, fp8_gemm, fp8_indexassumes the script is run from a specific directory. For better portability and explicit dependencies, use an absolute import.Apply this diff:
-from kernel import act_quant, fp8_gemm, fp8_index +from examples.deepseek_v32.inference.kernel import act_quant, fp8_gemm, fp8_indexAlternatively, if this is intended to remain a local import, add a comment explaining the expected execution context.
13-15: Initialize distributed globals properly. The module-level globals `world_size`, `rank`, and `block_size` are reassigned in `Transformer.__init__`. This pattern can lead to confusion and makes testing harder. Consider encapsulating these in a configuration object or passing them explicitly.

If you prefer to keep globals for simplicity in this inference script, consider adding a comment explaining that these are overwritten during `Transformer.__init__` based on distributed context.
162-168: Clarify assertion for bias parameter. Line 162 asserts `bias is None`, but the function signature allows `Optional[torch.Tensor]`. This indicates bias is not yet supported. Either remove the parameter or add a TODO comment explaining future support.

Apply this diff:

 def linear(x: torch.Tensor,
            weight: torch.Tensor,
            bias: Optional[torch.Tensor] = None,
            scale_fmt: Optional[str] = None) -> torch.Tensor:
     """
     ...
     """
-    assert bias is None
+    assert bias is None, "Bias is not yet supported in quantized linear"
508-511: Remove debug assertion in production code. Lines 508-511 broadcast `topk_indices` to verify consistency across ranks using an assertion. This is useful for debugging but adds overhead in production inference. Consider removing or gating it behind a debug flag.

     topk_indices = index_score.topk(min(self.index_topk, end_pos), dim=-1)[1]
-    topk_indices_ = topk_indices.clone()
-    dist.broadcast(topk_indices_, src=0)
-    assert torch.all(topk_indices == topk_indices_), f"{topk_indices=} {topk_indices_=}"
+    # Debug check (can be removed in production):
+    # topk_indices_ = topk_indices.clone()
+    # dist.broadcast(topk_indices_, src=0)
+    # assert torch.all(topk_indices == topk_indices_), f"{topk_indices=} {topk_indices_=}"
     return topk_indices
714-714: Document the hardcoded dimension check for bias. Line 714 conditionally creates a bias parameter only when `self.dim == 7168`. This appears to be a model-specific configuration. Document why this dimension requires bias.

-    self.bias = nn.Parameter(torch.empty(args.n_routed_experts,
-                                         dtype=torch.float32)) if self.dim == 7168 else None
+    # Bias is only used for specific model sizes (e.g., 671B model uses dim=7168)
+    self.bias = nn.Parameter(torch.empty(args.n_routed_experts,
+                                         dtype=torch.float32)) if self.dim == 7168 else None
839-849: Optimize expert routing to avoid sparse indexing overhead. Lines 839-849 use `torch.where` to find indices for each expert and then index into the input. For large batch sizes or many experts, this can be inefficient. Consider batching expert calls or using grouped GEMMs if available.

This is a known pattern in MoE inference (a standalone sketch of the pattern follows below). If performance is critical, explore grouped GEMM kernels or expert batching strategies used in libraries like Megablocks or tutel.
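For readers unfamiliar with this pattern, the following is a minimal, self-contained sketch of `torch.where`-based expert dispatch. It is illustrative only: the function name, shapes, and the toy `nn.Linear` experts are invented for this example and are not the PR's actual MoE code.

```python
import torch

def moe_dispatch_sketch(x, weights, indices, experts):
    """Route tokens to experts with torch.where, then combine weighted outputs.

    x:       (num_tokens, dim) activations
    weights: (num_tokens, top_k) routing weights
    indices: (num_tokens, top_k) selected expert ids
    experts: list of callables, one per expert
    """
    y = torch.zeros_like(x)
    counts = torch.bincount(indices.flatten(), minlength=len(experts))
    for e, expert in enumerate(experts):
        if counts[e] == 0:
            continue  # skip experts that received no tokens
        tok, slot = torch.where(indices == e)           # token rows routed to expert e
        y[tok] += expert(x[tok]) * weights[tok, slot, None]
    return y

# Toy usage: 8 tokens, dim 16, 4 experts, top-2 routing.
experts = [torch.nn.Linear(16, 16) for _ in range(4)]
x = torch.randn(8, 16)
weights, indices = torch.topk(torch.randn(8, 4).softmax(-1), k=2, dim=-1)
print(moe_dispatch_sketch(x, weights, indices, experts).shape)  # torch.Size([8, 16])
```

Grouped-GEMM libraries avoid the per-expert Python loop by batching all expert matmuls into one kernel launch, which is what the comment above is suggesting for large batches.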
891-898: Clarify residual connection flow in Block.forward. The `Block.forward` signature includes both `x` and `residual` parameters, and the norm layers can return updated residuals. This pattern is correct but may be confusing. Add a comment explaining the fused residual + norm pattern.

 def forward(self, x: torch.Tensor, residual: torch.Tensor, start_pos: int,
             freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
     """
     Forward pass for the Transformer block.
+
+    Note: RMSNorm can fuse residual addition for efficiency.
+    If residual is None, this is the first block; otherwise, residual accumulates.
     ...
     """
921-925: Avoid mutating global state in class initialization. Lines 921-925 modify module-level globals `world_size`, `rank`, and `Linear.dtype`/`scale_fmt` during `Transformer.__init__`. This makes the module stateful and harder to test. Consider passing these as explicit parameters or using a context manager.

For an inference script, this pattern is acceptable for simplicity, but document that `Transformer.__init__` has side effects on module-level state.
965-972: Add guard for main block execution. The `if __name__ == "__main__"` block instantiates a model and runs a forward pass. This is useful for testing but should not run during imports. Consider moving to a separate test file or adding a comment that this is for quick validation.

The test code is fine for a standalone script. If this file is ever imported as a module, ensure the main block doesn't execute unintentionally.
examples/deepseek_v32/inference/generate.py (5)
25-27: Clarify Gumbel sampling implementation. Lines 25-27 implement Gumbel-max sampling by dividing probabilities by exponential noise. This is correct but uncommon. The more typical approach is `torch.multinomial`. Document why this method is preferred (e.g., efficiency, determinism).

If you prefer the Gumbel trick for performance or reproducibility, add a comment:

 def sample(logits, temperature: float = 1.0):
     """
     Samples a token from the logits using temperature scaling.
+
+    Uses Gumbel-max trick: argmax(log(p) - log(-log(u))) for efficiency.
     ...
     """
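For context, here is a minimal standalone sketch of the Gumbel-max trick the comment describes (dividing probabilities by exponential noise and taking an argmax). It is illustrative and not necessarily identical to the PR's `sample` function.

```python
import torch

def sample_gumbel_max(logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    """Sample one token id per row via the Gumbel-max trick.

    argmax(probs / Exponential(1)) has the same distribution as
    torch.multinomial(probs, 1), but needs only elementwise ops plus an argmax,
    which is convenient on GPU.
    """
    probs = torch.softmax(logits / max(temperature, 1e-5), dim=-1)
    noise = torch.empty_like(probs).exponential_(1.0)
    return probs.div(noise).argmax(dim=-1)

logits = torch.randn(2, 32000)
print(sample_gumbel_max(logits, temperature=0.6).shape)  # torch.Size([2])
```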
111-111: Document the choice of random seed. Line 111 sets a manual seed to `33377335`. If this is for reproducibility, document it. If it's arbitrary, consider using a configurable seed via CLI or environment variable.

-    torch.manual_seed(33377335)
+    torch.manual_seed(33377335)  # Fixed seed for reproducibility
125-134: Simplify interactive prompt collection. Lines 125-134 handle prompt input differently based on world size. The logic for `world_size == 1` is redundant since the `elif rank == 0` branch can handle both cases.

Apply this diff:

-    if world_size == 1:
-        prompt = input(">>> ")
-    elif rank == 0:
+    if rank == 0:
         prompt = input(">>> ")
         objects = [prompt]
-        dist.broadcast_object_list(objects, 0)
+        if world_size > 1:
+            dist.broadcast_object_list(objects, 0)
     else:
         objects = [None]
         dist.broadcast_object_list(objects, 0)
         prompt = objects[0]
148-166: Validate input file format in batch mode. Lines 148-166 read prompts from a file using `split("\n\n")`. This assumes double-newline delimited prompts. Document this format or add validation for malformed input.

Add a docstring or comment explaining the expected format:

     else:
+        # Batch mode: expects prompts separated by double newlines in input_file
         with open(input_file) as f:
             prompts = f.read().split("\n\n")
195-195: Improve CLI validation message. Line 195 asserts that either `input_file` or `interactive` must be specified. The assertion message is clear, but using argparse groups or mutual exclusivity would be more idiomatic.

Consider using argparse mutual exclusivity:

     parser = ArgumentParser()
     parser.add_argument("--ckpt-path", type=str, required=True)
     parser.add_argument("--config", type=str, required=True)
-    parser.add_argument("--input-file", type=str, default="")
-    parser.add_argument("--interactive", action="store_true")
+    mode_group = parser.add_mutually_exclusive_group(required=True)
+    mode_group.add_argument("--input-file", type=str)
+    mode_group.add_argument("--interactive", action="store_true")
     parser.add_argument("--max-new-tokens", type=int, default=200)
     parser.add_argument("--temperature", type=float, default=0.6)
     args = parser.parse_args()
-    assert args.input_file or args.interactive, "Either input-file or interactive mode must be specified"
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
⛔ Files ignored due to path filters (1)
- examples/deepseek_v32/figures/v32_arch.png is excluded by !**/*.png
📒 Files selected for processing (11)
- examples/deepseek_v32/README.md (1 hunks)
- examples/deepseek_v32/inference/.claude/settings.local.json (1 hunks)
- examples/deepseek_v32/inference/README.md (1 hunks)
- examples/deepseek_v32/inference/config_671B_v3.2.json (1 hunks)
- examples/deepseek_v32/inference/convert.py (1 hunks)
- examples/deepseek_v32/inference/generate.py (1 hunks)
- examples/deepseek_v32/inference/kernel.py (1 hunks)
- examples/deepseek_v32/inference/model.py (1 hunks)
- examples/deepseek_v32/inference/requirements.txt (1 hunks)
- examples/deepseek_v32/topk_selector.py (1 hunks)
- pyproject.toml (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (5)
examples/deepseek_v32/inference/generate.py (2)
examples/deepseek_v32/inference/model.py (16)
Transformer(901-962)
ModelArgs(19-91)
forward(113-134)
forward(202-212)
forward(231-242)
forward(267-283)
forward(301-321)
forward(336-337)
forward(478-511)
forward(579-644)
forward(670-680)
forward(716-748)
forward(774-784)
forward(825-849)
forward(877-898)
forward(938-962)
examples/deepseek_v32/inference/convert.py (1)
main(36-89)
examples/deepseek_v32/inference/kernel.py (10)
tilelang/__init__.py (1)
set_log_level(25-35)
tilelang/transform/pass_config.py (1)
PassConfigKey(6-101)
tilelang/language/tir/op.py (2)
reinterpret(1816-1835)
if_then_else(2824-2854)
tilelang/jit/__init__.py (1)
jit(237-310)
tilelang/language/__init__.py (2)
symbolic(83-94)
use_swizzle(97-103)
tilelang/language/allocate.py (2)
alloc_shared(21-36)
alloc_fragment(53-64)
tilelang/language/pipeline.py (1)
Pipelined(9-46)
tilelang/language/copy.py (1)
copy(84-152)
tilelang/language/reduce.py (2)
reduce_absmax(127-139)
reduce_sum(87-109)
tilelang/language/customize.py (1)
clamp(110-123)
examples/deepseek_v32/topk_selector.py (6)
tilelang/language/tir/op.py (2)
reinterpret(1816-1835)
if_then_else(2824-2854)
tilelang/jit/__init__.py (1)
jit(237-310)
tilelang/language/__init__.py (1)
symbolic(83-94)
tilelang/language/allocate.py (1)
alloc_shared(21-36)
tilelang/language/fill.py (1)
fill(9-21)
tilelang/language/builtin.py (1)
sync_threads(333-341)
examples/deepseek_v32/inference/model.py (1)
examples/deepseek_v32/inference/kernel.py (3)
act_quant(83-106)
fp8_gemm(167-190)
fp8_index(248-268)
examples/deepseek_v32/inference/convert.py (1)
examples/deepseek_v32/inference/generate.py (1)
main(81-169)
🪛 markdownlint-cli2 (0.18.1)
examples/deepseek_v32/README.md
123-123: Spaces inside emphasis markers
(MD037, no-space-in-emphasis)
123-123: Spaces inside emphasis markers
(MD037, no-space-in-emphasis)
🪛 Ruff (0.13.1)
examples/deepseek_v32/inference/generate.py
107-107: Do not assign a lambda expression, use a def
Rewrite print as a def
(E731)
examples/deepseek_v32/topk_selector.py
181-181: Unpacked variable seq_len is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: build-test-amd
- GitHub Check: format-check
🔇 Additional comments (25)
pyproject.toml (1)
60-60: LGTM! Appropriate lint exclusion for example code. The per-file ignore follows the established pattern and is appropriate for the new inference example directory.
examples/deepseek_v32/inference/config_671B_v3.2.json (1)
1-26: Configuration is well-aligned with model code. The JSON configuration keys match the `ModelArgs` dataclass fields in `model.py`. The config is loaded via `ModelArgs(**json.load(f))` in `generate.py`, ensuring exact key matching. All 25 parameters (vocab_size, n_routed_experts, q_lora_rank, kv_lora_rank, moe_inter_dim, index_topk, etc.) are correctly referenced throughout the model initialization and forward passes.

examples/deepseek_v32/inference/.claude/settings.local.json (1)
1-10: Verify whether local settings files should be committed. This file contains user-specific filesystem paths (`wanglei`) and internal infrastructure references (`//weka-hg/prod/deepseek/permanent/`). The `.local.` naming convention typically indicates environment-specific configuration that should not be version-controlled.

Verification confirmed:
- This file is currently tracked by git and being committed
- No `.gitignore` pattern exists to exclude `.local.json` files
- This is the only `.claude/settings.local.json` file in the repository
- No template or example file exists for reference

Recommendations:
- Add `.claude/settings.local.json` (or the pattern `**/.claude/settings.local.json`) to `.gitignore`
- Remove this file from version control: `git rm --cached examples/deepseek_v32/inference/.claude/settings.local.json`
- Create a template file like `.claude/settings.example.json` documenting the expected structure with placeholder paths

examples/deepseek_v32/inference/convert.py (5)
10-33: Well-structured parameter mapping. The mapping dictionary provides a clear and maintainable way to handle parameter name transformations and shard dimensions. The tuple structure `(new_name, shard_dimension)` effectively encodes both the renaming and sharding strategy.
69-80: Sharding logic is well-implemented. The dual sharding strategy (expert-based and dimension-based) is clearly implemented. The divisibility check on lines 76-77 prevents runtime errors, and the use of `.contiguous()` ensures proper memory layout after slicing (a minimal sharding sketch follows below).

Note: The expert index extraction on line 72 using `split(".")[-3]` makes a similar assumption about name structure as the earlier key extraction. This should be covered by the verification script suggested above.
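To make the dimension-based sharding concrete, here is a small sketch of slicing one tensor for model parallelism. Names and shapes are hypothetical; the real `convert.py` additionally remaps parameter names and handles expert tensors.

```python
import torch

def shard_tensor(t: torch.Tensor, dim: int, mp: int, rank: int) -> torch.Tensor:
    """Return this rank's slice of `t` along `dim` for mp-way model parallelism."""
    assert t.size(dim) % mp == 0, f"size {t.size(dim)} along dim {dim} not divisible by mp={mp}"
    shard = t.size(dim) // mp
    # narrow + contiguous mirrors "slice, then materialize a dense shard" before saving
    return t.narrow(dim, rank * shard, shard).contiguous()

w = torch.randn(1024, 7168)
shards = [shard_tensor(w, dim=0, mp=8, rank=r) for r in range(8)]
assert torch.equal(torch.cat(shards, dim=0), w)  # shards reassemble the original tensor
```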
82-89: File I/O operations look correct. The directory creation, shard file saving, and token file copying are all implemented correctly. The naming convention `model{i}-mp{mp}.safetensors` aligns with the loading pattern in `generate.py`.
92-100: CLI argument handling is appropriate. The argument parser is well-structured with clear names, and the divisibility assertion on line 99 provides early validation before the conversion begins. This prevents wasted processing time on invalid configurations.
65-66: Key extraction logic verified and robust. The `split(".")[-2]` approach correctly extracts mapping keys for all standard PyTorch checkpoint parameter names. Verification confirms that parameter names follow the convention `<module_path>.<param_name>.<param_type>`, ensuring at least 2 dot-separated components (minimum: `"param_name.weight"`). The runtime assertion on line 66 provides additional validation, making this implementation both safe and appropriate for the expected checkpoint format.

examples/deepseek_v32/inference/kernel.py (4)
19-32: Helper functions implement standard FP8 quantization patterns. The bit manipulation helpers (`fast_log2_ceil`, `fast_pow2`, `fast_round_scale`) correctly implement IEEE 754 float32 exponent extraction and reconstruction for efficient FP8 scaling. The use of `T.reinterpret` and `T.if_then_else` is appropriate for TileLang kernels.
35-106: Activation quantization kernel is well-structured. The implementation correctly handles block-wise FP8 quantization with per-block scaling (a host-side reference sketch follows below):
- The safety bound `max(amax, 1e-4)` on line 68 prevents division by zero.
- The clamping to `[fp8_min, fp8_max]` on line 74 ensures valid FP8 range.
- Input validation in the public API ensures contiguity and block alignment.
- The conditional pipeline depth based on `round_scale` is a good optimization.
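A host-side reference of the same block-wise scaling idea can be written in a few lines of PyTorch. This is an emulation for illustration only: it keeps float storage and uses the e4m3 maximum of 448.0 as an assumed constant, rather than reproducing the TileLang kernel.

```python
import torch

FP8_MAX = 448.0  # assumed e4m3 max; the real kernel derives this from the FP8 dtype

def act_quant_ref(x: torch.Tensor, block_size: int = 128):
    """Reference block-wise quantization with per-block scales (emulation only).

    Mirrors the reviewed pattern: per-block amax with a 1e-4 safety bound,
    scale the block into [-FP8_MAX, FP8_MAX], then clamp. Values stay in
    float32 here; the actual kernel casts to an FP8 storage type.
    """
    assert x.is_contiguous() and x.size(-1) % block_size == 0
    blocks = x.view(*x.shape[:-1], -1, block_size)
    amax = blocks.abs().amax(dim=-1, keepdim=True).clamp_min(1e-4)  # avoid div-by-zero
    scale = amax / FP8_MAX
    q = (blocks / scale).clamp(-FP8_MAX, FP8_MAX)
    return q.view_as(x), scale.squeeze(-1)

x = torch.randn(4, 512)
q, s = act_quant_ref(x)
print(q.shape, s.shape)  # torch.Size([4, 512]) torch.Size([4, 4])
```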
109-190: FP8 GEMM implementation follows best practices. The kernel demonstrates advanced TileLang techniques:
- 4-stage pipeline (line 145) for overlapping computation and memory transfers.
- Swizzle pattern (line 140) to improve L2 cache locality.
- Proper double-buffer accumulation pattern to enable 2x accumulation parallelism.
- Correct handling of per-block scales for both inputs.
The public API provides appropriate input validation and tensor reshaping.
193-268: FP8 index kernel correctly implements the scoring computation. The kernel follows the documented computation flow:
- FP8 Q @ FP8 K → FP32 logits (lines 225-232)
- ReLU(logits) * q_s → scaled logits (line 235)
- Sum reduction → logits_sum (line 238)
- logits_sum * k_s → index_score (line 241)
The use of `out_idx=[4]` properly indicates the output tensor position in the signature, and the 2-stage pipeline provides good performance (a plain-PyTorch reference of this flow follows below).
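The computation flow above can be expressed as a plain PyTorch reference. The shapes and argument names here are hypothetical and chosen only to illustrate the ReLU-scale-sum-scale structure, not to mirror `fp8_index`'s exact signature.

```python
import torch

def index_score_ref(q, k, q_scale, k_scale):
    """Plain-PyTorch reference of the scoring flow (hypothetical shapes).

    q:       (seq_q, heads, dim)  queries (already dequantized)
    k:       (seq_kv, dim)        keys, single KV head
    q_scale: (seq_q, heads)       per-query, per-head scales
    k_scale: (seq_kv,)            per-key scales
    """
    logits = torch.einsum("qhd,kd->qhk", q.float(), k.float())  # the FP8 GEMM step
    scored = torch.relu(logits) * q_scale[:, :, None]           # ReLU(logits) * q_s
    summed = scored.sum(dim=1)                                  # sum over heads
    return summed * k_scale[None, :]                            # * k_s -> index_score

q, k = torch.randn(16, 32, 64), torch.randn(128, 64)
q_s, k_s = torch.rand(16, 32), torch.rand(128)
print(index_score_ref(q, k, q_s, k_s).shape)  # torch.Size([16, 128])
```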
5-24: Bit conversion helpers correctly handle float-to-unsigned mapping. The conversion functions implement the standard technique for treating floats as unsigned integers for magnitude comparison (essential for radix-based selection):
- Negative values are bitwise-inverted to maintain ordering.
- Positive values have the sign bit set to ensure they sort after negatives.
The `TL_DISABLE_THREAD_STORAGE_SYNC` configuration means synchronization must be manually managed (which is done correctly throughout the kernel). A host-side check of the bit mapping follows below.
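The order-preserving float-to-unsigned mapping can be sanity-checked on the host with plain PyTorch. This sketch works on float32 bits held in int64 (the kernel operates on device types instead) and is only meant to illustrate the bit trick.

```python
import torch

def float_to_ordered_key(x: torch.Tensor) -> torch.Tensor:
    """Map float32 values to integer keys whose unsigned order matches float order.

    Negative floats: flip all 32 bits; non-negative floats: set the sign bit.
    The 32-bit pattern is kept in an int64 so ordinary integer comparison
    behaves like unsigned comparison.
    """
    bits = x.view(torch.int32).to(torch.int64) & 0xFFFFFFFF  # raw IEEE-754 bits
    neg = (bits >> 31) == 1
    return torch.where(neg, bits ^ 0xFFFFFFFF, bits | 0x80000000)

x = torch.randn(1000)
keys = float_to_ordered_key(x)
# Sorting by key reproduces sorting by value.
assert torch.equal(x.sort().values, x[keys.argsort()])
```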
27-112: Stage 1 histogram-based filtering is well-implemented. The first stage correctly implements the 8-bit histogram approach:
- Atomic histogram accumulation (line 73) safely handles concurrent updates.
- The parallel cumulative sum (lines 78-85) uses proper barriers with `barrier_id=3` and `arrive_count=RADIX`.
- Threshold bin identification (line 89) correctly finds where cumsum exceeds topk.
- Position allocation via `atomic_add(..., return_prev=True)` (line 105) prevents race conditions.

The assumption that the threshold bucket size is less than 4K (line 33) is reasonable for typical distributions but should be validated for edge cases.
113-176: Stage 2 multi-pass refinement is implemented correctly. The tail pass implements a sophisticated multi-round approach:
- Ping-pong buffering (line 118) efficiently reuses shared memory.
- Progressive bit extraction (line 130-132) refines the selection 8 bits at a time.
- Early exit (lines 115-116) avoids unnecessary work.
- Final round (line 167) directly writes results with proper bounds checking.
The hardcoded limit of 4 rounds provides 32 bits of resolution (4 × 8 bits), which is appropriate for float32 inputs.
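To see why the two stages select the correct elements, here is a single-row, host-side sketch of the same idea: bucket by the top 8 bits of an order-preserving key, take whole buckets above the threshold bucket, then resolve the threshold bucket. For brevity the threshold bucket is resolved with an exact top-k rather than the kernel's further 8-bit refinement rounds; this is an illustration, not the kernel's algorithm verbatim.

```python
import torch

def histogram_topk_sketch(scores: torch.Tensor, k: int) -> torch.Tensor:
    """Single-row, host-side sketch of the two-stage radix selection."""
    bits = scores.view(torch.int32).to(torch.int64) & 0xFFFFFFFF
    keys = torch.where((bits >> 31) == 1, bits ^ 0xFFFFFFFF, bits | 0x80000000)
    buckets = keys >> 24                                   # top 8 bits -> 256 buckets
    hist = torch.bincount(buckets, minlength=256)
    order = torch.arange(255, -1, -1)                      # walk buckets high -> low
    cum = hist[order].cumsum(0)
    thr_pos = int((cum >= k).nonzero()[0])
    thr_bucket = int(order[thr_pos])
    above = (buckets > thr_bucket).nonzero(as_tuple=True)[0]    # taken wholesale
    within = (buckets == thr_bucket).nonzero(as_tuple=True)[0]  # needs refinement
    need = k - above.numel()
    refined = within[torch.topk(scores[within], need).indices]
    return torch.cat([above, refined])

scores = torch.randn(32 * 1024)
idx = histogram_topk_sketch(scores, 2048)
assert set(idx.tolist()) == set(torch.topk(scores, 2048).indices.tolist())
```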
188-244: Test code follows GPU benchmarking best practices (a generic helper sketch follows below). The test/demo code is well-structured:
- Uses CUDA events for accurate timing (lines 220-221).
- Includes proper warmup iterations (lines 224-226) to avoid cold-start effects.
- Validates correctness via set intersection (lines 207-215), which is appropriate for top-k where order isn't guaranteed.
- Compares against PyTorch's built-in implementation for both correctness and performance.
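The warmup-plus-CUDA-events pattern praised here generalizes to any kernel. A small generic helper (illustrative; not part of the PR) looks like this:

```python
import torch

def benchmark_cuda(fn, warmup: int = 5, iters: int = 20) -> float:
    """Time a CUDA callable with warmup and CUDA events; returns ms per iteration."""
    for _ in range(warmup):           # absorb JIT/compile and cache cold-start effects
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()          # event timings are only valid after a sync
    return start.elapsed_time(end) / iters

# Example (requires a GPU): time torch.topk on a random matrix.
if torch.cuda.is_available():
    x = torch.randn(64, 32 * 1024, device="cuda")
    print(f"torch.topk: {benchmark_cuda(lambda: torch.topk(x, 2048, dim=-1)):.3f} ms")
```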
examples/deepseek_v32/inference/model.py (6)
624-626: Cache dequantized weights to avoid redundant computation. Lines 624-626 lazily dequantize `wkv_b` during decode passes. This is good, but ensure the cache (`self.dequant_wkv_b`) is properly invalidated if weights are updated (e.g., during fine-tuning or weight swapping).

The lazy dequantization pattern improves decode performance. Just confirm that weight updates (if any) invalidate the cache.
933-934: Clarify dtype choice for lm_head. The comment on line 933 explains that `lm_head` is stored in fp32 for "easier computation of logits later." Ensure this doesn't cause precision issues when the rest of the model is in bf16 or fp8.

Using fp32 for the final projection is a reasonable choice to maintain numerical stability for logits. The comment is helpful.
440-443: Document the `rotate_activation` function and its bfloat16 requirement. The function lacks a docstring. Add one that explains:
- What the function does (applies Hadamard transform to activations)
- The bfloat16 dtype requirement and why it's necessary
- Parameters and return type
The `fast_hadamard_transform` dependency is already listed in `requirements.txt` (an illustrative wrapper sketch follows below).
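As a sketch of what such a documented wrapper could look like: the import below assumes the PyPI package `fast_hadamard_transform` exposes `hadamard_transform(x, scale)` (check the installed version), and both the stated reason for the bfloat16 requirement and the power-of-two check are assumptions for illustration, not facts taken from the PR.

```python
import torch
from fast_hadamard_transform import hadamard_transform  # assumed API: hadamard_transform(x, scale)


def rotate_activation_sketch(x: torch.Tensor) -> torch.Tensor:
    """Apply a normalized Hadamard transform along the last dimension.

    Rotating activations this way spreads outliers across channels before
    quantization. The underlying CUDA kernel targets 16-bit inputs, which is
    a plausible reason for requiring bfloat16 here (assumption).
    """
    assert x.dtype == torch.bfloat16, "expects bfloat16 activations"
    hidden = x.size(-1)
    assert hidden & (hidden - 1) == 0, "last dim assumed to be a power of two"
    return hadamard_transform(x, scale=hidden ** -0.5)  # orthonormal scaling
```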
958-962: Concatenation order is consistent with vocab sharding. The vocab sharding in `ParallelEmbedding` (lines 108-110) assigns consecutive shards by rank: rank 0 owns vocab indices `[0, part_vocab_size)`, rank 1 owns `[part_vocab_size, 2*part_vocab_size)`, etc. The `all_gather` operation collects logits in rank order, and `torch.cat(all_logits, dim=-1)` concatenates them in that same order, correctly reconstructing the full vocabulary dimension.
107-110: Manual verification recommended for distributed embedding sharding. The `ParallelEmbedding` implementation appears mathematically correct:
- Each rank owns vocabulary indices `[vocab_start_idx, vocab_end_idx)`
- Tokens outside a rank's range are masked, zeroed after lookup, and summed via all-reduce
- Edge cases at shard boundaries (e.g., `vocab_start_idx`, `vocab_end_idx - 1`) are handled correctly by the mask conditions `(x < vocab_start_idx) | (x >= vocab_end_idx)`

However, no tests were found for this class. Given the complexity of distributed operations and the critical nature of embedding correctness, verify the implementation with multi-rank testing to ensure token lookups, masking, and all-reduce behavior work as expected across different vocabulary distributions (a single-process reference sketch follows below).
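A compact single-process reference of the masked lookup plus all-reduce can double as a starting point for such a test. This is a sketch under assumed shapes, not the PR's `ParallelEmbedding` class; with one rank owning the whole vocabulary it must reduce to a plain embedding lookup.

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F

def parallel_embedding_ref(tokens, weight_shard, vocab_start, vocab_end):
    """Reference for a masked, vocab-sharded embedding lookup (sketch).

    Each rank holds rows [vocab_start, vocab_end) of the embedding table.
    Out-of-range tokens are clamped for the local lookup, their rows zeroed,
    and an all-reduce sums the partial embeddings across ranks.
    """
    mask = (tokens < vocab_start) | (tokens >= vocab_end)
    local = (tokens - vocab_start).masked_fill(mask, 0)   # safe local indices
    out = F.embedding(local, weight_shard)
    out = out.masked_fill(mask.unsqueeze(-1), 0.0)        # drop rows owned by other ranks
    if dist.is_initialized():
        dist.all_reduce(out)                              # sum partials across ranks
    return out

# With a single "rank" owning the whole vocabulary, this must match a plain lookup.
vocab, dim = 100, 8
w = torch.randn(vocab, dim)
tok = torch.randint(0, vocab, (2, 5))
assert torch.equal(parallel_embedding_ref(tok, w, 0, vocab), F.embedding(tok, w))
```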
514-522: Verify weight_dequant correctness - weight quantization format is external and untested. The `weight_dequant` function performs block-wise dequantization with a transpose operation (line 517: `.transpose(1, 2)`), but verification reveals several concerns:

- No matching quantization code: The codebase contains no weight quantization implementation. The `kernel.py` file only provides activation quantization (`act_quant`), not weight quantization. Weights come pre-quantized from the HuggingFace checkpoint via `convert.py`.
- No tests: No test coverage exists for `weight_dequant` to validate the block layout transformations.
- Single usage: Only called once (line 625) for `wkv_b` weights during MHA decode, making bugs harder to detect.
- Unclear block layout: The transpose suggests weights are stored with interleaved dimensions: `(out_blocks, block_size, in_blocks, block_size)` → `(out_blocks, in_blocks, block_size, block_size)`, but this layout is not documented or validated against the checkpoint format.

Recommendation: Verify that the weight block layout in the checkpoint matches the expected format, especially the dimension ordering and transpose operation. Consider adding unit tests to validate dequantization correctness (an illustrative reference and spot check follow below).
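The hypothesized layout can be pinned down with a tiny reference implementation and a spot check, which would also serve as the suggested unit test. Everything below is an assumption-driven sketch (float storage instead of FP8, invented sizes), not the checkpoint's confirmed format.

```python
import torch

def weight_dequant_ref(wq: torch.Tensor, scale: torch.Tensor, block: int = 128) -> torch.Tensor:
    """Block-wise dequantization under the hypothesized tile layout (sketch).

    wq:    (out_features, in_features) quantized weight, stored as float here
    scale: (out_features // block, in_features // block) per-tile scales
    The view/transpose exposes (out_block, in_block, block, block) tiles so
    each tile can be multiplied by its own scale.
    """
    out_f, in_f = wq.shape
    tiles = wq.view(out_f // block, block, in_f // block, block).transpose(1, 2)
    tiles = tiles * scale[:, :, None, None]
    return tiles.transpose(1, 2).reshape(out_f, in_f)

wq = torch.randn(256, 512)
scale = torch.rand(2, 4)
w = weight_dequant_ref(wq, scale)
# Spot-check one tile: rows 0-127, cols 128-255 should use scale[0, 1].
assert torch.allclose(w[:128, 128:256], wq[:128, 128:256] * scale[0, 1])
```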
examples/deepseek_v32/inference/generate.py (1)
50-52: Verify assertion message consistency. The assertion message on line 52 uses `max_seq_len`, which matches the variable name in `ModelArgs`. Ensure this is clear to users who may not be familiar with the model internals.

The assertion is clear and prevents a common error. Good defensive programming.
examples/deepseek_v32/README.md (2)
14-168: Excellent documentation expansion. The added sections provide clear, detailed explanations of the Lightning Indexer, Top-k Selector, and Sparse MLA components. The code snippets and architecture diagram references help readers understand the implementation.
This documentation significantly improves the usability of the example. Great work!
123-123: Fix markdown emphasis formatting. Line 123 has spaces inside emphasis markers. The linter flags `* seq_len *`, which should be `*seq_len*` or similar.

Review the emphasis syntax on line 123 and remove extra spaces:

-This reduces compute from O(seq_len * seq_len_kv) to O(seq_len * topk).
+This reduces compute from O(seq_len × seq_len_kv) to O(seq_len × topk).

Or if using italics:

-This reduces compute from O(seq_len * seq_len_kv) to O(seq_len * topk).
+This reduces compute from O(*seq_len* × *seq_len_kv*) to O(*seq_len* × *topk*).

Likely an incorrect or invalid review comment.
    global print
    if rank != 0:
        print = lambda *_, **__: None
🛠️ Refactor suggestion | 🟠 Major
Use a proper no-op function instead of lambda.
Line 107 assigns a lambda to print to suppress output on non-zero ranks. Static analysis flags this (Ruff E731). Use a proper function definition.
Apply this diff:
     global print
     if rank != 0:
-        print = lambda *_, **__: None
+        def print(*args, **kwargs):
+            pass
🧰 Tools
🪛 Ruff (0.13.1)
107-107: Do not assign a lambda expression, use a def
Rewrite print as a def
(E731)
🤖 Prompt for AI Agents
In examples/deepseek_v32/inference/generate.py around lines 105 to 107, the code
suppresses printing by assigning a lambda to the built-in print (which triggers
Ruff E731); replace the lambda with a proper no-op function definition (e.g.,
define def _noop(*args, **kwargs): pass) and assign print = _noop for non-zero
ranks so static analysis is satisfied and intent remains the same.
torch
transformers
safetensors
fast_hadamard_transform
tilelang==0.1.6 (no newline at end of file)
🛠️ Refactor suggestion | 🟠 Major
🧩 Analysis chain
Pin dependency versions for reproducibility and security.
Most dependencies lack version constraints, which can lead to:
- Non-reproducible builds when dependency APIs change
- Unexpected breakage from major version updates
- Difficulty tracking security vulnerabilities
Consider pinning minimum (or exact) versions:
-torch
-transformers
-safetensors
-fast_hadamard_transform
+torch>=2.0.0
+transformers>=4.30.0
+safetensors>=0.3.0
+fast_hadamard_transform>=0.3.0
 tilelang==0.1.6

Alternatively, generate a requirements-lock.txt using pip freeze after testing to capture exact working versions.
Run the following to check for known vulnerabilities in these packages:
🏁 Script executed:
#!/bin/bash
# Check for security advisories on the core dependencies
for pkg in torch transformers safetensors fast_hadamard_transform; do
  echo "=== Checking $pkg ==="
  gh api graphql -f query="
  {
    securityVulnerabilities(first: 3, ecosystem: PIP, package: \"$pkg\") {
      nodes {
        advisory {
          summary
          severity
          publishedAt
        }
        vulnerableVersionRange
        firstPatchedVersion {
          identifier
        }
      }
    }
  }"
doneLength of output: 1778
🏁 Script executed:
#!/bin/bash
# Check latest stable versions of torch and transformers via PyPI
echo "=== Checking torch latest version ==="
curl -s https://pypi.org/pypi/torch/json | jq -r '.info.version'
echo "=== Checking transformers latest version ==="
curl -s https://pypi.org/pypi/transformers/json | jq -r '.info.version'
echo "=== Checking safetensors latest version ==="
curl -s https://pypi.org/pypi/safetensors/json | jq -r '.info.version'
echo "=== Checking fast_hadamard_transform latest version ==="
curl -s https://pypi.org/pypi/fast_hadamard_transform/json | jq -r '.info.version'

Length of output: 213
Pin dependency versions to avoid critical security vulnerabilities.
Without version constraints, unpinned dependencies can install vulnerable versions. Security scan reveals:
- torch < 2.6.0: CRITICAL remote code execution via torch.load
- torch ≤ 2.7.1: MODERATE resource shutdown vulnerability
- transformers < 4.53.0: Multiple MODERATE ReDoS vulnerabilities
Recommend pinning to patched versions:
-torch
-transformers
-safetensors
-fast_hadamard_transform
+torch>=2.8.0
+transformers>=4.53.0
+safetensors>=0.4.0
+fast_hadamard_transform>=0.3.0
 tilelang==0.1.6

Latest stable versions (as of Sept 2025): torch 2.8.0, transformers 4.56.2, safetensors 0.6.2, fast_hadamard_transform 1.0.4.post1.
Alternatively, use pip freeze > requirements-lock.txt after testing to capture exact working versions.
Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In examples/deepseek_v32/inference/requirements.txt lines 1-5 the dependencies
are unpinned which allows installing vulnerable package versions; update the
file to pin packages to safe versions (e.g. torch>=2.8.0,<2.9,
transformers>=4.56.2,<4.57, safetensors>=0.6.2,<0.7,
fast_hadamard_transform>=1.0.4.post1,<1.1 and keep tilelang==0.1.6 or pin it if
needed) and after validating the environment run pip freeze >
requirements-lock.txt to capture exact tested versions for reproducible, secure
installs.
…clarity and consistency
- Updated `fp8_lighting_indexer.py` to define a dedicated test function for the lighting indexer.
- Refactored `sparse_mla_fwd_pipelined.py` and `sparse_mla_fwd.py` to standardize test function parameters and improve readability.
- Enhanced `topk_selector.py` by introducing a test function with parameters for batch size and sequence length.
- Ensured all test functions are invoked correctly in the main execution block.
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️  Outside diff range comments (1)
examples/deepseek_v32/test_tilelang_example_deepseek_v32.py (1)
1-24: Fix formatting to resolve pipeline failure. The CI pipeline failed with: "Reformatted files. Please review and stage the changes." Run the project's formatter (likely `ruff format` or `black`) and commit the changes.

#!/bin/bash
# Run formatter to fix the issue
cd examples/deepseek_v32
ruff format test_tilelang_example_deepseek_v32.py || black test_tilelang_example_deepseek_v32.py
🧹 Nitpick comments (4)
examples/deepseek_v32/topk_selector.py (1)
181-181: Address static analysis hint: unused variable seq_len. Ruff flags that `seq_len` is unpacked but never used in `tl_topk`. Since the kernel requires `seq_len` as a symbolic dimension, the variable is actually used implicitly by the kernel invocation. You can silence the warning by prefixing with an underscore if it's truly not needed, or keep it as-is if the JIT compilation infers the shape from the tensor.

If `seq_len` is not explicitly referenced, apply this diff to silence the warning:

 def tl_topk(input, starts, ends, topk):
-    batch, seq_len = input.shape
+    batch, _seq_len = input.shape
     indexes = torch.zeros(batch, topk, dtype=torch.int32, device=input.device)

Alternatively, if the kernel's symbolic dimension inference requires the variable to be named `seq_len`, you can ignore the static analysis hint and add a comment explaining the necessity.

examples/deepseek_v32/test_tilelang_example_deepseek_v32.py (2)
1-1: Consider disabling specific linting rules instead of a blanket noqa. A blanket `# ruff: noqa` suppresses all linting feedback for the entire file. If certain rules need to be disabled (e.g., for relative imports), specify them explicitly (e.g., `# ruff: noqa: TID252`).

9-19: Consider removing wrapper functions if not required by test framework. The wrapper functions (`test_example_*`) simply delegate to the imported test functions without adding parameters, documentation, or additional logic. If the test discovery framework doesn't require the `test_example_` prefix, consider importing and using the original functions directly.

examples/deepseek_v32/sparse_mla_fwd_pipelined.py (1)
438-439: Consider removing commented-out debug prints. The commented debug statements are no longer needed and can be removed to keep the code clean.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- examples/deepseek_v32/fp8_lighting_indexer.py (2 hunks)
- examples/deepseek_v32/sparse_mla_fwd.py (2 hunks)
- examples/deepseek_v32/sparse_mla_fwd_pipelined.py (3 hunks)
- examples/deepseek_v32/test_tilelang_example_deepseek_v32.py (1 hunks)
- examples/deepseek_v32/topk_selector.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
examples/deepseek_v32/test_tilelang_example_deepseek_v32.py (3)
examples/deepseek_v32/topk_selector.py (1)
test_topk_selector(188-245)
examples/deepseek_v32/sparse_mla_fwd.py (2)
sparse_mla_fwd(14-173)
test_sparse_mla_fwd(234-268)
examples/deepseek_v32/sparse_mla_fwd_pipelined.py (2)
sparse_mla_fwd(18-311)
test_sparse_mla_fwd_pipelined(400-451)
examples/deepseek_v32/topk_selector.py (7)
tilelang/transform/pass_config.py (1)
PassConfigKey(6-101)
tilelang/language/tir/op.py (2)
reinterpret(1816-1835)
if_then_else(2824-2854)
tilelang/jit/__init__.py (1)
jit(237-310)
tilelang/language/__init__.py (1)
symbolic(83-94)
tilelang/language/allocate.py (1)
alloc_shared(21-36)
tilelang/language/fill.py (1)
fill(9-21)
tilelang/language/builtin.py (1)
sync_threads(333-341)
examples/deepseek_v32/sparse_mla_fwd_pipelined.py (1)
examples/deepseek_v32/sparse_mla_fwd.py (3)
fn(256-257)
ref_sparse_mla_fwd_interface(197-231)
test_sparse_mla_fwd(234-268)
🪛 GitHub Actions: CI Test on AMD
examples/deepseek_v32/test_tilelang_example_deepseek_v32.py
[error] 1-1: Reformatted files. Please review and stage the changes.
🪛 Ruff (0.13.1)
examples/deepseek_v32/topk_selector.py
181-181: Unpacked variable seq_len is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
🔇 Additional comments (11)
examples/deepseek_v32/sparse_mla_fwd.py (2)
234-242: LGTM! Good parameterization for test reusability. The test function now accepts configurable parameters with sensible defaults, making it easier to exercise different tensor shapes and configurations. The defaults align with the original hard-coded values and the explicit call in `__main__`.

272-273: LGTM! Explicit parameter passing preserves original behavior. The `__main__` block now calls the parameterized test function with explicit keyword arguments, maintaining the original test behavior while benefiting from the new flexible interface.

examples/deepseek_v32/fp8_lighting_indexer.py (2)
261-282: LGTM! Test function now supports parameterized configuration. The refactoring extracts the test logic into a reusable function with sensible defaults (S=4096, SKV=8192, H=32, HKV=1, D=64, kv_stride=1), enabling flexible test invocation from external test harnesses.

305-306: LGTM! Simplified entry point. The `__main__` block now delegates to the parameterized test function, maintaining the original behavior while improving modularity.

examples/deepseek_v32/topk_selector.py (5)
5-7: Pass config disables automatic synchronization; ensure manual sync is correct. The `TL_DISABLE_THREAD_STORAGE_SYNC` flag disables automatic `__syncthreads()` insertion for shared memory coordination. Verify that all manual `T.sync_threads()` calls in the kernel (lines 68, 74, 80, 83, 88, 91, 94, 98, etc.) correctly guard shared memory accesses and prevent race conditions.
10-14: Verify bit manipulation logic for signed float→uint16 conversion. The function reinterprets float as float16, then flips bits based on sign to establish an ordering where larger floats map to larger uints. The logic inverts negative values and sets the sign bit for positives, then shifts right by 8 to produce an 8-bit bucket ID.
Double-check that the bit manipulation correctly handles:
- Negative zero vs. positive zero
- NaN and infinity values
- The shift-right by 8 to extract the exponent bucket
17-24: Verify bit manipulation logic for signed float32→uint32 conversion. Similar to `convert_to_uint16`, this function reinterprets float32 as uint32 and flips bits to create an ordering. The logic inverts negative values and sets the MSB for positives.

Verify correctness for:
- Sign transitions (negative/positive boundary)
- Special float values (NaN, ±Inf, ±0.0)
- Bit ordering consistency with the radix-sort phases
27-177: Complex two-stage top-k kernel: confirm correctness of stage transitions and boundary checks. The kernel implements a two-stage radix-selection algorithm:

Stage 1 (lines 64-111):
- Builds an 8-bit histogram by hashing input to uint16 buckets (lines 69-74)
- Computes cumulative sums to find the threshold bucket (lines 77-93)
- Collects elements above the threshold directly into `index`, and elements equal to the threshold into `s_input_idx` for stage 2 (lines 97-111)

Stage 2 (lines 114-175):
- Iterates up to 4 rounds, refining the threshold with 8-bit slices of the uint32 representation (lines 114-175)
- Each round builds a histogram, finds a new threshold, and partitions elements (lines 128-175)
- Final round (round == 3) directly writes remaining elements to output with bounds check (lines 167-171)

Potential concerns:
- Boundary checks: Ensure `input_idx < seq_len` and `input_idx >= l_start_idx and input_idx < l_end_idx` are applied consistently to prevent out-of-bounds access.
- Atomic races: Multiple `T.atomic_add` calls on `s_histogram`, `s_num_input`, and output `index` positions; verify that all accesses are race-free and output indices stay within `[0, topk)`.
- SMEM_INPUT_SIZE=4096 assumption: Line 33 assumes the threshold bucket after stage 1 has ≤4096 elements. If this is violated, `s_input_idx` overflows. Consider adding a runtime assertion or documenting this constraint.
- Loop break condition: `if l_new_topk <= 0: T.loop_break()` (lines 115-116); confirm that all threads in the block agree on the break condition to avoid divergence issues.
- Synchronization: Manual syncs are present, but verify coverage of all shared memory writes/reads (especially around `s_histogram`, `s_input_idx`, `s_num_input`, and `s_threshold_bin_id`).

Consider adding:
- A runtime assertion that `s_num_input[0] <= SMEM_INPUT_SIZE` after stage 1
- Comments documenting the invariants for each stage
- A bounds check on `pos` before writing to `index[bx, pos]` (e.g., lines 106, 165, 169, 171)
180-185: LGTM! Clean wrapper API. The `tl_topk` function provides a simple interface: allocates the output tensor, instantiates the kernel, and returns the result. The API is clear and consistent with PyTorch conventions.

examples/deepseek_v32/sparse_mla_fwd_pipelined.py (2)
400-409: LGTM! Clear parameterization for the pipelined test. The function signature is well-structured with explicit parameters and sensible defaults for testing the pipelined sparse MLA forward pass.
432-433: Masking logic correctly references KV_stride. The masking condition properly zeroes out positions `[:KV_stride - 1]` when `q_start_s_index == 0` and stride > 1, consistent with the interface function behavior.
    if args.test_correctness:
        B, S, SKV, H, HKV, DQK, DV, topk, dtype = 1, 1024, 2048, 128, 1, 576, 512, 2048, torch.bfloat16
    else:
        B, S, SKV, H, HKV, DQK, DV, topk, dtype = 1, 4096, 8192, 128, 1, 576, 512, 2048, torch.bfloat16
    test_sparse_mla_fwd(B, S, SKV, H, HKV, DQK, DV, topk, dtype)
    test_sparse_mla_fwd(B, S, SKV, H, HKV, DQK, DV, topk, dtype)
Fix function name mismatch: test_sparse_mla_fwd is not defined in this file.
The main block calls test_sparse_mla_fwd, but the function defined in this file is test_sparse_mla_fwd_pipelined. This will cause a NameError at runtime.
Apply this diff to fix the function name:
     if args.test_correctness:
         B, S, SKV, H, HKV, DQK, DV, topk, dtype = 1, 1024, 2048, 128, 1, 576, 512, 2048, torch.bfloat16
     else:
         B, S, SKV, H, HKV, DQK, DV, topk, dtype = 1, 4096, 8192, 128, 1, 576, 512, 2048, torch.bfloat16
-    test_sparse_mla_fwd(B, S, SKV, H, HKV, DQK, DV, topk, dtype)
-    test_sparse_mla_fwd(B, S, SKV, H, HKV, DQK, DV, topk, dtype)
+    test_sparse_mla_fwd_pipelined(B, S, SKV, H, HKV, DQK, DV, topk, dtype)
+    test_sparse_mla_fwd_pipelined(B, S, SKV, H, HKV, DQK, DV, topk, dtype)📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| if args.test_correctness: | |
| B, S, SKV, H, HKV, DQK, DV, topk, dtype = 1, 1024, 2048, 128, 1, 576, 512, 2048, torch.bfloat16 | |
| else: | |
| B, S, SKV, H, HKV, DQK, DV, topk, dtype = 1, 4096, 8192, 128, 1, 576, 512, 2048, torch.bfloat16 | |
| test_sparse_mla_fwd(B, S, SKV, H, HKV, DQK, DV, topk, dtype) | |
| test_sparse_mla_fwd(B, S, SKV, H, HKV, DQK, DV, topk, dtype) | |
| if args.test_correctness: | |
| B, S, SKV, H, HKV, DQK, DV, topk, dtype = 1, 1024, 2048, 128, 1, 576, 512, 2048, torch.bfloat16 | |
| else: | |
| B, S, SKV, H, HKV, DQK, DV, topk, dtype = 1, 4096, 8192, 128, 1, 576, 512, 2048, torch.bfloat16 | |
| test_sparse_mla_fwd_pipelined(B, S, SKV, H, HKV, DQK, DV, topk, dtype) | |
| test_sparse_mla_fwd_pipelined(B, S, SKV, H, HKV, DQK, DV, topk, dtype) | 
🤖 Prompt for AI Agents
In examples/deepseek_v32/sparse_mla_fwd_pipelined.py around lines 458 to 463 the
code calls undefined function test_sparse_mla_fwd causing a NameError; update
the two calls to use the correctly defined function name
test_sparse_mla_fwd_pipelined so they match the function implementation (i.e.,
replace test_sparse_mla_fwd with test_sparse_mla_fwd_pipelined).
from topk_selector import test_topk_selector
from fp8_lighting_indexer import test_fp8_lighting_indexer
from sparse_mla_fwd import test_sparse_mla_fwd
from sparse_mla_fwd_pipelined import test_sparse_mla_fwd_pipelined
Use explicit relative imports for sibling modules.
The imports assume the current working directory is examples/deepseek_v32/, which will fail when test runners or scripts invoke this module from other directories.
Apply this diff to use explicit relative imports:
-from topk_selector import test_topk_selector
-from fp8_lighting_indexer import test_fp8_lighting_indexer
-from sparse_mla_fwd import test_sparse_mla_fwd
-from sparse_mla_fwd_pipelined import test_sparse_mla_fwd_pipelined
+from .topk_selector import test_topk_selector
+from .fp8_lighting_indexer import test_fp8_lighting_indexer
+from .sparse_mla_fwd import test_sparse_mla_fwd
+from .sparse_mla_fwd_pipelined import test_sparse_mla_fwd_pipelined📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| from topk_selector import test_topk_selector | |
| from fp8_lighting_indexer import test_fp8_lighting_indexer | |
| from sparse_mla_fwd import test_sparse_mla_fwd | |
| from sparse_mla_fwd_pipelined import test_sparse_mla_fwd_pipelined | |
| from .topk_selector import test_topk_selector | |
| from .fp8_lighting_indexer import test_fp8_lighting_indexer | |
| from .sparse_mla_fwd import test_sparse_mla_fwd | |
| from .sparse_mla_fwd_pipelined import test_sparse_mla_fwd_pipelined | 
🤖 Prompt for AI Agents
In examples/deepseek_v32/test_tilelang_example_deepseek_v32.py around lines 4 to
7, the module uses implicit (absolute) sibling imports which break when the
module is executed from outside the package; change each import to an explicit
relative import (e.g., from .topk_selector import test_topk_selector, from
.fp8_lighting_indexer import test_fp8_lighting_indexer, etc.) so the tests
reliably import sibling modules regardless of the current working directory or
test runner.
def test_topk_selector(batch=64, seq_len=32 * 1024, topk=2048):

    batch = 64
    seq_len = 32 * 1024
    topk = 2048
    torch.manual_seed(1)
    input = torch.randn(batch, seq_len, dtype=torch.float32).cuda()
    starts = torch.zeros(batch, dtype=torch.int32).cuda()
    ends = torch.ones(batch, dtype=torch.int32).cuda() * seq_len

    indexes = tl_topk(input, starts, ends, topk)
    print(indexes)

    indexes_ref = torch.topk(input, topk, dim=-1)[1]
    print(indexes_ref)

    # indexes_ref = fast_topk(input, topk)
    # print(indexes_ref)

    # Calculate intersection of out_ref and out_trt
    for i in range(batch):
        ref_np = indexes_ref[i].cpu().to(torch.int32).numpy()
        trt_np = indexes[i].cpu().to(torch.int32).numpy()

        set_ref = set(ref_np)
        set_trt = set(trt_np)
        intersection = set_ref & set_trt
        print("selected/all:", len(intersection), "/", len(set_ref), "=",
              len(intersection) / len(set_ref))

    # Performance test with CUDA events

    torch.cuda.synchronize()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    # Warmup
    for _ in range(5):
        _ = tl_topk(input, starts, ends, topk)
    torch.cuda.synchronize()

    n_iters = 20
    start_event.record()
    for _ in range(n_iters):
        _ = tl_topk(input, starts, ends, topk)
    end_event.record()
    torch.cuda.synchronize()
    elapsed_time_ms = start_event.elapsed_time(end_event)
    print(f"Average tl_topk time: {elapsed_time_ms / n_iters:.3f} ms")

    # Torch topk time
    start_event.record()
    for _ in range(n_iters):
        _ = torch.topk(input, topk, dim=-1)[1]
    end_event.record()
    torch.cuda.synchronize()
    elapsed_time_ms = start_event.elapsed_time(end_event)
    print(f"Average torch.topk time: {elapsed_time_ms / n_iters:.3f} ms")
Test function hardcodes parameters that shadow the function arguments.
Lines 190-192 redefine batch, seq_len, and topk with hard-coded values, ignoring the function parameters. This defeats the purpose of parameterization and will confuse callers who pass different values.
Apply this diff to use the function parameters:
 def test_topk_selector(batch=64, seq_len=32 * 1024, topk=2048):
-
-    batch = 64
-    seq_len = 32 * 1024
-    topk = 2048
     torch.manual_seed(1)
     input = torch.randn(batch, seq_len, dtype=torch.float32).cuda()
🤖 Prompt for AI Agents
In examples/deepseek_v32/topk_selector.py around lines 188 to 246, the
test_topk_selector function currently overwrites its parameters by reassigning
batch, seq_len, and topk on lines 190-192; remove those three hard-coded
assignments so the function uses the incoming arguments, and ensure subsequent
uses (torch.randn, ends, tl_topk, torch.topk, warmup/iteration loops) reference
the original function parameters as intended.
…irements and parameterization
- Added CUDA requirements decorators to `test_example_sparse_mla_fwd` and `test_example_sparse_mla_fwd_pipelined`.
- Parameterized test functions to use specific small shapes for testing, improving test coverage and clarity.
* Remove unused `fp8_mqa_logits.py` file and update README.md to reflect new directory structure and file descriptions for deepseek_v32 example. Added sections for architecture overview, Lightning Indexer, Top-k Selector, and Sparse MLA Forward implementations.
* Update linting configurations and improve code formatting in deepseek_v32 example scripts: added per-file ignores for the inference directory in `pyproject.toml`; refactored code in `topk_selector.py`, `convert.py`, `generate.py`, `kernel.py`, and `model.py` to enhance readability by adjusting spacing and line breaks; ensured consistent formatting across function definitions and assertions for better clarity.
* Refactor test functions in deepseek_v32 example scripts for improved clarity and consistency: updated `fp8_lighting_indexer.py` to define a dedicated test function for the lighting indexer; refactored `sparse_mla_fwd_pipelined.py` and `sparse_mla_fwd.py` to standardize test function parameters and improve readability; enhanced `topk_selector.py` by introducing a test function with parameters for batch size and sequence length; ensured all test functions are invoked correctly in the main execution block.
* Enhance test functions in deepseek_v32 example scripts with CUDA requirements and parameterization: added CUDA requirements decorators to `test_example_sparse_mla_fwd` and `test_example_sparse_mla_fwd_pipelined`; parameterized test functions to use specific small shapes for testing, improving test coverage and clarity.
* lint fix
* Update README.md to correct image path for DeepSeek V3.2 architecture diagram
as title.
Summary by CodeRabbit
New Features
Documentation
Tests
Chores