Grouped GEMM with ck_tile #434

Open
matthiasdiener wants to merge 58 commits into dev from ck-grouped-gemm
Conversation

matthiasdiener (Contributor) commented on Jan 28, 2026

Description

See https://github.com/ROCm/frameworks-internal/issues/15185 and https://github.com/ROCm/frameworks-internal/issues/13792 for context.

Primus-Turbo implementation: https://github.com/AMD-AGI/Primus-Turbo/blob/5bcd13785ef380fec0eec0911b7d6db5e606143e/csrc/kernels/grouped_gemm

TODOs:

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Implement ck_tile-based grouped GEMM, similar to the Cutlass implementation
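For reviewers unfamiliar with the operation: a grouped GEMM runs one independent GEMM per group, where each group may have its own shapes. A minimal NumPy sketch of the semantics (illustrative only, not the ck_tile API):

```python
import numpy as np

def grouped_gemm_ref(A_list, B_list):
    """Reference semantics: one independent GEMM per group.

    Each group may have its own (M, N, K); only K must match
    between A[i] and B[i] within a group.
    """
    return [A @ B for A, B in zip(A_list, B_list)]

# Two groups with different M/N/K.
rng = np.random.default_rng(0)
A_list = [rng.standard_normal((4, 8)), rng.standard_normal((3, 5))]
B_list = [rng.standard_normal((8, 6)), rng.standard_normal((5, 7))]
D_list = grouped_gemm_ref(A_list, B_list)
print([d.shape for d in D_list])  # [(4, 6), (3, 7)]
```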

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@matthiasdiener matthiasdiener self-assigned this Jan 28, 2026
@matthiasdiener matthiasdiener changed the title from "[WIP] proof-of-concept: grouped GEMM with ck_tile" to "[WIP] Grouped GEMM with ck_tile" on Jan 29, 2026
@matthiasdiener matthiasdiener marked this pull request as ready for review February 17, 2026 22:58
)
if IS_HIP_EXTENSION:
from transformer_engine.pytorch.utils import is_mi200, is_mi308
from transformer_engine.pytorch.utils import is_mi200, is_mi308, is_mi300_class
Collaborator:

The is_mi300_class method is not needed; it is just the gfx 9.4 family.

Contributor Author:

Removed in 7910038

@@ -0,0 +1,276 @@
/* Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. */
Collaborator:

Add proper copyright header

Contributor Author:

Thanks, done in f680d6a

@@ -0,0 +1,11 @@
/* Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. */
Collaborator:

Put proper copyright header

Contributor Author:

Thanks, done in f680d6a

size_t workspace_bytes,
hipStream_t stream) {

// FIXME: This could be a templated lambda function in C++20.
Collaborator:

As an alternative, dispatch_grouped can be incorporated into ck_tile_grouped_gemm using nested TRANSFORMER_ENGINE_SWITCH_CONDITION.
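For context, the pattern the reviewer suggests: a macro that exposes a runtime bool as a compile-time constant so both branches get instantiated. A self-contained sketch with a simplified stand-in macro (the real TRANSFORMER_ENGINE_SWITCH_CONDITION may differ in detail; all type names here are illustrative):

```cpp
#include <type_traits>

// Simplified stand-in for TRANSFORMER_ENGINE_SWITCH_CONDITION (hypothetical
// macro body): expose a runtime bool as a compile-time constant to the
// enclosed code, instantiating both branches.
#define SWITCH_CONDITION(cond, kConst, ...) \
  if (cond) {                               \
    constexpr bool kConst = true;           \
    __VA_ARGS__                             \
  } else {                                  \
    constexpr bool kConst = false;          \
    __VA_ARGS__                             \
  }

struct RowMajor {};
struct ColMajor {};

// Stand-in for a layout-specialized kernel instantiation.
template <typename ALayout, typename BLayout>
int layout_id() {
  return (std::is_same<ALayout, ColMajor>::value ? 2 : 0) +
         (std::is_same<BLayout, ColMajor>::value ? 1 : 0);
}

// Nested use, as suggested for folding dispatch_grouped into
// ck_tile_grouped_gemm: runtime (transA, transB) select compile-time layouts.
int dispatch(bool transA, bool transB) {
  int result = -1;
  SWITCH_CONDITION(transA, kTransA, {
    using ALayout = std::conditional_t<kTransA, ColMajor, RowMajor>;
    SWITCH_CONDITION(transB, kTransB, {
      using BLayout = std::conditional_t<kTransB, ColMajor, RowMajor>;
      result = layout_id<ALayout, BLayout>();
    });
  });
  return result;
}
```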

Contributor Author:

What do you think of 6d85088?

Contributor Author:

I misread your initial comment, c5d83a4 merges dispatch_grouped and ck_tile_grouped_gemm.
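For reference, the pattern behind the FIXME in the diff above: before C++20, a generic lambda takes a value tag and recovers its type with decltype; a C++20 templated lambda could name the type directly. A standalone sketch with hypothetical tile-config tags:

```cpp
// Hypothetical tile-config tags standing in for the real ck_tile configs.
struct TileCfg_256x256x64 { static constexpr int tile_n = 256; };
struct TileCfg_256x128x64 { static constexpr int tile_n = 128; };

int select_tile_n(int N) {
  // Pre-C++20 pattern, as used in the PR code: generic lambda + decltype(tag).
  auto run = [](auto tile_tag) {
    using TileCfg = decltype(tile_tag);
    return TileCfg::tile_n;
  };
  // In C++20 the value tag becomes unnecessary; a templated lambda can name
  // the type directly:  []<typename TileCfg>() { return TileCfg::tile_n; }
  if (N % 256 == 0) return run(TileCfg_256x256x64{});
  return run(TileCfg_256x128x64{});
}
```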

if (!transA_use && !transB_use) { CALL(RowMajor, RowMajor, false, false); }
if (!transA_use && transB_use) { CALL(RowMajor, ColMajor, false, true ); }
if ( transA_use && !transB_use) { CALL(ColMajor, RowMajor, true, false); }
/* transA_use && transB_use */ { CALL(ColMajor, ColMajor, true, true ); }
Collaborator:

NV upstream does not support TT; do we support TT?

Contributor Author:

We do, yes.

}

template <typename T, typename CLayout, ck_tile::memory_operation_enum MemOp>
static inline bool dispatch_grouped(bool transA_use,
Collaborator:

Why is a separate function needed?

Contributor Author (matthiasdiener, Feb 26, 2026):

Not strictly needed, but merging dispatch_grouped and run_grouped_impl makes the resulting function very complex, and this complexity will likely increase when we add FP8 support.

Here is how it would look:

Details:
template <typename T, typename CLayout, ck_tile::memory_operation_enum MemOp>
static inline bool dispatch_grouped(bool transA_use,
                                    bool transB_use,
                                    const transformer_engine::Tensor* const* A_use,
                                    const transformer_engine::Tensor* const* B_use,
                                    transformer_engine::Tensor* const* D,
                                    int group_num,
                                    void* workspace,
                                    size_t workspace_bytes,
                                    hipStream_t stream) {

  int64_t ref_d0 = 0, ref_d1 = 0;
  if (!get_flat_2d_dims(*D[0], ref_d0, ref_d1)) {
    NVTE_ERROR("ck_tile_grouped_gemm: expected rank>=2 for D[0]");
    return false;
  }
  const ck_tile::index_t N = static_cast<ck_tile::index_t>(ref_d1);

  auto run_with_tilecfg = [&](auto tile_tag) -> bool {
    using TileCfgSel = decltype(tile_tag);
    TRANSFORMER_ENGINE_SWITCH_CONDITION(transA_use, kTransA, {
      using ALayout = std::conditional_t<kTransA, ColMajor, RowMajor>;

      TRANSFORMER_ENGINE_SWITCH_CONDITION(transB_use, kTransB, {
        using BLayout = std::conditional_t<kTransB, ColMajor, RowMajor>;

        using Kernel = typename Runner<T, T, T, ALayout, BLayout, CLayout, TileCfgSel, MemOp>::Kernel;

        const size_t needed = Kernel::GetWorkSpaceSize(group_num);
        if (!workspace || workspace_bytes < needed) {
          NVTE_ERROR("ck_tile_grouped_gemm: insufficient workspace. Needed bytes=", needed);
          return false;
        }

        thread_local std::vector<ck_tile::GroupedGemmHostArgs<0>> descs;
        descs.clear();
        descs.reserve(group_num);

        for (int i = 0; i < group_num; ++i) {
          const auto& a = data_view(*A_use[i]);
          const auto& b = data_view(*B_use[i]);
          const auto& d = data_view(*D[i]);

          int64_t Ad0 = 0, Ad1 = 0, Bd0 = 0, Bd1 = 0, Dd0 = 0, Dd1 = 0;
          if (!get_flat_2d_dims(*A_use[i], Ad0, Ad1) ||
              !get_flat_2d_dims(*B_use[i], Bd0, Bd1) ||
              !get_flat_2d_dims(*D[i],     Dd0, Dd1)) {
            NVTE_ERROR("ck_tile_grouped_gemm: expected all groups to be rank>=2 (2D or higher).");
            return false;
          }

          const int64_t M  = transA_use ? Ad1 : Ad0;
          const int64_t K  = transA_use ? Ad0 : Ad1;
          const int64_t N  = transB_use ? Bd0 : Bd1;
          const int64_t Kb = transB_use ? Bd1 : Bd0;

          if (Kb != K) {
            NVTE_ERROR("ck_tile_grouped_gemm: K mismatch between A and B in group ", i);
            return false;
          }

          if (Dd0 != M || Dd1 != N) {
            NVTE_ERROR("ck_tile_grouped_gemm: D shape mismatch in group ", i);
            return false;
          }

          // Leading dimensions under the flattened-contiguous interpretation
          const ck_tile::index_t stride_A = Ad1;
          const ck_tile::index_t stride_B = Bd1;
          const ck_tile::index_t stride_E = Dd1;

          descs.emplace_back(
              a.dptr,
              b.dptr,
              std::array<const void*, 0>{},
              d.dptr,
              1,
              M,
              N,
              K,
              stride_A,
              stride_B,
              std::array<ck_tile::index_t, 0>{},
              stride_E);
        }

        const dim3 grids = Kernel::GridSize(descs);
        auto kargs = Kernel::MakeKargs(descs);
        if (!Kernel::IsSupportedArgument(kargs)) {
          NVTE_ERROR("ck_tile_grouped_gemm: CK_Tile kernel arguments not supported for this config.");
          return false;
        }

        HIP_CHECK_ERROR(hipMemcpyAsync(workspace,
                                      kargs.data(),
                                      kargs.size() * sizeof(typename decltype(kargs)::value_type),
                                      hipMemcpyHostToDevice,
                                      stream));

        const ck_tile::stream_config s{stream};
        const dim3 blocks = Kernel::BlockSize();

        ck_tile::launch_kernel(
            s,
            ck_tile::make_kernel<1>(
                Kernel{}, grids, blocks, 0,
                ck_tile::cast_pointer_to_constant_address_space(workspace),
                group_num));
        return true;
      });
    });
  };

  // Select tile config like Primus-Turbo for FP16/BF16:
  //   N%256 -> 256x256x64
  //   N%128 -> 256x128x64
  //   else  -> 256x128x64 padding
  // NOTE: We assume N is uniform across groups.
  if ((N % 256) == 0) {
    return run_with_tilecfg(TileCfg_256x256x64{});
  } else if ((N % 128) == 0) {
    return run_with_tilecfg(TileCfg_256x128x64{});
  } else {
    return run_with_tilecfg(TileCfg_256x128x64_padding{});
  }
}

Which one do you prefer?

Contributor Author:

I found a way to merge dispatch_grouped and ck_tile_grouped_gemm instead. Implemented in c5d83a4
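As an aside, the Primus-Turbo-style tile selection in the snippet above reduces to a simple divisibility rule; a Python sketch (config names mirror the tags quoted above, for illustration only):

```python
def select_tile_cfg(N: int) -> str:
    """Mirror of the N-based tile-config rule in the quoted code:
    N % 256 == 0 -> 256x256x64, N % 128 == 0 -> 256x128x64,
    otherwise the padded 256x128x64 variant.
    Assumes N is uniform across groups, as the quoted code notes."""
    if N % 256 == 0:
        return "TileCfg_256x256x64"
    if N % 128 == 0:
        return "TileCfg_256x128x64"
    return "TileCfg_256x128x64_padding"

print(select_tile_cfg(512))  # TileCfg_256x256x64
print(select_tile_cfg(384))  # TileCfg_256x128x64
print(select_tile_cfg(100))  # TileCfg_256x128x64_padding
```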

#else
const int current_device = transformer_engine::cuda::current_device();
const bool is_hopper = (transformer_engine::cuda::sm_arch(current_device) == 90);
const int current_device = transformer_engine::cuda::current_device();
Collaborator:

Please restore indent

Contributor Author:

Done in 98e0c66

const int current_device = transformer_engine::cuda::current_device();
const bool is_hopper = (transformer_engine::cuda::sm_arch(current_device) == 90);
#endif
const bool use_cutlass = transformer_engine::getenv<bool>("NVTE_USE_CUTLASS_GROUPED_GEMM", false);
Collaborator:

I wonder, should we use a different env var name on ROCm? Or it should be well documented what CUTLASS means on ROCm.

Collaborator:

Previously Matthias had a different env var. I left this comment to suggest using the same env var as NV upstream, since I recall CK is meant to be a drop-in replacement for Cutlass?

Maybe we can explain this in README?

Contributor Author:

I added a paragraph in the README in 7b1dbfa, what do you think?

Added information about CK_Tile-based grouped GEMM implementation and how to enable it.