Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR][AMDGPU] Add OCP FP8 support for new hardware #106160

Open
wants to merge 6 commits into
base: main
Choose a base branch
from

Conversation

pcf000
Copy link
Contributor

@pcf000 pcf000 commented Aug 26, 2024

Upcoming hardware (gfx12 and some future gfx9) will support the OCP 8-bit float formats for their matrix multiplication intrinsics and conversion operations, retaining existing opcodes and compiler builtins.

This commit adds support for these types to the MLIR wrappers around such operaitons, ensuring that the OCP types aren't used to generate those builtins on hardware that doesn't expect that format and, conversely, to ensure that the pre-OCP formats aren't used on new hardware.

@llvmbot
Copy link
Collaborator

llvmbot commented Aug 26, 2024

@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-backend-amdgpu
@llvm/pr-subscribers-mlir-amdgpu
@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tosa

Author: Paul C Fuqua (pcf000)

Changes

This part mostly just allows the new types in places where the other F8 formats were allowed.


Full diff: https://github.com/llvm/llvm-project/pull/106160.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td (+9-9)
  • (modified) mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h (+7)
  • (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+21-18)
  • (modified) mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp (+7-2)
  • (modified) mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp (+4-2)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+3-1)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index e5c1a53f34bf64..04b66cea661afc 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -41,8 +41,8 @@ class AMDGPU_Op<string mnemonic, list<Trait> traits = []> :
 
 def AMDGPU_ExtPackedFp8Op :
     AMDGPU_Op<"ext_packed_fp8", [Pure]>,
