Skip to content

Commit 22f0c7a

Browse files
committed
[mlir][AMDGPU] 8-bit float usage in the AMDGPU dialect
Upcoming AMD hardware will include functions that accept 8-bit floats. Specifically, there are MFMA instructions that accept 8-bit floats, either using the same or mixed formats. This patch adds MLIR wrappers for these intrinsics and explicitly adds support for 8-bit floats in the gpu-to-rocdl conversion by way of amdgpu-to-rocdl. Since LLVM does not have f8 types, when targeting LLVM for compilation on an AMD GPU, both f8 types used on AMD hardware (f8E5M2FNUZ and f8E4M3FNUZ) are rewritten to i8. This patch also relaxes the restriction that the types of both source operands to an amdgpu.mfma instruction match exactly, as this is not necessarily required for the bf8 (f8E5M2FNUZ) and fp8 (f8E4M3FNUZ) instructions. In addition, since the buffer_{load,store} operations maintain a whitelist of permitted types, we add the relevant f8 types to that list. This patch does not add any implementations of arithmetic operations for f8 types. Reviewed By: jakeh-gc Differential Revision: https://reviews.llvm.org/D143956
1 parent 50ef867 commit 22f0c7a

File tree

9 files changed

+229
-50
lines changed

9 files changed

+229
-50
lines changed

