
Conversation


@LeiWang1999 LeiWang1999 commented Sep 11, 2025

  • Moved GEMM-related dispatch instructions to the cute::tl_mma namespace for better organization.
  • Introduced TL_DISPATCH_MMA and TL_DISPATCH_MMA_TEMPLATE macros to streamline the definition of dispatch instructions for various data types and architectures.
  • Updated the handling of CUDA architecture checks to include additional support for newer architectures.
  • Improved clarity and maintainability of the code by restructuring the layout and organization of dispatch instructions.
  • Ensured consistent usage of tensor views and memory clearing operations across different GEMM implementations.

Summary by CodeRabbit

  • Refactor

    • GEMM public implementations reorganized: a new high-performance Hopper wgmma path is now the canonical implementation, with the previous path retained as fallback.
    • Public GEMM entry points now reference the new implementation path; callers may need to update qualified references/imports.
  • Chores

    • Updated architecture guards and dispatch logic for consistent behavior across supported GPUs.


coderabbitai bot commented Sep 11, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Reorganizes the GEMM implementation: introduces a cute::tl_mma namespace with macro-driven DispatchInstruction specializations, adds a new tl_wgmma GemmTensorOp for Hopper wgmma, routes the tl::gemm_* wrappers to tl_wgmma when enabled, and adjusts accumulation-clear timing within the GEMM control flow.
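
For orientation, here is a minimal sketch of how such macro-generated DispatchInstruction specializations are typically consumed. Everything except the DispatchInstruction trait and its MMA / MMA_Group members (both visible in the quoted code further down) is an assumption for illustration, not code from this PR; it also assumes src/tl_templates/cuda/gemm_mma.h is included.

// Illustrative sketch only: each TL_DISPATCH_MMA(A, B, C, instr) invocation
// emits one DispatchInstruction<A, B, C, ...> specialization inside
// cute::tl_mma, exposing an MMA atom plus a warp-group tile that a GEMM body
// can consume. GemmTensorOpSketch is a hypothetical name.
template <class A_type, class B_type, class C_type,
          int num_warp_m, int num_warp_n, int N>
struct GemmTensorOpSketch {
  using Dispatch =
      cute::tl_mma::DispatchInstruction<A_type, B_type, C_type,
                                        num_warp_m, num_warp_n, N>;
  using MMA = typename Dispatch::MMA;             // e.g. MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>
  using MMA_Group = typename Dispatch::MMA_Group; // N-dimension tiling across warps
};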

Changes

tl_mma refactor & dispatch consolidation (src/tl_templates/cuda/gemm_mma.h):
Introduces the cute::tl_mma namespace, moves GemmTensorOp and related traits into it, replaces per-arch DispatchInstruction specializations with TL_DISPATCH_MMA / TL_DISPATCH_MMA_TEMPLATE macros, updates includes/guards, and moves clear(acc) to after view creation.

wgmma Hopper implementation & wrapper routing (src/tl_templates/cuda/gemm_sm90.h):
Removes the prior tl_mma GEMM path for SM90 and adds cute::tl_wgmma::GemmTensorOp (new signature, raw A/B/C types, smem layout aliases, static layout checks) with body, body_rs, body_sr using Hopper wgmma; updates tl::gemm_ss/gemm_rs/gemm_sr to prefer tl_wgmma when use_wgmma is true, falling back to tl_mma otherwise (a hedged routing sketch follows below).
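
To make the routing concrete, here is a hedged sketch of the wrapper dispatch under stated assumptions: the template parameter list, the name gemm_ss_sketch, and the Op::body entry point are illustrative; only the use_wgmma preference and the tl_wgmma / tl_mma split come from the file summary above.

// Illustrative only, not the actual gemm_sm90.h code. The point is the
// compile-time preference for the Hopper wgmma path with a fallback to the
// macro-dispatched tl_mma path.
namespace tl {
template <int M, int N, int K, int num_warp_m, int num_warp_n,
          bool trans_A, bool trans_B, bool clear_accum, bool use_wgmma,
          class A_type, class B_type, class C_type>
__device__ void gemm_ss_sketch(A_type *pA, B_type *pB, C_type *accum) {
  if constexpr (use_wgmma) {
    using Op = cute::tl_wgmma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n,
                                            trans_A, trans_B, clear_accum,
                                            A_type, B_type, C_type>;
    Op::body(pA, pB, accum);   // Hopper wgmma path
  } else {
    using Op = cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n,
                                          trans_A, trans_B, clear_accum,
                                          A_type, B_type, C_type>;
    Op::body(pA, pB, accum);   // classic MMA fallback
  }
}
} // namespace tl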

Sequence Diagram(s)

sequenceDiagram
    autonumber
    actor Caller
    participant TL_WRAPPER as tl::gemm_{ss,rs,sr}
    participant TL_WGMMA as cute::tl_wgmma::GemmTensorOp
    participant TL_MMA as cute::tl_mma::GemmTensorOp
    participant DISPATCH as DispatchInstruction (macro-generated)
    participant ARCH as Arch-specific MMA/WGMMA kernel

    Caller->>TL_WRAPPER: call gemm_*
    TL_WRAPPER->>TL_WGMMA: if use_wgmma -> configure types/layouts
    TL_WRAPPER->>TL_MMA: else -> configure tl_mma types
    alt use_wgmma
        TL_WGMMA->>DISPATCH: choose wgmma path / sync/barrier handling
        DISPATCH->>ARCH: invoke wgmma kernel (hopper)
        ARCH-->>TL_WGMMA: compute tiles
        TL_WGMMA-->>TL_WRAPPER: return accum result
    else fallback
        TL_MMA->>DISPATCH: select arch/type instr via macros
        DISPATCH->>ARCH: invoke arch-specific MMA kernel
        ARCH-->>TL_MMA: compute tiles
        TL_MMA-->>TL_WRAPPER: return accum result
    end
    TL_WRAPPER-->>Caller: final result

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related issues

Possibly related PRs

Pre-merge checks (2 passed, 1 warning)

❌ Failed checks (1 warning)
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 16.67%, which is below the required threshold of 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Description Check (✅ Passed): Check skipped; CodeRabbit's high-level summary is enabled.
  • Title Check (✅ Passed): The title succinctly and accurately captures the primary changes in the PR: moving GEMM dispatch into a new namespace and introducing dispatch macros for MMA. It matches the PR objectives and file summaries, and it is concise, specific, and free of vague language or noise.

Poem

I hop through namespaces, tidy each lair,
Macros line up instructions with care.
WGMMA whistles, tl_mma nods on the side,
Accums cleared later — the loop takes its stride.
A rabbit refactors, then bounds off with pride. 🐇✨


📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 13b9c97 and dfd8ec8.

📒 Files selected for processing (1)
  • src/tl_templates/cuda/gemm_sm90.h (0 hunks)
💤 Files with no reviewable changes (1)
  • src/tl_templates/cuda/gemm_sm90.h
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: bot-task
  • GitHub Check: format-check

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.

@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run bash format.sh in the root directory of the project so that your changes are properly linted and formatted. This will help your contribution pass the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work!

🚀


@gemini-code-assist gemini-code-assist bot left a comment


Summary of Changes

Hello @LeiWang1999, I'm Gemini Code Assist [1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly refactors the GEMM-related dispatch instructions within the codebase. The primary goal is to enhance code organization and maintainability by introducing a new namespace and a set of helper macros. These changes also extend support for a wider range of CUDA architectures, ensuring future compatibility and improved performance across different hardware.

Highlights

  • Namespace Refactoring: Moved GEMM-related dispatch instructions to the cute::tl_mma namespace for better organization.
  • Macro Introduction: Introduced TL_DISPATCH_MMA and TL_DISPATCH_MMA_TEMPLATE macros to streamline the definition of dispatch instructions for various data types and architectures.
  • CUDA Architecture Support: Updated the handling of CUDA architecture checks to include additional support for newer architectures.
  • Code Clarity and Maintainability: Improved clarity and maintainability of the code by restructuring the layout and organization of dispatch instructions.
  • Consistency in Tensor Operations: Ensured consistent usage of tensor views and memory clearing operations across different GEMM implementations.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e., a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

  • Code Review: /gemini review performs a code review for the current pull request in its current state.
  • Pull Request Summary: /gemini summary provides a summary of the current pull request in its current state.
  • Comment: @gemini-code-assist responds in comments when explicitly tagged, both in issue comments and review comments.
  • Help: /gemini help displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder at the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double-check its output and use code with caution.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request effectively refactors the GEMM dispatch logic by introducing macros to reduce code duplication and by reorganizing the code into a new cute::tl_mma namespace. The changes improve maintainability and add support for newer CUDA architectures. However, I've found a critical issue where the refactoring for the SM75 architecture is incorrect, potentially breaking support for Turing GPUs. I've also identified an opportunity to further reduce code duplication in the preprocessor directives for architecture-specific code. The other changes, like moving the accumulator-clearing operations, are minor but positive improvements to code clarity.

Comment on lines +43 to 104
#ifdef __CUDA_ARCH_LIST__
#if __CUDA_ARCH_LIST__ >= 1200
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<fp8_e4_t, fp8_e4_t, float, num_warp_m, num_warp_n,
N> {
using MMA = MMA_Atom<SM120_16x8x32_TN<fp8_e4_t, fp8_e4_t, float>>;
using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<fp8_e5_t, fp8_e5_t, float, num_warp_m, num_warp_n,
N> {
using MMA = MMA_Atom<SM120_16x8x32_TN<fp8_e5_t, fp8_e5_t, float>>;
using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};
#include "cuda_fp8.h"
#include <cute/arch/mma_sm120.hpp>
#include <cute/arch/mma_sm80.hpp>
TL_DISPATCH_MMA_TEMPLATE(fp8_e4_t, fp8_e4_t, float, SM120_16x8x32_TN)
TL_DISPATCH_MMA_TEMPLATE(fp8_e5_t, fp8_e5_t, float, SM120_16x8x32_TN)
TL_DISPATCH_MMA(half_t, half_t, half_t, SM80_16x8x16_F16F16F16F16_TN)
TL_DISPATCH_MMA(half_t, half_t, float, SM80_16x8x16_F32F16F16F32_TN)
TL_DISPATCH_MMA(bfloat16_t, bfloat16_t, float, SM80_16x8x16_F32BF16BF16F32_TN)
TL_DISPATCH_MMA(tfloat32_t, tfloat32_t, float, SM80_16x8x8_F32TF32TF32F32_TN)
TL_DISPATCH_MMA(int8_t, int8_t, int, SM80_16x8x32_S32S8S8S32_TN)
TL_DISPATCH_MMA(double, double, double, SM80_8x8x4_F64F64F64F64_TN)
#elif __CUDA_ARCH_LIST__ >= 1000
#include "cuda_fp8.h"
#include <cute/arch/mma_sm100.hpp>
#include <cute/arch/mma_sm80.hpp>
#include <cute/arch/mma_sm89.hpp>
TL_DISPATCH_MMA(fp8_e4_t, fp8_e4_t, float, SM89_16x8x32_F32E4M3E4M3F32_TN)
TL_DISPATCH_MMA(fp8_e5_t, fp8_e5_t, float, SM89_16x8x32_F32E5M2E5M2F32_TN)
TL_DISPATCH_MMA(half_t, half_t, half_t, SM80_16x8x16_F16F16F16F16_TN)
TL_DISPATCH_MMA(half_t, half_t, float, SM80_16x8x16_F32F16F16F32_TN)
TL_DISPATCH_MMA(bfloat16_t, bfloat16_t, float, SM80_16x8x16_F32BF16BF16F32_TN)
TL_DISPATCH_MMA(tfloat32_t, tfloat32_t, float, SM80_16x8x8_F32TF32TF32F32_TN)
TL_DISPATCH_MMA(int8_t, int8_t, int, SM80_16x8x32_S32S8S8S32_TN)
TL_DISPATCH_MMA(double, double, double, SM80_8x8x4_F64F64F64F64_TN)
#elif __CUDA_ARCH_LIST__ >= 900
#include "cuda_fp8.h"
#include <cute/arch/mma_sm80.hpp>
#include <cute/arch/mma_sm89.hpp>
TL_DISPATCH_MMA(fp8_e4_t, fp8_e4_t, float, SM89_16x8x32_F32E4M3E4M3F32_TN)
TL_DISPATCH_MMA(fp8_e5_t, fp8_e5_t, float, SM89_16x8x32_F32E5M2E5M2F32_TN)
TL_DISPATCH_MMA(half_t, half_t, half_t, SM80_16x8x16_F16F16F16F16_TN)
TL_DISPATCH_MMA(half_t, half_t, float, SM80_16x8x16_F32F16F16F32_TN)
TL_DISPATCH_MMA(bfloat16_t, bfloat16_t, float, SM80_16x8x16_F32BF16BF16F32_TN)
TL_DISPATCH_MMA(tfloat32_t, tfloat32_t, float, SM80_16x8x8_F32TF32TF32F32_TN)
TL_DISPATCH_MMA(int8_t, int8_t, int, SM80_16x8x32_S32S8S8S32_TN)
TL_DISPATCH_MMA(double, double, double, SM80_8x8x4_F64F64F64F64_TN)
#elif __CUDA_ARCH_LIST__ >= 890
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<fp8_e4_t, fp8_e4_t, float, num_warp_m, num_warp_n,
N> {
using MMA = MMA_Atom<SM89_16x8x32_F32E4M3E4M3F32_TN>;
using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<fp8_e5_t, fp8_e5_t, float, num_warp_m, num_warp_n,
N> {
using MMA = MMA_Atom<SM89_16x8x32_F32E5M2E5M2F32_TN>;
using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};
#include "cuda_fp8.h"
#include <cute/arch/mma_sm80.hpp>
#include <cute/arch/mma_sm89.hpp>
TL_DISPATCH_MMA(fp8_e4_t, fp8_e4_t, float, SM89_16x8x32_F32E4M3E4M3F32_TN)
TL_DISPATCH_MMA(fp8_e5_t, fp8_e5_t, float, SM89_16x8x32_F32E5M2E5M2F32_TN)
TL_DISPATCH_MMA(half_t, half_t, half_t, SM80_16x8x16_F16F16F16F16_TN)
TL_DISPATCH_MMA(half_t, half_t, float, SM80_16x8x16_F32F16F16F32_TN)
TL_DISPATCH_MMA(bfloat16_t, bfloat16_t, float, SM80_16x8x16_F32BF16BF16F32_TN)
TL_DISPATCH_MMA(tfloat32_t, tfloat32_t, float, SM80_16x8x8_F32TF32TF32F32_TN)
TL_DISPATCH_MMA(int8_t, int8_t, int, SM80_16x8x32_S32S8S8S32_TN)
TL_DISPATCH_MMA(double, double, double, SM80_8x8x4_F64F64F64F64_TN)
#elif __CUDA_ARCH_LIST__ >= 800
#include <cute/arch/mma_sm80.hpp>
TL_DISPATCH_MMA(half_t, half_t, half_t, SM80_16x8x16_F16F16F16F16_TN)
TL_DISPATCH_MMA(half_t, half_t, float, SM80_16x8x16_F32F16F16F32_TN)
TL_DISPATCH_MMA(bfloat16_t, bfloat16_t, float, SM80_16x8x16_F32BF16BF16F32_TN)
TL_DISPATCH_MMA(tfloat32_t, tfloat32_t, float, SM80_16x8x8_F32TF32TF32F32_TN)
TL_DISPATCH_MMA(int8_t, int8_t, int, SM80_16x8x32_S32S8S8S32_TN)
TL_DISPATCH_MMA(double, double, double, SM80_8x8x4_F64F64F64F64_TN)
#elif __CUDA_ARCH_LIST__ >= 750
TL_DISPATCH_MMA(half_t, half_t, float, SM75_16x8x8_F32F16F16F32_TN)
#endif
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<half_t, half_t, half_t, num_warp_m, num_warp_n, N> {
using MMA = MMA_Atom<SM80_16x8x16_F16F16F16F16_TN>;
using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n, N> {
using MMA = MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>;
using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<bfloat16_t, bfloat16_t, float, num_warp_m,
num_warp_n, N> {
using MMA = MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>;
using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<tfloat32_t, tfloat32_t, float, num_warp_m,
num_warp_n, N> {
using MMA = MMA_Atom<SM80_16x8x8_F32TF32TF32F32_TN>;
using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<int8_t, int8_t, int, num_warp_m, num_warp_n, N> {
using MMA = MMA_Atom<SM80_16x8x32_S32S8S8S32_TN>;
using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<double, double, double, num_warp_m, num_warp_n, N> {
using MMA = MMA_Atom<SM80_8x8x4_F64F64F64F64_TN>;
using MMA_Group = Tile<Int<num_warp_m * 16>, Int<num_warp_n * 16>, _X>;
};
#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 750))
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n, N> {
using MMA = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _16>;
};
#endif

critical

This section has two issues: a critical bug for SM75 and a maintainability issue due to code duplication.

  1. Critical Bug for SM75 (Turing): The refactoring to use the TL_DISPATCH_MMA macro for the SM75 architecture is incorrect. The macro hardcodes the last dimension of MMA_Group to _X, while the original code correctly used _16. This will likely cause compilation errors or incorrect behavior on Turing GPUs.

  2. Code Duplication: The current #if/#elif structure leads to significant code duplication. For example, the SM80 instructions are repeated for every newer architecture block (SM89, SM90, SM100, SM120). Also, the blocks for __CUDA_ARCH_LIST__ >= 900 and __CUDA_ARCH_LIST__ >= 890 are identical.

I recommend refactoring this entire section to use a series of non-exclusive #if blocks. This will fix the SM75 bug, eliminate redundancy, and make the code much easier to maintain.

Here is a suggested refactoring:

#ifdef __CUDA_ARCH_LIST__

// SM75 (Turing)
#if __CUDA_ARCH_LIST__ >= 750 && __CUDA_ARCH_LIST__ < 800
#include <cute/arch/mma_sm75.hpp>
namespace cute::tl_mma {
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n, N> {
  using MMA = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _16>;
};
} // namespace cute::tl_mma
#endif

// SM80+ (Ampere)
#if __CUDA_ARCH_LIST__ >= 800
#include <cute/arch/mma_sm80.hpp>
TL_DISPATCH_MMA(half_t, half_t, half_t, SM80_16x8x16_F16F16F16F16_TN)
TL_DISPATCH_MMA(half_t, half_t, float, SM80_16x8x16_F32F16F16F32_TN)
TL_DISPATCH_MMA(bfloat16_t, bfloat16_t, float, SM80_16x8x16_F32BF16BF16F32_TN)
TL_DISPATCH_MMA(tfloat32_t, tfloat32_t, float, SM80_16x8x8_F32TF32TF32F32_TN)
TL_DISPATCH_MMA(int8_t, int8_t, int, SM80_16x8x32_S32S8S8S32_TN)
TL_DISPATCH_MMA(double, double, double, SM80_8x8x4_F64F64F64F64_TN)
#endif

// SM89+ (Ada Lovelace)
#if __CUDA_ARCH_LIST__ >= 890
#include "cuda_fp8.h"
#include <cute/arch/mma_sm89.hpp>
TL_DISPATCH_MMA(fp8_e4_t, fp8_e4_t, float, SM89_16x8x32_F32E4M3E4M3F32_TN)
TL_DISPATCH_MMA(fp8_e5_t, fp8_e5_t, float, SM89_16x8x32_F32E5M2E5M2F32_TN)
#endif

// SM90 is covered by SM89 for these instructions.
// SM100+ (Blackwell)
#if __CUDA_ARCH_LIST__ >= 1000
#include <cute/arch/mma_sm100.hpp>
#endif

// SM120+ (Future Arch)
#if __CUDA_ARCH_LIST__ >= 1200
#include <cute/arch/mma_sm120.hpp>
// Override fp8 dispatches for SM120+
TL_DISPATCH_MMA_TEMPLATE(fp8_e4_t, fp8_e4_t, float, SM120_16x8x32_TN)
TL_DISPATCH_MMA_TEMPLATE(fp8_e5_t, fp8_e5_t, float, SM120_16x8x32_TN)
#endif

#endif // __CUDA_ARCH_LIST__


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
src/tl_templates/cuda/gemm_mma.h (2)

270-272: Bug: B_type alias falls back to A_type_raw instead of B_type_raw.

This will mis-type the B operands when B_type_raw != float.

-  using B_type =
-      typename std::conditional<std::is_same<B_type_raw, float>::value,
-                                tfloat32_t, A_type_raw>::type;
+  using B_type =
+      typename std::conditional<std::is_same<B_type_raw, float>::value,
+                                tfloat32_t, B_type_raw>::type;

453-479: Namespace move is breaking — preserve backward-compatibility or update callers

  • Verified that occurrences referencing the old symbol names (as strings/handles) exist; they must either be updated or covered by a compatibility shim:
    • src/op/gemm.cc (op_name = "tl::gemm_ss" / "tl::gemm_rs" / "tl::gemm_sr").
    • testing/python/transform/* (several tests embed "tl::gemm_ss<...>" in call_extern handles).
  • Action: either add a temporary compatibility alias exposing the new implementations under the old tl:: names, or update the callers (C++ op_name and tests) to the new namespace.
  • Suggested compatibility shim (place in a header included by callers, outside this file):
namespace tl {
  using tl_mma::gemm_ss;
  using tl_mma::gemm_rs;
  using tl_mma::gemm_sr;
}  // namespace tl
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 91a7bb2 and 13b9c97.

📒 Files selected for processing (1)
  • src/tl_templates/cuda/gemm_mma.h (7 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/tl_templates/cuda/gemm_mma.h (3)
tilelang/tileop/gemm/gemm_base.py (6)
  • N (37-38)
  • clear_accum (106-107)
  • M (33-34)
  • K (41-42)
  • trans_A (45-46)
  • trans_B (49-50)
tilelang/language/fill.py (1)
  • clear (24-48)
tilelang/language/proxy.py (1)
  • make_tensor (301-305)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: format-check
🔇 Additional comments (2)
src/tl_templates/cuda/gemm_mma.h (2)

361-363: clear(acc) after view creation — good.

Semantics preserved; avoids unnecessary clears if early exits were ever added pre-view creation.
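
A minimal sketch of the sequencing being reviewed, using illustrative CuTe-style names; the tensors, flags, and partition calls below are assumptions rather than the exact code at lines 361-363.

// Inside a hypothetical body(): build the register/accumulator views first,
// then clear, so a future early return inserted before the views would also
// skip the (now unnecessary) clear.
auto tCrA = thr_mma.partition_fragment_A(sA);   // A register view
auto tCrB = thr_mma.partition_fragment_B(sB);   // B register view
auto acc  = make_tensor(make_rmem_ptr(pC),      // accumulator view
                        partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
if constexpr (clear_accum) {
  clear(acc);   // performed after the views exist
}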


395-395: Precompute {tCrB|tCrA}_view before the loop — good.

Minor readability/perf win; matches body() sequencing.

Also applies to: 431-431

Comment on lines +24 to +41
#define TL_DISPATCH_MMA(A_type, B_type, C_type, MMA_instr) \
namespace cute::tl_mma { \
template <int num_warp_m, int num_warp_n, int N> \
struct DispatchInstruction<A_type, B_type, C_type, num_warp_m, num_warp_n, \
N> { \
using MMA = MMA_Atom<MMA_instr>; \
using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>; \
}; \
}
#define TL_DISPATCH_MMA_TEMPLATE(A_type, B_type, C_type, MMA_instr) \
namespace cute::tl_mma { \
template <int num_warp_m, int num_warp_n, int N> \
struct DispatchInstruction<A_type, B_type, C_type, num_warp_m, num_warp_n, \
N> { \
using MMA = MMA_Atom<MMA_instr<A_type, B_type, C_type>>; \
using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>; \
}; \
}

🛠️ Refactor suggestion

Add the missing #include <algorithm> (std::min is used in a template argument) and prefer an explicit constexpr min to avoid surprises.

std::min appears in a non-type template argument; some toolchains require <algorithm> to be included for it to resolve. Also consider a constexpr min to avoid ADL/overload issues.

Apply at file top:

+#include <algorithm>
+#include <type_traits>

Optional (safer constexpr min inside the macro body):

-    using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
+    static constexpr int kGroupN = (num_warp_n * 16 < N) ? (num_warp_n * 16) : N;
+    using MMA_Group = Tile<_X, Int<kGroupN>, _X>;
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
#define TL_DISPATCH_MMA(A_type, B_type, C_type, MMA_instr) \
namespace cute::tl_mma { \
template <int num_warp_m, int num_warp_n, int N> \
struct DispatchInstruction<A_type, B_type, C_type, num_warp_m, num_warp_n, \
N> { \
using MMA = MMA_Atom<MMA_instr>; \
using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>; \
}; \
}
#define TL_DISPATCH_MMA_TEMPLATE(A_type, B_type, C_type, MMA_instr) \
namespace cute::tl_mma { \
template <int num_warp_m, int num_warp_n, int N> \
struct DispatchInstruction<A_type, B_type, C_type, num_warp_m, num_warp_n, \
N> { \
using MMA = MMA_Atom<MMA_instr<A_type, B_type, C_type>>; \
using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>; \
}; \
}
#include <algorithm>
#include <type_traits>
#define TL_DISPATCH_MMA(A_type, B_type, C_type, MMA_instr) \
namespace cute::tl_mma { \
template <int num_warp_m, int num_warp_n, int N> \
struct DispatchInstruction<A_type, B_type, C_type, num_warp_m, num_warp_n, \
N> { \
using MMA = MMA_Atom<MMA_instr>; \
static constexpr int kGroupN = (num_warp_n * 16 < N) ? (num_warp_n * 16) : N; \
using MMA_Group = Tile<_X, Int<kGroupN>, _X>; \
}; \
}
#define TL_DISPATCH_MMA_TEMPLATE(A_type, B_type, C_type, MMA_instr) \
namespace cute::tl_mma { \
template <int num_warp_m, int num_warp_n, int N> \
struct DispatchInstruction<A_type, B_type, C_type, num_warp_m, num_warp_n, \
N> { \
using MMA = MMA_Atom<MMA_instr<A_type, B_type, C_type>>; \
static constexpr int kGroupN = (num_warp_n * 16 < N) ? (num_warp_n * 16) : N; \
using MMA_Group = Tile<_X, Int<kGroupN>, _X>; \
}; \
}
🤖 Prompt for AI Agents
In src/tl_templates/cuda/gemm_mma.h around lines 24-41, add the missing #include
<algorithm> at the top of the file and replace the use of std::min inside the
non-type template argument with an explicit constexpr min to avoid ADL/overload
surprises; specifically, declare a small constexpr inline function (e.g.,
tl_constexpr_min) that returns the smaller of two ints and use that function in
the Tile<_X, Int<...>, _X> template argument (or alternatively call std::min
with full std:: qualification and ensure <algorithm> is included), so toolchains
that evaluate non-type template expressions resolve correctly.
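
For reference, a minimal sketch of the helper suggested in the prompt above; the name tl_constexpr_min comes from the prompt itself, and its exact placement in the header is an assumption.

// Hypothetical helper: a constexpr min that keeps std::min (and ADL) out of
// the non-type template argument used for MMA_Group.
constexpr int tl_constexpr_min(int a, int b) { return a < b ? a : b; }
// Example use inside the TL_DISPATCH_MMA macro body:
//   using MMA_Group = Tile<_X, Int<tl_constexpr_min(num_warp_n * 16, N)>, _X>;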

Comment on lines +43 to 104

⚠️ Potential issue

Fix arch gating and headers: add the SM75 header, drop the duplicate cuda_fp8.h includes, and add a __CUDA_ARCH__ fallback.

  • SM75 instructions are used but mma_sm75.hpp is never included (compilation will fail for sm_75).
  • cuda_fp8.h is included redundantly in each branch.
  • Only checking __CUDA_ARCH_LIST__ breaks older nvcc; add a fallback to __CUDA_ARCH__.

Proposed patch to this block:

-#ifdef __CUDA_ARCH_LIST__
-#if __CUDA_ARCH_LIST__ >= 1200
-#include "cuda_fp8.h"
-#include <cute/arch/mma_sm120.hpp>
-#include <cute/arch/mma_sm80.hpp>
+#if defined(__CUDA_ARCH_LIST__) || defined(__CUDA_ARCH__)
+#ifndef TL_ARCH_LEVEL
+#  if defined(__CUDA_ARCH_LIST__)
+#    define TL_ARCH_LEVEL __CUDA_ARCH_LIST__
+#  else
+#    define TL_ARCH_LEVEL __CUDA_ARCH__
+#  endif
+#endif
+#if TL_ARCH_LEVEL >= 1200
+#include <cute/arch/mma_sm120.hpp>
+#include <cute/arch/mma_sm80.hpp>
 TL_DISPATCH_MMA_TEMPLATE(fp8_e4_t, fp8_e4_t, float, SM120_16x8x32_TN)
 TL_DISPATCH_MMA_TEMPLATE(fp8_e5_t, fp8_e5_t, float, SM120_16x8x32_TN)
 TL_DISPATCH_MMA(half_t, half_t, half_t, SM80_16x8x16_F16F16F16F16_TN)
 TL_DISPATCH_MMA(half_t, half_t, float, SM80_16x8x16_F32F16F16F32_TN)
 TL_DISPATCH_MMA(bfloat16_t, bfloat16_t, float, SM80_16x8x16_F32BF16BF16F32_TN)
 TL_DISPATCH_MMA(tfloat32_t, tfloat32_t, float, SM80_16x8x8_F32TF32TF32F32_TN)
 TL_DISPATCH_MMA(int8_t, int8_t, int, SM80_16x8x32_S32S8S8S32_TN)
 TL_DISPATCH_MMA(double, double, double, SM80_8x8x4_F64F64F64F64_TN)
-#elif __CUDA_ARCH_LIST__ >= 1000
-#include "cuda_fp8.h"
-#include <cute/arch/mma_sm100.hpp>
-#include <cute/arch/mma_sm80.hpp>
-#include <cute/arch/mma_sm89.hpp>
+#elif TL_ARCH_LEVEL >= 1000
+#include <cute/arch/mma_sm100.hpp>
+#include <cute/arch/mma_sm80.hpp>
+#include <cute/arch/mma_sm89.hpp>
 TL_DISPATCH_MMA(fp8_e4_t, fp8_e4_t, float, SM89_16x8x32_F32E4M3E4M3F32_TN)
 TL_DISPATCH_MMA(fp8_e5_t, fp8_e5_t, float, SM89_16x8x32_F32E5M2E5M2F32_TN)
 TL_DISPATCH_MMA(half_t, half_t, half_t, SM80_16x8x16_F16F16F16F16_TN)
 TL_DISPATCH_MMA(half_t, half_t, float, SM80_16x8x16_F32F16F16F32_TN)
 TL_DISPATCH_MMA(bfloat16_t, bfloat16_t, float, SM80_16x8x16_F32BF16BF16F32_TN)
 TL_DISPATCH_MMA(tfloat32_t, tfloat32_t, float, SM80_16x8x8_F32TF32TF32F32_TN)
 TL_DISPATCH_MMA(int8_t, int8_t, int, SM80_16x8x32_S32S8S8S32_TN)
 TL_DISPATCH_MMA(double, double, double, SM80_8x8x4_F64F64F64F64_TN)
-#elif __CUDA_ARCH_LIST__ >= 900
-#include "cuda_fp8.h"
-#include <cute/arch/mma_sm80.hpp>
-#include <cute/arch/mma_sm89.hpp>
+#elif TL_ARCH_LEVEL >= 900
+#include <cute/arch/mma_sm80.hpp>
+#include <cute/arch/mma_sm89.hpp>
 TL_DISPATCH_MMA(fp8_e4_t, fp8_e4_t, float, SM89_16x8x32_F32E4M3E4M3F32_TN)
 TL_DISPATCH_MMA(fp8_e5_t, fp8_e5_t, float, SM89_16x8x32_F32E5M2E5M2F32_TN)
 TL_DISPATCH_MMA(half_t, half_t, half_t, SM80_16x8x16_F16F16F16F16_TN)
 TL_DISPATCH_MMA(half_t, half_t, float, SM80_16x8x16_F32F16F16F32_TN)
 TL_DISPATCH_MMA(bfloat16_t, bfloat16_t, float, SM80_16x8x16_F32BF16BF16F32_TN)
 TL_DISPATCH_MMA(tfloat32_t, tfloat32_t, float, SM80_16x8x8_F32TF32TF32F32_TN)
 TL_DISPATCH_MMA(int8_t, int8_t, int, SM80_16x8x32_S32S8S8S32_TN)
 TL_DISPATCH_MMA(double, double, double, SM80_8x8x4_F64F64F64F64_TN)
-#elif __CUDA_ARCH_LIST__ >= 890
-#include "cuda_fp8.h"
-#include <cute/arch/mma_sm80.hpp>
-#include <cute/arch/mma_sm89.hpp>
+#elif TL_ARCH_LEVEL >= 890
+#include <cute/arch/mma_sm80.hpp>
+#include <cute/arch/mma_sm89.hpp>
 TL_DISPATCH_MMA(fp8_e4_t, fp8_e4_t, float, SM89_16x8x32_F32E4M3E4M3F32_TN)
 TL_DISPATCH_MMA(fp8_e5_t, fp8_e5_t, float, SM89_16x8x32_F32E5M2E5M2F32_TN)
 TL_DISPATCH_MMA(half_t, half_t, half_t, SM80_16x8x16_F16F16F16F16_TN)
 TL_DISPATCH_MMA(half_t, half_t, float, SM80_16x8x16_F32F16F16F32_TN)
 TL_DISPATCH_MMA(bfloat16_t, bfloat16_t, float, SM80_16x8x16_F32BF16BF16F32_TN)
 TL_DISPATCH_MMA(tfloat32_t, tfloat32_t, float, SM80_16x8x8_F32TF32TF32F32_TN)
 TL_DISPATCH_MMA(int8_t, int8_t, int, SM80_16x8x32_S32S8S8S32_TN)
 TL_DISPATCH_MMA(double, double, double, SM80_8x8x4_F64F64F64F64_TN)
-#elif __CUDA_ARCH_LIST__ >= 800
+#elif TL_ARCH_LEVEL >= 800
 #include <cute/arch/mma_sm80.hpp>
 TL_DISPATCH_MMA(half_t, half_t, half_t, SM80_16x8x16_F16F16F16F16_TN)
 TL_DISPATCH_MMA(half_t, half_t, float, SM80_16x8x16_F32F16F16F32_TN)
 TL_DISPATCH_MMA(bfloat16_t, bfloat16_t, float, SM80_16x8x16_F32BF16BF16F32_TN)
 TL_DISPATCH_MMA(tfloat32_t, tfloat32_t, float, SM80_16x8x8_F32TF32TF32F32_TN)
 TL_DISPATCH_MMA(int8_t, int8_t, int, SM80_16x8x32_S32S8S8S32_TN)
 TL_DISPATCH_MMA(double, double, double, SM80_8x8x4_F64F64F64F64_TN)
-#elif __CUDA_ARCH_LIST__ >= 750
-TL_DISPATCH_MMA(half_t, half_t, float, SM75_16x8x8_F32F16F16F32_TN)
-#endif
-#endif
+#elif TL_ARCH_LEVEL >= 750
+#include <cute/arch/mma_sm75.hpp>
+TL_DISPATCH_MMA(half_t, half_t, float, SM75_16x8x8_F32F16F16F32_TN)
+#endif
+#undef TL_ARCH_LEVEL

Additionally (outside this block), include LDSM ops used by SelectCopy:

+#include <cute/arch/ldsm.hpp>

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In src/tl_templates/cuda/gemm_mma.h around lines 43 to 104, fix the arch gating
and header organization: move the cuda_fp8.h include out of the per-arch
branches and include it once before the conditional block; add an #include for
the SM75 MMA header (cute/arch/mma_sm75.hpp) so SM75 dispatches compile; replace
the sole use of __CUDA_ARCH_LIST__ with a guarded check that handles both
__CUDA_ARCH_LIST__ and a fallback to __CUDA_ARCH__ (e.g., check
defined(__CUDA_ARCH_LIST__) ? use it : if defined(__CUDA_ARCH__) compare
__CUDA_ARCH__ numeric value) so older nvcc works; remove duplicate cuda_fp8.h
includes from the branches; and, outside this block, add the LDSM ops include
needed by SelectCopy (for example include the cute LDSM header).

…ace from CUDA GEMM implementation. This cleanup enhances code clarity and maintainability by eliminating unused structures and streamlining the overall organization of the GEMM operations.