-    Arguments<(ins AnyTypeOf<[F8E5M2FNUZ, F8E4M3FNUZ,
-        VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2FNUZ, F8E4M3FNUZ]>]>:$source,
+    Arguments<(ins AnyTypeOf<[F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN,
+        VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN]>]>:$source,
       ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$index)>,
     Results<(outs F32:$res)> {
   let summary = "Extend one of a vector of packed fp8 values to a float";
@@ -68,8 +68,8 @@ def AMDGPU_PackedTrunc2xFp8Op :
     Arguments<(ins F32:$sourceA,
       Optional<F32>:$sourceB,
       ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<1>]>:$wordIndex,
-      Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>>:$existing)>,
-    Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>:$res)> {
+      Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>>:$existing)>,
+    Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>:$res)> {
   let summary = "Round two floats into a packed vector of 8-bit floats";
   let description = [{
     Round the inputs `sourceA` and `sourceB` (which is undefined if not
@@ -95,8 +95,8 @@ def AMDGPU_PackedStochRoundFp8Op :
     Arguments<(ins F32:$source,
       I32:$stochiasticParam,
       ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$storeIndex,
-      Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>>:$existing)>,
-    Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>:$res)> {
+      Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>>:$existing)>,
+    Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>:$res)> {
   let summary = "Round float stochiastically into a packed vector of 8-bit floats";
   let description = [{
     Round the input `source`, adding in `stochiasticParam`, and place it into
@@ -405,7 +405,7 @@ def AMDGPU_RawBufferAtomicUminOp :
 
 def AMDGPU_DPPPerm : I32EnumAttr<"DPPPerm",
     "The possible permutations for a DPP operation",
-    [  
+    [
       I32EnumAttrCase<"quad_perm",  0>,
       I32EnumAttrCase<"row_shl",    1>,
       I32EnumAttrCase<"row_shr",    2>,
@@ -419,7 +419,7 @@ def AMDGPU_DPPPerm : I32EnumAttr<"DPPPerm",
       I32EnumAttrCase<"row_bcast_15", 10>,
       I32EnumAttrCase<"row_bcast_31", 11>
     ]> {
-  let genSpecializedAttr = 0;  
+  let genSpecializedAttr = 0;
   let cppNamespace = "::mlir::amdgpu";
 }
 
@@ -546,7 +546,7 @@ def MFMAInTypes : AnyTypeOf<[F32, F64, I32, I64,
                              VectorOfLengthAndType<[4], [F16]>,
                              VectorOfLengthAndType<[2, 4], [BF16]>,
                              VectorOfLengthAndType<[4, 8], [I8]>,
-                             VectorOfLengthAndType<[8], [F8E5M2FNUZ, F8E4M3FNUZ]>]>;
+                             VectorOfLengthAndType<[8], [F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN]>]>;
 def MFMAOutTypes : AnyTypeOf<[F64,
                               VectorOfLengthAndType<[4, 16, 32], [F32]>,
                               VectorOfLengthAndType<[4, 16, 32], [I32]>,
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h b/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h
index 38e0ebe68f943b..6de12a3d50878b 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h
@@ -18,6 +18,13 @@ struct Chipset {
       : majorVersion(majorVersion), minorVersion(minorVersion){};
   static FailureOr<Chipset> parse(StringRef name);
 
+  bool isGfx940() const {
+    return majorVersion == 9 && minorVersion >= 0x40 && majorVersion < 0x50;
+  }
+  bool hasOcpFp8() const {
+    return (majorVersion == 9 && minorVersion >= 0x50) || majorVersion >= 12;
+  }
+
   unsigned majorVersion = 0;
   unsigned minorVersion = 0;
 };
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 809e9448e80abf..9323fdc7dacd6d 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -539,40 +539,42 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
       return ROCDL::mfma_f64_4x4x4f64::getOperationName();
   }
 
-  if (sourceElem.isFloat8E5M2FNUZ() && destElem.isF32() &&
-      chipset.minorVersion >= 0x40) {
+  if (destElem.isF32() &&
+      ((sourceElem.isFloat8E5M2FNUZ() && chipset.isGfx940()) ||
+       (sourceElem.isFloat8E5M2() && chipset.hasOcpFp8()))) {
     // Known to be correct because there are no scalar f8 instructions and
     // because a length mismatch will have been caught by the verifier.
     Type sourceBElem =
         cast<VectorType>(mfma.getSourceB().getType()).getElementType();
     if (m == 16 && n == 16 && k == 32 && b == 1) {
-      if (sourceBElem.isFloat8E5M2FNUZ())
+      if (sourceBElem.isFloat8E5M2FNUZ() || sourceBElem.isFloat8E5M2())
         return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
-      if (sourceBElem.isFloat8E4M3FNUZ())
+      if (sourceBElem.isFloat8E4M3FNUZ() || sourceBElem.isFloat8E4M3FN())
         return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
     }
     if (m == 32 && n == 32 && k == 16 && b == 1) {
-      if (sourceBElem.isFloat8E5M2FNUZ())
+      if (sourceBElem.isFloat8E5M2FNUZ() || sourceBElem.isFloat8E5M2())
         return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
-      if (sourceBElem.isFloat8E4M3FNUZ())
+      if (sourceBElem.isFloat8E4M3FNUZ() || sourceBElem.isFloat8E4M3FN())
         return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
     }
   }
 
-  if (sourceElem.isFloat8E4M3FNUZ() && destElem.isF32() &&
-      chipset.minorVersion >= 0x40) {
+  if (destElem.isF32() &&
+      ((sourceElem.isFloat8E4M3FNUZ() && chipset.isGfx940()) ||
+       (sourceElem.isFloat8E4M3FN() && chipset.hasOcpFp8()))) {
     Type sourceBElem =
         cast<VectorType>(mfma.getSourceB().getType()).getElementType();
     if (m == 16 && n == 16 && k == 32 && b == 1) {
-      if (sourceBElem.isFloat8E5M2FNUZ())
+      if (sourceBElem.isFloat8E5M2FNUZ() || sourceBElem.isFloat8E5M2())
         return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
-      if (sourceBElem.isFloat8E4M3FNUZ())
+      if (sourceBElem.isFloat8E4M3FNUZ() || sourceBElem.isFloat8E4M3FN())
         return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
     }
     if (m == 32 && n == 32 && k == 16 && b == 1) {
-      if (sourceBElem.isFloat8E5M2FNUZ())
+      if (sourceBElem.isFloat8E5M2FNUZ() || sourceBElem.isFloat8E5M2())
         return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
-      if (sourceBElem.isFloat8E4M3FNUZ())
+      if (sourceBElem.isFloat8E4M3FNUZ() || sourceBElem.isFloat8E4M3FN())
         return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
     }
   }
@@ -762,10 +764,11 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
   }
   Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
   Value wordSel = createI32Constant(rewriter, loc, op.getIndex());
-  if (sourceElemType.isFloat8E5M2FNUZ()) {
+  if (sourceElemType.isFloat8E5M2FNUZ() || sourceElemType.isFloat8E5M2()) {
     rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
                                                     wordSel);
-  } else if (sourceElemType.isFloat8E4M3FNUZ()) {
+  } else if (sourceElemType.isFloat8E4M3FNUZ() ||
+             sourceElemType.isFloat8E4M3FN()) {
     rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
                                                     wordSel);
   }
@@ -797,10 +800,10 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
   Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());
 
   Value result;
-  if (resultElemType.isFloat8E5M2FNUZ())
+  if (resultElemType.isFloat8E5M2FNUZ() || resultElemType.isFloat8E5M2())
     result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
                                                    existing, wordSel);
-  else if (resultElemType.isFloat8E4M3FNUZ())
+  else if (resultElemType.isFloat8E4M3FNUZ() || resultElemType.isFloat8E4M3FN())
     result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
                                                    existing, wordSel);
 
@@ -832,10 +835,10 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
   Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex());
 
   Value result;
-  if (resultElemType.isFloat8E5M2FNUZ())
+  if (resultElemType.isFloat8E5M2FNUZ() || resultElemType.isFloat8E5M2())
     result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
                                                    existing, byteSel);
