
Commit 622b0e8

Add a wrapper for hgemm kernel (#69)
Adds a common wrapper function to mma_ops.hpp for hgemm kernels that works for both CUDA and HIP. Replaces the HIP-specific `amdgcn_mfma_fp32_16x16x16fp16` entry point with `mma_sync_m16n16k16_row_col_f16f16f32`.
1 parent 4627b70 commit 622b0e8

3 files changed: +13, -28 lines changed

libflashinfer/include/gpu_iface/backend/hip/mma_hip.h

Lines changed: 1 addition & 2 deletions
@@ -112,11 +112,10 @@ load_fragment_transpose(uint32_t *R, const T *smem_ptr, uint32_t stride)
            static_cast<const uint32_t>(*v3);
 }
 
-// MMA operation for FP16 inputs with FP32 accumulator
 // MMA operation for FP16 inputs with FP32 accumulator
 template <typename T, mma::MMAMode mma_mode = mma::MMAMode::kInplaceUpdate>
 __device__ __forceinline__ void
-amdgcn_mfma_fp32_16x16x16fp16(float *C, uint32_t *A, uint32_t *B)
+mma_sync_m16n16k16_row_col_f16f16f32(float *C, uint32_t *A, uint32_t *B)
 {
     // Ensure T is either __half or __hip_bfloat16
     static_assert(std::is_same_v<T, __half> ||
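
For orientation, here is a minimal sketch of how the renamed HIP-side function can forward to the AMD MFMA builtin on CDNA hardware. Only the function name, template signature, and the static_assert shown in this diff come from the commit; the vector typedefs, the use of __builtin_amdgcn_mfma_f32_16x16x16f16, and the handling of the non-inplace accumulator mode are assumptions about the surrounding implementation, not code from this change.

#include <hip/hip_fp16.h>   // __half
#include <cstdint>
#include <type_traits>

// Clang/HIP extended vector types matching the MFMA per-lane operand layout:
// 4 packed f16 values for A and B, 4 f32 accumulator values for C.
typedef _Float16 half4_t __attribute__((ext_vector_type(4)));
typedef float float4_t __attribute__((ext_vector_type(4)));

// mma::MMAMode comes from the gpu_iface headers, as in the diff above.
template <typename T, mma::MMAMode mma_mode = mma::MMAMode::kInplaceUpdate>
__device__ __forceinline__ void
mma_sync_m16n16k16_row_col_f16f16f32(float *C, uint32_t *A, uint32_t *B)
{
    // The real header also accepts __hip_bfloat16; this sketch covers __half only.
    static_assert(std::is_same_v<T, __half>, "sketch limited to __half");

    // Reinterpret the packed 32-bit registers (two per operand) as 4-wide f16 vectors.
    half4_t a = *reinterpret_cast<const half4_t *>(A);
    half4_t b = *reinterpret_cast<const half4_t *>(B);

    float4_t acc = {0.f, 0.f, 0.f, 0.f};               // assumed init mode: start from zero
    if constexpr (mma_mode == mma::MMAMode::kInplaceUpdate) {
        acc = *reinterpret_cast<const float4_t *>(C);  // accumulate onto existing C
    }

    // One MFMA instruction computing D = A * B + C for a 16x16x16 tile.
    acc = __builtin_amdgcn_mfma_f32_16x16x16f16(a, b, acc, 0, 0, 0);
    *reinterpret_cast<float4_t *>(C) = acc;
}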

libflashinfer/include/gpu_iface/mma_ops.hpp

Lines changed: 11 additions & 25 deletions
@@ -9,10 +9,10 @@
 // Include platform-specific implementations
 #if defined(PLATFORM_CUDA_DEVICE)
 #include "backend/cuda/mma.cuh"
-namespace detail = flashinfer::gpu_iface::mma_impl::cuda;
+namespace mma_detail = flashinfer::gpu_iface::mma_impl::cuda;
 #elif defined(PLATFORM_HIP_DEVICE)
 #include "backend/hip/mma_hip.h"
-namespace detail = flashinfer::gpu_iface::mma_impl::hip;
+namespace mma_detail = flashinfer::gpu_iface::mma_impl::hip;
 #endif
 
 namespace flashinfer
@@ -34,14 +34,14 @@ namespace mma
 template <typename T>
 __device__ __forceinline__ void load_fragment(uint32_t *R, const T *smem_ptr)
 {
-    detail::load_fragment<T>(R, smem_ptr);
+    mma_detail::load_fragment<T>(R, smem_ptr);
 }
 
 template <typename T>
 __device__ __forceinline__ void
 load_fragment_transpose(uint32_t *R, const T *smem_ptr, uint32_t stride)
 {
-    detail::load_fragment_transpose<T>(R, smem_ptr, stride);
+    mma_detail::load_fragment_transpose<T>(R, smem_ptr, stride);
 }
 
 #if defined(PLATFORM_HIP_DEVICE) && defined(__gfx942__)
@@ -51,13 +51,14 @@ load_fragment_transpose_4x4_half_registers(uint32_t *R, const T *smem_ptr)
 {
     static_assert(std::is_same<T, int>::value,
                   "Only __half is supported for the 4x4 register transpose");
-    detail::load_fragment_4x4_half_registers<half>(R, smem_ptr);
+    mma_detail::load_fragment_4x4_half_registers<half>(R, smem_ptr);
 }
 #endif
 
 /*!
- * \brief Wrapper of two mma m16n16k16 instructions for row major and column
- * major f16 matrix multiplication, accumulated in f32.
+ * \brief An m16n16k16 gemm kernel using MMA instructions for CUDA/HIP for row
+ * major and column major f16 matrix multiplication, accumulated in f32.
+ *
  * \tparam T data type of the fragment
  * \tparam mma_mode whether we are initializing the accumulator or updating it
  * \param C pointer to the accumulator
@@ -66,32 +67,17 @@ load_fragment_transpose_4x4_half_registers(uint32_t *R, const T *smem_ptr)
  */
 template <typename T, MMAMode mma_mode = MMAMode::kInplaceUpdate>
 __device__ __forceinline__ void
-amdgcn_mfma_fp32_16x16x16fp16(float *C, uint32_t *A, uint32_t *B)
+mma_sync_m16n16k16_row_col_f16f16f32(float *C, uint32_t *A, uint32_t *B)
 {
-#if defined(PLATFORM_HIP_DEVICE)
-    detail::amdgcn_mfma_fp32_16x16x16fp16<T, mma_mode>(C, A, B);
-#else
-    FLASHINFER_RUNTIME_ASSERT(
-        "MMA f16f16f32 not supported on this architecture");
-#endif
+    mma_detail::mma_sync_m16n16k16_row_col_f16f16f32<T, mma_mode>(C, A, B);
 }
 
 template <typename DType>
 __device__ __forceinline__ void m16k16_rowsum_f16f16f32(float *d, DType *s)
 {
-    detail::m16k16_rowsum_f16f16f32<DType>(d, s);
+    mma_detail::m16k16_rowsum_f16f16f32<DType>(d, s);
 }
 
-// /*!
-// * \brief Use mma instructions to compute rowsum.
-// */
-// template <typename DType>
-// __device__ __forceinline__ void
-// m16k16_rowsum_f16f16f32(float* d, DType* s)
-// {
-//     detail::m16k16_rowsum_f16f16f32(d, s);
-// }
-
 } // namespace mma
 } // namespace gpu_iface
 } // namespace flashinfer
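
To illustrate how a caller uses the unified wrapper, here is a condensed device-side sketch mirroring the call pattern in the HIP test kernel below. The include path, helper name, and per-lane fragment sizes (two 32-bit registers per f16 operand, four accumulator floats) are assumptions for illustration, not part of this commit; the CUDA backend may use different per-lane register counts.

#include "gpu_iface/mma_ops.hpp"   // include path assumed from the repo layout
#include <cstdint>

using namespace flashinfer::gpu_iface;

// Multiply one 16x16 f16 tile of A (row major) by one 16x16 f16 tile of B
// (column major), accumulating into f32. Pointers are expected to be
// lane-adjusted by the caller, as in the test kernel below.
__device__ void mma_tile_16x16x16(const __half *a_smem, const __half *b_smem,
                                  float *c_out)
{
    uint32_t a_reg[2];                          // packed f16 fragment of A
    uint32_t b_reg[2];                          // packed f16 fragment of B
    float c_reg[4] = {0.f, 0.f, 0.f, 0.f};      // per-lane accumulator

    mma::load_fragment<__half>(a_reg, a_smem);
    mma::load_fragment_transpose<__half>(b_reg, b_smem, /*stride=*/16);

    // Backend-agnostic m16n16k16 f16 x f16 -> f32 MMA; dispatches to the CUDA
    // or HIP implementation selected at the top of mma_ops.hpp.
    mma::mma_sync_m16n16k16_row_col_f16f16f32<__half>(c_reg, a_reg, b_reg);

    // Each lane owns four accumulator values; output indexing depends on the
    // MMA fragment layout and is left to the caller in this sketch.
    for (int i = 0; i < 4; ++i) {
        c_out[i] = c_reg[i];
    }
}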

libflashinfer/tests/hip/test_mfma_fp32_16x16x16fp16.cpp

Lines changed: 1 addition & 1 deletion
@@ -81,7 +81,7 @@ __global__ void test_mfma_kernel(const __half *A, const __half *B, float *C)
     flashinfer::gpu_iface::mma::load_fragment_transpose<__half>(b_reg,
                                                                 &B[b_idx], LDB);
 
-    flashinfer::gpu_iface::mma::amdgcn_mfma_fp32_16x16x16fp16<__half>(
+    flashinfer::gpu_iface::mma::mma_sync_m16n16k16_row_col_f16f16f32<__half>(
         c_reg, a_reg, b_reg);
 
     for (int i = 0; i < 4; ++i) {
