[Refactor] Use new namespace and enhance dispatch macros for mma #801
Conversation
- Moved GEMM-related dispatch instructions to the `cute::tl_mma` namespace for better organization.
- Introduced `TL_DISPATCH_MMA` and `TL_DISPATCH_MMA_TEMPLATE` macros to streamline the definition of dispatch instructions for various data types and architectures.
- Updated the handling of CUDA architecture checks to include additional support for newer architectures.
- Improved clarity and maintainability of the code by restructuring the layout and organization of dispatch instructions.
- Ensured consistent usage of tensor views and memory clearing operations across different GEMM implementations.
Note: CodeRabbit detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in its review comments, which may lead to a less comprehensive review.

Walkthrough: Moves the GEMM implementation: introduces the cute::tl_mma namespace and macro-driven DispatchInstruction specializations, adds a new tl_wgmma GemmTensorOp for Hopper wgmma, routes the tl::gemm_* wrappers to tl_wgmma when enabled, and adjusts accumulation-clear timing within the GEMM control flow. A rough routing sketch follows the sequence diagram below.
Sequence Diagram(s)

sequenceDiagram
    autonumber
    actor Caller
    participant TL_WRAPPER as tl::gemm_{ss,rs,sr}
    participant TL_WGMMA as cute::tl_wgmma::GemmTensorOp
    participant TL_MMA as cute::tl_mma::GemmTensorOp
    participant DISPATCH as DispatchInstruction (macro-generated)
    participant ARCH as Arch-specific MMA/WGMMA kernel
    Caller->>TL_WRAPPER: call gemm_*
    TL_WRAPPER->>TL_WGMMA: if use_wgmma -> configure types/layouts
    TL_WRAPPER->>TL_MMA: else -> configure tl_mma types
    alt use_wgmma
        TL_WGMMA->>DISPATCH: choose wgmma path / sync/barrier handling
        DISPATCH->>ARCH: invoke wgmma kernel (hopper)
        ARCH-->>TL_WGMMA: compute tiles
        TL_WGMMA-->>TL_WRAPPER: return accum result
    else fallback
        TL_MMA->>DISPATCH: select arch/type instr via macros
        DISPATCH->>ARCH: invoke arch-specific MMA kernel
        ARCH-->>TL_MMA: compute tiles
        TL_MMA-->>TL_WRAPPER: return accum result
    end
    TL_WRAPPER-->>Caller: final result
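
As a rough, non-authoritative sketch of the routing the diagram describes (the actual wrapper signatures, template parameters, and helper calls are not shown on this page and are placeholders here), the dispatch amounts to a compile-time branch:

```cpp
// Sketch only: parameter lists and the commented-out body calls are
// placeholders, not the real tl::gemm_ss implementation.
template <bool use_wgmma, class AType, class BType, class CType /*, ... */>
__device__ void gemm_ss_sketch(AType const *pA, BType const *pB, CType *accum) {
  if constexpr (use_wgmma) {
    // Hopper path: configure tl_wgmma types/layouts, then run the wgmma kernel.
    // cute::tl_wgmma::GemmTensorOp</* ... */>::body(pA, pB, accum);
  } else {
    // Fallback path: macro-dispatched MMA instructions via cute::tl_mma.
    // cute::tl_mma::GemmTensorOp</* ... */>::body(pA, pB, accum);
  }
}
```

The only load-bearing claim here is the one the walkthrough makes: the wrapper picks tl_wgmma when wgmma is enabled and tl_mma otherwise.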
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Summary of Changes
Hello @LeiWang1999, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request significantly refactors the GEMM-related dispatch instructions within the codebase. The primary goal is to enhance code organization and maintainability by introducing a new namespace and a set of helper macros. These changes also extend support for a wider range of CUDA architectures, ensuring future compatibility and improved performance across different hardware.
Highlights
- Namespace Refactoring: Moved GEMM-related dispatch instructions to the cute::tl_mma namespace for better organization.
- Macro Introduction: Introduced TL_DISPATCH_MMA and TL_DISPATCH_MMA_TEMPLATE macros to streamline the definition of dispatch instructions for various data types and architectures (a simplified expansion of one of these macros is sketched after this list).
- CUDA Architecture Support: Updated the handling of CUDA architecture checks to include additional support for newer architectures.
- Code Clarity and Maintainability: Improved clarity and maintainability of the code by restructuring the layout and organization of dispatch instructions.
- Consistency in Tensor Operations: Ensured consistent usage of tensor views and memory clearing operations across different GEMM implementations.
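
For illustration, here is roughly what one macro-generated specialization expands to, based on the TL_DISPATCH_MMA body quoted later in this review (simplified; the exact formatting in the PR may differ, and the CuTe types are assumed to come from the file's existing includes):

```cpp
// Approximate expansion of
//   TL_DISPATCH_MMA(half_t, half_t, float, SM80_16x8x16_F32F16F16F32_TN)
// Assumes the CuTe headers and types (half_t, MMA_Atom, Tile, Int, _X)
// already used elsewhere in gemm_mma.h.
namespace cute::tl_mma {
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n, N> {
  using MMA = MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>;
  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};
} // namespace cute::tl_mma
```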
Code Review
This pull request effectively refactors the GEMM dispatch logic by introducing macros to reduce code duplication and reorganizing the code into a new cute::tl_mma namespace. The changes improve maintainability and add support for newer CUDA architectures. However, I've found a critical issue where the refactoring for SM75 architecture is incorrect, potentially breaking support for Turing GPUs. I've also identified an opportunity to further reduce code duplication in the preprocessor directives for architecture-specific code. The other changes, like moving accumulator clearing operations, are minor but positive improvements to code clarity.
#ifdef __CUDA_ARCH_LIST__
#if __CUDA_ARCH_LIST__ >= 1200
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<fp8_e4_t, fp8_e4_t, float, num_warp_m, num_warp_n,
                           N> {
  using MMA = MMA_Atom<SM120_16x8x32_TN<fp8_e4_t, fp8_e4_t, float>>;
  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<fp8_e5_t, fp8_e5_t, float, num_warp_m, num_warp_n,
                           N> {
  using MMA = MMA_Atom<SM120_16x8x32_TN<fp8_e5_t, fp8_e5_t, float>>;
  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};
#include "cuda_fp8.h"
#include <cute/arch/mma_sm120.hpp>
#include <cute/arch/mma_sm80.hpp>
TL_DISPATCH_MMA_TEMPLATE(fp8_e4_t, fp8_e4_t, float, SM120_16x8x32_TN)
TL_DISPATCH_MMA_TEMPLATE(fp8_e5_t, fp8_e5_t, float, SM120_16x8x32_TN)
TL_DISPATCH_MMA(half_t, half_t, half_t, SM80_16x8x16_F16F16F16F16_TN)
TL_DISPATCH_MMA(half_t, half_t, float, SM80_16x8x16_F32F16F16F32_TN)
TL_DISPATCH_MMA(bfloat16_t, bfloat16_t, float, SM80_16x8x16_F32BF16BF16F32_TN)
TL_DISPATCH_MMA(tfloat32_t, tfloat32_t, float, SM80_16x8x8_F32TF32TF32F32_TN)
TL_DISPATCH_MMA(int8_t, int8_t, int, SM80_16x8x32_S32S8S8S32_TN)
TL_DISPATCH_MMA(double, double, double, SM80_8x8x4_F64F64F64F64_TN)
#elif __CUDA_ARCH_LIST__ >= 1000
#include "cuda_fp8.h"
#include <cute/arch/mma_sm100.hpp>
#include <cute/arch/mma_sm80.hpp>
#include <cute/arch/mma_sm89.hpp>
TL_DISPATCH_MMA(fp8_e4_t, fp8_e4_t, float, SM89_16x8x32_F32E4M3E4M3F32_TN)
TL_DISPATCH_MMA(fp8_e5_t, fp8_e5_t, float, SM89_16x8x32_F32E5M2E5M2F32_TN)
TL_DISPATCH_MMA(half_t, half_t, half_t, SM80_16x8x16_F16F16F16F16_TN)
TL_DISPATCH_MMA(half_t, half_t, float, SM80_16x8x16_F32F16F16F32_TN)
TL_DISPATCH_MMA(bfloat16_t, bfloat16_t, float, SM80_16x8x16_F32BF16BF16F32_TN)
TL_DISPATCH_MMA(tfloat32_t, tfloat32_t, float, SM80_16x8x8_F32TF32TF32F32_TN)
TL_DISPATCH_MMA(int8_t, int8_t, int, SM80_16x8x32_S32S8S8S32_TN)
TL_DISPATCH_MMA(double, double, double, SM80_8x8x4_F64F64F64F64_TN)
#elif __CUDA_ARCH_LIST__ >= 900
#include "cuda_fp8.h"
#include <cute/arch/mma_sm80.hpp>
#include <cute/arch/mma_sm89.hpp>
TL_DISPATCH_MMA(fp8_e4_t, fp8_e4_t, float, SM89_16x8x32_F32E4M3E4M3F32_TN)
TL_DISPATCH_MMA(fp8_e5_t, fp8_e5_t, float, SM89_16x8x32_F32E5M2E5M2F32_TN)
TL_DISPATCH_MMA(half_t, half_t, half_t, SM80_16x8x16_F16F16F16F16_TN)
TL_DISPATCH_MMA(half_t, half_t, float, SM80_16x8x16_F32F16F16F32_TN)
TL_DISPATCH_MMA(bfloat16_t, bfloat16_t, float, SM80_16x8x16_F32BF16BF16F32_TN)
TL_DISPATCH_MMA(tfloat32_t, tfloat32_t, float, SM80_16x8x8_F32TF32TF32F32_TN)
TL_DISPATCH_MMA(int8_t, int8_t, int, SM80_16x8x32_S32S8S8S32_TN)
TL_DISPATCH_MMA(double, double, double, SM80_8x8x4_F64F64F64F64_TN)
#elif __CUDA_ARCH_LIST__ >= 890
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<fp8_e4_t, fp8_e4_t, float, num_warp_m, num_warp_n,
                           N> {
  using MMA = MMA_Atom<SM89_16x8x32_F32E4M3E4M3F32_TN>;
  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<fp8_e5_t, fp8_e5_t, float, num_warp_m, num_warp_n,
                           N> {
  using MMA = MMA_Atom<SM89_16x8x32_F32E5M2E5M2F32_TN>;
  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};
#include "cuda_fp8.h"
#include <cute/arch/mma_sm80.hpp>
#include <cute/arch/mma_sm89.hpp>
TL_DISPATCH_MMA(fp8_e4_t, fp8_e4_t, float, SM89_16x8x32_F32E4M3E4M3F32_TN)
TL_DISPATCH_MMA(fp8_e5_t, fp8_e5_t, float, SM89_16x8x32_F32E5M2E5M2F32_TN)
TL_DISPATCH_MMA(half_t, half_t, half_t, SM80_16x8x16_F16F16F16F16_TN)
TL_DISPATCH_MMA(half_t, half_t, float, SM80_16x8x16_F32F16F16F32_TN)
TL_DISPATCH_MMA(bfloat16_t, bfloat16_t, float, SM80_16x8x16_F32BF16BF16F32_TN)
TL_DISPATCH_MMA(tfloat32_t, tfloat32_t, float, SM80_16x8x8_F32TF32TF32F32_TN)
TL_DISPATCH_MMA(int8_t, int8_t, int, SM80_16x8x32_S32S8S8S32_TN)
TL_DISPATCH_MMA(double, double, double, SM80_8x8x4_F64F64F64F64_TN)
#elif __CUDA_ARCH_LIST__ >= 800
#include <cute/arch/mma_sm80.hpp>
TL_DISPATCH_MMA(half_t, half_t, half_t, SM80_16x8x16_F16F16F16F16_TN)
TL_DISPATCH_MMA(half_t, half_t, float, SM80_16x8x16_F32F16F16F32_TN)
TL_DISPATCH_MMA(bfloat16_t, bfloat16_t, float, SM80_16x8x16_F32BF16BF16F32_TN)
TL_DISPATCH_MMA(tfloat32_t, tfloat32_t, float, SM80_16x8x8_F32TF32TF32F32_TN)
TL_DISPATCH_MMA(int8_t, int8_t, int, SM80_16x8x32_S32S8S8S32_TN)
TL_DISPATCH_MMA(double, double, double, SM80_8x8x4_F64F64F64F64_TN)
#elif __CUDA_ARCH_LIST__ >= 750
TL_DISPATCH_MMA(half_t, half_t, float, SM75_16x8x8_F32F16F16F32_TN)
#endif
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<half_t, half_t, half_t, num_warp_m, num_warp_n, N> {
  using MMA = MMA_Atom<SM80_16x8x16_F16F16F16F16_TN>;
  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n, N> {
  using MMA = MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>;
  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<bfloat16_t, bfloat16_t, float, num_warp_m,
                           num_warp_n, N> {
  using MMA = MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>;
  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<tfloat32_t, tfloat32_t, float, num_warp_m,
                           num_warp_n, N> {
  using MMA = MMA_Atom<SM80_16x8x8_F32TF32TF32F32_TN>;
  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<int8_t, int8_t, int, num_warp_m, num_warp_n, N> {
  using MMA = MMA_Atom<SM80_16x8x32_S32S8S8S32_TN>;
  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<double, double, double, num_warp_m, num_warp_n, N> {
  using MMA = MMA_Atom<SM80_8x8x4_F64F64F64F64_TN>;
  using MMA_Group = Tile<Int<num_warp_m * 16>, Int<num_warp_n * 16>, _X>;
};
#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 750))
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n, N> {
  using MMA = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _16>;
};
#endif
This section has two issues: a critical bug for SM75 and a maintainability issue due to code duplication.

- Critical Bug for SM75 (Turing): The refactoring to use the TL_DISPATCH_MMA macro for the SM75 architecture is incorrect. The macro hardcodes the last dimension of MMA_Group to _X, while the original code correctly used _16. This will likely cause compilation errors or incorrect behavior on Turing GPUs.

- Code Duplication: The current #if/#elif structure leads to significant code duplication. For example, the SM80 instructions are repeated for every newer architecture block (SM89, SM90, SM100, SM120). Also, the blocks for __CUDA_ARCH_LIST__ >= 900 and __CUDA_ARCH_LIST__ >= 890 are identical.

I recommend refactoring this entire section to use a series of non-exclusive #if blocks. This will fix the SM75 bug, eliminate redundancy, and make the code much easier to maintain.
Here is a suggested refactoring:
#ifdef __CUDA_ARCH_LIST__
// SM75 (Turing)
#if __CUDA_ARCH_LIST__ >= 750 && __CUDA_ARCH_LIST__ < 800
#include <cute/arch/mma_sm75.hpp>
namespace cute::tl_mma {
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n, N> {
  using MMA = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _16>;
};
} // namespace cute::tl_mma
#endif
// SM80+ (Ampere)
#if __CUDA_ARCH_LIST__ >= 800
#include <cute/arch/mma_sm80.hpp>
TL_DISPATCH_MMA(half_t, half_t, half_t, SM80_16x8x16_F16F16F16F16_TN)
TL_DISPATCH_MMA(half_t, half_t, float, SM80_16x8x16_F32F16F16F32_TN)
TL_DISPATCH_MMA(bfloat16_t, bfloat16_t, float, SM80_16x8x16_F32BF16BF16F32_TN)
TL_DISPATCH_MMA(tfloat32_t, tfloat32_t, float, SM80_16x8x8_F32TF32TF32F32_TN)
TL_DISPATCH_MMA(int8_t, int8_t, int, SM80_16x8x32_S32S8S8S32_TN)
TL_DISPATCH_MMA(double, double, double, SM80_8x8x4_F64F64F64F64_TN)
#endif
// SM89+ (Ada Lovelace)
#if __CUDA_ARCH_LIST__ >= 890
#include "cuda_fp8.h"
#include <cute/arch/mma_sm89.hpp>
TL_DISPATCH_MMA(fp8_e4_t, fp8_e4_t, float, SM89_16x8x32_F32E4M3E4M3F32_TN)
TL_DISPATCH_MMA(fp8_e5_t, fp8_e5_t, float, SM89_16x8x32_F32E5M2E5M2F32_TN)
#endif
// SM90 is covered by SM89 for these instructions.
// SM100+ (Blackwell)
#if __CUDA_ARCH_LIST__ >= 1000
#include <cute/arch/mma_sm100.hpp>
#endif
// SM120+ (Future Arch)
#if __CUDA_ARCH_LIST__ >= 1200
#include <cute/arch/mma_sm120.hpp>
// Override fp8 dispatches for SM120+
TL_DISPATCH_MMA_TEMPLATE(fp8_e4_t, fp8_e4_t, float, SM120_16x8x32_TN)
TL_DISPATCH_MMA_TEMPLATE(fp8_e5_t, fp8_e5_t, float, SM120_16x8x32_TN)
#endif
#endif // __CUDA_ARCH_LIST__
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️  Outside diff range comments (2)
src/tl_templates/cuda/gemm_mma.h (2)
270-272: Bug: the B_type alias falls back to A_type_raw instead of B_type_raw. This mis-types the B operands when B_type_raw != float.
- using B_type =
-     typename std::conditional<std::is_same<B_type_raw, float>::value,
-                               tfloat32_t, A_type_raw>::type;
+ using B_type =
+     typename std::conditional<std::is_same<B_type_raw, float>::value,
+                               tfloat32_t, B_type_raw>::type;
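
To make the mis-typing concrete, here is a minimal, self-contained illustration of the std::conditional pattern involved (plain stub types stand in for the CuTe element types; this is not code from the PR):

```cpp
#include <type_traits>

// Stand-ins for the real element types (hypothetical, for illustration only).
struct tf32_stub {};
struct bf16_stub {};

// Buggy form: the "else" branch uses the A-side raw type.
template <class A_raw, class B_raw>
using B_type_buggy =
    typename std::conditional<std::is_same<B_raw, float>::value, tf32_stub,
                              A_raw>::type;

// Fixed form: the "else" branch must use the B-side raw type.
template <class A_raw, class B_raw>
using B_type_fixed =
    typename std::conditional<std::is_same<B_raw, float>::value, tf32_stub,
                              B_raw>::type;

// With A = float and B = bf16, the buggy alias yields float, the fixed one bf16.
static_assert(std::is_same<B_type_buggy<float, bf16_stub>, float>::value, "");
static_assert(std::is_same<B_type_fixed<float, bf16_stub>, bf16_stub>::value, "");
```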
453-479: Namespace move is breaking — preserve backward-compatibility or update callers
- Verified occurrences referencing the old symbol name (string/handle) exist and must be updated or kept as a shim:
- src/op/gemm.cc (op_name = "tl::gemm_ss" / "tl::gemm_rs" / "tl::gemm_sr").
- testing/python/transform/* (several tests embed "tl::gemm_ss<...>" in call_extern handles).
- Action: either add a temporary compatibility alias exposing the new implementations under the old tl:: names, or update the callers (C++ op_name and tests) to the new namespace.
- Suggested compatibility shim (place in a header included by callers, outside this file):

  namespace tl {
  using tl_mma::gemm_ss;
  using tl_mma::gemm_rs;
  using tl_mma::gemm_sr;
  } // namespace tl
📒 Files selected for processing (1)
- src/tl_templates/cuda/gemm_mma.h (7 hunks)
🔇 Additional comments (2)
src/tl_templates/cuda/gemm_mma.h (2)
361-363: clear(acc) after view creation — good. Semantics preserved; avoids unnecessary clears if early exits were ever added pre-view creation.
395-395: Precompute {tCrB|tCrA}_view before the loop — good. Minor readability/perf win; matches body() sequencing.
Also applies to: 431-431
#define TL_DISPATCH_MMA(A_type, B_type, C_type, MMA_instr)                    \
  namespace cute::tl_mma {                                                     \
  template <int num_warp_m, int num_warp_n, int N>                             \
  struct DispatchInstruction<A_type, B_type, C_type, num_warp_m, num_warp_n,   \
                             N> {                                              \
    using MMA = MMA_Atom<MMA_instr>;                                           \
    using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;         \
  };                                                                           \
  }

#define TL_DISPATCH_MMA_TEMPLATE(A_type, B_type, C_type, MMA_instr)            \
  namespace cute::tl_mma {                                                     \
  template <int num_warp_m, int num_warp_n, int N>                             \
  struct DispatchInstruction<A_type, B_type, C_type, num_warp_m, num_warp_n,   \
                             N> {                                              \
    using MMA = MMA_Atom<MMA_instr<A_type, B_type, C_type>>;                   \
    using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;         \
  };                                                                           \
  }
🛠️ Refactor suggestion
Add the missing #include <algorithm> (needed for std::min in a template argument) and prefer an explicit constexpr min to avoid surprises.
std::min is used in a non-type template argument; some toolchains require <algorithm> to be included for it to resolve. Also consider a constexpr min to avoid ADL/overload issues.
Apply at file top:
+#include <algorithm>
+#include <type_traits>

Optional (safer constexpr min inside the macro body):
-    using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
+    static constexpr int kGroupN = (num_warp_n * 16 < N) ? (num_warp_n * 16) : N;
+    using MMA_Group = Tile<_X, Int<kGroupN>, _X>;

📝 Committable suggestion
#include <algorithm>
#include <type_traits>

#define TL_DISPATCH_MMA(A_type, B_type, C_type, MMA_instr)                    \
  namespace cute::tl_mma {                                                     \
  template <int num_warp_m, int num_warp_n, int N>                             \
  struct DispatchInstruction<A_type, B_type, C_type, num_warp_m, num_warp_n,   \
                             N> {                                              \
    using MMA = MMA_Atom<MMA_instr>;                                           \
    static constexpr int kGroupN = (num_warp_n * 16 < N) ? (num_warp_n * 16) : N; \
    using MMA_Group = Tile<_X, Int<kGroupN>, _X>;                              \
  };                                                                           \
  }

#define TL_DISPATCH_MMA_TEMPLATE(A_type, B_type, C_type, MMA_instr)            \
  namespace cute::tl_mma {                                                     \
  template <int num_warp_m, int num_warp_n, int N>                             \
  struct DispatchInstruction<A_type, B_type, C_type, num_warp_m, num_warp_n,   \
                             N> {                                              \
    using MMA = MMA_Atom<MMA_instr<A_type, B_type, C_type>>;                   \
    static constexpr int kGroupN = (num_warp_n * 16 < N) ? (num_warp_n * 16) : N; \
    using MMA_Group = Tile<_X, Int<kGroupN>, _X>;                              \
  };                                                                           \
  }
🤖 Prompt for AI Agents
In src/tl_templates/cuda/gemm_mma.h around lines 24-41, add the missing #include
<algorithm> at the top of the file and replace the use of std::min inside the
non-type template argument with an explicit constexpr min to avoid ADL/overload
surprises; specifically, declare a small constexpr inline function (e.g.,
tl_constexpr_min) that returns the smaller of two ints and use that function in
the Tile<_X, Int<...>, _X> template argument (or alternatively call std::min
with full std:: qualification and ensure <algorithm> is included), so toolchains
that evaluate non-type template expressions resolve correctly.
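
For reference, a minimal, self-contained sketch of the constexpr-min alternative suggested above; the helper name tl_constexpr_min is taken from the prompt and is hypothetical, and IntStub merely stands in for cute::Int:

```cpp
#include <algorithm>  // the <algorithm> include the comment asks for (std::min variant)

// Hypothetical helper: an explicit constexpr min avoids relying on std::min
// inside a non-type template argument.
constexpr int tl_constexpr_min(int a, int b) { return a < b ? a : b; }

// Stub standing in for cute::Int in this illustration.
template <int V> struct IntStub { static constexpr int value = V; };

template <int num_warp_n, int N>
struct GroupSketch {
  // Equivalent to Int<std::min(num_warp_n * 16, N)> in the macro body above.
  using GroupN = IntStub<tl_constexpr_min(num_warp_n * 16, N)>;
};

static_assert(GroupSketch<2, 16>::GroupN::value == 16, "");
static_assert(GroupSketch<2, 128>::GroupN::value == 32, "");
```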
(The diff hunk quoted here is identical to the dispatch-instruction block shown earlier in this review.)
Fix arch gating and headers: add the SM75 header, drop the duplicate cuda_fp8.h includes, and support a __CUDA_ARCH__ fallback.
- SM75 instructions are used but mma_sm75.hpp is never included (compilation will fail for sm_75).
- cuda_fp8.h is included redundantly in each branch.
- Only checking __CUDA_ARCH_LIST__ breaks older nvcc; add a fallback to __CUDA_ARCH__.
Proposed patch to this block:
-#ifdef __CUDA_ARCH_LIST__
-#if __CUDA_ARCH_LIST__ >= 1200
-#include "cuda_fp8.h"
-#include <cute/arch/mma_sm120.hpp>
-#include <cute/arch/mma_sm80.hpp>
+#if defined(__CUDA_ARCH_LIST__) || defined(__CUDA_ARCH__)
+#ifndef TL_ARCH_LEVEL
+#  if defined(__CUDA_ARCH_LIST__)
+#    define TL_ARCH_LEVEL __CUDA_ARCH_LIST__
+#  else
+#    define TL_ARCH_LEVEL __CUDA_ARCH__
+#  endif
+#endif
+#if TL_ARCH_LEVEL >= 1200
+#include <cute/arch/mma_sm120.hpp>
+#include <cute/arch/mma_sm80.hpp>
 TL_DISPATCH_MMA_TEMPLATE(fp8_e4_t, fp8_e4_t, float, SM120_16x8x32_TN)
 TL_DISPATCH_MMA_TEMPLATE(fp8_e5_t, fp8_e5_t, float, SM120_16x8x32_TN)
 TL_DISPATCH_MMA(half_t, half_t, half_t, SM80_16x8x16_F16F16F16F16_TN)
 TL_DISPATCH_MMA(half_t, half_t, float, SM80_16x8x16_F32F16F16F32_TN)
 TL_DISPATCH_MMA(bfloat16_t, bfloat16_t, float, SM80_16x8x16_F32BF16BF16F32_TN)
 TL_DISPATCH_MMA(tfloat32_t, tfloat32_t, float, SM80_16x8x8_F32TF32TF32F32_TN)
 TL_DISPATCH_MMA(int8_t, int8_t, int, SM80_16x8x32_S32S8S8S32_TN)
 TL_DISPATCH_MMA(double, double, double, SM80_8x8x4_F64F64F64F64_TN)
-#elif __CUDA_ARCH_LIST__ >= 1000
-#include "cuda_fp8.h"
-#include <cute/arch/mma_sm100.hpp>
-#include <cute/arch/mma_sm80.hpp>
-#include <cute/arch/mma_sm89.hpp>
+#elif TL_ARCH_LEVEL >= 1000
+#include <cute/arch/mma_sm100.hpp>
+#include <cute/arch/mma_sm80.hpp>
+#include <cute/arch/mma_sm89.hpp>
 TL_DISPATCH_MMA(fp8_e4_t, fp8_e4_t, float, SM89_16x8x32_F32E4M3E4M3F32_TN)
 TL_DISPATCH_MMA(fp8_e5_t, fp8_e5_t, float, SM89_16x8x32_F32E5M2E5M2F32_TN)
 TL_DISPATCH_MMA(half_t, half_t, half_t, SM80_16x8x16_F16F16F16F16_TN)
 TL_DISPATCH_MMA(half_t, half_t, float, SM80_16x8x16_F32F16F16F32_TN)
 TL_DISPATCH_MMA(bfloat16_t, bfloat16_t, float, SM80_16x8x16_F32BF16BF16F32_TN)
 TL_DISPATCH_MMA(tfloat32_t, tfloat32_t, float, SM80_16x8x8_F32TF32TF32F32_TN)
 TL_DISPATCH_MMA(int8_t, int8_t, int, SM80_16x8x32_S32S8S8S32_TN)
 TL_DISPATCH_MMA(double, double, double, SM80_8x8x4_F64F64F64F64_TN)
-#elif __CUDA_ARCH_LIST__ >= 900
-#include "cuda_fp8.h"
-#include <cute/arch/mma_sm80.hpp>
-#include <cute/arch/mma_sm89.hpp>
+#elif TL_ARCH_LEVEL >= 900
+#include <cute/arch/mma_sm80.hpp>
+#include <cute/arch/mma_sm89.hpp>
 TL_DISPATCH_MMA(fp8_e4_t, fp8_e4_t, float, SM89_16x8x32_F32E4M3E4M3F32_TN)
 TL_DISPATCH_MMA(fp8_e5_t, fp8_e5_t, float, SM89_16x8x32_F32E5M2E5M2F32_TN)
 TL_DISPATCH_MMA(half_t, half_t, half_t, SM80_16x8x16_F16F16F16F16_TN)
 TL_DISPATCH_MMA(half_t, half_t, float, SM80_16x8x16_F32F16F16F32_TN)
 TL_DISPATCH_MMA(bfloat16_t, bfloat16_t, float, SM80_16x8x16_F32BF16BF16F32_TN)
 TL_DISPATCH_MMA(tfloat32_t, tfloat32_t, float, SM80_16x8x8_F32TF32TF32F32_TN)
 TL_DISPATCH_MMA(int8_t, int8_t, int, SM80_16x8x32_S32S8S8S32_TN)
 TL_DISPATCH_MMA(double, double, double, SM80_8x8x4_F64F64F64F64_TN)
-#elif __CUDA_ARCH_LIST__ >= 890
-#include "cuda_fp8.h"
-#include <cute/arch/mma_sm80.hpp>
-#include <cute/arch/mma_sm89.hpp>
+#elif TL_ARCH_LEVEL >= 890
+#include <cute/arch/mma_sm80.hpp>
+#include <cute/arch/mma_sm89.hpp>
 TL_DISPATCH_MMA(fp8_e4_t, fp8_e4_t, float, SM89_16x8x32_F32E4M3E4M3F32_TN)
 TL_DISPATCH_MMA(fp8_e5_t, fp8_e5_t, float, SM89_16x8x32_F32E5M2E5M2F32_TN)
 TL_DISPATCH_MMA(half_t, half_t, half_t, SM80_16x8x16_F16F16F16F16_TN)
 TL_DISPATCH_MMA(half_t, half_t, float, SM80_16x8x16_F32F16F16F32_TN)
 TL_DISPATCH_MMA(bfloat16_t, bfloat16_t, float, SM80_16x8x16_F32BF16BF16F32_TN)
 TL_DISPATCH_MMA(tfloat32_t, tfloat32_t, float, SM80_16x8x8_F32TF32TF32F32_TN)
 TL_DISPATCH_MMA(int8_t, int8_t, int, SM80_16x8x32_S32S8S8S32_TN)
 TL_DISPATCH_MMA(double, double, double, SM80_8x8x4_F64F64F64F64_TN)
-#elif __CUDA_ARCH_LIST__ >= 800
+#elif TL_ARCH_LEVEL >= 800
 #include <cute/arch/mma_sm80.hpp>
 TL_DISPATCH_MMA(half_t, half_t, half_t, SM80_16x8x16_F16F16F16F16_TN)
 TL_DISPATCH_MMA(half_t, half_t, float, SM80_16x8x16_F32F16F16F32_TN)
 TL_DISPATCH_MMA(bfloat16_t, bfloat16_t, float, SM80_16x8x16_F32BF16BF16F32_TN)
 TL_DISPATCH_MMA(tfloat32_t, tfloat32_t, float, SM80_16x8x8_F32TF32TF32F32_TN)
 TL_DISPATCH_MMA(int8_t, int8_t, int, SM80_16x8x32_S32S8S8S32_TN)
 TL_DISPATCH_MMA(double, double, double, SM80_8x8x4_F64F64F64F64_TN)
-#elif __CUDA_ARCH_LIST__ >= 750
-TL_DISPATCH_MMA(half_t, half_t, float, SM75_16x8x8_F32F16F16F32_TN)
-#endif
-#endif
+#elif TL_ARCH_LEVEL >= 750
+#include <cute/arch/mma_sm75.hpp>
+TL_DISPATCH_MMA(half_t, half_t, float, SM75_16x8x8_F32F16F16F32_TN)
+#endif
+#undef TL_ARCH_LEVEL

Additionally (outside this block), include the LDSM ops used by SelectCopy:
+#include <cute/arch/ldsm.hpp>

Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In src/tl_templates/cuda/gemm_mma.h around lines 43 to 104, fix the arch gating
and header organization: move the cuda_fp8.h include out of the per-arch
branches and include it once before the conditional block; add an #include for
the SM75 MMA header (cute/arch/mma_sm75.hpp) so SM75 dispatches compile; replace
the sole use of __CUDA_ARCH_LIST__ with a guarded check that handles both
__CUDA_ARCH_LIST__ and a fallback to __CUDA_ARCH__ (e.g., check
defined(__CUDA_ARCH_LIST__) ? use it : if defined(__CUDA_ARCH__) compare
__CUDA_ARCH__ numeric value) so older nvcc works; remove duplicate cuda_fp8.h
includes from the branches; and, outside this block, add the LDSM ops include
needed by SelectCopy (for example include the cute LDSM header).
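
A condensed, standalone sketch of the architecture-level fallback this comment asks for (it mirrors the TL_ARCH_LEVEL idea from the suggested diff above; the exact guard names and where the registrations live are assumptions, not the final implementation):

```cpp
// Pick one architecture "level" macro, preferring __CUDA_ARCH_LIST__ when the
// toolchain provides it and falling back to __CUDA_ARCH__ otherwise.
#if defined(__CUDA_ARCH_LIST__)
#define TL_ARCH_LEVEL __CUDA_ARCH_LIST__
#elif defined(__CUDA_ARCH__)
#define TL_ARCH_LEVEL __CUDA_ARCH__
#endif

#if defined(TL_ARCH_LEVEL)
#if TL_ARCH_LEVEL >= 800
// SM80+ dispatch registrations (TL_DISPATCH_MMA(...)) would go here.
#elif TL_ARCH_LEVEL >= 750
// SM75-only dispatch registrations would go here.
#endif
#undef TL_ARCH_LEVEL
#endif
```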
…ace from CUDA GEMM implementation. This cleanup enhances code clarity and maintainability by eliminating unused structures and streamlining the overall organization of the GEMM operations.
- Moved GEMM-related dispatch instructions to the cute::tl_mma namespace for better organization.
- Introduced TL_DISPATCH_MMA and TL_DISPATCH_MMA_TEMPLATE macros to streamline the definition of dispatch instructions for various data types and architectures.

Summary by CodeRabbit
Refactor
Chores