 // Include platform-specific implementations
 #if defined(PLATFORM_CUDA_DEVICE)
 #include "backend/cuda/mma.cuh"
-namespace detail = flashinfer::gpu_iface::mma_impl::cuda;
+namespace mma_detail = flashinfer::gpu_iface::mma_impl::cuda;
 #elif defined(PLATFORM_HIP_DEVICE)
 #include "backend/hip/mma_hip.h"
-namespace detail = flashinfer::gpu_iface::mma_impl::hip;
+namespace mma_detail = flashinfer::gpu_iface::mma_impl::hip;
 #endif

 namespace flashinfer
@@ -34,14 +34,14 @@ namespace mma
 template <typename T>
 __device__ __forceinline__ void load_fragment(uint32_t *R, const T *smem_ptr)
 {
-  detail::load_fragment<T>(R, smem_ptr);
+  mma_detail::load_fragment<T>(R, smem_ptr);
 }

 template <typename T>
 __device__ __forceinline__ void
 load_fragment_transpose(uint32_t *R, const T *smem_ptr, uint32_t stride)
 {
-  detail::load_fragment_transpose<T>(R, smem_ptr, stride);
+  mma_detail::load_fragment_transpose<T>(R, smem_ptr, stride);
 }

 #if defined(PLATFORM_HIP_DEVICE) && defined(__gfx942__)
@@ -51,13 +51,14 @@ load_fragment_transpose_4x4_half_registers(uint32_t *R, const T *smem_ptr)
 {
   static_assert(std::is_same<T, half>::value,
                 "Only __half is supported for the 4x4 register transpose");
-  detail::load_fragment_4x4_half_registers<half>(R, smem_ptr);
+  mma_detail::load_fragment_4x4_half_registers<half>(R, smem_ptr);
 }
 #endif

 /*!
- * \brief Wrapper of two mma m16n16k16 instructions for row major and column
- * major f16 matrix multiplication, accumulated in f32.
+ * \brief An m16n16k16 gemm kernel using MMA instructions for CUDA/HIP for row
+ * major and column major f16 matrix multiplication, accumulated in f32.
+ *
  * \tparam T data type of the fragment
  * \tparam mma_mode whether we are initializing the accumulator or updating it
  * \param C pointer to the accumulator
@@ -66,32 +67,17 @@ load_fragment_transpose_4x4_half_registers(uint32_t *R, const T *smem_ptr)
  */
 template <typename T, MMAMode mma_mode = MMAMode::kInplaceUpdate>
 __device__ __forceinline__ void
-amdgcn_mfma_fp32_16x16x16fp16(float *C, uint32_t *A, uint32_t *B)
+mma_sync_m16n16k16_row_col_f16f16f32(float *C, uint32_t *A, uint32_t *B)
 {
-#if defined(PLATFORM_HIP_DEVICE)
-  detail::amdgcn_mfma_fp32_16x16x16fp16<T, mma_mode>(C, A, B);
-#else
-  FLASHINFER_RUNTIME_ASSERT(
-      "MMA f16f16f32 not supported on this architecture");
-#endif
+  mma_detail::mma_sync_m16n16k16_row_col_f16f16f32<T, mma_mode>(C, A, B);
 }
7874
7975template <typename DType>
8076__device__ __forceinline__ void m16k16_rowsum_f16f16f32 (float *d, DType *s)
8177{
82- detail ::m16k16_rowsum_f16f16f32<DType>(d, s);
78+ mma_detail ::m16k16_rowsum_f16f16f32<DType>(d, s);
8379}
8480
-// /*!
-//  * \brief Use mma instructions to compute rowsum.
-//  */
-// template <typename DType>
-// __device__ __forceinline__ void
-// m16k16_rowsum_f16f16f32(float* d, DType* s)
-// {
-//   detail::m16k16_rowsum_f16f16f32(d, s);
-// }
-
 } // namespace mma
 } // namespace gpu_iface
 } // namespace flashinfer
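
For context, a minimal usage sketch of the renamed wrapper as a warp-level 16x16 tile multiply. This is not part of the patch: the kernel name, the "gpu_iface/mma.h" include path, and the lane-to-data offsets are assumptions; the register counts follow the m16n16k16 shape (A and B fragments of four 32-bit registers, eight f32 accumulator values per lane).

// Sketch only: header path, kernel, and fragment offsets are hypothetical.
#include <cuda_fp16.h>
#include "gpu_iface/mma.h"

using namespace flashinfer::gpu_iface;

__global__ void tile_gemm_16x16x16_sketch(const half *A, const half *B,
                                          float *C)
{
  __shared__ half a_smem[16 * 16]; // row-major A tile
  __shared__ half b_smem[16 * 16]; // column-major B tile
  // ... stage A and B into shared memory, then __syncthreads() ...

  uint32_t a_frag[4], b_frag[4];
  float c_frag[8] = {0.f}; // kInplaceUpdate accumulates, so start from zero

  mma::load_fragment<half>(a_frag, a_smem /* + lane-dependent offset */);
  mma::load_fragment<half>(b_frag, b_smem /* + lane-dependent offset */);
  mma::mma_sync_m16n16k16_row_col_f16f16f32<half>(c_frag, a_frag, b_frag);

  // ... write c_frag back to C using the lane's output mapping ...
}

Because the call goes through the mma_detail alias set up at the top of the header, the same call site dispatches to mma_impl::cuda on CUDA and mma_impl::hip on HIP, keeping the kernel platform-neutral.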