mlir/include/mlir/Dialect/AMDGPU/AMDGPU.td

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,11 @@ def AMDGPU_RawBufferLoadOp :
4747
DefaultValuedAttr<BoolAttr, "true">:$boundsCheck,
4848
OptionalAttr<I32Attr>:$indexOffset,
4949
Optional<I32>:$sgprOffset)>,
50-
Results<(outs AnyTypeOf<[BF16, F16, F32, I32, I8,
50+
Results<(outs AnyTypeOf<[BF16, F16, F32, I32, I8, F8E5M2FNUZ, F8E4M3FNUZ,
5151
VectorOfLengthAndType<[2, 4], [F32, I32]>,
5252
VectorOfLengthAndType<[2, 4, 8], [F16, BF16]>,
53-
VectorOfLengthAndType<[2, 4, 8, 16], [I8]>]>:$value)> {
53+
VectorOfLengthAndType<[2, 4, 8, 16],
54+
[I8, F8E5M2FNUZ, F8E4M3FNUZ]>]>:$value)> {
5455

5556
let summary = "Raw Buffer load, exposing GCN features";
5657
let description = [{
@@ -96,10 +97,11 @@ def AMDGPU_RawBufferLoadOp :
9697
def AMDGPU_RawBufferStoreOp :
9798
AMDGPU_Op<"raw_buffer_store", [AllElementTypesMatch<["value", "memref"]>,
9899
AttrSizedOperandSegments]>,
99-
Arguments<(ins AnyTypeOf<[BF16, F16, F32, I32, I8,
100+
Arguments<(ins AnyTypeOf<[BF16, F16, F32, I32, I8, F8E5M2FNUZ, F8E4M3FNUZ,
100101
VectorOfLengthAndType<[2, 4], [F32, I32]>,
101102
VectorOfLengthAndType<[2, 4, 8], [F16, BF16]>,
102-
VectorOfLengthAndType<[2, 4, 8, 16], [I8]>]>:$value,
103+
VectorOfLengthAndType<[2, 4, 8, 16],
104+
[I8, F8E5M2FNUZ, F8E4M3FNUZ]>]>:$value,
103105
Arg<AnyMemRef, "buffer to store to", [MemWrite]>:$memref,
104106
Variadic<I32>:$indices,
105107
DefaultValuedAttr<BoolAttr, "true">:$boundsCheck,
@@ -215,15 +217,15 @@ def MFMAInTypes : AnyTypeOf<[F32, F64, I32, I64,
215217
VectorOfLengthAndType<[2], [F32]>,
216218
VectorOfLengthAndType<[4], [F16]>,
217219
VectorOfLengthAndType<[2, 4], [BF16]>,
218-
VectorOfLengthAndType<[4, 8], [I8]>]>;
220+
VectorOfLengthAndType<[4, 8], [I8]>,
221+
VectorOfLengthAndType<[8], [F8E5M2FNUZ, F8E4M3FNUZ]>]>;
219222
def MFMAOutTypes : AnyTypeOf<[F64,
220223
VectorOfLengthAndType<[4, 16, 32], [F32]>,
221224
VectorOfLengthAndType<[4, 16, 32], [I32]>,
222225
VectorOfLengthAndType<[4], [F64]>]>;
223226

224227
def AMDGPU_MFMAOp :
225-
AMDGPU_Op<"mfma", [AllTypesMatch<["sourceA", "sourceB"]>,
226-
AllTypesMatch<["destC", "destD"]>,
228+
AMDGPU_Op<"mfma", [AllTypesMatch<["destC", "destD"]>,
227229
Pure]>,
228230
Arguments<(ins
229231
I32Attr:$m,
@@ -274,7 +276,7 @@ def AMDGPU_MFMAOp :
274276
$sourceA `*` $sourceB `+` $destC
275277
attr-dict
276278
`blgp` `=` $blgp
277-
`:` type($sourceA) `,` type($destC)
279+
`:` type($sourceA) `,` type($sourceB) `,` type($destC)
278280
}];
279281
let hasVerifier = 1;
280282
}

mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,15 @@ def ROCDL_mfma_i32_16x16x32_i8 : ROCDL_Mfma_IntrOp<"mfma.i32.16x16x32.i8">;
172172
def ROCDL_mfma_i32_32x32x16_i8 : ROCDL_Mfma_IntrOp<"mfma.i32.32x32x16.i8">;
173173
def ROCDL_mfma_f32_16x16x8_xf32 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x8.xf32">;
174174
def ROCDL_mfma_f32_32x32x4_xf32 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x4.xf32">;
175+
// fp8, only on gfx940
176+
def ROCDL_mfma_f32_16x16x32_bf8_bf8 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.bf8.bf8">;
177+
def ROCDL_mfma_f32_16x16x32_bf8_fp8 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.bf8.fp8">;
178+
def ROCDL_mfma_f32_16x16x32_fp8_bf8 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.fp8.bf8">;
179+
def ROCDL_mfma_f32_16x16x32_fp8_fp8 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.fp8.fp8">;
180+
def ROCDL_mfma_f32_32x32x16_bf8_bf8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.bf8.bf8">;
181+
def ROCDL_mfma_f32_32x32x16_bf8_fp8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.bf8.fp8">;
182+
def ROCDL_mfma_f32_32x32x16_fp8_bf8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.fp8.bf8">;
183+
def ROCDL_mfma_f32_32x32x16_fp8_fp8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.fp8.fp8">;
175184

176185
//===---------------------------------------------------------------------===//
177186
// Vector buffer load/store intrinsics

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,45 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
404404
if (m == 4 && n == 4 && k == 4 && b == 4)
405405
return ROCDL::mfma_f64_4x4x4f64::getOperationName();
406406
}
407+
408+
if (sourceElem.isFloat8E5M2FNUZ() && destElem.isF32() &&
409+
chipset.minorVersion >= 0x40) {
410+
// Known to be correct because there are no scalar f8 instructions and
411+
// because a length mismatch will have been caught by the verifier.
412+
Type sourceBElem =
413+
mfma.getSourceB().getType().cast<VectorType>().getElementType();
414+
if (m == 16 && n == 16 && k == 32 && b == 1) {
415+
if (sourceBElem.isFloat8E5M2FNUZ())
416+
return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
417+
if (sourceBElem.isFloat8E4M3FNUZ())
418+
return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
419+
}
420+
if (m == 32 && n == 32 && k == 16 && b == 1) {
421+
if (sourceBElem.isFloat8E5M2FNUZ())
422+
return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
423+
if (sourceBElem.isFloat8E4M3FNUZ())
424+
return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
425+
}
426+
}
427+
428+
if (sourceElem.isFloat8E4M3FNUZ() && destElem.isF32() &&
429+
chipset.minorVersion >= 0x40) {
430+
Type sourceBElem =
431+
mfma.getSourceB().getType().cast<VectorType>().getElementType();
432+
if (m == 16 && n == 16 && k == 32 && b == 1) {
433+
if (sourceBElem.isFloat8E5M2FNUZ())
434+
return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
435+
if (sourceBElem.isFloat8E4M3FNUZ())
436+
return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
437+
}
438+
if (m == 32 && n == 32 && k == 16 && b == 1) {
439+
if (sourceBElem.isFloat8E5M2FNUZ())
440+
return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
441+
if (sourceBElem.isFloat8E4M3FNUZ())
442+
return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
443+
}
444+
}
445+
407446
return std::nullopt;
408447
}
409448

@@ -475,6 +514,14 @@ struct ConvertAMDGPUToROCDLPass
475514
void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
476515
RewritePatternSet &patterns,
477516
Chipset chipset) {
517+
// ROCDL supports fp8 types in some contexts, but there is no LLVM-level f8
518+
// type. Therefore, for this target, declare f8 to be equal to i8.
519+
converter.addConversion([](FloatType type) -> std::optional<Type> {
520+
if (type.isFloat8E5M2FNUZ() || type.isFloat8E4M3FNUZ())
521+
return IntegerType::get(type.getContext(), 8);
522+
return std::nullopt;
523+
});
524+
478525
patterns.add<LDSBarrierOpLowering>(converter);
479526
patterns.add<
480527
RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawBufferLoadOp>,

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,24 @@ LogicalResult MFMAOp::verify() {
189189
destElem = destVector.getElementType();
190190
}
191191

192+
Type sourceBType = getSourceB().getType();
193+
if (sourceElem.isFloat8E5M2FNUZ() || sourceElem.isFloat8E4M3FNUZ()) {
194+
int64_t sourceBLen = 1;
195+
Type sourceBElem = sourceBType;
196+
if (auto sourceBVector = sourceBType.dyn_cast<VectorType>()) {
197+
sourceBLen = sourceBVector.getNumElements();
198+
sourceBElem = sourceBVector.getElementType();
199+
}
200+
if (!sourceBElem.isFloat8E5M2FNUZ() && !sourceBElem.isFloat8E4M3FNUZ())
201+
return emitOpError("expected both source operands to have f8 elements");
202+
if (sourceLen != sourceBLen)
203+
return emitOpError(
204+
"expected both f8 source vectors to have the same length");
205+
} else {
206+
if (sourceType != sourceBType)
207+
return emitOpError(
208+
"expected both non-f8 source operand types to match exactly");
209+
}
192210
// Normalize the wider integer types the compiler expects to i8
193211
if (sourceElem.isInteger(32)) {
194212
sourceLen *= 4;

mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ func.func @gpu_gcn_raw_buffer_load_i8(%buf: memref<64xi8>, %idx: i32) -> i8 {
4949
%0 = amdgpu.raw_buffer_load {boundsCheck = true} %buf[%idx] : memref<64xi8>, i32 -> i8
5050
func.return %0 : i8
5151
}
52+
5253
// CHECK-LABEL: func @gpu_gcn_raw_buffer_load_2xi8
5354
func.func @gpu_gcn_raw_buffer_load_2xi8(%buf: memref<64xi8>, %idx: i32) -> vector<2xi8> {
5455
// CHECK: %[[numRecords:.*]] = llvm.mlir.constant(64 : i32)
@@ -69,6 +70,29 @@ func.func @gpu_gcn_raw_buffer_load_16xi8(%buf: memref<64xi8>, %idx: i32) -> vect
6970
func.return %0 : vector<16xi8>
7071
}
7172

73+
// CHECK-LABEL: func @gpu_gcn_raw_buffer_load_f8E5M2FNUZ
74+
func.func @gpu_gcn_raw_buffer_load_f8E5M2FNUZ(%buf: memref<64xf8E5M2FNUZ>, %idx: i32) -> f8E5M2FNUZ {
75+
// CHECK: %[[numRecords:.*]] = llvm.mlir.constant(64 : i32)
76+
// CHECK: llvm.insertelement{{.*}}%[[numRecords]]
77+
// CHECK: %[[loaded:.*]] = rocdl.raw.buffer.load %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : i8
78+
// CHECK: %[[ret:.*]] = builtin.unrealized_conversion_cast %[[loaded]] : i8 to f8E5M2FNUZ
79+
// CHECK: return %[[ret]]
80+
%0 = amdgpu.raw_buffer_load {boundsCheck = true} %buf[%idx] : memref<64xf8E5M2FNUZ>, i32 -> f8E5M2FNUZ
81+
func.return %0 : f8E5M2FNUZ
82+
}
83+
84+
// CHECK-LABEL: func @gpu_gcn_raw_buffer_load_4xf8E4M3FNUZ
85+
func.func @gpu_gcn_raw_buffer_load_4xf8E4M3FNUZ(%buf: memref<64xf8E4M3FNUZ>, %idx: i32) -> vector<4xf8E4M3FNUZ> {
86+
// CHECK: %[[numRecords:.*]] = llvm.mlir.constant(64 : i32)
87+
// CHECK: llvm.insertelement{{.*}}%[[numRecords]]
88+
// CHECK: %[[loaded:.*]] = rocdl.raw.buffer.load %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : i32
89+
// CHECK: %[[cast:.*]] = llvm.bitcast %[[loaded]] : i32 to vector<4xi8>
90+
// CHECK: %[[ret:.*]] = builtin.unrealized_conversion_cast %[[cast]] : vector<4xi8> to vector<4xf8E4M3FNUZ>
91+
// CHECK: return %[[ret]]
92+
%0 = amdgpu.raw_buffer_load {boundsCheck = true} %buf[%idx] : memref<64xf8E4M3FNUZ>, i32 -> vector<4xf8E4M3FNUZ>
93+
func.return %0 : vector<4xf8E4M3FNUZ>
94+
}
95+
7296
// Since the lowering logic is shared with loads, only bitcasts need to be rechecked
7397
// CHECK-LABEL: func @gpu_gcn_raw_buffer_store_i32
7498
func.func @gpu_gcn_raw_buffer_store_i32(%value: i32, %buf: memref<64xi32>, %idx: i32) {

0 commit comments

Comments
 (0)