@@ -133,6 +133,7 @@ __device__ __forceinline__ void stmatrix_m8n8x4(uint32_t* R, T* smem_ptr) {
133133template <typename T, MMAMode mma_mode = MMAMode::kInplaceUpdate >
134134__device__ __forceinline__ void mma_sync_m16n16k32_row_col_f8f8f32 (float * C, uint32_t * A,
135135 uint32_t * B) {
136+ static_assert (sizeof (T) == 1 , " DType must be 8bit floating data type" );
136137#if defined(FLASHINFER_MMA_F8F8F32_M16N8K32_ENABLED)
137138 if constexpr (mma_mode == MMAMode::kInit ) {
138139 if constexpr (std::is_same<T, __nv_fp8_e4m3>::value) {
@@ -216,7 +217,7 @@ __device__ __forceinline__ void mma_sync_m16n16k32_row_col_f8f8f32(float* C, uin
216217 }
217218 }
218219#else
219- static_assert ( false , " fp8 mma instruction is only available for sm89, PTX 8.4+ and CUDA 12.4+" );
220+ # error "fp8 mma instruction is only available for sm89, PTX 8.4+ and CUDA 12.4+"
220221#endif
221222}
222223
@@ -387,8 +388,45 @@ __device__ __forceinline__ void mma_sync_m16n16k16_row_col_f16f16f32(float* C, u
387388#endif
388389}
389390
390- // template <typename DType>
391- // __device__ __forceinline__ void
/*!
 * \brief Use mma instructions to compute rowsum of an fp8 A-fragment.
 *
 * The A operand is multiplied by a B operand whose elements are all 1.0, so
 * every column of the m16n8 product holds the sum of the corresponding row of
 * A. Only accumulator slots 0 and 2 are kept; slots 1 and 3 are written to a
 * scratch PTX register and discarded.
 *
 * \tparam DType 8-bit floating point type; __nv_fp8_e4m3 selects the e4m3
 *   instruction, any other 1-byte type is treated as e5m2.
 * \param d In/out accumulators held by this thread: d[0] and d[1] are read as
 *   the C operand and overwritten with C + rowsum.
 *   NOTE(review): per the m16n8 accumulator fragment layout these should be
 *   the thread's two rows (groupID and groupID + 8) — confirm against callers.
 * \param s Pointer to this thread's A fragment; it is reinterpreted as four
 *   32-bit registers (16 fp8 values), so it must be 4-byte aligned.
 * \note Must be executed by all threads of the warp (mma.sync.aligned).
 *   Requires FLASHINFER_MMA_F8F8F32_M16N8K32_ENABLED.
 */
template <typename DType>
__device__ __forceinline__ void rowsum_f8f8f32(float* d, DType* s) {
  static_assert(sizeof(DType) == 1, "DType must be 8bit floating data type");
  uint32_t* s_u32 = (uint32_t*)(s);
#if defined(FLASHINFER_MMA_F8F8F32_M16N8K32_ENABLED)
  if constexpr (std::is_same<DType, __nv_fp8_e4m3>::value) {
    // Four packed e4m3 values of 1.0 (0x38 each) == 943208504.
    constexpr uint32_t kOnesE4M3 = 0x38383838u;
    asm volatile(
        "{\n"
        ".reg .f32 ph;\n"
        "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
        "{%0,  ph,  %1,  ph},"
        "{%2,  %3,  %4,  %5},"
        "{%6,  %7},"
        "{%8,  0.,  %9,  0.};\n"
        "}\n"
        : "=f"(d[0]), "=f"(d[1])
        : "r"(s_u32[0]), "r"(s_u32[1]), "r"(s_u32[2]), "r"(s_u32[3]), "r"(kOnesE4M3),
          "r"(kOnesE4M3), "f"(d[0]), "f"(d[1]));
  } else {  // e5m2
    // Four packed e5m2 values of 1.0 (0x3C each) == 1010580540.
    constexpr uint32_t kOnesE5M2 = 0x3C3C3C3Cu;
    // The operand list (4 x b32 for A, 2 x b32 for B) is the m16n8k32 fragment
    // layout, so the shape must be k32 here as well, matching the e4m3 branch.
    asm volatile(
        "{\n"
        ".reg .f32 ph;\n"
        "mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e5m2.f32 "
        "{%0,  ph,  %1,  ph},"
        "{%2,  %3,  %4,  %5},"
        "{%6,  %7},"
        "{%8,  0.,  %9,  0.};\n"
        "}\n"
        : "=f"(d[0]), "=f"(d[1])
        : "r"(s_u32[0]), "r"(s_u32[1]), "r"(s_u32[2]), "r"(s_u32[3]), "r"(kOnesE5M2),
          "r"(kOnesE5M2), "f"(d[0]), "f"(d[1]));
  }
#else
#error "fp8 mma instruction is only available for sm89, PTX 8.4+ and CUDA 12.4+"
#endif
}
392430
393431/* !
394432 * \brief Use mma instructions to compute rowsum.
0 commit comments