-  else if (resultElemType.isFloat8E4M3FNUZ())
+  else if (resultElemType.isFloat8E4M3FNUZ() || resultElemType.isFloat8E4M3FN())
     result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,
                                                    existing, byteSel);
 
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index d36583c8118ff4..a66c13caa6d0ab 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -86,7 +86,8 @@ LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
       return failure();
     inType = inVecType.getElementType();
   }
-  return success(inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ());
+  return success(inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ() ||
+                 inType.isFloat8E5M2() || inType.isFloat8E4M3FN());
 }
 
 void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
@@ -216,7 +217,11 @@ LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
   if (inType && inType.getWidth() <= 8 && saturateFP8)
     // Conversion between 8-bit floats is not supported with truncation enabled.
     return failure();
-  return success(outType.isFloat8E5M2FNUZ() || outType.isFloat8E4M3FNUZ());
+
+  return success((((outType.isFloat8E5M2FNUZ() || outType.isFloat8E4M3FNUZ()) &&
+                   chipset.isGfx940()) ||
+                  ((outType.isFloat8E5M2() || outType.isFloat8E4M3FN()) &&
+                   chipset.hasOcpFp8())));
 }
 
 void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 3943696364950f..2747eebebefa52 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -271,14 +271,16 @@ LogicalResult MFMAOp::verify() {
   }
 
   Type sourceBType = getSourceB().getType();
