Conversation
This reverts commit 86fbbac.
tests/pytorch/test_numerics.py
Outdated
| ) | ||
| if IS_HIP_EXTENSION: | ||
| from transformer_engine.pytorch.utils import is_mi200, is_mi308 | ||
| from transformer_engine.pytorch.utils import is_mi200, is_mi308, is_mi300_class |
There was a problem hiding this comment.
The is_mi300_class method is not needed; it is just the gfx 9.4 family.
| @@ -0,0 +1,276 @@ | |||
| /* Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. */ | |||
There was a problem hiding this comment.
Add proper copyright header
| @@ -0,0 +1,11 @@ | |||
| /* Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. */ | |||
There was a problem hiding this comment.
Put proper copyright header
| size_t workspace_bytes, | ||
| hipStream_t stream) { | ||
|
|
||
| // FIXME: This could be a templated lambda function in C++20. |
There was a problem hiding this comment.
As an alternative, dispatch_grouped can be incorporated into ck_tile_grouped_gemm by using nested TRANSFORMER_ENGINE_SWITCH_CONDITION.
There was a problem hiding this comment.
I misread your initial comment, c5d83a4 merges dispatch_grouped and ck_tile_grouped_gemm.
| if (!transA_use && !transB_use) { CALL(RowMajor, RowMajor, false, false); } | ||
| if (!transA_use && transB_use) { CALL(RowMajor, ColMajor, false, true ); } | ||
| if ( transA_use && !transB_use) { CALL(ColMajor, RowMajor, true, false); } | ||
| /* transA_use && transB_use */ { CALL(ColMajor, ColMajor, true, true ); } |
There was a problem hiding this comment.
NV upstream does not support TT; do we support TT?
| } | ||
|
|
||
| template <typename T, typename CLayout, ck_tile::memory_operation_enum MemOp> | ||
| static inline bool dispatch_grouped(bool transA_use, |
There was a problem hiding this comment.
Why is a separate function needed?
There was a problem hiding this comment.
Not strictly needed, but merging dispatch_grouped and run_grouped_impl makes the resulting function very complex, and this complexity will likely increase when we add FP8 support.
Here is how it would look like:
Details
// Dispatches a CK-Tile grouped GEMM: lowers the runtime transpose flags to
// compile-time layout types, picks a tile configuration from N, builds the
// per-group kernel arguments, stages them into the device workspace, and
// launches the kernel.
//
// Template parameters:
//   T       - element type used for A, B and the output (same type for all three here).
//   CLayout - layout tag for the output matrix.
//   MemOp   - ck_tile memory operation selector for the epilogue.
//
// Parameters:
//   transA_use/transB_use - runtime transpose flags, lowered to ALayout/BLayout
//                           via TRANSFORMER_ENGINE_SWITCH_CONDITION.
//   A_use/B_use/D         - arrays of `group_num` tensor pointers, one GEMM per group.
//   workspace             - device buffer that receives the packed kernel args;
//                           must hold at least Kernel::GetWorkSpaceSize(group_num) bytes.
//   workspace_bytes       - size of `workspace` in bytes.
//   stream                - HIP stream used for both the argument copy and the launch.
//
// Returns true after a successful launch; failures are reported via NVTE_ERROR
// and return false.
template <typename T, typename CLayout, ck_tile::memory_operation_enum MemOp>
static inline bool dispatch_grouped(bool transA_use,
                                    bool transB_use,
                                    const transformer_engine::Tensor* const* A_use,
                                    const transformer_engine::Tensor* const* B_use,
                                    transformer_engine::Tensor* const* D,
                                    int group_num,
                                    void* workspace,
                                    size_t workspace_bytes,
                                    hipStream_t stream) {
  // Tile selection below is driven by the flattened 2D shape of D[0] only;
  // this relies on N being uniform across all groups (see NOTE at the bottom).
  int64_t ref_d0 = 0, ref_d1 = 0;
  if (!get_flat_2d_dims(*D[0], ref_d0, ref_d1)) {
    NVTE_ERROR("ck_tile_grouped_gemm: expected rank>=2 for D[0]");
    return false;
  }
  const ck_tile::index_t N = static_cast<ck_tile::index_t>(ref_d1);

  // Instantiates and launches the kernel for one tile configuration.
  // The nested SWITCH_CONDITION macros turn the two runtime transpose flags
  // into the compile-time ALayout/BLayout types the CK-Tile kernel requires.
  auto run_with_tilecfg = [&](auto tile_tag) -> bool {
    using TileCfgSel = decltype(tile_tag);
    TRANSFORMER_ENGINE_SWITCH_CONDITION(transA_use, kTransA, {
      using ALayout = std::conditional_t<kTransA, ColMajor, RowMajor>;
      TRANSFORMER_ENGINE_SWITCH_CONDITION(transB_use, kTransB, {
        using BLayout = std::conditional_t<kTransB, ColMajor, RowMajor>;
        using Kernel = typename Runner<T, T, T, ALayout, BLayout, CLayout, TileCfgSel, MemOp>::Kernel;
        // The workspace holds the device-side copy of the per-group kernel
        // arguments; reject it if it cannot fit them.
        const size_t needed = Kernel::GetWorkSpaceSize(group_num);
        if (!workspace || workspace_bytes < needed) {
          NVTE_ERROR("ck_tile_grouped_gemm: insufficient workspace. Needed bytes=", needed);
          return false;
        }
        // thread_local so the vector's capacity is reused across calls on the
        // same thread; cleared and refilled on every call.
        thread_local std::vector<ck_tile::GroupedGemmHostArgs<0>> descs;
        descs.clear();
        descs.reserve(group_num);
        for (int i = 0; i < group_num; ++i) {
          const auto& a = data_view(*A_use[i]);
          const auto& b = data_view(*B_use[i]);
          const auto& d = data_view(*D[i]);
          int64_t Ad0 = 0, Ad1 = 0, Bd0 = 0, Bd1 = 0, Dd0 = 0, Dd1 = 0;
          if (!get_flat_2d_dims(*A_use[i], Ad0, Ad1) ||
              !get_flat_2d_dims(*B_use[i], Bd0, Bd1) ||
              !get_flat_2d_dims(*D[i], Dd0, Dd1)) {
            NVTE_ERROR("ck_tile_grouped_gemm: expected all groups to be rank>=2 (2D or higher).");
            return false;
          }
          // Logical GEMM dims of this group under the transpose flags.
          // NOTE(review): this per-group N shadows the outer N used for tile
          // selection; they agree only if N really is uniform across groups.
          const int64_t M = transA_use ? Ad1 : Ad0;
          const int64_t K = transA_use ? Ad0 : Ad1;
          const int64_t N = transB_use ? Bd0 : Bd1;
          const int64_t Kb = transB_use ? Bd1 : Bd0;
          if (Kb != K) {
            NVTE_ERROR("ck_tile_grouped_gemm: K mismatch between A and B in group ", i);
            return false;
          }
          if (Dd0 != M || Dd1 != N) {
            NVTE_ERROR("ck_tile_grouped_gemm: D shape mismatch in group ", i);
            return false;
          }
          // Leading dimensions under the flattened-contiguous interpretation
          const ck_tile::index_t stride_A = Ad1;
          const ck_tile::index_t stride_B = Bd1;
          const ck_tile::index_t stride_E = Dd1;
          // NOTE(review): the empty std::array arguments presumably are the
          // (unused) extra-operand list of the epilogue, and the literal 1 the
          // per-group batch count - confirm against GroupedGemmHostArgs.
          descs.emplace_back(
              a.dptr,
              b.dptr,
              std::array<const void*, 0>{},
              d.dptr,
              1,
              M,
              N,
              K,
              stride_A,
              stride_B,
              std::array<ck_tile::index_t, 0>{},
              stride_E);
        }
        const dim3 grids = Kernel::GridSize(descs);
        auto kargs = Kernel::MakeKargs(descs);
        if (!Kernel::IsSupportedArgument(kargs)) {
          NVTE_ERROR("ck_tile_grouped_gemm: CK_Tile kernel arguments not supported for this config.");
          return false;
        }
        // Stage the packed kernel arguments into the device workspace on the
        // same stream as the launch, so the copy is ordered before the kernel.
        HIP_CHECK_ERROR(hipMemcpyAsync(workspace,
                                       kargs.data(),
                                       kargs.size() * sizeof(typename decltype(kargs)::value_type),
                                       hipMemcpyHostToDevice,
                                       stream));
        const ck_tile::stream_config s{stream};
        const dim3 blocks = Kernel::BlockSize();
        ck_tile::launch_kernel(
            s,
            ck_tile::make_kernel<1>(
                Kernel{}, grids, blocks, 0,
                ck_tile::cast_pointer_to_constant_address_space(workspace),
                group_num));
        return true;
      });
    });
  };

  // Select tile config like Primus-Turbo for FP16/BF16:
  //   N%256 -> 256x256x64
  //   N%128 -> 256x128x64
  //   else  -> 256x128x64 padding
  // NOTE: We assume N is uniform across groups.
  if ((N % 256) == 0) {
    return run_with_tilecfg(TileCfg_256x256x64{});
  } else if ((N % 128) == 0) {
    return run_with_tilecfg(TileCfg_256x128x64{});
  } else {
    return run_with_tilecfg(TileCfg_256x128x64_padding{});
  }
}

Which one do you prefer?
There was a problem hiding this comment.
I found a way to merge dispatch_grouped and ck_tile_grouped_gemm instead. Implemented in c5d83a4
| #else | ||
| const int current_device = transformer_engine::cuda::current_device(); | ||
| const bool is_hopper = (transformer_engine::cuda::sm_arch(current_device) == 90); | ||
| const int current_device = transformer_engine::cuda::current_device(); |
| const int current_device = transformer_engine::cuda::current_device(); | ||
| const bool is_hopper = (transformer_engine::cuda::sm_arch(current_device) == 90); | ||
| #endif | ||
| const bool use_cutlass = transformer_engine::getenv<bool>("NVTE_USE_CUTLASS_GROUPED_GEMM", false); |
There was a problem hiding this comment.
I wonder, should we use different env name on ROCm? Or it should be well documented - what does CUTLASS mean on ROCm
There was a problem hiding this comment.
Previously Matthias had a different env var. I left the comment to suggest using the same env var as NV upstream, since I recall CK is meant to be a drop-in replacement for CUTLASS.
Maybe we can explain this in the README?
There was a problem hiding this comment.
I added a paragraph in the README in 7b1dbfa, what do you think?
| const int current_device = transformer_engine::cuda::current_device(); | ||
| const bool is_hopper = (transformer_engine::cuda::sm_arch(current_device) == 90); | ||
| #endif | ||
| const bool use_cutlass = transformer_engine::getenv<bool>("NVTE_USE_CUTLASS_GROUPED_GEMM", false); |
There was a problem hiding this comment.
Previously Matthias had a different env var. I left the comment to suggest using the same env var as NV upstream, since I recall CK is meant to be a drop-in replacement for CUTLASS.
Maybe we can explain this in the README?
Added information about CK_Tile-based grouped GEMM implementation and how to enable it.
Description
See https://github.com/ROCm/frameworks-internal/issues/15185 and https://github.com/ROCm/frameworks-internal/issues/13792 for context.
Primus-Turbo implementation: https://github.com/AMD-AGI/Primus-Turbo/blob/5bcd13785ef380fec0eec0911b7d6db5e606143e/csrc/kernels/grouped_gemm
TODOs:
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: