
Conversation


@LeiWang1999 LeiWang1999 commented Sep 29, 2025

As titled.

Summary by CodeRabbit

  • New Features

    • DeepSeek V3.2 inference toolkit: model-weight conversion and sharded checkpoint support, distributed interactive chat and batch generation, FP8 quantization support, accelerated kernels, and a fast two-stage top-k selection pipeline.
    • Ready-to-use 671B v3.2 inference configuration and a modular Transformer with Mixture-of-Experts, rotary embeddings, and local indexing.
  • Documentation

    • Expanded README with architecture overview, component descriptions, and inference workflow instructions.
  • Tests

    • Added example test harnesses exercising indexer, top-k, FP8 pipelines, and pipelined MLA.
  • Chores

    • Added runtime dependencies, local permission settings, and lint exclusions.

…t new directory structure and file descriptions for deepseek_v32 example. Added sections for architecture overview, Lightning Indexer, Top-k Selector, and Sparse MLA Forward implementations.
…_v32 example scripts

- Added per-file ignores for the inference directory in `pyproject.toml`.
- Refactored code in `topk_selector.py`, `convert.py`, `generate.py`, `kernel.py`, and `model.py` to enhance readability by adjusting spacing and line breaks.
- Ensured consistent formatting across function definitions and assertions for better clarity.

coderabbitai bot commented Sep 29, 2025

Warning

Rate limit exceeded

@LeiWang1999 has exceeded the limit for the number of commits or files that can be reviewed per hour. Please wait 11 minutes and 28 seconds before requesting another review.

⌛ How to resolve this issue?

After the wait time has elapsed, a review can be triggered using the @coderabbitai review command as a PR comment. Alternatively, push new commits to this PR.

We recommend that you space out your commits to avoid hitting the rate limit.

🚦 How do rate limits work?

CodeRabbit enforces hourly rate limits for each developer per organization.

Our paid plans have higher rate limits than the trial, open-source and free plans. In all cases, we re-allow further reviews after a brief timeout.

Please see our FAQ for further information.

📥 Commits

Reviewing files that changed from the base of the PR and between ca216e7 and 6ad57ee.

📒 Files selected for processing (1)
  • examples/deepseek_v32/README.md (1 hunks)

Walkthrough

Adds a complete DeepSeek v3.2 inference example: documentation, config and requirements, HF-to-sharded safetensors conversion, interactive/batch generation, TileLang FP8 kernels and top‑k selector, a distributed Transformer (MLA/MoE/indexer), tests, and local tooling/settings. No changes to existing public APIs outside new modules.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| Documentation: `examples/deepseek_v32/README.md`, `examples/deepseek_v32/inference/README.md` | Expanded READMEs with file descriptions, architecture overview, and step‑by‑step inference instructions; removed fp8_mqa_logits reference and added figures/, inference/, topk_selector entries. |
| Inference config & tooling: `examples/deepseek_v32/inference/config_671B_v3.2.json`, `examples/deepseek_v32/inference/requirements.txt`, `examples/deepseek_v32/inference/.claude/settings.local.json`, `pyproject.toml` | Adds model config, dependency list, local Claude permissions, and Ruff ignore for inference directory. |
| Checkpoint conversion: `examples/deepseek_v32/inference/convert.py` | New script mapping HF parameter names, sharding tensors for model‑parallelism, saving per‑shard safetensors, and copying tokenizer files. |
| Generation CLI: `examples/deepseek_v32/inference/generate.py` | New interactive and batch generation CLI with distributed setup, sampling/argmax logic, EOS handling, safetensors weight loading, and prompt handling. |
| FP8 TileLang kernels & indexer: `examples/deepseek_v32/inference/kernel.py`, `examples/deepseek_v32/fp8_lighting_indexer.py` | Adds TileLang JIT kernels and wrappers for FP8 activation quantization, FP8 GEMM, FP8 index scoring; refactors fp8_lighting_indexer into a parameterized test entrypoint. |
| Model implementation: `examples/deepseek_v32/inference/model.py` | New distributed Transformer implementation with ModelArgs, ParallelEmbedding, parallel linear variants, rotary embeddings, MLA attention, MoE/Expert/Gate, weight dequant, and forward/inference paths. |
| Top‑k selector: `examples/deepseek_v32/topk_selector.py` | New two‑stage TileLang top‑k kernel (histogram + tail passes), wrapper tl_topk, uint conversions, and a test/demo harness. |
| Pipelined / sparse MLA tests: `examples/deepseek_v32/sparse_mla_fwd.py`, `examples/deepseek_v32/sparse_mla_fwd_pipelined.py` | Converted tests to parameterized functions; introduced pipelined variant, renamed/added entrypoints, and updated main invocation. |
| TileLang test harness: `examples/deepseek_v32/test_tilelang_example_deepseek_v32.py` | New test module exposing wrapper test functions for topk, fp8 indexer, sparse MLA (regular and pipelined) and adding tilelang test entry. |
| Top‑level example edits: `examples/deepseek_v32/README.md` | README reorganized and expanded with file descriptions and architecture overview. |

Sequence Diagram(s)

sequenceDiagram
  autonumber
  actor U as User
  participant C as convert.py
  participant HF as HF safetensors (*.safetensors)
  participant S as Sharded outputs (model{i}-mp{mp}.safetensors)
  participant T as Token files

  U->>C: Run convert.py (--hf-ckpt-path, --save-path, --n-experts, --model-parallel)
  C->>HF: Load and iterate tensors
  C->>C: Remap names, compute shard slices, shard expert/dense tensors
  C->>S: Save per‑shard safetensors
  C->>T: Copy tokenizer files
  C-->>U: Return saved model path
sequenceDiagram
  autonumber
  actor U as User
  participant G as generate.py
  participant Tok as AutoTokenizer
  participant M as Transformer (model.py)
  participant K as FP8 Kernels / Topk (kernel.py / topk_selector.py)

  U->>G: torchrun generate.py --ckpt-path --config ...
  G->>Tok: Tokenize prompt(s)
  G->>M: Load model shards, init distributed
  loop autoregressive steps
    G->>M: Forward(tokens)
    M->>K: FP8 quant / GEMM / Index / Topk
    K-->>M: Intermediate tensors
    M-->>G: Logits
    G->>G: Sample/argmax, append token(s)
  end
  G->>Tok: Decode output
  G-->>U: Print/emit completions

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Possibly related PRs

Suggested reviewers

  • tzj-fxz

Poem

I hop through shards and quantized light,
kernels hum through the silent night.
Top‑k finds whispers in token streams,
MoE gates open to bright new dreams.
V3.2 lands — hop, compute, delight! 🐇✨

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)
Title Check: ⚠️ Warning
Explanation: The PR title states "Add topk into sparse mla example and append some docs." The changeset does introduce a new topk_selector.py module implementing top-k selection functionality and adds extensive documentation (expanded README.md with architecture overviews, new inference README with setup instructions). However, the changeset also includes a complete inference implementation suite (convert.py, generate.py, kernel.py, model.py with distributed transformer code), new test infrastructure, configuration files, and refactoring of existing test functions to be parameterized, none of which are reflected in the title. The title accurately describes two parts of the change (the topk addition and the documentation) but omits the substantial inference implementation and testing infrastructure that represent a significant portion of the changeset.
Resolution: Consider revising the title to reflect the broader scope of changes, for example: "Add DeepSeek v3.2 inference implementation, topk selector, and expanded documentation" or "Extend DeepSeek v32 example with inference suite, topk selection, and architecture docs." This would better communicate to teammates that the PR includes not just topk and docs, but also a complete distributed inference pipeline with model conversion, generation, and custom kernels.

Docstring Coverage: ⚠️ Warning
Explanation: Docstring coverage is 0.00%, which is below the required threshold of 80.00%.
Resolution: You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (1 passed)
Description Check: ✅ Passed
Explanation: Check skipped; CodeRabbit’s high-level summary is enabled.

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

🧪 Early access (Sonnet 4.5): enabled

We are currently testing the Sonnet 4.5 model, which is expected to improve code review quality. However, this model may lead to increased noise levels in the review comments. Please disable the early access features if the noise level causes any inconvenience.

Note:

  • Public repositories are always opted into early access features.
  • You can enable or disable early access features from the CodeRabbit UI or by updating the CodeRabbit configuration file.

Comment @coderabbitai help to get the list of available commands and usage tips.

@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run bash format.sh in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work!

🚀


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (18)
examples/deepseek_v32/inference/README.md (1)

1-14: Enhance documentation with prerequisites and guidance.

The workflow is clearly structured, but users may need additional context:

  1. Hardware requirements: No mention of GPU count, VRAM requirements, or how to choose the MP value
  2. Model download: Missing instructions for obtaining HF_CKPT_PATH (Hugging Face model location)
  3. EXPERTS parameter: No explanation of why 256 or how to adjust it
  4. Generation modes: Only --interactive is shown; batch mode is not documented

Consider adding a prerequisites section:

## Prerequisites

- **Hardware**: 8x A100 80GB GPUs (or adjust `MP` for your setup)
- **Model weights**: Download from Hugging Face:
  ```bash
  export HF_CKPT_PATH=/path/to/deepseek-v3.2-671B
  # huggingface-cli download deepseek-ai/DeepSeek-V3.2 --local-dir $HF_CKPT_PATH
  ```
- **Experts**: Set to match your model configuration (256 for v3.2)

## Usage

### 1. Convert Model Weights

...

### 2. Interactive Generation

...

### 3. Batch Generation (Optional)

torchrun --nproc-per-node ${MP} generate.py --ckpt-path ${SAVE_PATH} --config ${CONFIG} --prompts prompts.txt

This helps users understand requirements before running commands.

examples/deepseek_v32/inference/convert.py (1)

56-57: Document the reason for skipping layer 61.

The skip of `model.layers.61` is hardcoded without explanation. Consider adding a comment explaining why this specific layer is excluded from conversion.


             for name in f.keys():
+                # Skip layer 61: [reason for exclusion]
                 if "model.layers.61" in name:
                     continue
examples/deepseek_v32/inference/kernel.py (1)

8-12: Consider using the non-deprecated fast math configuration.

According to the TileLang documentation, TL_DISABLE_FAST_MATH is deprecated and will be removed in version 0.1.7. Consider using TL_ENABLE_FAST_MATH: False instead.

Apply this diff:

 pass_configs = {
     tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
     tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
-    tilelang.PassConfigKey.TL_DISABLE_FAST_MATH: True,
+    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False,
 }
examples/deepseek_v32/topk_selector.py (1)

180-185: Public API wrapper is straightforward.

The function correctly allocates the output tensor and invokes the kernel.

Minor note: As flagged by static analysis, seq_len is unpacked on line 181 but not used in the function body (it's passed implicitly via the input tensor shape). This is harmless but could be simplified:

 def tl_topk(input, starts, ends, topk):
-    batch, seq_len = input.shape
+    batch, _ = input.shape
     indexes = torch.zeros(batch, topk, dtype=torch.int32, device=input.device)
examples/deepseek_v32/inference/model.py (9)

11-11: Use absolute import for kernel module.

The relative import from kernel import act_quant, fp8_gemm, fp8_index assumes the script is run from a specific directory. For better portability and explicit dependencies, use an absolute import.

Apply this diff:

-from kernel import act_quant, fp8_gemm, fp8_index
+from examples.deepseek_v32.inference.kernel import act_quant, fp8_gemm, fp8_index

Alternatively, if this is intended to remain a local import, add a comment explaining the expected execution context.


13-15: Initialize distributed globals properly.

The module-level globals world_size, rank, and block_size are reassigned in Transformer.__init__. This pattern can lead to confusion and makes testing harder. Consider encapsulating these in a configuration object or passing them explicitly.

If you prefer to keep globals for simplicity in this inference script, consider adding a comment explaining that these are overwritten during Transformer.__init__ based on distributed context.


162-168: Clarify assertion for bias parameter.

Line 162 asserts bias is None, but the function signature allows Optional[torch.Tensor]. This indicates bias is not yet supported. Either remove the parameter or add a TODO comment explaining future support.

Apply this diff:

 def linear(x: torch.Tensor,
            weight: torch.Tensor,
            bias: Optional[torch.Tensor] = None,
            scale_fmt: Optional[str] = None) -> torch.Tensor:
     """
     ...
     """
-    assert bias is None
+    assert bias is None, "Bias is not yet supported in quantized linear"

508-511: Remove debug assertion in production code.

Lines 508-511 broadcast topk_indices to verify consistency across ranks using an assertion. This is useful for debugging but adds overhead in production inference. Consider removing or gating it behind a debug flag.

         topk_indices = index_score.topk(min(self.index_topk, end_pos), dim=-1)[1]
-        topk_indices_ = topk_indices.clone()
-        dist.broadcast(topk_indices_, src=0)
-        assert torch.all(topk_indices == topk_indices_), f"{topk_indices=} {topk_indices_=}"
+        # Debug check (can be removed in production):
+        # topk_indices_ = topk_indices.clone()
+        # dist.broadcast(topk_indices_, src=0)
+        # assert torch.all(topk_indices == topk_indices_), f"{topk_indices=} {topk_indices_=}"
         return topk_indices

714-714: Document the hardcoded dimension check for bias.

Line 714 conditionally creates a bias parameter only when self.dim == 7168. This appears to be a model-specific configuration. Document why this dimension requires bias.

-        self.bias = nn.Parameter(torch.empty(args.n_routed_experts,
-                                             dtype=torch.float32)) if self.dim == 7168 else None
+        # Bias is only used for specific model sizes (e.g., 671B model uses dim=7168)
+        self.bias = nn.Parameter(torch.empty(args.n_routed_experts,
+                                             dtype=torch.float32)) if self.dim == 7168 else None

839-849: Optimize expert routing to avoid sparse indexing overhead.

Lines 839-849 use torch.where to find indices for each expert and then index into the input. For large batch sizes or many experts, this can be inefficient. Consider batching expert calls or using grouped GEMMs if available.

This is a known pattern in MoE inference. If performance is critical, explore grouped GEMM kernels or expert batching strategies used in libraries like Megablocks or tutel.
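
For context, the torch.where dispatch pattern under discussion looks roughly like the following sketch (a simplified illustration, not the PR's Expert/Gate/MoE classes; the function name, argument shapes, and routing layout are assumptions):

```python
import torch

def moe_dispatch_ref(x: torch.Tensor, experts: list, topk_idx: torch.Tensor,
                     topk_weight: torch.Tensor) -> torch.Tensor:
    # x: [tokens, dim]; topk_idx / topk_weight: [tokens, k] routing decisions.
    # For every expert, gather the tokens routed to it, run the expert once on
    # that slice, and scatter the weighted outputs back. This is the sparse-
    # indexing pattern that grouped GEMMs would replace for large batches.
    y = torch.zeros_like(x)
    for e, expert in enumerate(experts):
        tok, slot = torch.where(topk_idx == e)
        if tok.numel() == 0:
            continue
        y[tok] += expert(x[tok]) * topk_weight[tok, slot, None]
    return y
```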


891-898: Clarify residual connection flow in Block.forward.

The Block.forward signature includes both x and residual parameters, and the norm layers can return updated residuals. This pattern is correct but may be confusing. Add a comment explaining the fused residual + norm pattern.

     def forward(self, x: torch.Tensor, residual: torch.Tensor, start_pos: int,
                 freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
         """
         Forward pass for the Transformer block.
+        
+        Note: RMSNorm can fuse residual addition for efficiency.
+        If residual is None, this is the first block; otherwise, residual accumulates.
         ...
         """

921-925: Avoid mutating global state in class initialization.

Lines 921-925 modify module-level globals world_size, rank, and Linear.dtype/scale_fmt during Transformer.__init__. This makes the module stateful and harder to test. Consider passing these as explicit parameters or using a context manager.

For an inference script, this pattern is acceptable for simplicity, but document that Transformer.__init__ has side effects on module-level state.


965-972: Add guard for main block execution.

The if __name__ == "__main__" block instantiates a model and runs a forward pass. This is useful for testing but should not run during imports. Consider moving to a separate test file or adding a comment that this is for quick validation.

The test code is fine for a standalone script. If this file is ever imported as a module, ensure the main block doesn't execute unintentionally.

examples/deepseek_v32/inference/generate.py (5)

25-27: Clarify Gumbel sampling implementation.

Lines 25-27 implement Gumbel-max sampling by dividing probabilities by exponential noise. This is correct but uncommon. The more typical approach is torch.multinomial. Document why this method is preferred (e.g., efficiency, determinism).

If you prefer the Gumbel trick for performance or reproducibility, add a comment:

 def sample(logits, temperature: float = 1.0):
     """
     Samples a token from the logits using temperature scaling.
+    
+    Uses Gumbel-max trick: argmax(log(p) - log(-log(u))) for efficiency.
     ...
     """

111-111: Document the choice of random seed.

Line 111 sets a manual seed to 33377335. If this is for reproducibility, document it. If it's arbitrary, consider using a configurable seed via CLI or environment variable.

-    torch.manual_seed(33377335)
+    torch.manual_seed(33377335)  # Fixed seed for reproducibility

125-134: Simplify interactive prompt collection.

Lines 125-134 handle prompt input differently based on world size. The logic for world_size == 1 is redundant since the elif rank == 0 branch can handle both cases.

Apply this diff:

-            if world_size == 1:
-                prompt = input(">>> ")
-            elif rank == 0:
+            if rank == 0:
                 prompt = input(">>> ")
                 objects = [prompt]
-                dist.broadcast_object_list(objects, 0)
+                if world_size > 1:
+                    dist.broadcast_object_list(objects, 0)
             else:
                 objects = [None]
                 dist.broadcast_object_list(objects, 0)
                 prompt = objects[0]

148-166: Validate input file format in batch mode.

Lines 148-166 read prompts from a file using split("\n\n"). This assumes double-newline delimited prompts. Document this format or add validation for malformed input.

Add a docstring or comment explaining the expected format:

     else:
+        # Batch mode: expects prompts separated by double newlines in input_file
         with open(input_file) as f:
             prompts = f.read().split("\n\n")

195-195: Improve CLI validation message.

Line 195 asserts that either input_file or interactive must be specified. The assertion message is clear, but using argparse groups or mutual exclusivity would be more idiomatic.

Consider using argparse mutual exclusivity:

     parser = ArgumentParser()
     parser.add_argument("--ckpt-path", type=str, required=True)
     parser.add_argument("--config", type=str, required=True)
-    parser.add_argument("--input-file", type=str, default="")
-    parser.add_argument("--interactive", action="store_true")
+    mode_group = parser.add_mutually_exclusive_group(required=True)
+    mode_group.add_argument("--input-file", type=str)
+    mode_group.add_argument("--interactive", action="store_true")
     parser.add_argument("--max-new-tokens", type=int, default=200)
     parser.add_argument("--temperature", type=float, default=0.6)
     args = parser.parse_args()
-    assert args.input_file or args.interactive, "Either input-file or interactive mode must be specified"
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 54fc6ba and 1719847.

⛔ Files ignored due to path filters (1)
  • examples/deepseek_v32/figures/v32_arch.png is excluded by !**/*.png
📒 Files selected for processing (11)
  • examples/deepseek_v32/README.md (1 hunks)
  • examples/deepseek_v32/inference/.claude/settings.local.json (1 hunks)
  • examples/deepseek_v32/inference/README.md (1 hunks)
  • examples/deepseek_v32/inference/config_671B_v3.2.json (1 hunks)
  • examples/deepseek_v32/inference/convert.py (1 hunks)
  • examples/deepseek_v32/inference/generate.py (1 hunks)
  • examples/deepseek_v32/inference/kernel.py (1 hunks)
  • examples/deepseek_v32/inference/model.py (1 hunks)
  • examples/deepseek_v32/inference/requirements.txt (1 hunks)
  • examples/deepseek_v32/topk_selector.py (1 hunks)
  • pyproject.toml (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (5)
examples/deepseek_v32/inference/generate.py (2)
examples/deepseek_v32/inference/model.py (16)
  • Transformer (901-962)
  • ModelArgs (19-91)
  • forward (113-134)
  • forward (202-212)
  • forward (231-242)
  • forward (267-283)
  • forward (301-321)
  • forward (336-337)
  • forward (478-511)
  • forward (579-644)
  • forward (670-680)
  • forward (716-748)
  • forward (774-784)
  • forward (825-849)
  • forward (877-898)
  • forward (938-962)
examples/deepseek_v32/inference/convert.py (1)
  • main (36-89)
examples/deepseek_v32/inference/kernel.py (10)
tilelang/__init__.py (1)
  • set_log_level (25-35)
tilelang/transform/pass_config.py (1)
  • PassConfigKey (6-101)
tilelang/language/tir/op.py (2)
  • reinterpret (1816-1835)
  • if_then_else (2824-2854)
tilelang/jit/__init__.py (1)
  • jit (237-310)
tilelang/language/__init__.py (2)
  • symbolic (83-94)
  • use_swizzle (97-103)
tilelang/language/allocate.py (2)
  • alloc_shared (21-36)
  • alloc_fragment (53-64)
tilelang/language/pipeline.py (1)
  • Pipelined (9-46)
tilelang/language/copy.py (1)
  • copy (84-152)
tilelang/language/reduce.py (2)
  • reduce_absmax (127-139)
  • reduce_sum (87-109)
tilelang/language/customize.py (1)
  • clamp (110-123)
examples/deepseek_v32/topk_selector.py (6)
tilelang/language/tir/op.py (2)
  • reinterpret (1816-1835)
  • if_then_else (2824-2854)
tilelang/jit/__init__.py (1)
  • jit (237-310)
tilelang/language/__init__.py (1)
  • symbolic (83-94)
tilelang/language/allocate.py (1)
  • alloc_shared (21-36)
tilelang/language/fill.py (1)
  • fill (9-21)
tilelang/language/builtin.py (1)
  • sync_threads (333-341)
examples/deepseek_v32/inference/model.py (1)
examples/deepseek_v32/inference/kernel.py (3)
  • act_quant (83-106)
  • fp8_gemm (167-190)
  • fp8_index (248-268)
examples/deepseek_v32/inference/convert.py (1)
examples/deepseek_v32/inference/generate.py (1)
  • main (81-169)
🪛 markdownlint-cli2 (0.18.1)
examples/deepseek_v32/README.md

123-123: Spaces inside emphasis markers

(MD037, no-space-in-emphasis)


123-123: Spaces inside emphasis markers

(MD037, no-space-in-emphasis)

🪛 Ruff (0.13.1)
examples/deepseek_v32/inference/generate.py

107-107: Do not assign a lambda expression, use a def

Rewrite print as a def

(E731)

examples/deepseek_v32/topk_selector.py

181-181: Unpacked variable seq_len is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: build-test-amd
  • GitHub Check: format-check
🔇 Additional comments (25)
pyproject.toml (1)

60-60: LGTM! Appropriate lint exclusion for example code.

The per-file ignore follows the established pattern and is appropriate for the new inference example directory.

examples/deepseek_v32/inference/config_671B_v3.2.json (1)

1-26: Configuration is well-aligned with model code.

The JSON configuration keys match the ModelArgs dataclass fields in model.py. The config is loaded via ModelArgs(**json.load(f)) in generate.py, ensuring exact key matching. All 25 parameters (vocab_size, n_routed_experts, q_lora_rank, kv_lora_rank, moe_inter_dim, index_topk, etc.) are correctly referenced throughout the model initialization and forward passes.

examples/deepseek_v32/inference/.claude/settings.local.json (1)

1-10: Verify whether local settings files should be committed.

This file contains user-specific filesystem paths (wanglei) and internal infrastructure references (//weka-hg/prod/deepseek/permanent/). The .local. naming convention typically indicates environment-specific configuration that should not be version-controlled.

Verification confirmed:

  • This file is currently tracked by git and being committed
  • No .gitignore pattern exists to exclude .local.json files
  • This is the only .claude/settings.local.json file in the repository
  • No template or example file exists for reference

Recommendations:

  1. Add .claude/settings.local.json (or the pattern **/.claude/settings.local.json) to .gitignore
  2. Remove this file from version control: git rm --cached examples/deepseek_v32/inference/.claude/settings.local.json
  3. Create a template file like .claude/settings.example.json documenting the expected structure with placeholder paths
examples/deepseek_v32/inference/convert.py (5)

10-33: Well-structured parameter mapping.

The mapping dictionary provides a clear and maintainable way to handle parameter name transformations and shard dimensions. The tuple structure (new_name, shard_dimension) effectively encodes both the renaming and sharding strategy.


69-80: Sharding logic is well-implemented.

The dual sharding strategy (expert-based and dimension-based) is clearly implemented. The divisibility check on line 76-77 prevents runtime errors, and the use of .contiguous() ensures proper memory layout after slicing.

Note: The expert index extraction on line 72 using split(".")[-3] makes a similar assumption about name structure as the earlier key extraction. This should be covered by the verification script suggested above.
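
For reference, the dual strategy amounts to something like this sketch (a hedged illustration of the idea only; the helper name, arguments, and expert-ownership rule are assumptions rather than the script's exact code):

```python
from typing import Optional
import torch

def shard_param(name: str, param: torch.Tensor, rank: int, mp: int,
                n_local_experts: int, shard_dim: Optional[int]) -> Optional[torch.Tensor]:
    # Expert weights: keep the tensor only on the rank that owns that expert.
    if "experts" in name and "shared_experts" not in name:
        expert_id = int(name.split(".")[-3])
        return param if expert_id // n_local_experts == rank else None
    # Dense weights: slice the sharded dimension evenly across ranks.
    if shard_dim is not None:
        assert param.size(shard_dim) % mp == 0, f"{name} not divisible by mp={mp}"
        return param.chunk(mp, dim=shard_dim)[rank].contiguous()
    return param  # replicated parameters (norms, scalars, ...)
```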


82-89: File I/O operations look correct.

The directory creation, shard file saving, and token file copying are all implemented correctly. The naming convention model{i}-mp{mp}.safetensors aligns with the loading pattern in generate.py.


92-100: CLI argument handling is appropriate.

The argument parser is well-structured with clear names, and the divisibility assertion on line 99 provides early validation before the conversion begins. This prevents wasted processing time on invalid configurations.


65-66: Key extraction logic verified and robust.

The split(".")[-2] approach correctly extracts mapping keys for all standard PyTorch checkpoint parameter names. Verification confirms that parameter names follow the convention <module_path>.<param_name>.<param_type>, ensuring at least 2 dot-separated components (minimum: "param_name.weight"). The runtime assertion on line 66 provides additional validation, making this implementation both safe and appropriate for the expected checkpoint format.

examples/deepseek_v32/inference/kernel.py (4)

19-32: Helper functions implement standard FP8 quantization patterns.

The bit manipulation helpers (fast_log2_ceil, fast_pow2, fast_round_scale) correctly implement IEEE 754 float32 exponent extraction and reconstruction for efficient FP8 scaling. The use of T.reinterpret and T.if_then_else is appropriate for TileLang kernels.


35-106: Activation quantization kernel is well-structured.

The implementation correctly handles block-wise FP8 quantization with per-block scaling:

  • The safety bound max(amax, 1e-4) on line 68 prevents division by zero.
  • The clamping to [fp8_min, fp8_max] on line 74 ensures valid FP8 range.
  • Input validation in the public API ensures contiguity and block alignment.
  • The conditional pipeline depth based on round_scale is a good optimization.
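
As a plain-PyTorch reference for the block-wise scheme described above (a sketch assuming 1D blocks along the last dimension and a PyTorch build with torch.float8_e4m3fn; not the TileLang kernel itself):

```python
import torch

FP8_MAX = 448.0  # largest finite value of float8_e4m3fn

def act_quant_ref(x: torch.Tensor, block_size: int = 128):
    # One scale per contiguous block along the last dimension, with a floor on
    # amax to avoid dividing by zero and clamping into the representable range.
    assert x.shape[-1] % block_size == 0
    blocks = x.reshape(*x.shape[:-1], -1, block_size).float()
    amax = blocks.abs().amax(dim=-1, keepdim=True).clamp_min(1e-4)
    scale = amax / FP8_MAX
    q = (blocks / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return q.reshape_as(x), scale.squeeze(-1)
```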

109-190: FP8 GEMM implementation follows best practices.

The kernel demonstrates advanced TileLang techniques:

  • 4-stage pipeline (line 145) for overlapping computation and memory transfers.
  • Swizzle pattern (line 140) to improve L2 cache locality.
  • Proper double-buffer accumulation pattern to enable 2x accumulation parallelism.
  • Correct handling of per-block scales for both inputs.

The public API provides appropriate input validation and tensor reshaping.
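
A dequantize-then-matmul reference that a fused block-scaled FP8 GEMM is expected to match up to rounding (a sketch only; the scale shapes and block layout below are assumptions, not the kernel's exact contract):

```python
import torch

def fp8_gemm_ref(a_q: torch.Tensor, a_s: torch.Tensor,
                 b_q: torch.Tensor, b_s: torch.Tensor,
                 block_size: int = 128) -> torch.Tensor:
    # a_q: [M, K] FP8 with per-(row, K-block) scales a_s: [M, K // block_size]
    # b_q: [N, K] FP8 with per-(N-block, K-block) scales b_s: [N // bs, K // bs]
    a = a_q.float() * a_s.repeat_interleave(block_size, dim=-1)
    b = b_q.float() * (b_s.repeat_interleave(block_size, dim=0)
                          .repeat_interleave(block_size, dim=1))
    return a @ b.t()  # [M, N] in float32
```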


193-268: FP8 index kernel correctly implements the scoring computation.

The kernel follows the documented computation flow:

  1. FP8 Q @ FP8 K → FP32 logits (lines 225-232)
  2. ReLU(logits) * q_s → scaled logits (line 235)
  3. Sum reduction → logits_sum (line 238)
  4. logits_sum * k_s → index_score (line 241)

The use of out_idx=[4] properly indicates the output tensor position in the signature, and the 2-stage pipeline provides good performance.
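
Expressed in plain PyTorch, the four steps above correspond roughly to the following reference (a sketch; tensor shapes and the placement of the head reduction are assumptions inferred from the description, not the kernel's exact signature):

```python
import torch

def fp8_index_ref(q: torch.Tensor, q_s: torch.Tensor,
                  k: torch.Tensor, k_s: torch.Tensor) -> torch.Tensor:
    # q: [seq_len, heads, dim] (FP8), q_s: [seq_len, heads] per-head query scales
    # k: [seq_len_kv, dim]     (FP8), k_s: [seq_len_kv]     per-key scales
    logits = torch.einsum("shd,td->sht", q.float(), k.float())  # 1) Q @ K^T in FP32
    scaled = torch.relu(logits) * q_s.unsqueeze(-1)             # 2) ReLU(logits) * q_s
    logits_sum = scaled.sum(dim=1)                              # 3) reduce over heads
    return logits_sum * k_s                                     # 4) scale by k_s
```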

examples/deepseek_v32/topk_selector.py (4)

5-24: Bit conversion helpers correctly handle float-to-unsigned mapping.

The conversion functions implement the standard technique for treating floats as unsigned integers for magnitude comparison (essential for radix-based selection):

  • Negative values are bitwise-inverted to maintain ordering.
  • Positive values have the sign bit set to ensure they sort after negatives.

The TL_DISABLE_THREAD_STORAGE_SYNC configuration means synchronization must be manually managed (which is done correctly throughout the kernel).
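
The order-preserving mapping can be sanity-checked on the CPU with a few lines (a sketch of the standard technique in NumPy, not the TileLang helpers themselves):

```python
import numpy as np

def float32_to_ordered_uint32(x: np.ndarray) -> np.ndarray:
    # Flip all bits of negative values and set the sign bit of non-negative
    # values, so that unsigned comparison of the keys matches float comparison.
    u = x.astype(np.float32).view(np.uint32)
    sign = np.uint32(0x80000000)
    return np.where((u & sign) != 0, ~u, u | sign)

vals = np.array([-3.5, -1e-6, 0.0, 1e-6, 2.0], dtype=np.float32)
keys = float32_to_ordered_uint32(vals)
assert np.array_equal(np.argsort(keys), np.argsort(vals))  # same ordering
```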


27-112: Stage 1 histogram-based filtering is well-implemented.

The first stage correctly implements the 8-bit histogram approach:

  • Atomic histogram accumulation (line 73) safely handles concurrent updates.
  • The parallel cumulative sum (lines 78-85) uses proper barriers with barrier_id=3 and arrive_count=RADIX.
  • Threshold bin identification (line 89) correctly finds where cumsum exceeds topk.
  • Position allocation via atomic_add(..., return_prev=True) (line 105) prevents race conditions.

The assumption that the threshold bucket size is less than 4K (line 33) is reasonable for typical distributions but should be validated for edge cases.


113-176: Stage 2 multi-pass refinement correctly refines the top-k selection.

The tail pass implements a sophisticated multi-round approach:

  • Ping-pong buffering (line 118) efficiently reuses shared memory.
  • Progressive bit extraction (line 130-132) refines the selection 8 bits at a time.
  • Early exit (lines 115-116) avoids unnecessary work.
  • Final round (line 167) directly writes results with proper bounds checking.

The hardcoded limit of 4 rounds provides 32 bits of resolution (4 × 8 bits), which is appropriate for float32 inputs.
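
A CPU sketch of the overall selection idea (an illustration of the histogram/radix approach, not the kernel's exact algorithm): bucket values by the top 8 bits of their order-preserving keys, take every bucket above the threshold bucket wholesale, and break ties inside the threshold bucket using the remaining bits.

```python
import numpy as np

def radix_topk_indices_ref(vals: np.ndarray, k: int) -> np.ndarray:
    # Order-preserving float32 -> uint32 keys (see the conversion sketch above).
    u = vals.astype(np.float32).view(np.uint32)
    sign = np.uint32(0x80000000)
    keys = np.where((u & sign) != 0, ~u, u | sign)

    buckets = (keys >> np.uint32(24)).astype(np.intp)   # stage 1: 8-bit buckets
    hist = np.bincount(buckets, minlength=256)
    cum = np.cumsum(hist[::-1])[::-1]                   # elements with bucket >= b
    thresh = int(np.max(np.nonzero(cum >= k)[0]))       # threshold bucket

    above = np.nonzero(buckets > thresh)[0]             # taken unconditionally
    ties = np.nonzero(buckets == thresh)[0]             # refined by lower bits (stage 2)
    ties = ties[np.argsort(keys[ties])[::-1]][: k - len(above)]
    return np.concatenate([above, ties])

x = np.random.randn(10_000).astype(np.float32)
sel = radix_topk_indices_ref(x, 256)
assert set(sel) == set(np.argsort(x)[-256:])            # matches an exact top-k
```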


188-244: Test code follows GPU benchmarking best practices.

The test/demo code is well-structured:

  • Uses CUDA events for accurate timing (lines 220-221).
  • Includes proper warmup iterations (lines 224-226) to avoid cold-start effects.
  • Validates correctness via set intersection (lines 207-215), which is appropriate for top-k where order isn't guaranteed.
  • Compares against PyTorch's built-in implementation for both correctness and performance.
examples/deepseek_v32/inference/model.py (6)

624-626: Cache dequantized weights to avoid redundant computation.

Lines 624-626 lazily dequantize wkv_b during decode passes. This is good, but ensure the cache (self.dequant_wkv_b) is properly invalidated if weights are updated (e.g., during fine-tuning or weight swapping).

The lazy dequantization pattern improves decode performance. Just confirm that weight updates (if any) invalidate the cache.


933-934: Clarify dtype choice for lm_head.

The comment on line 933 explains that lm_head is stored in fp32 for "easier computation of logits later." Ensure this doesn't cause precision issues when the rest of the model is in bf16 or fp8.

Using fp32 for the final projection is a reasonable choice to maintain numerical stability for logits. The comment is helpful.


440-443: Document the rotate_activation function and its bfloat16 requirement.

The function lacks a docstring. Add one that explains:

  • What the function does (applies Hadamard transform to activations)
  • The bfloat16 dtype requirement and why it's necessary
  • Parameters and return type

The fast_hadamard_transform dependency is already listed in requirements.txt.


958-962: Concatenation order is consistent with vocab sharding.

The vocab sharding in ParallelEmbedding (lines 108-110) assigns consecutive shards by rank: rank 0 owns vocab indices [0, part_vocab_size), rank 1 owns [part_vocab_size, 2*part_vocab_size), etc. The all_gather operation collects logits in rank order, and torch.cat(all_logits, dim=-1) concatenates them in that same order, correctly reconstructing the full vocabulary dimension.


107-110: Manual verification recommended for distributed embedding sharding.

The ParallelEmbedding implementation appears mathematically correct:

  • Each rank owns vocabulary indices [vocab_start_idx, vocab_end_idx)
  • Tokens outside a rank's range are masked, zeroed after lookup, and summed via all-reduce
  • Edge cases at shard boundaries (e.g., vocab_start_idx, vocab_end_idx-1) are handled correctly by the mask conditions (x < vocab_start_idx) | (x >= vocab_end_idx)

However, no tests were found for this class. Given the complexity of distributed operations and the critical nature of embedding correctness, verify the implementation with multi-rank testing to ensure token lookups, masking, and all-reduce behavior work as expected across different vocabulary distributions.
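
A single-process sketch of the pattern being verified (the all-reduce is simulated by summing per-rank partial outputs; this is a hypothetical reference, not the ParallelEmbedding class):

```python
import torch
import torch.nn.functional as F

def vocab_parallel_embedding_ref(tokens: torch.Tensor, full_weight: torch.Tensor,
                                 world_size: int) -> torch.Tensor:
    # Each simulated rank owns a contiguous vocab shard; out-of-range tokens are
    # masked to index 0, their rows zeroed after lookup, and the all-reduce is
    # modeled by summing the per-rank partial outputs.
    vocab_size, dim = full_weight.shape
    part = vocab_size // world_size
    out = torch.zeros(*tokens.shape, dim, dtype=full_weight.dtype)
    for rank in range(world_size):
        start, end = rank * part, (rank + 1) * part
        shard = full_weight[start:end]
        mask = (tokens < start) | (tokens >= end)
        local = (tokens - start).masked_fill(mask, 0)
        y = F.embedding(local, shard)
        y[mask] = 0
        out += y
    return out

# Sanity check against a plain lookup over the full table.
w = torch.randn(16, 4)
toks = torch.randint(0, 16, (2, 5))
assert torch.equal(vocab_parallel_embedding_ref(toks, w, 4), F.embedding(toks, w))
```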


514-522: Verify weight_dequant correctness - weight quantization format is external and untested.

The weight_dequant function performs block-wise dequantization with a transpose operation (line 517: .transpose(1, 2)), but verification reveals several concerns:

  1. No matching quantization code: The codebase contains no weight quantization implementation. The kernel.py file only provides activation quantization (act_quant), not weight quantization. Weights come pre-quantized from the HuggingFace checkpoint via convert.py.

  2. No tests: No test coverage exists for weight_dequant to validate the block layout transformations.

  3. Single usage: Only called once (line 625) for wkv_b weights during MHA decode, making bugs harder to detect.

  4. Unclear block layout: The transpose suggests weights are stored with interleaved dimensions: (out_blocks, block_size, in_blocks, block_size) → (out_blocks, in_blocks, block_size, block_size), but this layout is not documented or validated against the checkpoint format.

Recommendation: Verify that the weight block layout in the checkpoint matches the expected format, especially the dimension ordering and transpose operation. Consider adding unit tests to validate dequantization correctness.
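
For orientation, a generic 2D block-wise dequantization reference (a sketch under the assumption of one scale per block_size × block_size tile; the actual checkpoint layout is exactly what this comment asks to verify, so treat the arrangement below as hypothetical):

```python
import torch

def weight_dequant_ref(w_q: torch.Tensor, scales: torch.Tensor,
                       block_size: int = 128) -> torch.Tensor:
    # w_q: [out_features, in_features] quantized weight
    # scales: one scale per (block_size x block_size) tile of w_q
    out_f, in_f = w_q.shape
    s = scales.repeat_interleave(block_size, dim=0).repeat_interleave(block_size, dim=1)
    return w_q.float() * s[:out_f, :in_f]
```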

examples/deepseek_v32/inference/generate.py (1)

50-52: Verify assertion message consistency.

The assertion message on line 52 uses max_seq_len in the message, which matches the variable name in ModelArgs. Ensure this is clear to users who may not be familiar with the model internals.

The assertion is clear and prevents a common error. Good defensive programming.

examples/deepseek_v32/README.md (2)

14-168: Excellent documentation expansion.

The added sections provide clear, detailed explanations of the Lightning Indexer, Top-k Selector, and Sparse MLA components. The code snippets and architecture diagram references help readers understand the implementation.

This documentation significantly improves the usability of the example. Great work!


123-123: Fix markdown emphasis formatting.

Line 123 has spaces inside emphasis markers. The linter flags * seq_len * which should be *seq_len* or similar.

Review the emphasis syntax on line 123 and remove extra spaces:

-This reduces compute from O(seq_len * seq_len_kv) to O(seq_len * topk).
+This reduces compute from O(seq_len × seq_len_kv) to O(seq_len × topk).

Or if using italics:

-This reduces compute from O(seq_len * seq_len_kv) to O(seq_len * topk).
+This reduces compute from O(*seq_len* × *seq_len_kv*) to O(*seq_len* × *topk*).

Likely an incorrect or invalid review comment.

Comment on lines +105 to +107
    global print
    if rank != 0:
        print = lambda *_, **__: None

🛠️ Refactor suggestion | 🟠 Major

Use a proper no-op function instead of lambda.

Line 107 assigns a lambda to print to suppress output on non-zero ranks. Static analysis flags this (Ruff E731). Use a proper function definition.

Apply this diff:

     global print
     if rank != 0:
-        print = lambda *_, **__: None
+        def print(*args, **kwargs):
+            pass
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
     global print
     if rank != 0:
-        print = lambda *_, **__: None
+        def print(*args, **kwargs):
+            pass
🧰 Tools
🪛 Ruff (0.13.1)

107-107: Do not assign a lambda expression, use a def

Rewrite print as a def

(E731)

🤖 Prompt for AI Agents
In examples/deepseek_v32/inference/generate.py around lines 105 to 107, the code
suppresses printing by assigning a lambda to the built-in print (which triggers
Ruff E731); replace the lambda with a proper no-op function definition (e.g.,
define def _noop(*args, **kwargs): pass) and assign print = _noop for non-zero
ranks so static analysis is satisfied and intent remains the same.

Comment on lines +1 to +5
torch
transformers
safetensors
fast_hadamard_transform
tilelang==0.1.6

🛠️ Refactor suggestion | 🟠 Major

🧩 Analysis chain

Pin dependency versions for reproducibility and security.

Most dependencies lack version constraints, which can lead to:

  • Non-reproducible builds when dependency APIs change
  • Unexpected breakage from major version updates
  • Difficulty tracking security vulnerabilities

Consider pinning minimum (or exact) versions:

-torch
-transformers
-safetensors
-fast_hadamard_transform
+torch>=2.0.0
+transformers>=4.30.0
+safetensors>=0.3.0
+fast_hadamard_transform>=0.3.0
 tilelang==0.1.6

Alternatively, generate a requirements-lock.txt using pip freeze after testing to capture exact working versions.

Run the following to check for known vulnerabilities in these packages:


🏁 Script executed:

#!/bin/bash
# Check for security advisories on the core dependencies
for pkg in torch transformers safetensors fast_hadamard_transform; do
  echo "=== Checking $pkg ==="
  gh api graphql -f query="
  {
    securityVulnerabilities(first: 3, ecosystem: PIP, package: \"$pkg\") {
      nodes {
        advisory {
          summary
          severity
          publishedAt
        }
        vulnerableVersionRange
        firstPatchedVersion {
          identifier
        }
      }
    }
  }"
done

Length of output: 1778


🏁 Script executed:

#!/bin/bash
# Check latest stable versions of torch and transformers via PyPI
echo "=== Checking torch latest version ==="
curl -s https://pypi.org/pypi/torch/json | jq -r '.info.version'

echo "=== Checking transformers latest version ==="
curl -s https://pypi.org/pypi/transformers/json | jq -r '.info.version'

echo "=== Checking safetensors latest version ==="
curl -s https://pypi.org/pypi/safetensors/json | jq -r '.info.version'

echo "=== Checking fast_hadamard_transform latest version ==="
curl -s https://pypi.org/pypi/fast_hadamard_transform/json | jq -r '.info.version'

Length of output: 213


Pin dependency versions to avoid critical security vulnerabilities.

Without version constraints, unpinned dependencies can install vulnerable versions. Security scan reveals:

  • torch < 2.6.0: CRITICAL remote code execution via torch.load
  • torch ≤ 2.7.1: MODERATE resource shutdown vulnerability
  • transformers < 4.53.0: Multiple MODERATE ReDoS vulnerabilities

Recommend pinning to patched versions:

-torch
-transformers
-safetensors
-fast_hadamard_transform
+torch>=2.8.0
+transformers>=4.53.0
+safetensors>=0.4.0
+fast_hadamard_transform>=0.3.0
 tilelang==0.1.6

Latest stable versions (as of Sept 2025): torch 2.8.0, transformers 4.56.2, safetensors 0.6.2, fast_hadamard_transform 1.0.4.post1.

Alternatively, use pip freeze > requirements-lock.txt after testing to capture exact working versions.

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In examples/deepseek_v32/inference/requirements.txt lines 1-5 the dependencies
are unpinned which allows installing vulnerable package versions; update the
file to pin packages to safe versions (e.g. torch>=2.8.0,<2.9,
transformers>=4.56.2,<4.57, safetensors>=0.6.2,<0.7,
fast_hadamard_transform>=1.0.4.post1,<1.1 and keep tilelang==0.1.6 or pin it if
needed) and after validating the environment run pip freeze >
requirements-lock.txt to capture exact tested versions for reproducible, secure
installs.

…clarity and consistency

- Updated `fp8_lighting_indexer.py` to define a dedicated test function for the lighting indexer.
- Refactored `sparse_mla_fwd_pipelined.py` and `sparse_mla_fwd.py` to standardize test function parameters and improve readability.
- Enhanced `topk_selector.py` by introducing a test function with parameters for batch size and sequence length.
- Ensured all test functions are invoked correctly in the main execution block.

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
examples/deepseek_v32/test_tilelang_example_deepseek_v32.py (1)

1-24: Fix formatting to resolve pipeline failure.

The CI pipeline failed with: "Reformatted files. Please review and stage the changes." Run the project's formatter (likely ruff format or black) and commit the changes.

#!/bin/bash
# Run formatter to fix the issue
cd examples/deepseek_v32
ruff format test_tilelang_example_deepseek_v32.py || black test_tilelang_example_deepseek_v32.py
🧹 Nitpick comments (4)
examples/deepseek_v32/topk_selector.py (1)

181-181: Address static analysis hint: unused variable seq_len.

Ruff flags that seq_len is unpacked but never used in tl_topk. Since the kernel requires seq_len as a symbolic dimension, the variable is actually used implicitly by the kernel invocation. You can silence the warning by prefixing with an underscore if it's truly not needed, or keep it as-is if the JIT compilation infers the shape from the tensor.

If seq_len is not explicitly referenced, apply this diff to silence the warning:

 def tl_topk(input, starts, ends, topk):
-    batch, seq_len = input.shape
+    batch, _seq_len = input.shape
     indexes = torch.zeros(batch, topk, dtype=torch.int32, device=input.device)

Alternatively, if the kernel's symbolic dimension inference requires the variable to be named seq_len, you can ignore the static analysis hint and add a comment explaining the necessity.

examples/deepseek_v32/test_tilelang_example_deepseek_v32.py (2)

1-1: Consider disabling specific linting rules instead of blanket noqa.

A blanket # ruff: noqa suppresses all linting feedback for the entire file. If certain rules need to be disabled (e.g., for relative imports), specify them explicitly (e.g., # ruff: noqa: TID252).


9-19: Consider removing wrapper functions if not required by test framework.

The wrapper functions (test_example_*) simply delegate to the imported test functions without adding parameters, documentation, or additional logic. If the test discovery framework doesn't require the test_example_ prefix, consider importing and using the original functions directly.

examples/deepseek_v32/sparse_mla_fwd_pipelined.py (1)

438-439: Consider removing commented-out debug prints.

The commented debug statements are no longer needed and can be removed to keep the code clean.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1719847 and 38ad487.

📒 Files selected for processing (5)
  • examples/deepseek_v32/fp8_lighting_indexer.py (2 hunks)
  • examples/deepseek_v32/sparse_mla_fwd.py (2 hunks)
  • examples/deepseek_v32/sparse_mla_fwd_pipelined.py (3 hunks)
  • examples/deepseek_v32/test_tilelang_example_deepseek_v32.py (1 hunks)
  • examples/deepseek_v32/topk_selector.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
examples/deepseek_v32/test_tilelang_example_deepseek_v32.py (3)
examples/deepseek_v32/topk_selector.py (1)
  • test_topk_selector (188-245)
examples/deepseek_v32/sparse_mla_fwd.py (2)
  • sparse_mla_fwd (14-173)
  • test_sparse_mla_fwd (234-268)
examples/deepseek_v32/sparse_mla_fwd_pipelined.py (2)
  • sparse_mla_fwd (18-311)
  • test_sparse_mla_fwd_pipelined (400-451)
examples/deepseek_v32/topk_selector.py (7)
tilelang/transform/pass_config.py (1)
  • PassConfigKey (6-101)
tilelang/language/tir/op.py (2)
  • reinterpret (1816-1835)
  • if_then_else (2824-2854)
tilelang/jit/__init__.py (1)
  • jit (237-310)
tilelang/language/__init__.py (1)
  • symbolic (83-94)
tilelang/language/allocate.py (1)
  • alloc_shared (21-36)
tilelang/language/fill.py (1)
  • fill (9-21)
tilelang/language/builtin.py (1)
  • sync_threads (333-341)
examples/deepseek_v32/sparse_mla_fwd_pipelined.py (1)
examples/deepseek_v32/sparse_mla_fwd.py (3)
  • fn (256-257)
  • ref_sparse_mla_fwd_interface (197-231)
  • test_sparse_mla_fwd (234-268)
🪛 GitHub Actions: CI Test on AMD
examples/deepseek_v32/test_tilelang_example_deepseek_v32.py

[error] 1-1: Reformatted files. Please review and stage the changes.

🪛 Ruff (0.13.1)
examples/deepseek_v32/topk_selector.py

181-181: Unpacked variable seq_len is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

🔇 Additional comments (11)
examples/deepseek_v32/sparse_mla_fwd.py (2)

234-242: LGTM! Good parameterization for test reusability.

The test function now accepts configurable parameters with sensible defaults, making it easier to exercise different tensor shapes and configurations. The defaults align with the original hard-coded values and the explicit call in __main__.


272-273: LGTM! Explicit parameter passing preserves original behavior.

The __main__ block now calls the parameterized test function with explicit keyword arguments, maintaining the original test behavior while benefiting from the new flexible interface.

examples/deepseek_v32/fp8_lighting_indexer.py (2)

261-282: LGTM! Test function now supports parameterized configuration.

The refactoring extracts the test logic into a reusable function with sensible defaults (S=4096, SKV=8192, H=32, HKV=1, D=64, kv_stride=1), enabling flexible test invocation from external test harnesses.


305-306: LGTM! Simplified entry point.

The __main__ block now delegates to the parameterized test function, maintaining the original behavior while improving modularity.

examples/deepseek_v32/topk_selector.py (5)

5-7: Pass config disables automatic synchronization—ensure manual sync is correct.

The TL_DISABLE_THREAD_STORAGE_SYNC flag disables automatic __syncthreads() insertion for shared memory coordination. Verify that all manual T.sync_threads() calls in the kernel (lines 68, 74, 80, 83, 88, 91, 94, 98, etc.) correctly guard shared memory accesses and prevent race conditions.


10-14: Verify bit manipulation logic for signed float→uint16 conversion.

The function reinterprets float as float16, then flips bits based on sign to establish an ordering where larger floats map to larger uints. The logic inverts negative values and sets the sign bit for positives, then shifts right by 8 to produce an 8-bit bucket ID.

Double-check that the bit manipulation correctly handles:

  • Negative zero vs. positive zero
  • NaN and infinity values
  • The shift-right by 8 to extract the exponent bucket

17-24: Verify bit manipulation logic for signed float32→uint32 conversion.

Similar to convert_to_uint16, this function reinterprets float32 as uint32 and flips bits to create an ordering. The logic inverts negative values and sets the MSB for positives.

Verify correctness for:

  • Sign transitions (negative/positive boundary)
  • Special float values (NaN, ±Inf, ±0.0)
  • Bit ordering consistency with the radix-sort phases

27-177: Complex two-stage top-k kernel—confirm correctness of stage transitions and boundary checks.

The kernel implements a two-stage radix-selection algorithm:

Stage 1 (lines 64-111):

  • Builds an 8-bit histogram by hashing input to uint16 buckets (lines 69-74)
  • Computes cumulative sums to find the threshold bucket (lines 77-93)
  • Collects elements above the threshold directly into index, and elements equal to the threshold into s_input_idx for stage 2 (lines 97-111)

Stage 2 (lines 114-175):

  • Iterates up to 4 rounds, refining the threshold with 8-bit slices of uint32 representation (lines 114-175)
  • Each round builds a histogram, finds a new threshold, and partitions elements (lines 128-175)
  • Final round (round == 3) directly writes remaining elements to output with bounds check (lines 167-171)

Potential concerns:

  1. Boundary checks: Ensure input_idx < seq_len and input_idx >= l_start_idx and input_idx < l_end_idx are applied consistently to prevent out-of-bounds access.
  2. Atomic races: Multiple T.atomic_add calls on s_histogram, s_num_input, and output index positions—verify that all accesses are race-free and output indices stay within [0, topk).
  3. SMEM_INPUT_SIZE=4096 assumption: Line 33 assumes the threshold bucket after stage 1 has ≤4096 elements. If this is violated, s_input_idx overflows. Consider adding a runtime assertion or documenting this constraint.
  4. Loop break condition: if l_new_topk <= 0: T.loop_break() (line 115-116)—confirm that all threads in the block agree on the break condition to avoid divergence issues.
  5. Synchronization: Manual syncs are present, but verify coverage of all shared memory writes/reads (especially around s_histogram, s_input_idx, s_num_input, and s_threshold_bin_id).

Consider adding:

  • A runtime assertion that s_num_input[0] <= SMEM_INPUT_SIZE after stage 1
  • Comments documenting the invariants for each stage
  • A bounds check on pos before writing to index[bx, pos] (e.g., lines 106, 165, 169, 171)

180-185: LGTM! Clean wrapper API.

The tl_topk function provides a simple interface: allocates the output tensor, instantiates the kernel, and returns the result. The API is clear and consistent with PyTorch conventions.

examples/deepseek_v32/sparse_mla_fwd_pipelined.py (2)

400-409: LGTM! Clear parameterization for the pipelined test.

The function signature is well-structured with explicit parameters and sensible defaults for testing the pipelined sparse MLA forward pass.


432-433: Masking logic correctly references KV_stride.

The masking condition properly zeroes out positions [:KV_stride - 1] when q_start_s_index == 0 and stride > 1, consistent with the interface function behavior.

Comment on lines +458 to +463
    if args.test_correctness:
        B, S, SKV, H, HKV, DQK, DV, topk, dtype = 1, 1024, 2048, 128, 1, 576, 512, 2048, torch.bfloat16
    else:
        B, S, SKV, H, HKV, DQK, DV, topk, dtype = 1, 4096, 8192, 128, 1, 576, 512, 2048, torch.bfloat16
    test_sparse_mla_fwd(B, S, SKV, H, HKV, DQK, DV, topk, dtype)
    test_sparse_mla_fwd(B, S, SKV, H, HKV, DQK, DV, topk, dtype)

⚠️ Potential issue | 🔴 Critical

Fix function name mismatch: test_sparse_mla_fwd is not defined in this file.

The main block calls test_sparse_mla_fwd, but the function defined in this file is test_sparse_mla_fwd_pipelined. This will cause a NameError at runtime.

Apply this diff to fix the function name:

     if args.test_correctness:
         B, S, SKV, H, HKV, DQK, DV, topk, dtype = 1, 1024, 2048, 128, 1, 576, 512, 2048, torch.bfloat16
     else:
         B, S, SKV, H, HKV, DQK, DV, topk, dtype = 1, 4096, 8192, 128, 1, 576, 512, 2048, torch.bfloat16
-    test_sparse_mla_fwd(B, S, SKV, H, HKV, DQK, DV, topk, dtype)
-    test_sparse_mla_fwd(B, S, SKV, H, HKV, DQK, DV, topk, dtype)
+    test_sparse_mla_fwd_pipelined(B, S, SKV, H, HKV, DQK, DV, topk, dtype)
+    test_sparse_mla_fwd_pipelined(B, S, SKV, H, HKV, DQK, DV, topk, dtype)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
     if args.test_correctness:
         B, S, SKV, H, HKV, DQK, DV, topk, dtype = 1, 1024, 2048, 128, 1, 576, 512, 2048, torch.bfloat16
     else:
         B, S, SKV, H, HKV, DQK, DV, topk, dtype = 1, 4096, 8192, 128, 1, 576, 512, 2048, torch.bfloat16
-    test_sparse_mla_fwd(B, S, SKV, H, HKV, DQK, DV, topk, dtype)
-    test_sparse_mla_fwd(B, S, SKV, H, HKV, DQK, DV, topk, dtype)
+    test_sparse_mla_fwd_pipelined(B, S, SKV, H, HKV, DQK, DV, topk, dtype)
+    test_sparse_mla_fwd_pipelined(B, S, SKV, H, HKV, DQK, DV, topk, dtype)
🤖 Prompt for AI Agents
In examples/deepseek_v32/sparse_mla_fwd_pipelined.py around lines 458 to 463 the
code calls undefined function test_sparse_mla_fwd causing a NameError; update
the two calls to use the correctly defined function name
test_sparse_mla_fwd_pipelined so they match the function implementation (i.e.,
replace test_sparse_mla_fwd with test_sparse_mla_fwd_pipelined).

Comment on lines +4 to +7
from topk_selector import test_topk_selector
from fp8_lighting_indexer import test_fp8_lighting_indexer
from sparse_mla_fwd import test_sparse_mla_fwd
from sparse_mla_fwd_pipelined import test_sparse_mla_fwd_pipelined

⚠️ Potential issue | 🟠 Major

Use explicit relative imports for sibling modules.

The imports assume the current working directory is examples/deepseek_v32/, which will fail when test runners or scripts invoke this module from other directories.

Apply this diff to use explicit relative imports:

-from topk_selector import test_topk_selector
-from fp8_lighting_indexer import test_fp8_lighting_indexer
-from sparse_mla_fwd import test_sparse_mla_fwd
-from sparse_mla_fwd_pipelined import test_sparse_mla_fwd_pipelined
+from .topk_selector import test_topk_selector
+from .fp8_lighting_indexer import test_fp8_lighting_indexer
+from .sparse_mla_fwd import test_sparse_mla_fwd
+from .sparse_mla_fwd_pipelined import test_sparse_mla_fwd_pipelined
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-from topk_selector import test_topk_selector
-from fp8_lighting_indexer import test_fp8_lighting_indexer
-from sparse_mla_fwd import test_sparse_mla_fwd
-from sparse_mla_fwd_pipelined import test_sparse_mla_fwd_pipelined
+from .topk_selector import test_topk_selector
+from .fp8_lighting_indexer import test_fp8_lighting_indexer
+from .sparse_mla_fwd import test_sparse_mla_fwd
+from .sparse_mla_fwd_pipelined import test_sparse_mla_fwd_pipelined
🤖 Prompt for AI Agents
In examples/deepseek_v32/test_tilelang_example_deepseek_v32.py around lines 4 to
7, the module uses implicit (absolute) sibling imports which break when the
module is executed from outside the package; change each import to an explicit
relative import (e.g., from .topk_selector import test_topk_selector, from
.fp8_lighting_indexer import test_fp8_lighting_indexer, etc.) so the tests
reliably import sibling modules regardless of the current working directory or
test runner.
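
If the file must also keep working when executed directly (e.g. python test_tilelang_example_deepseek_v32.py from inside the directory), one option is a guarded fallback. This is only a sketch under that assumption, not the repository's actual pattern:

```python
import os
import sys

try:
    # Package-relative import: works when examples/deepseek_v32 is imported as a package.
    from .topk_selector import test_topk_selector
except ImportError:
    # Direct script execution has no parent package, so fall back to putting this
    # file's directory on sys.path and importing the sibling module by name.
    sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
    from topk_selector import test_topk_selector
```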

Comment on lines +188 to +246
def test_topk_selector(batch=64, seq_len=32 * 1024, topk=2048):

    batch = 64
    seq_len = 32 * 1024
    topk = 2048
    torch.manual_seed(1)
    input = torch.randn(batch, seq_len, dtype=torch.float32).cuda()
    starts = torch.zeros(batch, dtype=torch.int32).cuda()
    ends = torch.ones(batch, dtype=torch.int32).cuda() * seq_len

    indexes = tl_topk(input, starts, ends, topk)
    print(indexes)

    indexes_ref = torch.topk(input, topk, dim=-1)[1]
    print(indexes_ref)

    # indexes_ref = fast_topk(input, topk)
    # print(indexes_ref)

    # Calculate intersection of out_ref and out_trt
    for i in range(batch):
        ref_np = indexes_ref[i].cpu().to(torch.int32).numpy()
        trt_np = indexes[i].cpu().to(torch.int32).numpy()

        set_ref = set(ref_np)
        set_trt = set(trt_np)
        intersection = set_ref & set_trt
        print("selected/all:", len(intersection), "/", len(set_ref), "=",
              len(intersection) / len(set_ref))

    # Performance test with CUDA events

    torch.cuda.synchronize()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    # Warmup
    for _ in range(5):
        _ = tl_topk(input, starts, ends, topk)
    torch.cuda.synchronize()

    n_iters = 20
    start_event.record()
    for _ in range(n_iters):
        _ = tl_topk(input, starts, ends, topk)
    end_event.record()
    torch.cuda.synchronize()
    elapsed_time_ms = start_event.elapsed_time(end_event)
    print(f"Average tl_topk time: {elapsed_time_ms / n_iters:.3f} ms")

    # Torch topk time
    start_event.record()
    for _ in range(n_iters):
        _ = torch.topk(input, topk, dim=-1)[1]
    end_event.record()
    torch.cuda.synchronize()
    elapsed_time_ms = start_event.elapsed_time(end_event)
    print(f"Average torch.topk time: {elapsed_time_ms / n_iters:.3f} ms")


⚠️ Potential issue | 🟠 Major

Test function hard-codes values that override its own arguments.

Lines 190-192 redefine batch, seq_len, and topk with hard-coded values, ignoring the function parameters. This defeats the purpose of parameterization and will confuse callers who pass different values.

Apply this diff to use the function parameters:

 def test_topk_selector(batch=64, seq_len=32 * 1024, topk=2048):
-
-    batch = 64
-    seq_len = 32 * 1024
-    topk = 2048
     torch.manual_seed(1)
     input = torch.randn(batch, seq_len, dtype=torch.float32).cuda()
🤖 Prompt for AI Agents
In examples/deepseek_v32/topk_selector.py around lines 188 to 246, the
test_topk_selector function currently overwrites its parameters by reassigning
batch, seq_len, and topk on lines 190-192; remove those three hard-coded
assignments so the function uses the incoming arguments, and ensure subsequent
uses (torch.randn, ends, tl_topk, torch.topk, warmup/iteration loops) reference
the original function parameters as intended.
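
With the hard-coded assignments removed, the arguments actually take effect, so smaller shapes can be used for a quick check while the defaults remain for benchmarking. A hypothetical invocation (shape values are illustrative; assumes execution from examples/deepseek_v32/ with a CUDA device available):

```python
from topk_selector import test_topk_selector

# Reduced batch and sequence length for a fast smoke test; topk is kept at the
# 2048 used elsewhere in the script to stay within the exercised configuration.
test_topk_selector(batch=8, seq_len=8 * 1024, topk=2048)

# Default configuration, equivalent to test_topk_selector(64, 32 * 1024, 2048).
test_topk_selector()
```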

Enhance test functions in deepseek_v32 example scripts with CUDA requirements and parameterization

- Added CUDA requirements decorators to `test_example_sparse_mla_fwd` and `test_example_sparse_mla_fwd_pipelined`.
- Parameterized test functions to use specific small shapes for testing, improving test coverage and clarity (a sketch of this pattern follows below).
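
For orientation only, a minimal sketch of CUDA gating plus shape parameterization in a pytest harness; the decorator actually used in the repository may differ, so the requires_cuda helper below is an assumption, and the shapes are taken from the script's correctness branch shown above:

```python
import pytest
import torch

from sparse_mla_fwd import test_sparse_mla_fwd  # assumes examples/deepseek_v32/ is importable

# Hypothetical stand-in for the repository's CUDA-requirement decorator.
requires_cuda = pytest.mark.skipif(
    not torch.cuda.is_available(), reason="CUDA device required")


@requires_cuda
@pytest.mark.parametrize(
    "B,S,SKV,H,HKV,DQK,DV,topk,dtype",
    [(1, 1024, 2048, 128, 1, 576, 512, 2048, torch.bfloat16)],  # correctness shapes from the script
)
def test_example_sparse_mla_fwd(B, S, SKV, H, HKV, DQK, DV, topk, dtype):
    # Delegate to the example's own correctness check with the reduced shapes.
    test_sparse_mla_fwd(B, S, SKV, H, HKV, DQK, DV, topk, dtype)
```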
@LeiWang1999 LeiWang1999 merged commit 5d66d32 into tile-ai:main Sep 29, 2025
4 checks passed
LeiWang1999 added a commit that referenced this pull request Sep 30, 2025
* Remove unused `fp8_mqa_logits.py` file and update README.md to reflect new directory structure and file descriptions for deepseek_v32 example. Added sections for architecture overview, Lightning Indexer, Top-k Selector, and Sparse MLA Forward implementations.

* Update linting configurations and improve code formatting in deepseek_v32 example scripts

- Added per-file ignores for the inference directory in `pyproject.toml`.
- Refactored code in `topk_selector.py`, `convert.py`, `generate.py`, `kernel.py`, and `model.py` to enhance readability by adjusting spacing and line breaks.
- Ensured consistent formatting across function definitions and assertions for better clarity.

* Refactor test functions in deepseek_v32 example scripts for improved clarity and consistency

- Updated `fp8_lighting_indexer.py` to define a dedicated test function for the lighting indexer.
- Refactored `sparse_mla_fwd_pipelined.py` and `sparse_mla_fwd.py` to standardize test function parameters and improve readability.
- Enhanced `topk_selector.py` by introducing a test function with parameters for batch size and sequence length.
- Ensured all test functions are invoked correctly in the main execution block.

* Enhance test functions in deepseek_v32 example scripts with CUDA requirements and parameterization

- Added CUDA requirements decorators to `test_example_sparse_mla_fwd` and `test_example_sparse_mla_fwd_pipelined`.
- Parameterized test functions to use specific small shapes for testing, improving test coverage and clarity.

* lint fix

* Update README.md to correct image path for DeepSeek V3.2 architecture diagram