Skip to content

Commit 04014e7

Browse files
committed
Update ROCm float multiplication in sparse Marlin MMA
Replace __builtin_amdgcn_fmul_f32 with __ocml_fmul_f32 for more accurate and consistent float multiplication in the scale_floats function on AMD GPU platforms.
1 parent 66691c3 commit 04014e7

File tree

1 file changed

+10
-10
lines changed
  • torchao/csrc/cuda/sparse_marlin

1 file changed

+10
-10
lines changed

torchao/csrc/cuda/sparse_marlin/mma.h

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ namespace torchao {
3535
// On CUDA earlier than 12.5, the ordered_metadata version of this instruction
3636
// is not supported. On later versions of CUDA the version without ordered
3737
// metadata results in the following warning:
38-
// | Advisory: Modifier ‘.sp::ordered_metadata should be used on instruction
39-
// | mma instead of modifier ‘.sp’ as it is expected to have substantially
38+
// | Advisory: Modifier '.sp::ordered_metadata' should be used on instruction
39+
// | 'mma' instead of modifier '.sp' as it is expected to have substantially
4040
// | reduced performance on some future architectures
4141

4242
#if defined(USE_ROCM)
@@ -281,15 +281,15 @@ __device__ inline void scale_floats(float* c0, float* c1, float* c2, float* c3,
281281
float* c7, FragS& s1) {
282282
#ifdef USE_ROCM
283283
// AMD implementation - fixed
284-
*c0 = __builtin_amdgcn_fmul_f32(*c0, __half2float(s0[0].x));
285-
*c1 = __builtin_amdgcn_fmul_f32(*c1, __half2float(s0[0].y));
286-
*c2 = __builtin_amdgcn_fmul_f32(*c2, __half2float(s0[1].x));
287-
*c3 = __builtin_amdgcn_fmul_f32(*c3, __half2float(s0[1].y));
284+
*c0 = __ocml_fmul_f32(*c0, __half2float(s0[0].x));
285+
*c1 = __ocml_fmul_f32(*c1, __half2float(s0[0].y));
286+
*c2 = __ocml_fmul_f32(*c2, __half2float(s0[1].x));
287+
*c3 = __ocml_fmul_f32(*c3, __half2float(s0[1].y));
288288

289-
*c4 = __builtin_amdgcn_fmul_f32(*c4, __half2float(s1[0].x));
290-
*c5 = __builtin_amdgcn_fmul_f32(*c5, __half2float(s1[0].y));
291-
*c6 = __builtin_amdgcn_fmul_f32(*c6, __half2float(s1[1].x));
292-
*c7 = __builtin_amdgcn_fmul_f32(*c7, __half2float(s1[1].y));
289+
*c4 = __ocml_fmul_f32(*c4, __half2float(s1[0].x));
290+
*c5 = __ocml_fmul_f32(*c5, __half2float(s1[0].y));
291+
*c6 = __ocml_fmul_f32(*c6, __half2float(s1[1].x));
292+
*c7 = __ocml_fmul_f32(*c7, __half2float(s1[1].y));
293293
#else
294294
// NVIDIA implementation
295295
*c0 = __fmul_rn(*c0, __half2float(s0[0].x));

0 commit comments

Comments
 (0)