
Conversation

Contributor

@botbw botbw commented Aug 4, 2025

As title

Benchmark file: benchmark/matmul/benchmark_matmul_sp.py

Benchmark result on 4090:

| M=N=K | Ref TFLOPS (cuBLAS Dense, 2 experiments) | Ref TFLOPS (Torch, CUTLASS backend) | Ref TFLOPS (Torch, cuSPARSELt backend) | Best TFLOPS (TileLang Sparse, fp32 accum) | Best TFLOPS (TileLang Sparse, fp16 accum) |
|---|---|---|---|---|---|
| 1024 | 77.418 / 64.313 | 85.820 | 4.628 | 123.362 | 131.072 |
| 2048 | 144.972 / 145.965 | 238.84 | 35.703 | 218.418 | 322.639 |
| 4096 | 161.983 / 163.666 | 295.478 | 256.995 | 268.973 | 434.537 |
| 8192 | 153.190 / 152.554 | 287.410 | 260.614 | 278.677 | 468.269 |
| 16284 | 154.379 / 154.395 | 195.060 | 183.264 | 282.267 | 468.142 |

Theoretical 4090 FP16 Sparse Tensor Core TFLOPS:

| Tag | Ref Sparse TFLOPS |
|---|---|
| Peak FP16 Tensor TFLOPS with FP32 Accumulate | 330.4 |
| Peak FP16 Tensor TFLOPS with FP16 Accumulate | 660.6 |
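
For context, a quick back-of-the-envelope check of the M=N=K=8192 row against these peaks (plain Python, just arithmetic on the numbers above):

```python
# Compare the 8192 benchmark row with the theoretical sparse tensor-core peaks.
best_fp32_accum = 278.677   # TileLang Sparse, fp32 accumulate, 8192
best_fp16_accum = 468.269   # TileLang Sparse, fp16 accumulate, 8192
peak_fp32_accum = 330.4     # peak FP16 tensor TFLOPS with FP32 accumulate
peak_fp16_accum = 660.6     # peak FP16 tensor TFLOPS with FP16 accumulate

print(f"fp32-accum efficiency: {best_fp32_accum / peak_fp32_accum:.1%}")  # ~84.3%
print(f"fp16-accum efficiency: {best_fp16_accum / peak_fp16_accum:.1%}")  # ~70.9%
```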

Summary by CodeRabbit

  • New Features

    • Added SM80 support and automatic GPU-arch detection for sparse GEMM, with per-arch tiling and accumulation dtype controls.
    • New CLI flags to disable caching, select accumulation dtype, and benchmark against torch sparse.
    • New compression API that dispatches to SM80/SM90 paths; non-transposed B layout and improved benchmarking metrics.
  • Documentation

    • New sparse GEMM example demonstrating setup, validation, and benchmarking.
  • Tests

    • Split and expanded tests for SM90 and SM80 with broader dtype and block coverage.
  • Chores

    • Added int16 device debug print and minor formatting fixes.


github-actions bot commented Aug 4, 2025

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run bash format.sh in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work!

🚀

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Summary of Changes

Hello @botbw, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces support for sparse General Matrix Multiply (GEMM) operations on NVIDIA Ampere architecture GPUs. It integrates new CUTLASS-based kernels and updates the system's layout inference and metadata generation to handle Ampere-specific requirements.

Highlights

  • Ampere Architecture Support: Added support for sparse General Matrix Multiply (GEMM) operations on NVIDIA Ampere (SM80/SM89) GPUs.
  • CUTLASS Integration for Sparse GEMM: Implemented the core sparse GEMM kernel for Ampere using CUTLASS, including instruction shapes and shared memory layouts for various data types.
  • Metadata Layout Generation: Extended the layout generation logic to create metadata layouts compatible with CUTLASS SM8x sparse kernels, handling 16-bit metadata types and column-major interleaved layouts.
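
To make the column-major interleaved metadata indexing concrete, here is a minimal standalone sketch in plain Python (the function name and the 4x4 example are illustrative only; the actual layout builder lives in tilelang/layout/gemm_sp.py and uses an interleave factor of 2, as seen later in this review):

```python
def column_major_interleaved_offset(i: int, j: int, rows: int, interleave: int = 2) -> int:
    """Map a (row, col) metadata coordinate to a linear offset in a
    column-major layout whose columns are interleaved in groups of `interleave`."""
    stride = rows * interleave   # elements covered by one group of interleaved columns
    col_group = j // interleave  # which column group
    col_within = j % interleave  # position inside the group
    return col_group * stride + i * interleave + col_within

# Example: a 4x4 metadata tile; row 0 maps to offsets [0, 1, 8, 9].
offsets = [[column_major_interleaved_offset(i, j, rows=4) for j in range(4)] for i in range(4)]
assert offsets[0] == [0, 1, 8, 9]
assert offsets[1] == [2, 3, 10, 11]
```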
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

| Feature | Command | Description |
|---|---|---|
| Code Review | /gemini review | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state. |
| Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | /gemini help | Displays a list of available commands. |

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments or fill out our survey to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds support for sparse GEMM on NVIDIA's Ampere architecture (sm80). The changes include new C++ CUDA kernel implementations using CUTLASS, updates to the operator logic in C++ to handle Ampere targets, and new Python layout functions for metadata. My review focuses on improving code maintainability by reducing duplication, enhancing debuggability with better error messages, and adhering to language-specific best practices and style guides (C++ casts, Python PEP 8 naming).

@botbw botbw changed the title [feat] add gemm_sp for ampere arch [feat] support gemm_sp for ampere arch Aug 11, 2025
Contributor

coderabbitai bot commented Sep 11, 2025

Caution

Review failed

The pull request is closed.

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Adds SM80/SM90 structured-sparse GEMM support and arch-aware dispatch across Python and C++: new CUDA templates, layouts, GemmSP warp-policy, compress/annotate_layout dispatch, tests split per-arch, and benchmark/example updates (B layout, E tiling, accum dtype, CLI flags, optional torch-sparse).

Changes

| Cohort / File(s) | Summary |
|---|---|
| Benchmarks & Examples: benchmark/matmul/benchmark_matmul_sp.py, examples/gemm_sp/example_gemm_sp.py, examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py | Make benchmark and examples arch-aware: change B layout to (K,N), add module-level arch and ARCH_INFO, extend matmul_sp signature with accum_dtype, add CLI flags (--disable_cache, --accum_dtype, --bench_torch_sparse), update E tiling/typing and example metadata strings. |
| Layout API (C++): src/layout/gemm_layouts.cc, src/layout/layout.h | Add sparse C fragment, CDNA C fragment, 128B tensor-op multiplicand layout, and Ampere sparse AB layout; expose new factory functions in header with docs. |
| GemmSP Operator (C++): src/op/gemm_sp.cc, src/op/gemm_sp.h | Introduce GemmSP-specific warp policy (bit-width-aware ComputeWarpPartition), switch GemmSP policy type to GemmSPWarpPolicy, pass dtype bits to partitioning, add Ampere infer/lower paths using new sparse fragments/layouts. |
| CUDA Templates & CUDA helpers: src/tl_templates/cuda/gemm_sp.h, src/tl_templates/cuda/gemm_sp_sm80.h, src/tl_templates/cuda/gemm_sp_sm90.h, src/tl_templates/cuda/compress_sm90.cu, src/tl_templates/cuda/debug.h | Add SM80 sparse GEMM implementation/header (gemm_sp_sm80.h), adjust SM90 template alias/GMMA usage, add int16 device debug print specialization, minor formatting tweak in compress_sm90.cu. |
| Python Layout & Sparse Utils: tilelang/layout/gemm_sp.py, tilelang/utils/sparse.py | Auto-detect arch via nvcc, dispatch make_metadata_layout to SM90/SM8x builders, add compress_sm80 and a compress(..., arch=None) wrapper selecting SM90/SM80 compression backends. |
| Tests: testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py | Split tests per-arch (SM90/SM80), add run_gemm_sp_sm90/run_gemm_sp_sm80 wrappers, update compression API usage, broaden dtype/config coverage, and gate tests by compute version. |
| Miscellaneous: tilelang/language/builtin.py, src/target/codegen_webgpu.cc | Trailing newline addition (no-op) and minor C++ idiomatic improvements (empty(), const & iteration). |

Sequence Diagram(s)

sequenceDiagram
  autonumber
  actor User
  participant Bench as benchmark_matmul_sp.py
  participant NVCC as nvcc
  participant TL as tilelang.layout & utils
  participant Kernel as CUDA kernels (SM80/SM90)
  User->>Bench: run (M,N,K, --accum_dtype, --bench_torch_sparse)
  Bench->>NVCC: get_target_compute_version()
  NVCC-->>Bench: arch_version
  Bench->>TL: make_metadata_layout(..., arch=None, backend="cutlass")
  TL->>NVCC: (if arch None) get_target_compute_version()
  NVCC-->>TL: compute_version
  TL-->>Bench: metadata layout (SM80 or SM90)
  Bench->>Kernel: launch gemm_sp (B:(K,N), E tiled per-arch, accum dtype)
  Kernel-->>Bench: C result + timing
  alt bench_torch_sparse requested
    Bench->>Kernel: run torch-sparse path
    Kernel-->>Bench: torch sparse TFlops
  end
sequenceDiagram
  autonumber
  actor PyTest
  participant Test as test_tilelibrary_gemm_sp.py
  participant Utils as tilelang.utils.sparse
  participant NVCC as nvcc
  participant Kernel as SM80/SM90 kernels
  PyTest->>Test: run_gemm_sp_sm80/sm90(...)
  Test->>Utils: compress(A, transposed, arch=None)
  Utils->>NVCC: get_target_compute_version()
  NVCC-->>Utils: compute_version
  alt compute_version >= 9.0
    Utils-->>Test: compress_sm90(A,...)
  else compute_version >= 8.0
    Utils-->>Test: compress_sm80(A,...)
  end
  Test->>Kernel: launch arch-specific kernel
  Kernel-->>Test: output & validate
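
As a rough companion to the compression dispatch shown in the diagram above, a minimal Python sketch of the arch-aware wrapper (helper names are taken from the walkthrough; exact signatures in tilelang.utils.sparse and tilelang.contrib.nvcc may differ):

```python
from typing import Optional

# Assumed helpers, per the walkthrough; treat these imports as approximate.
from tilelang.contrib.nvcc import get_target_compute_version, parse_compute_version
from tilelang.utils.sparse import compress_sm80, compress_sm90


def compress(A, transposed: bool, block_k: int, arch: Optional[str] = None):
    """Pick the SM90 or SM80 compression backend based on the detected arch."""
    arch = arch or get_target_compute_version()  # e.g. "8.9" or "9.0"
    major, minor = parse_compute_version(arch)
    if (major, minor) >= (9, 0):
        return compress_sm90(A, block_k=block_k, transposed=transposed)
    if (major, minor) >= (8, 0):
        return compress_sm80(A, transposed=transposed)
    raise ValueError(f"structured-sparse compression needs SM80+, got {arch}")
```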

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

Possibly related PRs

Poem

In the lattice I nibble, tiles in a row,
SM80 and SM90 help carrots to grow.
B turns its jacket, E finds its stride,
Warps line up neatly—no meander, no slide.
I hop, I compress, TFlops gleam—hooray! 🥕

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 0.00%, which is insufficient. The required threshold is 80.00%. | Run @coderabbitai generate docstrings to improve docstring coverage. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped - CodeRabbit's high-level summary is enabled. |
| Title Check | ✅ Passed | The title "[feat] support gemm_sp for ampere arch" succinctly and accurately summarizes the primary change in the PR (adding gemm_sp support for Ampere (SM80) architectures along with related kernel, layout, and benchmark updates) and is directly related to the changeset and PR description. It is concise, focused on a single feature, and clear enough for a teammate scanning history. |

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 597b66f and 27b4491.

📒 Files selected for processing (1)
  • testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py (6 hunks)

Tip

👮 Agentic pre-merge checks are now available in preview!

Pro plan users can now enable pre-merge checks in their settings to enforce checklists before merging PRs.

  • Built-in checks – Quickly apply ready-made checks to enforce title conventions, require pull request descriptions that follow templates, validate linked issues for compliance, and more.
  • Custom agentic checks – Define your own rules using CodeRabbit's advanced agentic capabilities to enforce organization-specific policies and workflows. For example, you can instruct CodeRabbit's agent to verify that API documentation is updated whenever API schema files are modified in a PR. Note: Up to 5 custom checks are currently allowed during the preview period. Pricing for this feature will be announced in a few weeks.

Please see the documentation for more information.

Example:

reviews:
  pre_merge_checks:
    custom_checks:
      - name: "Undocumented Breaking Changes"
        mode: "warning"
        instructions: |
          Pass/fail criteria: All breaking changes to public APIs, CLI flags, environment variables, configuration keys, database schemas, or HTTP/GraphQL endpoints must be documented in the "Breaking Change" section of the PR description and in CHANGELOG.md. Exclude purely internal or private changes (e.g., code not exported from package entry points or explicitly marked as internal).

Please share your feedback with us on this Discord post.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.

This could be a missing check related to the cutlass 2.x implementation. Using the cutlass example can't trigger this because it's bypassed by padding the input.

For now I think it might be safe to increase the atom size and investigate in the future.
@botbw botbw marked this pull request as ready for review September 12, 2025 09:29
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 10

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/op/gemm_sp.h (1)

91-98: Equality/hash contract mismatch: policy ignored in equality but included in hash.

SEqualReduce should include policy to stay consistent with SHashReduce and avoid hard-to-track bugs in memoization/Maps.

-  bool SEqualReduce(const GemmSPNode *other, SEqualReducer equal) const {
-    return equal(A, other->A) && equal(B, other->B) && equal(C, other->C) &&
-           equal(E, other->E) && equal(trans_A, other->trans_A) &&
-           equal(trans_B, other->trans_B) && equal(M, other->M) &&
-           equal(N, other->N) && equal(K, other->K) &&
-           equal(clear_accum, other->clear_accum) &&
-           equal(kPack, other->kPack) && equal(wg_wait, other->wg_wait);
-  }
+  bool SEqualReduce(const GemmSPNode *other, SEqualReducer equal) const {
+    return equal(policy, other->policy) &&
+           equal(A, other->A) && equal(B, other->B) && equal(C, other->C) &&
+           equal(E, other->E) && equal(trans_A, other->trans_A) &&
+           equal(trans_B, other->trans_B) && equal(M, other->M) &&
+           equal(N, other->N) && equal(K, other->K) &&
+           equal(clear_accum, other->clear_accum) &&
+           equal(kPack, other->kPack) && equal(wg_wait, other->wg_wait);
+  }
♻️ Duplicate comments (3)
tilelang/layout/gemm_sp.py (1)

125-134: PEP8 naming for locals (duplicate of past review).

Rename kInterleaved -> K_INTERLEAVED (constant) or k_interleaved, and ColumnMajorInterleaved -> column_major_interleaved for consistency.

-    kInterleaved = 2
-    stride = buffer.shape[0] * kInterleaved
+    K_INTERLEAVED = 2
+    stride = buffer.shape[0] * K_INTERLEAVED
@@
-    def ColumnMajorInterleaved(i: int, j: int) -> int:
-        column_major = j // kInterleaved
-        column_minor = j % kInterleaved
-        return column_major * stride + i * kInterleaved + column_minor
+    def column_major_interleaved(i: int, j: int) -> int:
+        column_major = j // K_INTERLEAVED
+        column_minor = j % K_INTERLEAVED
+        return column_major * stride + i * K_INTERLEAVED + column_minor
@@
-    return T.Layout(buffer.shape, ColumnMajorInterleaved)
+    return T.Layout(buffer.shape, column_major_interleaved)
src/op/gemm_sp.cc (1)

261-295: DRY the Ampere A/B layout code and improve ICHECK messages

There’s duplication across A and B branches and bare ICHECK(0). Factor into a helper and emit descriptive messages. This mirrors prior feedback.

@@
-    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
-      int dim_A = A->shape.size();
-      const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
-      const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
-      results.Set(A, makeGemmSparseAmpereABLayout(mat_stride, mat_continuous,
-                                                  A->dtype.bits()));
-    } else if (A.scope() == "local.fragment") {
-      // auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
-      //                                   A->dtype.bits(), trans_A);
-      // results.Set(A, fragment->BindThreadRange(thread_range));
-      ICHECK(false) << "Not Implemented";
-    } else {
-      ICHECK(0);
-    }
+    auto set_ampere_ab_layout = [&](const Buffer& buf, const char* name) {
+      if (buf.scope() == "shared" || buf.scope() == "shared.dyn") {
+        int dim = buf->shape.size();
+        const int64_t mat_stride = *as_const_int(buf->shape[dim - 2]);
+        const int64_t mat_continuous = *as_const_int(buf->shape[dim - 1]);
+        results.Set(buf, makeGemmSparseAmpereABLayout(mat_stride, mat_continuous,
+                                                      buf->dtype.bits()));
+      } else if (buf.scope() == "local.fragment") {
+        ICHECK(false) << "Ampere GEMM_SP does not support local.fragment for " << name;
+      } else {
+        ICHECK(false) << "Unsupported scope for Ampere GEMM_SP (" << name << "): " << buf.scope();
+      }
+    };
+    set_ampere_ab_layout(A, "A");
@@
-    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
-      int dim_B = B->shape.size();
-      const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
-      const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
-      results.Set(B, makeGemmSparseAmpereABLayout(mat_stride, mat_continuous,
-                                                  B->dtype.bits()));
-    } else if (B.scope() == "local.fragment") {
-      // auto fragment =
-      //     makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
-      // results.Set(B, fragment->BindThreadRange(thread_range));
-      ICHECK(false) << "Not Implemented";
-    } else {
-      ICHECK(0);
-    }
+    set_ampere_ab_layout(B, "B");
src/tl_templates/cuda/gemm_sp_sm80.h (1)

222-231: Prefer C++ casts for TensorRef construction

Replace C-style casts with reinterpret_cast for clarity and type safety.

-    const TensorRefA ref_A(
-        (ElementA *)pA,
+    const TensorRefA ref_A(
+        reinterpret_cast<ElementA *>(pA),
         MmaWarp::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}));
-    const TensorRefE ref_E(
-        (ElementE *)pE,
+    const TensorRefE ref_E(
+        reinterpret_cast<ElementE *>(pE),
         MmaWarp::LayoutE::packed({ShapeE::kRow, ShapeE::kColumn}));
-    const TensorRefB ref_B(
-        (ElementB *)pB,
+    const TensorRefB ref_B(
+        reinterpret_cast<ElementB *>(pB),
         MmaWarp::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}));
🧹 Nitpick comments (18)
src/tl_templates/cuda/debug.h (2)

229-238: Safer printf arg type and include for fixed-width ints

  • Use an int cast with %d to avoid potential mismatch with int32_t across platforms.
  • Ensure <cstdint> (or <stdint.h>) is included for int16_t.

Apply this diff in this hunk:

-         threadIdx.z, buf_name, index, (int32_t)var);
+         threadIdx.z, buf_name, index, static_cast<int>(var));

If not already available via transitive headers, add near the top of this header:

#include <cstdint>  // for int16_t

229-238: Add matching int16_t/uint16_t specializations for completeness

For parity with other dtypes, consider:

  • Adding debug_print_var<int16_t> (scalar) specialization.
  • Optionally adding debug_print_buffer_value<uint16_t> and debug_print_var<uint16_t>.

Example additions (outside this hunk):

// scalar int16_t
template <> __device__ void debug_print_var<int16_t>(const char* msg, int16_t var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=int16_t value=%d\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, threadIdx.z,
         static_cast<int>(var));
}

// buffer uint16_t
template <>
__device__ void debug_print_buffer_value<uint16_t>(const char* msg, const char* buf_name,
                                                   int index, uint16_t var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, index=%d, dtype=uint16_t value=%u\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, threadIdx.z,
         buf_name, index, static_cast<unsigned int>(var));
}

// scalar uint16_t
template <> __device__ void debug_print_var<uint16_t>(const char* msg, uint16_t var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=uint16_t value=%u\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, threadIdx.z,
         static_cast<unsigned int>(var));
}

Happy to push a patch if you want.

src/layout/gemm_layouts.cc (3)

138-157: Sparse C fragment tiling order looks correct; add explicit unsupported msg.

Replace the bare ICHECK(false) with a message for element_size == 64 to aid diagnostics.

Apply:

-  if (element_size == 64) {
-    ICHECK(false) << "Not supported";
-  }
+  if (element_size == 64) {
+    ICHECK(false) << "makeGemmSparseFragmentC: element_size == 64 is not supported";
+  }

593-681: Guard assumptions and improve naming in makeTensorOpMultiplicand.

  • Add sanity checks for elementsize and positive dims.
  • Rename local stride to ld_contig to avoid confusion with function arg mat_stride.

Proposed edits:

-  static int const kAccessSize = 128;
+  static constexpr int kAccessSize = 128;
+  ICHECK(mat_stride > 0 && mat_continuous > 0) << "invalid matrix extents";
+  ICHECK(elementsize == 8 || elementsize == 16 || elementsize == 32)
+      << "unsupported element size (bits): " << elementsize;

@@
-  const int stride = mat_continuous;
+  const int ld_contig = mat_continuous;
@@
-  return Layout(Array{i, j},
-                {element_contiguous + element_strided * stride * kFactor});
+  return Layout(Array{i, j},
+                {element_contiguous + element_strided * ld_contig * kFactor});

683-689: Crosswise selection: consider constraining to a power-of-two.

Ampere tensor-op multiplicand layouts typically expect crosswise to be a small power-of-two divisor. Clamping to the nearest supported power-of-two (e.g., min(128, next_pow2(...))) would make this more robust across shapes.
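
A minimal sketch of that clamp, written in plain Python for illustration (the helper name is hypothetical; the real code is C++ in src/layout/gemm_layouts.cc):

```python
def clamp_crosswise_pow2(mat_continuous: int, cap: int = 128) -> int:
    """Largest power of two that does not exceed mat_continuous or the cap."""
    limit = min(mat_continuous, cap)
    pow2 = 1
    while pow2 * 2 <= limit:
        pow2 *= 2
    return pow2

assert clamp_crosswise_pow2(96) == 64    # rounds a non-power-of-two extent down
assert clamp_crosswise_pow2(256) == 128  # capped at the 128-element limit
```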

benchmark/matmul/benchmark_matmul_sp.py (1)

198-199: Avoid extra C_shared roundtrip.

Copy C_local directly to global C to reduce traffic and simplify the kernel.

Apply:

-                # Allocate a shared memory for C sub-block of shape (block_M, block_N)
-                C_shared = T.alloc_shared((block_M, block_N), accum_dtype)
@@
-                T.copy(C_local, C_shared)
-                T.copy(C_shared, C[by * block_M, bx * block_N])
+                T.copy(C_local, C[by * block_M, bx * block_N])

Also applies to: 231-234

src/layout/layout.h (2)

181-185: Unify parameter naming to match existing API style.

Use element_size (with underscore) for consistency with other declarations in this header.

-Layout makeTensorOpMultiplicand(int mat_stride, int mat_continuous,
-                                int elementsize, int crosswise);
-Layout makeGemmSparseAmpereABLayout(int mat_stride, int mat_continuous,
-                                    int elementsize);
+Layout makeTensorOpMultiplicand(int mat_stride, int mat_continuous,
+                                int element_size, int crosswise);
+Layout makeGemmSparseAmpereABLayout(int mat_stride, int mat_continuous,
+                                    int element_size);

140-142: Consider exporting new symbols (optional).

If these factories are used across shared library boundaries, annotate with TVM_DLL for consistent visibility, mirroring constructors above. If not needed, ignore.

Also applies to: 181-185

tilelang/utils/sparse.py (1)

63-64: Shorten exception messages to satisfy TRY003 (ruff).

Move long messages into variables or keep them concise.

Also applies to: 95-96
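
One conventional way to keep the raise site short for TRY003 is to move the long text into a dedicated exception class; a hypothetical sketch (names and message text are illustrative, not the actual strings in tilelang/utils/sparse.py):

```python
class UnsupportedArchError(ValueError):
    """Raised when structured-sparse compression is requested on a pre-SM80 device."""

    def __init__(self, compute_version: str):
        super().__init__(f"sparse compression requires SM80+, got {compute_version}")


def check_arch(major: int, minor: int) -> None:
    if (major, minor) < (8, 0):
        raise UnsupportedArchError(f"{major}.{minor}")
```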

src/op/gemm_sp.h (1)

19-24: Optional: Add reflection for GemmSPWarpPolicyNode if used via FFI.

If policy is exposed through TVM FFI, consider adding reflection hooks (mirroring GemmWarpPolicyNode) for robustness. If unused externally, ignore.

tilelang/layout/gemm_sp.py (1)

93-95: Typo: rep_k_stirde.

Minor readability nit; rename to rep_k_stride.

-    rep_k_stirde = prod(shape_i + shape_k)
+    rep_k_stride = prod(shape_i + shape_k)
@@
-    stride_k.append(rep_k_stirde)
+    stride_k.append(rep_k_stride)
examples/gemm_sp/example_gemm_sp.py (2)

17-18: Handle more SM8x variants (e.g., 8.6).

ARCH_INFO misses 8.6 (A100/3080 class). Derive e_factor/e_dtype from compute_version instead of hard-coding strings.

-ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")}
+major_minor = tuple(map(int, arch.split(".")[:2]))
+ARCH_INFO = {**({} if major_minor < (8,0) else {
+    arch: (8, "uint8") if major_minor >= (9,0) else (16, "int16")
+})}

Alternatively:

e_factor, e_dtype = ((8, "uint8") if tuple(map(int, arch.split(".")[:2])) >= (9, 0) else (16, "int16"))

145-147: Fix assertion message to reference the correct tensor.

You’re checking c for NaNs; update the message accordingly.

-    assert not c.isnan().any(), "Reference result contains NaNs, please report an issue"
+    assert not c.isnan().any(), "Kernel result contains NaNs, please report an issue"
src/op/gemm_sp.cc (1)

157-165: Correct scope error messages to match the actual check

The check allows "shared" or "shared.dyn", but the message says “Only support shared.dyn”. Align the text with behavior.

-  ICHECK((A.scope() == "shared" || A.scope() == "shared.dyn") &&
-         (B.scope() == "shared" || B.scope() == "shared.dyn"))
-      << "Only support shared.dyn scope for A and B, but received " << A.scope()
+  ICHECK((A.scope() == "shared" || A.scope() == "shared.dyn") &&
+         (B.scope() == "shared" || B.scope() == "shared.dyn"))
+      << "Only support shared or shared.dyn scope for A and B, but received " << A.scope()
       << " and " << B.scope();
@@
-  ICHECK((E.scope() == "shared" || E.scope() == "shared.dyn"))
-      << "Only support shared.dyn scope for E as copy from smem to rmem are "
-         "delegated to cute implementation, found "
+  ICHECK((E.scope() == "shared" || E.scope() == "shared.dyn"))
+      << "Only support shared or shared.dyn scope for E (smem->rmem copy is delegated), found "
       << E.scope();
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py (2)

231-233: Pass architecture explicitly to compression when testing cross-arch

For deterministic testing across machines, consider fixing arch in compress() to match the kernel under test (SM80 here). Otherwise, a SM90 host will take the SM90 compression path.

-    A_sparse, E = compress(A, transposed=trans_A, block_k=block_K)
+    A_sparse, E = compress(A, transposed=trans_A, block_k=block_K, arch="8.0" if "sm80" in kernel.__name__ else "9.0")

130-138: Annotate SM80 metadata dtype explicitly in comments

E is int16/int32 on SM80 but uint8 on SM90. A short comment here will prevent accidental cross-arch reuse.

No code change needed; add a one-line comment clarifying E’s dtype per arch.

src/tl_templates/cuda/gemm_sp_sm80.h (2)

262-266: Avoid C-style cast for accum; also consider aligning warp-id mapping with SM90

Use reinterpret_cast for FragmentC. Additionally, SM90 uses (warp_id / num_warp_n, warp_id % num_warp_n). Please verify mapping consistency across archs.

-  MMA::body(pA, pE, pB, *(FragmentC *)(accum), warp_id % num_warp_m,
-            warp_id / num_warp_m, lane_id);
+  auto &acc_ref = *reinterpret_cast<FragmentC *>(accum);
+  MMA::body(pA, pE, pB, acc_ref,
+            warp_id % num_warp_m,  // consider: warp_id / num_warp_n
+            warp_id / num_warp_m,  // consider: warp_id % num_warp_n
+            lane_id);

1-4: Clean includes: drop unused stdio and add missing headers for traits/types

stdio.h is unused. Add type_traits (used by std::is_same_v) and cstdint (int8_t/uint8_t).

-#include <stdio.h>
+#include <type_traits>
+#include <cstdint>
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 409ab83 and fecadc8.

📒 Files selected for processing (14)
  • benchmark/matmul/benchmark_matmul_sp.py (6 hunks)
  • examples/gemm_sp/example_gemm_sp.py (1 hunks)
  • src/layout/gemm_layouts.cc (2 hunks)
  • src/layout/layout.h (2 hunks)
  • src/op/gemm_sp.cc (5 hunks)
  • src/op/gemm_sp.h (2 hunks)
  • src/tl_templates/cuda/compress_sm90.cu (1 hunks)
  • src/tl_templates/cuda/debug.h (1 hunks)
  • src/tl_templates/cuda/gemm_sp.h (1 hunks)
  • src/tl_templates/cuda/gemm_sp_sm80.h (1 hunks)
  • src/tl_templates/cuda/gemm_sp_sm90.h (1 hunks)
  • testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py (6 hunks)
  • tilelang/layout/gemm_sp.py (3 hunks)
  • tilelang/utils/sparse.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (10)
src/layout/layout.h (1)
src/layout/gemm_layouts.cc (6)
  • makeGemmSparseFragmentC (138-157)
  • makeGemmSparseFragmentC (138-140)
  • makeTensorOpMultiplicand (593-681)
  • makeTensorOpMultiplicand (593-594)
  • makeGemmSparseAmpereABLayout (683-688)
  • makeGemmSparseAmpereABLayout (683-684)
examples/gemm_sp/example_gemm_sp.py (4)
tilelang/layout/gemm_sp.py (1)
  • make_metadata_layout (136-158)
tilelang/utils/sparse.py (1)
  • compress (78-96)
tilelang/contrib/nvcc.py (1)
  • get_target_compute_version (257-298)
tilelang/jit/__init__.py (1)
  • jit (237-310)
tilelang/layout/gemm_sp.py (1)
tilelang/contrib/nvcc.py (1)
  • get_target_compute_version (257-298)
src/op/gemm_sp.h (3)
src/op/gemm.h (3)
  • GemmWarpPolicyNode (25-70)
  • int (28-28)
  • GemmWarpPolicyType (18-201)
src/op/gemm_sp.cc (2)
  • ComputeWarpPartition (21-65)
  • ComputeWarpPartition (21-25)
src/op/gemm.cc (2)
  • ComputeWarpPartition (110-280)
  • ComputeWarpPartition (111-112)
tilelang/utils/sparse.py (2)
tilelang/contrib/nvcc.py (2)
  • get_target_compute_version (257-298)
  • parse_compute_version (301-323)
src/tl_templates/cuda/compress_sm90.cu (2)
  • compress_sm90 (156-159)
  • compress_sm90 (156-156)
src/tl_templates/cuda/gemm_sp_sm80.h (1)
src/tl_templates/cuda/gemm_sp_sm90.h (1)
  • gemm_sp_ss (224-231)
benchmark/matmul/benchmark_matmul_sp.py (5)
tilelang/layout/gemm_sp.py (1)
  • make_metadata_layout (136-158)
tilelang/contrib/nvcc.py (1)
  • get_target_compute_version (257-298)
examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py (3)
  • matmul_sp (9-61)
  • main (30-59)
  • main (126-127)
tilelang/language/experimental/gemm_sp.py (1)
  • gemm_sp (9-86)
tilelang/env.py (1)
  • disable_cache (232-233)
src/layout/gemm_layouts.cc (1)
src/layout/layout.cc (2)
  • Layout (57-70)
  • Layout (72-75)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py (3)
tilelang/utils/sparse.py (1)
  • compress (78-96)
tilelang/layout/gemm_sp.py (1)
  • make_metadata_layout (136-158)
tilelang/env.py (1)
  • disable_cache (232-233)
src/op/gemm_sp.cc (4)
src/op/gemm.cc (2)
  • ComputeWarpPartition (110-280)
  • ComputeWarpPartition (111-112)
src/target/utils.cc (4)
  • TargetGetWarpSize (114-119)
  • TargetGetWarpSize (114-114)
  • TargetIsAmpere (42-47)
  • TargetIsAmpere (42-42)
src/op/gemm_sp.h (4)
  • GemmSPWarpPolicy (26-50)
  • GemmSPWarpPolicy (31-35)
  • GemmSPWarpPolicy (37-41)
  • GemmSPWarpPolicy (43-49)
src/layout/gemm_layouts.cc (4)
  • makeGemmSparseFragmentC (138-157)
  • makeGemmSparseFragmentC (138-140)
  • makeGemmSparseAmpereABLayout (683-688)
  • makeGemmSparseAmpereABLayout (683-684)
🪛 Ruff (0.12.2)
tilelang/layout/gemm_sp.py

120-120: Avoid specifying long messages outside the exception class

(TRY003)


123-123: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/utils/sparse.py

63-64: Avoid specifying long messages outside the exception class

(TRY003)


95-96: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (4)
src/tl_templates/cuda/compress_sm90.cu (1)

150-151: No-op formatting change is fine.

Macro continuation and control flow remain correct. No action needed.

benchmark/matmul/benchmark_matmul_sp.py (1)

`247-257`:

```shell
#!/bin/bash
set -euo pipefail
echo "PWD: $(pwd)"
echo "--- top-level ---"
ls -la | sed -n '1,200p'
echo "--- first few files ---"
find . -maxdepth 3 -type f | sed -n '1,200p' || true

echo "--- search for 'disable_cache' occurrences (including ignored files) ---"
rg -n --hidden -S -uu 'disable_cache' -C2 || true

echo "--- search for Python function defs named disable_cache ---"
rg -n --hidden -S -uu 'def\s+disable_cache\s*\(' -C2 || true

echo "--- search for 'tilelang' references ---"
rg -n --hidden -S -uu '\btilelang\b' -C2 || true

echo "--- search for Env / Environment classes ---"
rg -n --hidden -S -uu 'class\s+\w*(Env|Environment)\b' -C3 || true

echo "--- list tilelang dir or env.py if present ---"
if [ -d "tilelang" ]; then echo "tilelang/ exists:"; ls -la tilelang || true; fi
rg -n --hidden -S -uu 'env.py' -C2 || true

echo "--- show lines around the call site in benchmark/matmul/benchmark_matmul_sp.py ---"
sed -n '1,240p' benchmark/matmul/benchmark_matmul_sp.py || true
sed -n '240,320p' benchmark/matmul/benchmark_matmul_sp.py || true
```

src/tl_templates/cuda/gemm_sp_sm90.h (1)

`216-228`: **GMMA aliasing and call site switch LGTM.**

Template alias and the `GMMA::body` invocation are consistent with the rest of the file.

examples/gemm_sp/example_gemm_sp.py (1)

`149-156`: **Confirm do_bench units; TFLOPS/latency may be misreported.**

triton.testing.do_bench returns time in ms in recent versions. If so, convert to seconds when computing TFLOPS and print latency accordingly.


```diff
-    latency = do_bench(lambda: kernel(a_sparse, e, b))
-    ref_latency = do_bench(lambda: a @ b)
-
-    total_flops = 2 * args.m * args.n * args.k
-    tflops = total_flops / latency / 1e9
-    ref_tflops = total_flops / ref_latency / 1e9
-    print(f"Sparse TFLOPS: {tflops:.2f}, Latency: {latency/1e3} s")
-    print(f"Reference TFLOPS: {ref_tflops:.2f}, Latency: {ref_latency/1e3:} s")
+    latency_ms = do_bench(lambda: kernel(a_sparse, e, b))
+    ref_latency_ms = do_bench(lambda: a @ b)
+    total_flops = 2 * args.m * args.n * args.k
+    tflops = total_flops / (latency_ms / 1e3) / 1e12
+    ref_tflops = total_flops / (ref_latency_ms / 1e3) / 1e12
+    print(f"Sparse TFLOPS: {tflops:.2f}, Latency: {latency_ms/1e3:.6f} s")
+    print(f"Reference TFLOPS: {ref_tflops:.2f}, Latency: {ref_latency_ms/1e3:.6f} s")

If your triton version returns µs, adjust divisors accordingly.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (2)
src/op/gemm_sp.cc (2)

260-293: DRY the Ampere A/B layout setup and improve failure messages

Refactor duplicated A/B layout code into a small lambda and replace bare ICHECK(0) with descriptive messages. This was previously suggested.

-    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
-      int dim_A = A->shape.size();
-      const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
-      const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
-      results.Set(A, makeGemmSparseAmpereABLayout(mat_stride, mat_continuous,
-                                                  A->dtype.bits()));
-    } else if (A.scope() == "local.fragment") {
-      // auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
-      //                                   A->dtype.bits(), trans_A);
-      // results.Set(A, fragment->BindThreadRange(thread_range));
-      ICHECK(false) << "Not Implemented";
-    } else {
-      ICHECK(0);
-    }
-    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
-      int dim_B = B->shape.size();
-      const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
-      const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
-      results.Set(B, makeGemmSparseAmpereABLayout(mat_stride, mat_continuous,
-                                                  B->dtype.bits()));
-    } else if (B.scope() == "local.fragment") {
-      // auto fragment =
-      //     makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
-      // results.Set(B, fragment->BindThreadRange(thread_range));
-      ICHECK(false) << "Not Implemented";
-    } else {
-      ICHECK(0);
-    }
+    auto set_ampere_ab_layout = [&](const Buffer& buf) {
+      if (buf.scope() == "shared" || buf.scope() == "shared.dyn") {
+        int dim = buf->shape.size();
+        const int64_t mat_stride = *as_const_int(buf->shape[dim - 2]);
+        const int64_t mat_continuous = *as_const_int(buf->shape[dim - 1]);
+        results.Set(buf, makeGemmSparseAmpereABLayout(
+                             mat_stride, mat_continuous, buf->dtype.bits()));
+      } else if (buf.scope() == "local.fragment") {
+        ICHECK(false) << "Not Implemented for scope: " << buf.scope();
+      } else {
+        ICHECK(false) << "Unsupported scope for Ampere GEMM_SP: " << buf.scope();
+      }
+    };
+    set_ampere_ab_layout(A);
+    set_ampere_ab_layout(B);

39-61: Critical: Ampere warp partition can hit divide-by-zero; replace with safe factor search

m_warp = M / m_atom_size can exceed num_warps, yielding n_warp = 0 and N/0 before any ICHECK. Same risk on the N path. Replace with a factor search that preserves m_warp * n_warp == num_warps and atom alignment, avoiding 0 divisors.

-  if (TargetIsAmpere(target)) {
-    int warp_shape_m = M / m_warp;
-    int warp_shape_n = N / n_warp;
-    if (warp_shape_m % m_atom_size) { // GemmWarpPolicy::kFullRow
-      m_warp = M / m_atom_size;
-      ICHECK(m_warp > 0) << err_msg;
-      n_warp = num_warps / m_warp;
-      warp_shape_n = N / n_warp;
-      ICHECK(warp_shape_n % n_atom_size == 0) << err_msg;
-    } else if (warp_shape_n % n_atom_size != 0) { // GemmWarpPolicy::kFullColumn
-      n_warp = N / n_atom_size;
-      ICHECK(n_warp > 0) << err_msg;
-      m_warp = num_warps / n_warp;
-      warp_shape_m = M / m_warp;
-      ICHECK(warp_shape_m % m_atom_size == 0) << err_msg;
-    }
-    ICHECK(m_warp * n_warp == num_warps)
-        << "m_warp * n_warp must equal num_warps, please report an issue when "
-           "encounter this"
-        << ", m_warp: " << m_warp << ", n_warp: " << n_warp << ", num_warps"
-        << num_warps;
-    this->m_warp = m_warp;
-    this->n_warp = n_warp;
-  }
+  if (TargetIsAmpere(target)) {
+    auto aligned = [&](int mw, int nw) {
+      return mw > 0 && nw > 0 &&
+             mw * nw == num_warps &&
+             (M / mw) % m_atom_size == 0 &&
+             (N / nw) % n_atom_size == 0;
+    };
+    if (!aligned(m_warp, n_warp)) {
+      int best_m = -1, best_n = -1;
+      // Prefer fixing M first
+      for (int cand_m = std::min(M / m_atom_size, num_warps); cand_m >= 1; --cand_m) {
+        if (num_warps % cand_m) continue;
+        int cand_n = num_warps / cand_m;
+        if (aligned(cand_m, cand_n)) { best_m = cand_m; best_n = cand_n; break; }
+      }
+      // If not found, try fixing N
+      if (best_m < 0) {
+        for (int cand_n = std::min(N / n_atom_size, num_warps); cand_n >= 1; --cand_n) {
+          if (num_warps % cand_n) continue;
+          int cand_m = num_warps / cand_n;
+          if (aligned(cand_m, cand_n)) { best_m = cand_m; best_n = cand_n; break; }
+        }
+      }
+      ICHECK(best_m > 0 && best_n > 0) << err_msg;
+      m_warp = best_m;
+      n_warp = best_n;
+    }
+    ICHECK(m_warp * n_warp == num_warps)
+        << "m_warp * n_warp must equal num_warps"
+        << ", m_warp: " << m_warp << ", n_warp: " << n_warp
+        << ", num_warps: " << num_warps;
+    this->m_warp = m_warp;
+    this->n_warp = n_warp;
+  }
🧹 Nitpick comments (3)
src/op/gemm_sp.cc (1)

174-178: Avoid unchecked access to buffer_remap for C

Accessing T.buffer_remap[C] without count() can default-construct and crash later. Mirror A/B handling.

-  auto C_buffer = T.buffer_remap[C];
+  auto C_buffer = T.buffer_remap.count(C) ? T.buffer_remap[C] : C;
src/layout/gemm_layouts.cc (2)

683-688: Ensure crosswise divides the 1024/element product: use gcd for robustness

min(mat_continuous, 1024/elementsize) can yield a kCrosswise that doesn’t divide 1024/elementsize, causing fractional factors internally. Choose gcd to guarantee integer tiling.

-Layout makeGemmSparseAmpereABLayout(int mat_stride, int mat_continuous,
-                                    int elementsize) {
-  int kCrosswise = std::min(mat_continuous, (1024 / elementsize));
+Layout makeGemmSparseAmpereABLayout(int mat_stride, int mat_continuous,
+                                    int elementsize) {
+  int max_cross = 1024 / elementsize;  // 8 * (128 / elementsize)
+  int kCrosswise = std::gcd(mat_continuous, max_cross);
+  if (kCrosswise == 0) kCrosswise = std::min(mat_continuous, max_cross);
   return makeTensorOpMultiplicand(mat_stride, mat_continuous, elementsize,
                                   kCrosswise);
 }

Add include (outside this hunk):

#include <numeric>  // for std::gcd

606-612: Add divisibility precondition to prevent fractional tile math

Guard that kCrosswise divides kTileShapeContiguous * kElementsPerAccess; otherwise kFactor truncates and can mis-shape the layout.

-  int kFactor = kTileShapeContiguous * kElementsPerAccess / kCrosswise;
-
-  ICHECK(kFactor > 0)
+  int kFactor = kTileShapeContiguous * kElementsPerAccess / kCrosswise;
+  ICHECK(((kTileShapeContiguous * kElementsPerAccess) % kCrosswise) == 0)
+      << "kCrosswise must divide " << (kTileShapeContiguous * kElementsPerAccess)
+      << " for integer vector tiling.";
+  ICHECK(kFactor > 0)
       << "kCrosswise should be no large than one shared memory cache line.";
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between fecadc8 and 51655d7.

📒 Files selected for processing (2)
  • src/layout/gemm_layouts.cc (2 hunks)
  • src/op/gemm_sp.cc (5 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/op/gemm_sp.cc (4)
src/op/gemm.cc (2)
  • ComputeWarpPartition (110-280)
  • ComputeWarpPartition (111-112)
src/target/utils.cc (4)
  • TargetGetWarpSize (114-119)
  • TargetGetWarpSize (114-114)
  • TargetIsAmpere (42-47)
  • TargetIsAmpere (42-42)
src/op/gemm_sp.h (4)
  • GemmSPWarpPolicy (26-50)
  • GemmSPWarpPolicy (31-35)
  • GemmSPWarpPolicy (37-41)
  • GemmSPWarpPolicy (43-49)
src/layout/gemm_layouts.cc (4)
  • makeGemmSparseFragmentC (138-157)
  • makeGemmSparseFragmentC (138-140)
  • makeGemmSparseAmpereABLayout (683-688)
  • makeGemmSparseAmpereABLayout (683-684)
src/layout/gemm_layouts.cc (1)
src/layout/layout.cc (2)
  • Layout (57-70)
  • Layout (72-75)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: format-check
🔇 Additional comments (2)
src/op/gemm_sp.cc (1)

31-37: Confirm atom-size assumptions for SM80 sparse

bits==16 uses 32×32 atom; others use 16. Please confirm SM80 sparse only expects 16-bit inputs here; if not, add an ICHECK or a comment clarifying supported dtypes.

src/layout/gemm_layouts.cc (1)

138-157: Sparse C-fragment repeat order differs from dense; add minimal verification

Since this reverses the Repeat order vs dense C, add/point to a unit test that validates per-thread index mapping against CUTLASS-like expectations for a few (warp_m, warp_n) cases.

cursor[bot]

This comment was marked as outdated.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tilelang/layout/gemm_sp.py (1)

49-55: Validate block_k > 0 and improve divisibility check/message.

As written, block_k == 0 passes the modulus check and leads to invalid shapes; also the error message conflates “too small” with “not a multiple.”

-    BlockK = 512 // bits_map[mma_dtype]
-    if block_k % BlockK != 0:
-        raise ValueError(f"Tile K is too small, which should be at least {BlockK} for {mma_dtype}")
-    NumK = block_k // BlockK  # block_k is MinTileShapeK
+    BlockK = 512 // bits_map[mma_dtype]
+    if block_k <= 0 or block_k % BlockK != 0:
+        raise ValueError(
+            f"block_k must be a positive multiple of {BlockK} for {mma_dtype}, got {block_k}"
+        )
+    NumK = block_k // BlockK  # block_k is MinTileShapeK
♻️ Duplicate comments (3)
tilelang/layout/gemm_sp.py (3)

125-134: PEP 8: use snake_case and constants for interleave layout.

Rename kInterleaved → K_INTERLEAVED and ColumnMajorInterleaved → column_major_interleaved.

-    kInterleaved = 2
-    stride = buffer.shape[0] * kInterleaved
+    K_INTERLEAVED = 2
+    stride = buffer.shape[0] * K_INTERLEAVED
-
-    def ColumnMajorInterleaved(i: int, j: int) -> int:
-        column_major = j // kInterleaved
-        column_minor = j % kInterleaved
-        return column_major * stride + i * kInterleaved + column_minor
+    def column_major_interleaved(i: int, j: int) -> int:
+        column_major = j // K_INTERLEAVED
+        column_minor = j % K_INTERLEAVED
+        return column_major * stride + i * K_INTERLEAVED + column_minor
-
-    return T.Layout(buffer.shape, ColumnMajorInterleaved)
+    return T.Layout(buffer.shape, column_major_interleaved)

4-4: Fix incorrect Optional import (module fails at import time).

Import Optional from typing, not tilelang.autotuner.capture.

-from tilelang.autotuner.capture import Optional
 ...
-from typing import List
+from typing import List, Optional

Also applies to: 10-10


119-124: Bug: dtype guard never matches due to a single combined string.

Split into two entries so fp16/bf16 path actually validates 16‑bit metadata.

-    if mma_dtype in ["float16, bfloat16"] and buffer.dtype not in ["uint16", "int16"]:
+    if mma_dtype in ["float16", "bfloat16"] and buffer.dtype not in ["uint16", "int16"]:
         raise ValueError(f"metadata should be 16 bit, got {buffer.dtype}")
🧹 Nitpick comments (5)
tilelang/layout/gemm_sp.py (5)

31-35: Warn before clamping block_k; message currently shows the clamped value.

Preserve the original value in the warning, then clamp.

-    if block_k > 128:
-        block_k = 128
-        # Ref: https://github.com/NVIDIA/cutlass/blob/c2ad7c5b20f131c4ba33601860f1da3f9c9df0f3/include/cutlass/gemm/collective/builders/sm90_sparse_gmma_builder.inl#L145-L146
-        warnings.warn(f"block_k {block_k} is too large, set to 128 for {mma_dtype}.", stacklevel=2)
+    if block_k > 128:
+        original_block_k = block_k
+        # Ref: https://github.com/NVIDIA/cutlass/blob/c2ad7c5b20f131c4ba33601860f1da3f9c9df0f3/include/cutlass/gemm/collective/builders/sm90_sparse_gmma_builder.inl#L145-L146
+        warnings.warn(
+            f"block_k {original_block_k} is too large; clamping to 128 for {mma_dtype}.",
+            stacklevel=2,
+        )
+        block_k = 128

93-95: Typo: rep_k_stirde → rep_k_stride.

Rename for clarity.

-    rep_k_stirde = prod(shape_i + shape_k)
+    rep_k_stride = prod(shape_i + shape_k)
     shape_k.append(rep_k)
-    stride_k.append(rep_k_stirde)
+    stride_k.append(rep_k_stride)

97-105: Remove unnecessary nonlocal and avoid shadowing parameter 'k' in comprehensions.

Cleaner and less error‑prone.

-    def transform(i: int, k: int) -> int:
-        nonlocal shape_i, shape_k, stride_i, stride_k
-        i_decomposed = decompose_col_major(i, shape_i)
-        k_decomposed = decompose_col_major(k, shape_k)
-        i_offset = sum(i_decomposed[k] * stride_i[k] for k in range(len(i_decomposed)))
-        k_offset = sum(k_decomposed[k] * stride_k[k] for k in range(len(k_decomposed)))
-        return i_offset + k_offset
+    def transform(i: int, k: int) -> int:
+        i_decomposed = decompose_col_major(i, shape_i)
+        k_decomposed = decompose_col_major(k, shape_k)
+        i_offset = sum(i_decomposed[idx] * stride_i[idx] for idx in range(len(i_decomposed)))
+        k_offset = sum(k_decomposed[idx] * stride_k[idx] for idx in range(len(k_decomposed)))
+        return i_offset + k_offset

108-114: Optional: validate metadata buffer rank early.

Both SM8x/SM90 paths assume a 2D metadata buffer; add a guard for clearer errors.

 def _make_metadata_layout_sm8x_cutlass(buffer: tvm.tir.Buffer, mma_dtype: str):
@@
-    Args:
-        buffer: metadata buffer shape, for sm80 it should be a 16bit type
+    Args:
+        buffer: metadata buffer shape, for sm80 it should be a 16bit type
@@
+    if len(buffer.shape) != 2:
+        raise ValueError(f"expected 2D metadata buffer, got shape={tuple(buffer.shape)}")

141-145: nvcc helpers verified — parse_compute_version exists; unify compute-version suffix format.

  • parse_compute_version is defined and returns (major:int, minor:int). (tilelang/contrib/nvcc.py)
  • get_target_compute_version returns strings like "8.6" and, for suffixed arches, "9.0.a" (dot before the suffix) — parse_compute_version accepts "9.0.a" but would fail on "9.0a" (no dot).
  • Recommend normalizing the format (either return "9.0a" and update parse_compute_version to strip suffixes, or keep "9.0.a" and document it) — change in tilelang/contrib/nvcc.py.
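
A small, self-contained sketch of the normalization described above (a hypothetical helper, not the existing parse_compute_version; it simply illustrates accepting both suffix spellings):

```python
import re


def parse_major_minor(compute_version: str) -> tuple:
    """Accept "8.6", "9.0a", or "9.0.a" and return (major, minor) as ints."""
    match = re.match(r"^(\d+)\.(\d+)", compute_version)
    if match is None:
        raise ValueError(f"unrecognized compute version: {compute_version!r}")
    return int(match.group(1)), int(match.group(2))


assert parse_major_minor("8.6") == (8, 6)
assert parse_major_minor("9.0a") == (9, 0)
assert parse_major_minor("9.0.a") == (9, 0)
```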
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 51655d7 and 122bff0.

📒 Files selected for processing (2)
  • benchmark/matmul/benchmark_matmul_sp.py (6 hunks)
  • tilelang/layout/gemm_sp.py (3 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • benchmark/matmul/benchmark_matmul_sp.py
🧰 Additional context used
🧬 Code graph analysis (1)
tilelang/layout/gemm_sp.py (1)
tilelang/contrib/nvcc.py (1)
  • get_target_compute_version (257-298)
🪛 Ruff (0.12.2)
tilelang/layout/gemm_sp.py

120-120: Avoid specifying long messages outside the exception class

(TRY003)


123-123: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: Cursor Bugbot
  • GitHub Check: bot-task
  • GitHub Check: format-check

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py (1)

3-3: Dispatch compression and metadata layout by device compute capability (SM80/SM90).

Example currently imports/calls compress_sm90 and hardcodes layout arch="9.0". Use the library dispatcher or runtime CC and set metadata arch from the device capability so the script works on Ampere (sm_80) and Hopper (sm_90) and avoids metadata/layout mismatches.

Apply these minimal changes in examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py:

-from tilelang.utils.sparse import compress_sm90
+from tilelang.utils.sparse import compress
+major, minor = torch.cuda.get_device_capability()
+arch_str = f"{major}.{minor}"
-    A_sparse, E = compress_sm90(A, block_k=block_K, transposed=False)
+    A_sparse, E = compress(A, transposed=False, block_k=block_K)
-                        make_metadata_layout(
-                        E, mma_dtype="float16", arch="9.0", backend="cutlass", block_k=block_K),
+                        make_metadata_layout(
+                        E, mma_dtype="float16", arch=arch_str, backend="cutlass", block_k=block_K),
-                        make_metadata_layout(
-                        E_shared,
-                        mma_dtype="float16",
-                        arch="9.0",
-                        backend="cutlass",
-                        block_k=block_K),
+                        make_metadata_layout(
+                        E_shared,
+                        mma_dtype="float16",
+                        arch=arch_str,
+                        backend="cutlass",
+                        block_k=block_K),
🧹 Nitpick comments (1)
examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py (1)

41-52: Avoid hard‑coding arch; rely on auto‑detection to support SM80/SM90 seamlessly.

make_metadata_layout already auto-detects the device arch when arch is None. The hard-coded "9.0" will misconfigure metadata on SM80/Ada. Drop the explicit arch to prevent mismatches with the compression path and keep this example portable.

Apply this diff:

             T.annotate_layout({
                 E:
                     make_metadata_layout(
-                        E, mma_dtype="float16", arch="9.0", backend="cutlass", block_k=block_K),
+                        E, mma_dtype="float16", backend="cutlass", block_k=block_K),
                 E_shared:
                     make_metadata_layout(
                         E_shared,
                         mma_dtype="float16",
-                        arch="9.0",
                         backend="cutlass",
                         block_k=block_K),
             })
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 122bff0 and b8e195f.

📒 Files selected for processing (1)
  • examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py (1)
tilelang/layout/gemm_sp.py (1)
  • make_metadata_layout (136-158)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: format-check
  • GitHub Check: bot-task

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (1)
benchmark/matmul/benchmark_matmul_sp.py (1)

165-173: Update the e_factor/e_dtype selection to use the new arch config function.

Replace the dictionary lookup with the proper compute version parsing to prevent runtime failures.

Apply this diff:

-        e_factor, e_dtype = ARCH_INFO[arch]
+        e_factor, e_dtype = get_arch_config(arch)
🧹 Nitpick comments (3)
benchmark/matmul/benchmark_matmul_sp.py (3)

89-89: Add accum_dtype parameter documentation.

The function signature was updated to include accum_dtype but the docstring doesn't document this new parameter.

Add the missing parameter documentation:

     Parameters
     ----------
     M : int
         The dimension M of the matrix multiplication.
     N : int
         The dimension N of the matrix multiplication.
     K : int
         The dimension K of the matrix multiplication.
+    accum_dtype : str
+        The accumulation data type ("float" or "float16").

278-284: Handle potential import failures for torch sparse functionality.

The torch sparse import and functionality may not be available in all environments or PyTorch versions.

Add error handling for the torch sparse functionality:

     if args.bench_torch_sparse is not None:
-        from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
-        if args.bench_torch_sparse == 'cutlass':
-            SparseSemiStructuredTensor._FORCE_CUTLASS = True
-        A_sp = to_sparse_semi_structured(A, transposed=False)
-        torch_sparse_latency = do_bench(lambda: A_sp @ B)
+        try:
+            from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
+            if args.bench_torch_sparse == 'cutlass':
+                SparseSemiStructuredTensor._FORCE_CUTLASS = True
+            A_sp = to_sparse_semi_structured(A, transposed=False)
+            torch_sparse_latency = do_bench(lambda: A_sp @ B)
+        except ImportError as e:
+            print(f"Warning: Torch sparse functionality not available: {e}")
+            args.bench_torch_sparse = None

290-293: Handle torch_sparse_latency reference when benchmarking fails.

If torch sparse benchmarking fails or is disabled, torch_sparse_latency will be undefined when referenced.

Fix the undefined variable reference:

     if args.bench_torch_sparse is not None:
+        try:
             from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
             if args.bench_torch_sparse == 'cutlass':
                 SparseSemiStructuredTensor._FORCE_CUTLASS = True
             A_sp = to_sparse_semi_structured(A, transposed=False)
             torch_sparse_latency = do_bench(lambda: A_sp @ B)
+            print(f"Torch sparse ({args.bench_torch_sparse}) TFlops: {total_flops / torch_sparse_latency * 1e-9:.3f}")
+        except Exception as e:
+            print(f"Warning: Torch sparse benchmarking failed: {e}")

-    if args.bench_torch_sparse is not None:
-        print(f"Torch sparse ({args.bench_torch_sparse}) TFlops: {total_flops / torch_sparse_latency * 1e-9:.3f}")
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b8e195f and de09434.

📒 Files selected for processing (1)
  • benchmark/matmul/benchmark_matmul_sp.py (6 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
benchmark/matmul/benchmark_matmul_sp.py (5)
tilelang/layout/gemm_sp.py (1)
  • make_metadata_layout (136-158)
tilelang/contrib/nvcc.py (1)
  • get_target_compute_version (257-298)
examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py (3)
  • matmul_sp (9-61)
  • main (30-59)
  • main (126-127)
tilelang/language/experimental/gemm_sp.py (1)
  • gemm_sp (9-86)
tilelang/env.py (1)
  • disable_cache (232-233)
🪛 GitHub Actions: CI Test on AMD
benchmark/matmul/benchmark_matmul_sp.py

[error] 1-1: Clang-format reformatted this file. Changes not staged for commit.

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: format-check
  • GitHub Check: bot-task
🔇 Additional comments (4)
benchmark/matmul/benchmark_matmul_sp.py (4)

93-94: Update matrix shape documentation to reflect layout changes.

The docstring previously described B with shape (N, K); with transpose_B=False, B is now shaped (K, N).

The documentation and implementation are now consistent with the new B matrix layout.


246-264: LGTM! Good CLI argument additions.

The new command-line arguments enhance the benchmark's flexibility:

  • --disable_cache allows bypassing TileLang caching for development
  • --accum_dtype enables testing different accumulation precisions
  • --bench_torch_sparse provides comparative benchmarking against PyTorch sparse implementations

228-228: Verify transpose_B=False is correct for the new layout.

The change from default (likely transpose_B=True) to transpose_B=False aligns with the B matrix layout change from (N, K) to (K, N).

This change correctly reflects the new B matrix layout where B is now shaped as (K, N) instead of (N, K).


208-212: Resolved — make_metadata_layout auto-detects arch and only forwards block_k for SM90.

make_metadata_layout sets arch via nvcc.get_target_compute_version() when arch is None, parses the compute version, and forwards **extra_args (including block_k) only in the compute_version >= (9, 0) branch; the SM8x branch does not forward extra_args. Current calls that omit arch but pass block_k are safe.
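A rough sketch of that dispatch, paraphrased from the behavior described above (select_metadata_path is a placeholder name, not the real function; the actual layout-builder calls are omitted):

from tilelang.contrib import nvcc

def select_metadata_path(arch=None, **extra_args):
    # Auto-detect the device arch when none is given, as make_metadata_layout does.
    if arch is None:
        arch = nvcc.get_target_compute_version()
    compute_version = nvcc.parse_compute_version(arch)
    if compute_version >= (9, 0):
        # SM90 branch: extra_args such as block_k are forwarded to the layout builder.
        return "sm90", extra_args
    if compute_version >= (8, 0):
        # SM8x branch: extra_args are not forwarded, so a stray block_k is harmless.
        return "sm8x", {}
    raise NotImplementedError(f"Unsupported arch: {arch}")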

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (2)
tilelang/layout/gemm_sp.py (1)

125-133: PEP 8 naming for locals.

Use CONSTANT_CASE for constants and snake_case for callables.

Apply:

-    kInterleaved = 2
-    stride = buffer.shape[0] * kInterleaved
+    K_INTERLEAVED = 2
+    stride = buffer.shape[0] * K_INTERLEAVED
@@
-    def ColumnMajorInterleaved(i: int, j: int) -> int:
-        column_major = j // kInterleaved
-        column_minor = j % kInterleaved
-        return column_major * stride + i * kInterleaved + column_minor
+    def column_major_interleaved(i: int, j: int) -> int:
+        column_major = j // K_INTERLEAVED
+        column_minor = j % K_INTERLEAVED
+        return column_major * stride + i * K_INTERLEAVED + column_minor
@@
-    return T.Layout(buffer.shape, ColumnMajorInterleaved)
+    return T.Layout(buffer.shape, column_major_interleaved)
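As a quick sanity check of the mapping above, a standalone illustration (plain Python, no TileLang dependency) that prints the linear offsets produced by the interleaved indexing for a small tile:

# Reproduces the column-major interleaved arithmetic from the diff above for a
# 4-row metadata tile with an interleave factor of 2; illustration only.
K_INTERLEAVED = 2
ROWS = 4
STRIDE = ROWS * K_INTERLEAVED

def column_major_interleaved(i: int, j: int) -> int:
    column_major = j // K_INTERLEAVED
    column_minor = j % K_INTERLEAVED
    return column_major * STRIDE + i * K_INTERLEAVED + column_minor

for i in range(ROWS):
    print([column_major_interleaved(i, j) for j in range(4)])
# Row 0 prints [0, 1, 8, 9]: columns are consumed in interleaved pairs, and each
# pair advances column-major across the rows.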
benchmark/matmul/benchmark_matmul_sp.py (1)

18-21: Remove fragile ARCH_INFO mapping; derive by compute version.

Hardcoding will raise a KeyError for 8.6/8.7/9.0a and future arches.

Apply:

-arch = nvcc.get_target_compute_version()
-
-ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")}
+arch = nvcc.get_target_compute_version()

And compute e_factor/e_dtype at use site (see next comment).

🧹 Nitpick comments (12)
tilelang/layout/gemm_sp.py (4)

31-35: Fix clamped-warning to show the original value.

You mutate block_k before warning, so the message is misleading.

Apply:

-    if block_k > 128:
-        block_k = 128
-        warnings.warn(f"block_k {block_k} is too large, set to 128 for {mma_dtype}.", stacklevel=2)
+    if block_k > 128:
+        orig_block_k = block_k
+        block_k = 128
+        warnings.warn(
+            f"block_k {orig_block_k} is too large; clamped to 128 for {mma_dtype}.",
+            stacklevel=2,
+        )

93-96: Typo in variable name (rep_k_stirde) — rename for clarity.

Pure readability; keeps future edits sane.

Apply:

-    rep_k_stirde = prod(shape_i + shape_k)
+    rep_k_stride = prod(shape_i + shape_k)
     shape_k.append(rep_k)
-    stride_k.append(rep_k_stirde)
+    stride_k.append(rep_k_stride)

97-103: Avoid shadowing parameter k inside comprehensions.

Shadowing isn’t a bug here but hurts readability.

Apply:

-        i_offset = sum(i_decomposed[k] * stride_i[k] for k in range(len(i_decomposed)))
-        k_offset = sum(k_decomposed[k] * stride_k[k] for k in range(len(k_decomposed)))
+        i_offset = sum(i_decomposed[idx] * stride_i[idx] for idx in range(len(i_decomposed)))
+        k_offset = sum(k_decomposed[idx] * stride_k[idx] for idx in range(len(k_decomposed)))

119-124: Guard unsupported mma_dtype early.

If mma_dtype isn’t one of the known sets, both checks are skipped with no error. Fail fast.

Apply:

+    supported = {"float16", "bfloat16", "float8", "int8", "uint8"}
+    if mma_dtype not in supported:
+        raise NotImplementedError(f"Unsupported mma_dtype for sm8x: {mma_dtype}")
benchmark/matmul/benchmark_matmul_sp.py (4)

165-173: Select e_factor/e_dtype programmatically (8.x vs 9.x).

Prevents runtime KeyError and future-proofs.

Apply:

-        e_factor, e_dtype = ARCH_INFO[arch]
+        major_minor = nvcc.parse_compute_version(arch)
+        if major_minor >= (9, 0):
+            e_factor, e_dtype = 8, "uint8"
+        elif major_minor >= (8, 0):
+            e_factor, e_dtype = 16, "int16"
+        else:
+            raise NotImplementedError(f"Unsupported compute capability: {arch}")

23-39: Docstring drift: B shape and result expression.

Code uses B with shape (K, N) and returns A @ B (not A @ B.T). Update text.

Apply:

-    B : numpy.ndarray
-        The matrix with shape (N, K).
+    B : numpy.ndarray
+        The matrix with shape (K, N).
@@
-        The result of A @ B.T, shape (M, N).
+        The result of A @ B, shape (M, N).

140-161: Inline kernel docstring: fix “A @ B^T”.

The kernel computes C = A @ B. Adjust wording.

Apply:

-        The actual kernel to compute C = A @ B^T.
+        The actual kernel to compute C = A @ B.

278-285: Gate torch sparse benchmark to SM80+.

PyTorch 2:4 path is SM80-specific; fail fast on older CC.

Apply:

-    if args.bench_torch_sparse is not None:
+    if args.bench_torch_sparse is not None:
+        if nvcc.parse_compute_version(arch) < (8, 0):
+            raise NotImplementedError("--bench_torch_sparse requires SM80+")
         from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
src/tl_templates/cuda/gemm_sp_sm80.h (2)

220-228: Replace C-style pointer casts with reinterpret_cast.

Improves clarity and type-safety.

Apply:

-    const TensorRefA ref_A(
-        (ElementA *)pA,
+    const TensorRefA ref_A(
+        reinterpret_cast<ElementA *>(pA),
         MmaWarp::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}));
-    const TensorRefE ref_E(
-        (ElementE *)pE,
+    const TensorRefE ref_E(
+        reinterpret_cast<ElementE *>(pE),
         MmaWarp::LayoutE::packed({ShapeE::kRow, ShapeE::kColumn}));
-    const TensorRefB ref_B(
-        (ElementB *)pB,
+    const TensorRefB ref_B(
+        reinterpret_cast<ElementB *>(pB),
         MmaWarp::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}));

260-264: Avoid C-style cast on accum.

Use reinterpret_cast.

Apply:

-  MMA::body(pA, pE, pB, *(FragmentC *)(accum), warp_id % num_warp_m,
+  MMA::body(pA, pE, pB, *reinterpret_cast<FragmentC *>(accum), warp_id % num_warp_m,
             warp_id / num_warp_m, lane_id);
tilelang/utils/sparse.py (2)

40-41: Don’t hardcode -arch=sm_90.

Rely on TORCH_CUDA_ARCH_LIST set by env._initialize_torch_cuda_arch_flags() to support 9.0/9.0a and future CCs.

Apply:

-            f'-I{env.CUTLASS_INCLUDE_DIR}/../tools/util/include',
-            '-arch=sm_90',
+            f'-I{env.CUTLASS_INCLUDE_DIR}/../tools/util/include',
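If an explicit flag is still wanted rather than relying only on TORCH_CUDA_ARCH_LIST, a hedged sketch of deriving it from the detected device (the helper name _arch_flag is an assumption, not existing tilelang API):

from tilelang.contrib import nvcc

def _arch_flag() -> str:
    # Hypothetical helper: build the nvcc arch flag from the detected compute
    # capability instead of pinning -arch=sm_90.
    major, minor = nvcc.parse_compute_version(nvcc.get_target_compute_version())
    return f"-arch=sm_{major}{minor}"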

48-53: Fix clamped-warning to show the original block_k.

Same pattern as SM90 layout builder.

Apply:

-    if block_k > 128:
-        block_k = 128
-        warnings.warn(
-            f"block_k {block_k} is too large, set to 128 for sm90 compression.", stacklevel=2)
+    if block_k > 128:
+        orig_block_k = block_k
+        block_k = 128
+        warnings.warn(
+            f"block_k {orig_block_k} is too large; clamped to 128 for sm90 compression.",
+            stacklevel=2)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between de09434 and 9962ec5.

📒 Files selected for processing (8)
  • benchmark/matmul/benchmark_matmul_sp.py (6 hunks)
  • src/tl_templates/cpp/half.hpp (4 hunks)
  • src/tl_templates/cuda/common.h (1 hunks)
  • src/tl_templates/cuda/gemm_sp_sm80.h (1 hunks)
  • testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py (6 hunks)
  • tilelang/language/builtin.py (1 hunks)
  • tilelang/layout/gemm_sp.py (3 hunks)
  • tilelang/utils/sparse.py (2 hunks)
✅ Files skipped from review due to trivial changes (3)
  • tilelang/language/builtin.py
  • src/tl_templates/cpp/half.hpp
  • src/tl_templates/cuda/common.h
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-09-15T10:51:06.975Z
Learnt from: botbw
PR: tile-ai/tilelang#691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.975Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). However, the layout template parameter should use sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h where Crosswise uses sizeof(ElementA) but the layout template uses sizeof_bits<ElementA>::value.

Applied to files:

  • src/tl_templates/cuda/gemm_sp_sm80.h
  • tilelang/layout/gemm_sp.py
📚 Learning: 2025-09-15T10:51:06.975Z
Learnt from: botbw
PR: tile-ai/tilelang#691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.975Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.

Applied to files:

  • src/tl_templates/cuda/gemm_sp_sm80.h
  • tilelang/layout/gemm_sp.py
🧬 Code graph analysis (5)
src/tl_templates/cuda/gemm_sp_sm80.h (2)
src/tl_templates/cuda/common.h (3)
  • tl (299-335)
  • int (92-95)
  • int (130-137)
src/tl_templates/cuda/gemm_sp_sm90.h (1)
  • gemm_sp_ss (224-231)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py (3)
tilelang/utils/sparse.py (1)
  • compress (76-94)
tilelang/layout/gemm_sp.py (1)
  • make_metadata_layout (136-158)
tilelang/env.py (1)
  • disable_cache (232-233)
benchmark/matmul/benchmark_matmul_sp.py (4)
tilelang/layout/gemm_sp.py (1)
  • make_metadata_layout (136-158)
tilelang/contrib/nvcc.py (1)
  • get_target_compute_version (257-298)
examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py (3)
  • matmul_sp (9-61)
  • main (30-59)
  • main (126-127)
tilelang/env.py (1)
  • disable_cache (232-233)
tilelang/layout/gemm_sp.py (1)
tilelang/contrib/nvcc.py (1)
  • get_target_compute_version (257-298)
tilelang/utils/sparse.py (2)
tilelang/contrib/nvcc.py (2)
  • get_target_compute_version (257-298)
  • parse_compute_version (301-323)
src/tl_templates/cuda/compress_sm90.cu (2)
  • compress_sm90 (156-159)
  • compress_sm90 (156-156)
🪛 GitHub Actions: CI Test on AMD
src/tl_templates/cuda/gemm_sp_sm80.h

[error] 1-1: Tile-lang formatter reformatted this file. Changes not staged for commit.

🪛 GitHub Actions: CI
src/tl_templates/cuda/gemm_sp_sm80.h

[error] 1-1: clang-format reformatted this file. Changes not staged for commit. Please review and stage the changes.

🪛 Ruff (0.12.2)
tilelang/layout/gemm_sp.py

120-120: Avoid specifying long messages outside the exception class

(TRY003)


123-123: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/utils/sparse.py

63-64: Avoid specifying long messages outside the exception class

(TRY003)


93-94: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: bot-task
🔇 Additional comments (5)
src/tl_templates/cuda/gemm_sp_sm80.h (3)

83-87: Crosswise uses sizeof(T): correct per CUTLASS.

Matches CUTLASS’s SM80 sparse patterns (Crosswise in bytes; layout template in bits). Nothing to change.


97-101: Crosswise uses sizeof(T): correct per CUTLASS.

Same note for B-side.


1-1: Run formatters and commit the reformat (clang-format & tile-lang).
The verification run matched no files because fd treated 'src/tl_templates/cuda' as a pattern; run either make fmt or python -m tilelang.format, or run clang-format manually, e.g.:
fd . src/tl_templates/cuda -t f -e h -e cu | xargs -r -I{} clang-format -i {}
Then git add/commit. Affected: src/tl_templates/cuda/gemm_sp_sm80.h.

testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py (2)

367-369: Good: re-enabled SM80 gating.

Prevents running Ampere kernels on unsupported GPUs.


119-139: SM80 E dtypes and factors look correct.

int8/float8 -> 32 (int32 metadata), fp16/bf16 -> 16 (int16 metadata). Matches layout checks.

If you want auto-detect instead of hard-coding arch="8.0" in make_metadata_layout, you can drop arch=… and rely on autodetect, keeping tests arch-agnostic.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (2)
src/target/codegen_webgpu.cc (2)

727-731: Keep this copy-elision; also make output order deterministic

The const auto& avoids copies—good. Consider emitting smap_ in a stable order to ensure reproducible artifacts across runs.

Apply this diff:

-      for (const auto &kv : smap_) {
-        os << kv.second;
-      }
+      // Emit in a deterministic order to ensure reproducible binaries.
+      std::vector<std::string> keys;
+      keys.reserve(smap_.size());
+      for (const auto &kv : smap_) keys.push_back(kv.first);
+      std::sort(keys.begin(), keys.end());
+      for (const auto &k : keys) {
+        os << smap_.at(k);
+      }

754-771: Use const-reference here as well to avoid copies

Mirror the pattern used above to prevent copying function map entries.

Apply this diff:

-  for (auto kv : mod->functions) {
+  for (const auto &kv : mod->functions) {
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9962ec5 and 88435da.

📒 Files selected for processing (1)
  • src/target/codegen_webgpu.cc (2 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: format-check
  • GitHub Check: bot-task
🔇 Additional comments (1)
src/target/codegen_webgpu.cc (1)

80-98: Idiomatic empty() check — good change

Switching to empty() is clearer and avoids redundant size computation. No behavior change.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

♻️ Duplicate comments (3)
src/tl_templates/cuda/gemm_sp_sm80.h (3)

83-85: Layout computation correctly uses sizeof(T) for crosswise calculation.

The use of sizeof(T) (bytes) for computing crosswise constants aligns with the official CUTLASS implementation pattern, as documented in the retrieved learnings. This is correct and should not be changed to cutlass::sizeof_bits<T>::value.


224-232: Replace C-style casts with C++ reinterpret_cast for type safety.

Using C-style casts for pointer conversions reduces type safety. Modern C++ prefers explicit reinterpret_cast to make the type-punning operation clear and intentional.

     const TensorRefA ref_A(
-        (ElementA *)pA,
+        reinterpret_cast<ElementA *>(pA),
         MmaWarp::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}));
     const TensorRefE ref_E(
-        (ElementE *)pE,
+        reinterpret_cast<ElementE *>(pE),
         MmaWarp::LayoutE::packed({ShapeE::kRow, ShapeE::kColumn}));
     const TensorRefB ref_B(
-        (ElementB *)pB,
+        reinterpret_cast<ElementB *>(pB),
         MmaWarp::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}));

266-267: Replace C-style cast with reinterpret_cast for consistency.

For consistency with modern C++ practices and the previous suggestions, replace the C-style cast with reinterpret_cast.

-  MMA::body(pA, pE, pB, *(FragmentC *)(accum), warp_id % num_warp_m,
+  MMA::body(pA, pE, pB, *reinterpret_cast<FragmentC *>(accum), warp_id % num_warp_m,
             warp_id / num_warp_m, lane_id);
🧹 Nitpick comments (2)
src/tl_templates/cuda/gemm_sp_sm80.h (2)

22-24: Consider documenting the rationale for int8_t shape constraints.

The int8_t and uint8_t types require K to be divisible by 64 (vs 32 for fp16), likely due to tensor core instruction requirements for integer types on SM80.

Add a brief comment explaining the K dimension requirement:

 template <typename Shape> struct ShapeCheck<int8_t, Shape> {
+  // Int8 tensor cores on SM80 require K dimension to be multiple of 64
   static constexpr bool value =
       (Shape::kM % 16 == 0) && (Shape::kN % 16 == 0) && (Shape::kK % 64 == 0);
 };

255-258: Consider adding documentation for the gemm_sp_ss interface.

The template parameters and their constraints would benefit from documentation to help users understand the requirements and expected behavior.

Add a documentation comment before the function:

+/// Sparse GEMM operation for SM80 architecture with 2:4 structured sparsity
+/// @tparam M, N, K - Matrix dimensions  
+/// @tparam num_warp_m, num_warp_n - Warp tiling configuration
+/// @tparam trans_A, trans_B - Whether to transpose A and B matrices
+/// @tparam clear_accum - Whether to clear accumulator before operation
+/// @param pA - Pointer to sparse matrix A (2:4 structured)
+/// @param pB - Pointer to dense matrix B  
+/// @param accum - Accumulator for result C
+/// @param pE - Pointer to sparsity metadata for matrix A
 template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
           bool trans_B, bool clear_accum = false, typename A_type,
           typename B_type, typename C_type, typename E_type>
 TL_DEVICE void gemm_sp_ss(A_type *pA, B_type *pB, C_type *accum, E_type *pE) {
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 88435da and 597b66f.

📒 Files selected for processing (1)
  • src/tl_templates/cuda/gemm_sp_sm80.h (1 hunks)
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-09-15T10:51:06.975Z
Learnt from: botbw
PR: tile-ai/tilelang#691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.975Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). However, the layout template parameter should use sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h where Crosswise uses sizeof(ElementA) but the layout template uses sizeof_bits<ElementA>::value.

Applied to files:

  • src/tl_templates/cuda/gemm_sp_sm80.h
📚 Learning: 2025-09-15T10:51:06.975Z
Learnt from: botbw
PR: tile-ai/tilelang#691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.975Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.

Applied to files:

  • src/tl_templates/cuda/gemm_sp_sm80.h
🧬 Code graph analysis (1)
src/tl_templates/cuda/gemm_sp_sm80.h (2)
src/tl_templates/cuda/common.h (3)
  • tl (301-337)
  • int (92-95)
  • int (130-137)
src/tl_templates/cuda/gemm_sp_sm90.h (1)
  • gemm_sp_ss (224-231)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: build-test-amd
🔇 Additional comments (2)
src/tl_templates/cuda/gemm_sp_sm80.h (2)

143-146: Comprehensive input validation with clear error messages.

The static assertions correctly validate type compatibility and shape requirements. This ensures compile-time safety for the sparse GEMM operations.


94-96: Verify int8/uint8/float8 row-major support for matrix B

Static assert at src/tl_templates/cuda/gemm_sp_sm80.h:94–96 tests cutlass::sizeof_bits<T>::value != 8 while this specialization is for transpose = false (row-major). Repository search produced no output; cannot confirm whether row-major B is actually unsupported for 8‑bit element types. Confirm intended behavior and take one of the actions below:

  • If row-major B must be disallowed for 8‑bit types: remove this specialization or make the assert explicitly target concrete types (int8_t/uint8_t/float8_t) and update the message.
  • If row-major B is allowed: relax/remove the assert or adjust it to check bytes vs bits (sizeof(T) vs cutlass::sizeof_bits<T>::value) and clarify the message.

@@ -0,0 +1,270 @@
#include <cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h>
#include <stdio.h>
Contributor


🛠️ Refactor suggestion

Remove unused <stdio.h> header.

The <stdio.h> header is included but never used in this file.

 #include <cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h>
-#include <stdio.h>
 
 namespace tl {
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
 #include <cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h>
-#include <stdio.h>

 namespace tl {
🤖 Prompt for AI Agents
In src/tl_templates/cuda/gemm_sp_sm80.h around lines 2 to 2, remove the unused
#include <stdio.h> directive because the header is not referenced anywhere in
the file; simply delete that include line to avoid unnecessary dependency and
potential warnings.

cutlass::gemm::warp::SparseMmaTensorOp<WarpShape, ElementA, SmemLayoutA,
ElementB, SmemLayoutB, ElementC,
LayoutC, Policy>;
static_assert(kSparse == MmaWarp::kSparse, "not 2:4 structured sparse");
Contributor


🛠️ Refactor suggestion

Remove duplicate static assertion.

The same assertion kSparse == MmaWarp::kSparse appears on both line 184 and line 206. The second occurrence is redundant.

   static int const kElementsPerElementE = MmaWarp::kElementsPerElementE;
-  static_assert(kSparse == MmaWarp::kSparse, "not 2:4 structured sparse");
 
   using ShapeA = cutlass::MatrixShape<Shape::kM, Shape::kK / kSparse>;

Also applies to: 206-206

🤖 Prompt for AI Agents
In src/tl_templates/cuda/gemm_sp_sm80.h around lines 184 and 206, there are
duplicate static_asserts checking kSparse == MmaWarp::kSparse; keep the first
occurrence (line 184) and remove the redundant second occurrence at line 206 to
avoid duplication; ensure no other dependent code expects the second assert and
run a quick build to confirm no compilation regressions.

Comment on lines +264 to +267
int warp_id = threadIdx.x / 32;
int lane_id = threadIdx.x % 32;
MMA::body(pA, pE, pB, *(FragmentC *)(accum), warp_id % num_warp_m,
warp_id / num_warp_m, lane_id);
Contributor


⚠️ Potential issue

Incorrect warp index mapping for MMA operation.

The warp indices passed to MMA::body are swapped. The current code passes warp_id % num_warp_m as warp_idx_m and warp_id / num_warp_m as warp_idx_n, but this is incorrect for the standard warp tiling layout.

In CUTLASS and standard GEMM tiling, warps are typically arranged in row-major order within a threadblock, so:

  • warp_idx_m = warp_id / num_warp_n (row index)
  • warp_idx_n = warp_id % num_warp_n (column index)
   int warp_id = threadIdx.x / 32;
   int lane_id = threadIdx.x % 32;
-  MMA::body(pA, pE, pB, *reinterpret_cast<FragmentC *>(accum), warp_id % num_warp_m,
-            warp_id / num_warp_m, lane_id);
+  MMA::body(pA, pE, pB, *reinterpret_cast<FragmentC *>(accum), warp_id / num_warp_n,
+            warp_id % num_warp_n, lane_id);
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
   int warp_id = threadIdx.x / 32;
   int lane_id = threadIdx.x % 32;
-  MMA::body(pA, pE, pB, *(FragmentC *)(accum), warp_id % num_warp_m,
-            warp_id / num_warp_m, lane_id);
+  MMA::body(pA, pE, pB, *reinterpret_cast<FragmentC *>(accum), warp_id / num_warp_n,
+            warp_id % num_warp_n, lane_id);
🤖 Prompt for AI Agents
In src/tl_templates/cuda/gemm_sp_sm80.h around lines 264 to 267, the warp index
arguments to MMA::body are swapped; change the mapping so warp_idx_m = warp_id /
num_warp_n and warp_idx_n = warp_id % num_warp_n (i.e., pass warp_id /
num_warp_n as the first warp index and warp_id % num_warp_n as the second) to
match the standard row-major warp tiling used by CUTLASS/GEMM.
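To make the difference between the two mappings concrete, a small standalone illustration (plain Python, arithmetic only):

# Compare the mapping as written in the PR (warp_id % num_warp_m, warp_id / num_warp_m)
# with the row-major convention suggested above, for a 2x4 warp grid.
num_warp_m, num_warp_n = 2, 4
for warp_id in range(num_warp_m * num_warp_n):
    as_written = (warp_id % num_warp_m, warp_id // num_warp_m)
    row_major = (warp_id // num_warp_n, warp_id % num_warp_n)
    print(f"warp_id={warp_id}: as_written={as_written}, row_major={row_major}")
# e.g. warp_id=5 gives as_written=(1, 2) but row_major=(1, 1).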

@LeiWang1999 LeiWang1999 merged commit 0b3683b into tile-ai:main Sep 15, 2025
4 of 5 checks passed
@botbw botbw deleted the gemm_sp_sm80 branch September 16, 2025 07:33