-  if (sourceElem.isFloat8E5M2FNUZ() || sourceElem.isFloat8E4M3FNUZ()) {
+  if (sourceElem.isFloat8E5M2FNUZ() || sourceElem.isFloat8E4M3FNUZ() ||
+      sourceElem.isFloat8E5M2() || sourceElem.isFloat8E4M3FN()) {
     int64_t sourceBLen = 1;
     Type sourceBElem = sourceBType;
     if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
       sourceBLen = sourceBVector.getNumElements();
       sourceBElem = sourceBVector.getElementType();
     }
-    if (!sourceBElem.isFloat8E5M2FNUZ() && !sourceBElem.isFloat8E4M3FNUZ())
+    if (!sourceBElem.isFloat8E5M2FNUZ() && !sourceBElem.isFloat8E4M3FNUZ() &&
+        !sourceBElem.isFloat8E5M2() && !sourceBElem.isFloat8E4M3FN())
       return emitOpError("expected both source operands to have f8 elements");
     if (sourceLen != sourceBLen)
       return emitOpError(
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index b78c372af77e64..963fd6fd7c0511 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -509,7 +509,9 @@ bool TosaValidation::isValidElementType(Type type) {
   if (isa<FloatType>(type)) {
     if (profile == TosaProfileEnum::BaseInference)
       return false;
-    return type.isF32() || type.isF16() || type.isBF16();
+    return type.isF32() || type.isF16() || type.isBF16() ||
+           type.isFloat8E4M3FNUZ() || type.isFloat8E5M2FNUZ() ||
+           type.isFloat8E4M3FN() || type.isFloat8E5M2();
   }
   if (auto intTy = dyn_cast<IntegerType>(type)) {
     if (intTy.isUnsigned()) {

@pcf000
Copy link
Contributor Author

pcf000 commented Aug 26, 2024

@krzysz00 @giuseros (I can't add reviewers.)

@krzysz00
Copy link
Contributor

krzysz00 commented Sep 6, 2024

@pcf000 There've been downstream changes that fix bugs here, could we add those to this PR

pcf000 and others added 3 commits September 15, 2024 23:46
Upcoming hardware (gfx12 and some future gfx9) will support the OCP
8-bit float formats for their matrix multiplication intrinsics and
conversion operations, retaining existing opcodes and compiler builtins.

This commit adds support for these types to the MLIR wrappers around
such operations, ensuring that the OCP types aren't used to generate
those builtins on hardware that doesn't expect that format and,
conversely, to ensure that the pre-OCP formats aren't used on new
hardware.
those operations were not being converted to the LLVM intrinsics they
correspond to because the rewrite patterns were still checking for
gfx940+.

As part of this, factor out tests for type-match isto isNativeFp8()
and isNativeBf8() functions in the AMDGPUToRocdl rewrites.

Also, fix a typo in isGfx940() that caused it to be true for gfx950.

Finally, test all these OCP format conversions by duplicating the
gfx940 tests.
@pcf000 pcf000 changed the title [MLIR][AMDGPU] Implement emulated F8 for the OCP formats. [MLIR][AMDGPU] Add OCP FP8 support for new hardware Sep 17, 2024
mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h Outdated Show resolved Hide resolved
mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h Outdated Show resolved Hide resolved
Comment on lines 459 to 466
static bool isNativeBf8(Chipset chipset, Type type) {
return (chipset.isGfx940() && type.isFloat8E5M2FNUZ()) ||
(chipset.hasOcpFp8() && type.isFloat8E5M2());
}

/// Return true if `type` is the E4M3FN variant of an 8-bit float that is
/// supported by the `_fp8` instructions on the given `chipset`.
static bool isNativeFp8(Chipset chipset, Type type) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: hasNativeF8 to follow the type naming scheme in llvm/mlir?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

... yeah, I can see hasNaviveF8 or isNativeF8 if we switched the argument order

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This file uses "FP8" extensively, but both as a generic term for Float8 and as the specific for 4.3 (with "BF8" for 5.2). The two isNative functions are doing the latter.

@@ -68,6 +72,15 @@ struct TruncfToFloat16RewritePattern final

} // end namespace

static LogicalResult isSupportedFp8(Type elementType, Chipset chipset) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: supportsF8?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we've been using Fp8 in this file? But my memory of having written this code ages ago could be hazy

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The various ops use "fp8" as do the saturate parameters. I made this guy isSupportedF8().

Comment on lines 77 to 80
return success(elementType.isFloat8E5M2FNUZ() ||
elementType.isFloat8E4M3FNUZ());
if (chipset.hasOcpFp8())
return success(elementType.isFloat8E5M2() || elementType.isFloat8E4M3FN());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use isa<T1, T2>(elementType)

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp Outdated Show resolved Hide resolved
@llvmbot llvmbot added the mlir:core MLIR Core Infrastructure label Sep 18, 2024
@@ -454,6 +454,20 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
}
}

/// Return true if `type` is the E5M2 variant of an 8-bit float that is
/// supported by the `_bf8` instructions on the given `chipset`.
static bool isNativeBf8(Chipset chipset, Type type) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
static bool isNativeBf8(Chipset chipset, Type type) {
static bool hasNativeBF8(Chipset chipset, Type type) {

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since it's a query about the type, not about the chipset, I'm happier with an "is" form. Since I'm wordy, how about isNativeBf8ForChipset(Type,Chipset)?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand the distinction. The ISA and the supported types are the property of the chipset, no?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, yes, but all chipsets have a native BF8 type -- FNUZ for the older ones, OCP for the newer -- so the answer to "has native BF8" is "yes". Rather, we want to know "is this type the same as the chipset's native BF8 type?" Also, in a context in which we're doing isa<T1,T2>() and isFloat(), "is" flows better.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, in a context in which we're doing isa<T1,T2>() and isFloat(), "is" flows better.

The issue with chains of .isType is that it won't be optimized across multiple types and always performed one by one (with short circuiting). Our casting infra might not do it today, but in order to utilize faster 'group' isa, you have to use it across the codebase in the first place.

Copy link
Member

@kuhar kuhar Sep 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"is this type the same as the chipset's native BF8 type?"

Thanks for the explanation, I see the difference now. Maybe something like this then isSupportedByNativeBf8(Chipset, Type)?

If you prefer the current name, could you add a comment that explains what the intention is (similar to how you explained it above)?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An equivalent way of phrasing it, from where I'm standing, is "is this one of the types that the _fp8 or _bf8 instructions on the given chipset will be expecting"?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previous job would have called it "typeIsNativelySupportedBf8ForChipset(Type,Chipset)" -- we liked naming the parameters in the function name -- or "typeIsExpectedBf8ForChipset(Type,Chipset)". As I wrote those out, I think I like "expected" better than "native" or "supported," but I could be persuaded.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm ... isImplementedBf8OnChipset()?

But Expected works too

Comment on lines +143 to +145
/// Return true if this is an float type (with the specified width).
bool isFloat() const;
bool isFloat(unsigned width) const;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice!


/// Return true if `type` is the E4M3FN variant of an 8-bit float that is
/// supported by the `_fp8` instructions on the given `chipset`.
static bool isNativeFp8(Chipset chipset, Type type) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
static bool isNativeFp8(Chipset chipset, Type type) {
static bool hasNativeFp8(Chipset chipset, Type type) {

Comment on lines 512 to 514
return type.isF32() || type.isF16() || type.isBF16() ||
type.isFloat8E4M3FNUZ() || type.isFloat8E5M2FNUZ() ||
type.isFloat8E4M3FN() || type.isFloat8E5M2();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use isa

@pcf000 pcf000 requested a review from kuhar October 18, 2024 18:50
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants