Skip to content

[MLIR][ROCDL]Add MFMA_*_F8F6F4 instructions to the ROCDL dialect #123830

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 23, 2025

Conversation

yiqian1
Copy link
Contributor

@yiqian1 yiqian1 commented Jan 21, 2025

This PR adds mfma.scale.f32.32x32x64.f8f6f4 and mfma.scale.f32.16x16x128.f8f6f4 to the ROCDL dialect. They are converted to the corresponding intrinsics in the mlir-to-llvmir pass.

This PR adds mfma.scale.f32.32x32x64.f8f6f4 and mfma.scale.f32.16x16x128.f8f6f4 to the ROCDL dialect.
They are converted to the corresponding intrinsics in the mlir-to-llvmir pass.
@llvmbot
Copy link
Member

llvmbot commented Jan 21, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-llvm

Author: Yi Qian (yiqian1)

Changes

This PR adds mfma.scale.f32.32x32x64.f8f6f4 and mfma.scale.f32.16x16x128.f8f6f4 to the ROCDL dialect. They are converted to the corresponding intrinsics in the mlir-to-llvmir pass.


Patch is 25.72 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/123830.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td (+14-1)
  • (modified) mlir/test/Target/LLVMIR/rocdl.mlir (+276)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index e9e62a74237c4f..95fbe7ed66a434 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -343,6 +343,18 @@ class ROCDL_Mfma_IntrOp<string mnemonic, list<Trait> traits = []> :
     "$args attr-dict `:` functional-type($args, $res)";
 }
 
+//===---------------------------------------------------------------------===//
+// MFMA intrinsics with overloaded operands
+class ROCDL_Mfma_OO_IntrOp<string mnemonic, list<int> overloadedOperands,
+                        list<Trait> traits = []> :
+  LLVM_IntrOpBase<ROCDL_Dialect, mnemonic,
+                  "amdgcn_" # !subst(".","_", mnemonic),
+                  [], overloadedOperands, traits, 1>,
+  Arguments<(ins Variadic<LLVM_Type>:$args)> {
+  let assemblyFormat =
+    "$args attr-dict `:` functional-type($args, $res)";
+}
+
 // Available on all CDNA.
 def ROCDL_mfma_f32_32x32x1f32 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x1f32">;
 def ROCDL_mfma_f32_16x16x1f32 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x1f32">;
@@ -394,7 +406,8 @@ def ROCDL_mfma_f32_16x16x32_f16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.f16">;
 def ROCDL_mfma_f32_32x32x16_bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.bf16">;
 def ROCDL_mfma_i32_32x32x32_i8 : ROCDL_Mfma_IntrOp<"mfma.i32.32x32x32.i8">;
 def ROCDL_mfma_f32_32x32x16_f16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.f16">;
-
+def ROCDL_mfma_scale_f32_16x16x128_f8f6f4 : ROCDL_Mfma_OO_IntrOp<"mfma.scale.f32.16x16x128.f8f6f4", [0,1]>;
+def ROCDL_mfma_scale_f32_32x32x64_f8f6f4 : ROCDL_Mfma_OO_IntrOp<"mfma.scale.f32.32x32x64.f8f6f4", [0,1]>;
 //===---------------------------------------------------------------------===//
 // WMMA intrinsics
 class ROCDL_Wmma_IntrOp<string mnemonic, list<int> overloadedOperands,
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 8879ba02b24057..b74edb62106837 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -398,6 +398,282 @@ llvm.func @rocdl.xdlops(%arg0 : f32, %arg1 : f32,
   llvm.return %r0 : vector<32 x f32>
 }
 
+llvm.func @rocdl.mfma.scale.f32.32x32x64.f8f6f4(%arg0 : i32,
+                   %arg1 : vector<16 x f32>, %arg2 : vector<8xi32>,
+                   %arg3 : vector<6xi32>, %arg4 : vector<4xi32>) -> vector<16 x f32> {
+  %cst0 = llvm.mlir.constant(0 : i32) : i32
+  %cst1 = llvm.mlir.constant(1 : i32) : i32
+  %cst2 = llvm.mlir.constant(2 : i32) : i32
+  %cst3 = llvm.mlir.constant(3 : i32) : i32
+  %cst4 = llvm.mlir.constant(4 : i32) : i32
+
+  // CHECK-LABEL: rocdl.mfma.scale.f32.32x32x64.f8f6f4
+  // fp8 * fp8
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v8i32(<8 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 0, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r00 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg2, %arg1, %cst0, %cst0, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // fp8 * bf8
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v8i32(<8 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 0, i32 1, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r01 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg2, %arg1, %cst0, %cst1, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // fp8 * fp6
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v6i32(<8 x i32> %{{.*}}, <6 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 0, i32 2, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r02 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg3, %arg1, %cst0, %cst2, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // fp8 * bf6
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v6i32(<8 x i32> %{{.*}}, <6 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 0, i32 3, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r03 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg3, %arg1, %cst0, %cst3, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // fp8 * fp4
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v4i32(<8 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 0, i32 4, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r04 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg4, %arg1, %cst0, %cst4, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // bf8 * fp8
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v8i32(<8 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 1, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r10 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg2, %arg1, %cst1, %cst0, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // bf8 * bf8
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v8i32(<8 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 1, i32 1, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r11 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg2, %arg1, %cst1, %cst1, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // bf8 * fp6
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v6i32(<8 x i32> %{{.*}}, <6 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 1, i32 2, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r12 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg3, %arg1, %cst1, %cst2, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // bf8 * bf6
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v6i32(<8 x i32> %{{.*}}, <6 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 1, i32 3, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r13 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg3, %arg1, %cst1, %cst3, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // bf8 * fp4
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v4i32(<8 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 1, i32 4, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r14 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg2, %arg4, %arg1, %cst1, %cst4, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // fp6 * fp8
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v6i32.v8i32(<6 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 2, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r20 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg2, %arg1, %cst2, %cst0, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<6xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // fp6 * bf8
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v6i32.v8i32(<6 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 2, i32 1, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r21 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg2, %arg1, %cst2, %cst1, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<6xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // fp6 * fp6
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v6i32.v6i32(<6 x i32> %{{.*}}, <6 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 2, i32 2, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r22 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg3, %arg1, %cst2, %cst2, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // fp6 * bf6
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v6i32.v6i32(<6 x i32> %{{.*}}, <6 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 2, i32 3, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r23 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg3, %arg1, %cst2, %cst3, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // fp6 * fp4
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v6i32.v4i32(<6 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 2, i32 4, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r24 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg4, %arg1, %cst2, %cst4, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<6xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // bf6 * fp8
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v6i32.v8i32(<6 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 3, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r30 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg2, %arg1, %cst3, %cst0, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<6xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // bf6 * bf8
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v6i32.v8i32(<6 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 3, i32 1, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r31 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg2, %arg1, %cst3, %cst1, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<6xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // bf6 * fp6
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v6i32.v6i32(<6 x i32> %{{.*}}, <6 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 3, i32 2, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r32 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg3, %arg1, %cst3, %cst2, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // bf6 * bf6
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v6i32.v6i32(<6 x i32> %{{.*}}, <6 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 3, i32 3, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r33 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg3, %arg1, %cst3, %cst3, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // bf6 * fp4
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v6i32.v4i32(<6 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 3, i32 4, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r34 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg3, %arg4, %arg1, %cst3, %cst4, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<6xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // fp4 * fp8
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v4i32.v8i32(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 4, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r40 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg4, %arg2, %arg1, %cst4, %cst0, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // fp4 * bf8
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v4i32.v8i32(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 4, i32 1, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r41 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg4, %arg2, %arg1, %cst4, %cst1, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // fp4 * fp6
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v4i32.v6i32(<4 x i32> %{{.*}}, <6 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 4, i32 2, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r42 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg4, %arg3, %arg1, %cst4, %cst2, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<4xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // fp4 * bf6
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v4i32.v6i32(<4 x i32> %{{.*}}, <6 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 4, i32 3, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r43 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg4, %arg3, %arg1, %cst4, %cst3, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<4xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  // fp4 * fp4
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v4i32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 4, i32 4, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r44 = rocdl.mfma.scale.f32.32x32x64.f8f6f4 %arg4, %arg4, %arg1, %cst4, %cst4, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+
+  llvm.return %r00 : vector<16 x f32>
+}
+
+llvm.func @rocdl.mfma.scale.f32.16x16x128.f8f6f4(%arg0 : i32,
+                   %arg1 : vector<4 x f32>, %arg2 : vector<8xi32>,
+                   %arg3 : vector<6xi32>, %arg4 : vector<4xi32>) -> vector<4 x f32> {
+  %cst0 = llvm.mlir.constant(0 : i32) : i32
+  %cst1 = llvm.mlir.constant(1 : i32) : i32
+  %cst2 = llvm.mlir.constant(2 : i32) : i32
+  %cst3 = llvm.mlir.constant(3 : i32) : i32
+  %cst4 = llvm.mlir.constant(4 : i32) : i32
+
+  // CHECK-LABEL: rocdl.mfma.scale.f32.16x16x128.f8f6f4
+  // fp8 * fp8
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v8i32(<8 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 0, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r00 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg2, %arg1, %cst0, %cst0, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // fp8 * bf8
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v8i32(<8 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 0, i32 1, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r01 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg2, %arg1, %cst0, %cst1, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // fp8 * fp6
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v6i32(<8 x i32> %{{.*}}, <6 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 0, i32 2, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r02 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg3, %arg1, %cst0, %cst2, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // fp8 * bf6
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v6i32(<8 x i32> %{{.*}}, <6 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 0, i32 3, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r03 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg3, %arg1, %cst0, %cst3, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // fp8 * fp4
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v4i32(<8 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 0, i32 4, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r04 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg4, %arg1, %cst0, %cst4, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // bf8 * fp8
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v8i32(<8 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 1, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r10 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg2, %arg1, %cst1, %cst0, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // bf8 * bf8
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v8i32(<8 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 1, i32 1, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r11 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg2, %arg1, %cst1, %cst1, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // bf8 * fp6
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v6i32(<8 x i32> %{{.*}}, <6 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 1, i32 2, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r12 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg3, %arg1, %cst1, %cst2, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // bf8 * bf6
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v6i32(<8 x i32> %{{.*}}, <6 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 1, i32 3, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r13 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg3, %arg1, %cst1, %cst3, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // bf8 * fp4
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v4i32(<8 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 1, i32 4, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r14 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2, %arg4, %arg1, %cst1, %cst4, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<8xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // fp6 * fp8
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v6i32.v8i32(<6 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 2, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r20 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg2, %arg1, %cst2, %cst0, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<6xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // fp6 * bf8
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v6i32.v8i32(<6 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 2, i32 1, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r21 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg2, %arg1, %cst2, %cst1, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<6xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+
+  // fp6 * fp6
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v6i32.v6i32(<6 x i32> %{{.*}}, <6 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 2, i32 2, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
+  %r22 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3, %arg3, %arg1, %cst2, %cst2, %cst0, %arg0, %cst0, %arg0 :
+                              (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32,...
[truncated]

@jungpark-mlir jungpark-mlir requested review from kuhar and krzysz00 and removed request for kuhar January 22, 2025 14:56
Copy link
Contributor

@krzysz00 krzysz00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Approved, though I'd keep the newline

@krzysz00
Copy link
Contributor

Actually, wait, no, hold on.

Please add tests to Dialect/LLVM/rocdl.mlir that just show these ops existing.

@jungpark-mlir jungpark-mlir merged commit c118864 into llvm:main Jan 23, 2025
8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants