Skip to content

[mlir][IR] Remove isF...() type API for low-precision FP types #123326

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 20, 2025

Conversation

matthias-springer
Copy link
Member

Remove type.isFloat4E2M1FN() etc. Use isa<Float4E2M1FNType>(type) instead.

For details, see: https://discourse.llvm.org/t/rethink-on-approach-to-low-precision-fp-types/82361/28

Depends on #123321.

@llvmbot
Copy link
Member

llvmbot commented Jan 17, 2025

@llvm/pr-subscribers-mlir-tosa
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-llvm
@llvm/pr-subscribers-mlir-nvgpu
@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-backend-amdgpu

Author: Matthias Springer (matthias-springer)

Changes

Remove type.isFloat4E2M1FN() etc. Use isa&lt;Float4E2M1FNType&gt;(type) instead.

For details, see: https://discourse.llvm.org/t/rethink-on-approach-to-low-precision-fp-types/82361/28

Depends on #123321.


Patch is 22.11 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/123326.diff

11 Files Affected:

  • (modified) mlir/include/mlir/IR/CommonTypeConstraints.td (+13-13)
  • (modified) mlir/include/mlir/IR/Types.h (-11)
  • (modified) mlir/lib/CAPI/IR/BuiltinTypes.cpp (+24-16)
  • (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+20-18)
  • (modified) mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp (+2-2)
  • (modified) mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp (+4-5)
  • (modified) mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (+4-4)
  • (modified) mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp (+2-2)
  • (modified) mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp (+3-3)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+1-2)
  • (modified) mlir/lib/IR/Types.cpp (-19)
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 6f52195c1d7c92..e752cdfb47fbb1 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -329,31 +329,31 @@ def F64 : F<64>;
 def F80 : F<80>;
 def F128 : F<128>;
 
-def BF16 : Type<CPred<"$_self.isBF16()">, "bfloat16 type">,
+def BF16 : Type<CPred<"::llvm::isa<BFloat16Type>($_self)">, "bfloat16 type">,
            BuildableType<"$_builder.getType<BFloat16Type>()">;
-def TF32 : Type<CPred<"$_self.isTF32()">, "tf32 type">,
+def TF32 : Type<CPred<"::llvm::isa<FloatTF32Type>($_self)">, "tf32 type">,
            BuildableType<"$_builder.getType<FloatTF32Type>()">;
-def F8E4M3FN : Type<CPred<"$_self.isFloat8E4M3FN()">, "f8E4M3FN type">,
+def F8E4M3FN : Type<CPred<"::llvm::isa<Float8E4M3FNType>($_self)">, "f8E4M3FN type">,
                BuildableType<"$_builder.getType<Float8E4M3FNType>()">;
-def F8E5M2 : Type<CPred<"$_self.isFloat8E5M2()">, "f8E5M2 type">,
+def F8E5M2 : Type<CPred<"::llvm::isa<Float8E5M2Type>($_self)">, "f8E5M2 type">,
              BuildableType<"$_builder.getType<Float8E5M2Type>()">;
-def F8E4M3 : Type<CPred<"$_self.isFloat8E4M3()">, "f8E4M3 type">,
+def F8E4M3 : Type<CPred<"::llvm::isa<Float8E4M3Type>($_self)">, "f8E4M3 type">,
              BuildableType<"$_builder.getType<Float8E4M3Type>()">;
-def F8E4M3FNUZ : Type<CPred<"$_self.isFloat8E4M3FNUZ()">, "f8E4M3FNUZ type">,
+def F8E4M3FNUZ : Type<CPred<"::llvm::isa<Float8E4M3FNUZType>($_self)">, "f8E4M3FNUZ type">,
                  BuildableType<"$_builder.getType<Float8E4M3FNUZType>()">;
-def F8E4M3B11FNUZ : Type<CPred<"$_self.isFloat8E4M3B11FNUZ()">, "f8E4M3B11FNUZ type">,
+def F8E4M3B11FNUZ : Type<CPred<"::llvm::isa<Float8E4M3B11FNUZType>($_self)">, "f8E4M3B11FNUZ type">,
                  BuildableType<"$_builder.getType<Float8E4M3B11FNUZType>()">;
-def F8E5M2FNUZ : Type<CPred<"$_self.isFloat8E5M2FNUZ()">, "f8E5M2FNUZ type">,
+def F8E5M2FNUZ : Type<CPred<"::llvm::isa<Float8E5M2FNUZType>($_self)">, "f8E5M2FNUZ type">,
                  BuildableType<"$_builder.getType<Float8E5M2FNUZType>()">;
-def F8E3M4 : Type<CPred<"$_self.isFloat8E3M4()">, "f8E3M4 type">,
+def F8E3M4 : Type<CPred<"::llvm::isa<Float8E3M4Type>($_self)">, "f8E3M4 type">,
              BuildableType<"$_builder.getType<Float8E3M4Type>()">;
-def F4E2M1FN : Type<CPred<"$_self.isFloat4E2M1FN()">, "f4E2M1FN type">,
+def F4E2M1FN : Type<CPred<"::llvm::isa<Float4E2M1FNType>($_self)">, "f4E2M1FN type">,
                BuildableType<"$_builder.getType<Float4E2M1FNType>()">;
-def F6E2M3FN : Type<CPred<"$_self.isFloat6E2M3FN()">, "f6E2M3FN type">,
+def F6E2M3FN : Type<CPred<"::llvm::isa<Float6E2M3FNType>($_self)">, "f6E2M3FN type">,
                BuildableType<"$_builder.getType<Float6E2M3FNType>()">;
-def F6E3M2FN : Type<CPred<"$_self.isFloat6E3M2FN()">, "f6E3M2FN type">,
+def F6E3M2FN : Type<CPred<"::llvm::isa<Float6E3M2FNType($_self)">, "f6E3M2FN type">,
                BuildableType<"$_builder.getType<Float6E3M2FNType>()">;
-def F8E8M0FNU : Type<CPred<"$_self.isFloat8E8M0FNU()">, "f8E8M0FNU type">,
+def F8E8M0FNU : Type<CPred<"::llvm::isa<Float8E8M0FNUType>($_self)">, "f8E8M0FNU type">,
                 BuildableType<"$_builder.getType<Float8E8M0FNUType>()">;
 
 def AnyComplex : Type<CPred<"::llvm::isa<::mlir::ComplexType>($_self)">,
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index acd0f894abbbe6..0e82ad2be907ab 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -125,17 +125,6 @@ class Type {
   // Convenience predicates.  This is only for floating point types,
   // derived types should use isa/dyn_cast.
   bool isIndex() const;
-  bool isFloat4E2M1FN() const;
-  bool isFloat6E2M3FN() const;
-  bool isFloat6E3M2FN() const;
-  bool isFloat8E5M2() const;
-  bool isFloat8E4M3() const;
-  bool isFloat8E4M3FN() const;
-  bool isFloat8E5M2FNUZ() const;
-  bool isFloat8E4M3FNUZ() const;
-  bool isFloat8E4M3B11FNUZ() const;
-  bool isFloat8E3M4() const;
-  bool isFloat8E8M0FNU() const;
   bool isBF16() const;
   bool isF16() const;
   bool isTF32() const;
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index 250e4a6bbf8dfd..313d6830b41b2a 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -90,7 +90,7 @@ MlirTypeID mlirFloat4E2M1FNTypeGetTypeID() {
 }
 
 bool mlirTypeIsAFloat4E2M1FN(MlirType type) {
-  return unwrap(type).isFloat4E2M1FN();
+  return llvm::isa<Float4E2M1FNType>(unwrap(type));
 }
 
 MlirType mlirFloat4E2M1FNTypeGet(MlirContext ctx) {
@@ -102,7 +102,7 @@ MlirTypeID mlirFloat6E2M3FNTypeGetTypeID() {
 }
 
 bool mlirTypeIsAFloat6E2M3FN(MlirType type) {
-  return unwrap(type).isFloat6E2M3FN();
+  return llvm::isa<Float6E2M3FNType>(unwrap(type));
 }
 
 MlirType mlirFloat6E2M3FNTypeGet(MlirContext ctx) {
@@ -114,7 +114,7 @@ MlirTypeID mlirFloat6E3M2FNTypeGetTypeID() {
 }
 
 bool mlirTypeIsAFloat6E3M2FN(MlirType type) {
-  return unwrap(type).isFloat6E3M2FN();
+  return llvm::isa<Float6E3M2FNType>(unwrap(type));
 }
 
 MlirType mlirFloat6E3M2FNTypeGet(MlirContext ctx) {
@@ -126,7 +126,7 @@ MlirTypeID mlirFloat8E5M2TypeGetTypeID() {
 }
 
 bool mlirTypeIsAFloat8E5M2(MlirType type) {
-  return unwrap(type).isFloat8E5M2();
+  return llvm::isa<Float8E5M2Type>(unwrap(type));
 }
 
 MlirType mlirFloat8E5M2TypeGet(MlirContext ctx) {
@@ -138,7 +138,7 @@ MlirTypeID mlirFloat8E4M3TypeGetTypeID() {
 }
 
 bool mlirTypeIsAFloat8E4M3(MlirType type) {
-  return unwrap(type).isFloat8E4M3();
+  return llvm::isa<Float8E4M3Type>(unwrap(type));
 }
 
 MlirType mlirFloat8E4M3TypeGet(MlirContext ctx) {
@@ -150,7 +150,7 @@ MlirTypeID mlirFloat8E4M3FNTypeGetTypeID() {
 }
 
 bool mlirTypeIsAFloat8E4M3FN(MlirType type) {
-  return unwrap(type).isFloat8E4M3FN();
+  return llvm::isa<Float8E4M3FNType>(unwrap(type));
 }
 
 MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx) {
@@ -162,7 +162,7 @@ MlirTypeID mlirFloat8E5M2FNUZTypeGetTypeID() {
 }
 
 bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type) {
-  return unwrap(type).isFloat8E5M2FNUZ();
+  return llvm::isa<Float8E5M2FNUZType>(unwrap(type));
 }
 
 MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx) {
@@ -174,7 +174,7 @@ MlirTypeID mlirFloat8E4M3FNUZTypeGetTypeID() {
 }
 
 bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type) {
-  return unwrap(type).isFloat8E4M3FNUZ();
+  return llvm::isa<Float8E4M3FNUZType>(unwrap(type));
 }
 
 MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx) {
@@ -186,7 +186,7 @@ MlirTypeID mlirFloat8E4M3B11FNUZTypeGetTypeID() {
 }
 
 bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type) {
-  return unwrap(type).isFloat8E4M3B11FNUZ();
+  return llvm::isa<Float8E4M3B11FNUZType>(unwrap(type));
 }
 
 MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx) {
@@ -198,7 +198,7 @@ MlirTypeID mlirFloat8E3M4TypeGetTypeID() {
 }
 
 bool mlirTypeIsAFloat8E3M4(MlirType type) {
-  return unwrap(type).isFloat8E3M4();
+  return llvm::isa<Float8E3M4Type>(unwrap(type));
 }
 
 MlirType mlirFloat8E3M4TypeGet(MlirContext ctx) {
@@ -210,7 +210,7 @@ MlirTypeID mlirFloat8E8M0FNUTypeGetTypeID() {
 }
 
 bool mlirTypeIsAFloat8E8M0FNU(MlirType type) {
-  return unwrap(type).isFloat8E8M0FNU();
+  return llvm::isa<Float8E8M0FNUType>(unwrap(type));
 }
 
 MlirType mlirFloat8E8M0FNUTypeGet(MlirContext ctx) {
@@ -221,7 +221,9 @@ MlirTypeID mlirBFloat16TypeGetTypeID() {
   return wrap(BFloat16Type::getTypeID());
 }
 
-bool mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); }
+bool mlirTypeIsABF16(MlirType type) {
+  return llvm::isa<BFloat16Type>(unwrap(type));
+}
 
 MlirType mlirBF16TypeGet(MlirContext ctx) {
   return wrap(BFloat16Type::get(unwrap(ctx)));
@@ -229,7 +231,9 @@ MlirType mlirBF16TypeGet(MlirContext ctx) {
 
 MlirTypeID mlirFloat16TypeGetTypeID() { return wrap(Float16Type::getTypeID()); }
 
-bool mlirTypeIsAF16(MlirType type) { return unwrap(type).isF16(); }
+bool mlirTypeIsAF16(MlirType type) {
+  return llvm::isa<Float16Type>(unwrap(type));
+}
 
 MlirType mlirF16TypeGet(MlirContext ctx) {
   return wrap(Float16Type::get(unwrap(ctx)));
@@ -239,7 +243,7 @@ MlirTypeID mlirFloatTF32TypeGetTypeID() {
   return wrap(FloatTF32Type::getTypeID());
 }
 
-bool mlirTypeIsATF32(MlirType type) { return unwrap(type).isTF32(); }
+bool mlirTypeIsATF32(MlirType type) { return llvm::isa<FloatTF32Type>(type); }
 
 MlirType mlirTF32TypeGet(MlirContext ctx) {
   return wrap(FloatTF32Type::get(unwrap(ctx)));
@@ -247,7 +251,9 @@ MlirType mlirTF32TypeGet(MlirContext ctx) {
 
 MlirTypeID mlirFloat32TypeGetTypeID() { return wrap(Float32Type::getTypeID()); }
 
-bool mlirTypeIsAF32(MlirType type) { return unwrap(type).isF32(); }
+bool mlirTypeIsAF32(MlirType type) {
+  return llvm::isa<Float32Type>(unwrap(type));
+}
 
 MlirType mlirF32TypeGet(MlirContext ctx) {
   return wrap(Float32Type::get(unwrap(ctx)));
@@ -255,7 +261,9 @@ MlirType mlirF32TypeGet(MlirContext ctx) {
 
 MlirTypeID mlirFloat64TypeGetTypeID() { return wrap(Float64Type::getTypeID()); }
 
-bool mlirTypeIsAF64(MlirType type) { return unwrap(type).isF64(); }
+bool mlirTypeIsAF64(MlirType type) {
+  return llvm::isa<Float64Type>(unwrap(type));
+}
 
 MlirType mlirF64TypeGet(MlirContext ctx) {
   return wrap(Float64Type::get(unwrap(ctx)));
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 1564e417a7a48e..5d09d6f1d69523 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -564,38 +564,40 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
       return ROCDL::mfma_f64_4x4x4f64::getOperationName();
   }
 
-  if (sourceElem.isFloat8E5M2FNUZ() && destElem.isF32() && chipset >= kGfx940) {
+  if (isa<Float8E5M2FNUZType>(sourceElem) && destElem.isF32() &&
+      chipset >= kGfx940) {
     // Known to be correct because there are no scalar f8 instructions and
     // because a length mismatch will have been caught by the verifier.
     Type sourceBElem =
         cast<VectorType>(mfma.getSourceB().getType()).getElementType();
     if (m == 16 && n == 16 && k == 32 && b == 1) {
-      if (sourceBElem.isFloat8E5M2FNUZ())
+      if (isa<Float8E5M2FNUZType>(sourceBElem))
         return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
-      if (sourceBElem.isFloat8E4M3FNUZ())
+      if (isa<Float8E4M3FNUZType>(sourceBElem))
         return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
     }
     if (m == 32 && n == 32 && k == 16 && b == 1) {
-      if (sourceBElem.isFloat8E5M2FNUZ())
+      if (isa<Float8E5M2FNUZType>(sourceBElem))
         return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
-      if (sourceBElem.isFloat8E4M3FNUZ())
+      if (isa<Float8E4M3FNUZType>(sourceBElem))
         return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
     }
   }
 
-  if (sourceElem.isFloat8E4M3FNUZ() && destElem.isF32() && chipset >= kGfx940) {
+  if (isa<Float8E4M3FNUZType>(sourceElem) && destElem.isF32() &&
+      chipset >= kGfx940) {
     Type sourceBElem =
         cast<VectorType>(mfma.getSourceB().getType()).getElementType();
     if (m == 16 && n == 16 && k == 32 && b == 1) {
-      if (sourceBElem.isFloat8E5M2FNUZ())
+      if (isa<Float8E5M2FNUZType>(sourceBElem))
         return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
-      if (sourceBElem.isFloat8E4M3FNUZ())
+      if (isa<Float8E4M3FNUZType>(sourceBElem))
         return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
     }
     if (m == 32 && n == 32 && k == 16 && b == 1) {
-      if (sourceBElem.isFloat8E5M2FNUZ())
+      if (isa<Float8E5M2FNUZType>(sourceBElem))
         return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
-      if (sourceBElem.isFloat8E4M3FNUZ())
+      if (isa<Float8E4M3FNUZType>(sourceBElem))
         return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
     }
   }
@@ -623,9 +625,9 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
     return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
   if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
     return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
-  if (elemSourceType.isFloat8E4M3FN() && elemDestType.isF32())
+  if (isa<Float8E4M3FNType>(elemSourceType) && elemDestType.isF32())
     return ROCDL::wmma_f32_16x16x16_fp8::getOperationName();
-  if (elemSourceType.isFloat8E5M2() && elemDestType.isF32())
+  if (isa<Float8E5M2Type>(elemSourceType) && elemDestType.isF32())
     return ROCDL::wmma_f32_16x16x16_bf8::getOperationName();
   return std::nullopt;
 }
@@ -803,10 +805,10 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
   }
   Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
   Value wordSel = createI32Constant(rewriter, loc, op.getIndex());
-  if (sourceElemType.isFloat8E5M2FNUZ()) {
+  if (isa<Float8E5M2FNUZType>(sourceElemType)) {
     rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
                                                     wordSel);
-  } else if (sourceElemType.isFloat8E4M3FNUZ()) {
+  } else if (isa<Float8E4M3FNUZType>(sourceElemType)) {
     rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
                                                     wordSel);
   }
@@ -838,10 +840,10 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
   Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());
 
   Value result;
-  if (resultElemType.isFloat8E5M2FNUZ())
+  if (isa<Float8E5M2FNUZType>(resultElemType))
     result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
                                                    existing, wordSel);
-  else if (resultElemType.isFloat8E4M3FNUZ())
+  else if (isa<Float8E4M3FNUZType>(resultElemType))
     result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
                                                    existing, wordSel);
 
@@ -873,10 +875,10 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
   Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex());
 
   Value result;
-  if (resultElemType.isFloat8E5M2FNUZ())
+  if (isa<Float8E5M2FNUZType>(resultElemType))
     result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
                                                    existing, byteSel);
-  else if (resultElemType.isFloat8E4M3FNUZ())
+  else if (isa<Float8E4M3FNUZType>(resultElemType))
     result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,
                                                    existing, byteSel);
 
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index a8283023afc53d..33370566996eee 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -86,7 +86,7 @@ LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
       return failure();
     inType = inVecType.getElementType();
   }
-  return success(inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ());
+  return success(isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(inType));
 }
 
 void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
@@ -216,7 +216,7 @@ LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
   if (inType && inType.getWidth() <= 8 && saturateFP8)
     // Conversion between 8-bit floats is not supported with truncation enabled.
     return failure();
-  return success(outType.isFloat8E5M2FNUZ() || outType.isFloat8E4M3FNUZ());
+  return success(isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(outType));
 }
 
 void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 64bdb248dff430..247a8ab28a44be 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -299,11 +299,10 @@ Type LLVMTypeConverter::convertFloatType(FloatType type) const {
     return type;
 
   // F4, F6, F8 types are converted to integer types with the same bit width.
-  if (type.isFloat8E5M2() || type.isFloat8E4M3() || type.isFloat8E4M3FN() ||
-      type.isFloat8E5M2FNUZ() || type.isFloat8E4M3FNUZ() ||
-      type.isFloat8E4M3B11FNUZ() || type.isFloat8E3M4() ||
-      type.isFloat4E2M1FN() || type.isFloat6E2M3FN() || type.isFloat6E3M2FN() ||
-      type.isFloat8E8M0FNU())
+  if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
+          Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
+          Float4E2M1FNType, Float6E2M3FNType, Float6E3M2FNType,
+          Float8E8M0FNUType>(type))
     return IntegerType::get(&getContext(), type.getWidth());
 
   // Other floating-point types: A custom type conversion rule must be
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 34a6b1d506540d..7e97fb84434f89 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -1254,8 +1254,8 @@ struct NVGPUWarpgroupMmaOpLowering
         wgmmaK = 8;
       } else if (inputElemType.isF16() || inputElemType.isBF16()) {
         wgmmaK = 16;
-      } else if (inputElemType.isFloat8E4M3FN() ||
-                 inputElemType.isFloat8E5M2() || inputElemType.isInteger(16)) {
+      } else if (isa<Float8E4M3FNType, Float8E5M2Type>(inputElemType) ||
+                 inputElemType.isInteger(16)) {
         wgmmaK = 32;
       } else if (inputElemType.isInteger(1)) {
         wgmmaK = 256;
@@ -1276,9 +1276,9 @@ struct NVGPUWarpgroupMmaOpLowering
           return NVVM::WGMMATypes::f16;
         if (elemType.isBF16())
           return NVVM::WGMMATypes::bf16;
-        if (elemType.isFloat8E4M3FN())
+        if (isa<Float8E4M3FNType>(elemType))
           return NVVM::WGMMATypes::e4m3;
-        if (elemType.isFloat8E5M2())
+        if (isa<Float8E5M2Type>(elemType))
           return NVVM::WGMMATypes::e5m2;
         if (elemType.isInteger(1))
           return NVVM::WGMMATypes::b1;
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 492e4781f57810..5af0cb0c7ba1cc 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -272,14 +272,14 @@ LogicalResult MFMAOp::verify() {
   }
 
   Type sourceBType = getSourceB().getType();
-  if (sourceElem.isFloat8E5M2FNUZ() || sourceElem.isFloat8E4M3FNUZ()) {
+  if (isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(sourceElem)) {
     int64_t sourceBLen = 1;
     Type sourceBElem = sourceBType;
     if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
       sourceBLen = sourceBVector.getNumElements();
       sourceBElem = sourceBVector.getElementType();
     }
-    if (!sourceBElem.isFloat8E5M2FNUZ() && !sourceBElem.isFloat8E4M3FNUZ())
+    if (!isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(sourceBElem))
       return emitOpError("expected both source operands to have f8 elements");
     if (sourceLen != sourceBLen)
       return emitOpError(
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index a027350e8a5f70..47d1b8492e06ec 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -525,8 +525,8 @@ LogicalResult isAllowedWGMMADataType(Type typeD, Type typeA, Type typeB) {
     return success();
   // F16 += f8 + f8
   // F32 += f8 + f8
-  if ((typeA.isFloat8E5M2() || typeA.isFloat8E4M3FN()) &&
-      (typeB.isFloat8E5M2() || typeB.isFloat8E4M3FN()) &&
+  if (isa<Float8E5M2Type, Float8E4M3FNType>(typeA) &&
+      isa<Float8E5M2Type, Float8E4M3FNType>(typeB) &&
       (typeD.isF32() || typeD.isF16()))
     return success();
 
@@ -548,7 +548,7 @@ LogicalResult isAllowedSizeN(int sizeN, Type typeA) {
                                     80,  96,  112, 128, 144, 160,
                                     176, 192, 208, 224, 240, 256};
   if (typeA.isBF16() || typeA.isF16() || typeA.isF32() || typeA.isTF32() ||
-      typeA.isFloat8E4M3FN() || typeA.isFloat8E5M2())
+      isa<Float8E5M2Type, Float8E4M3FNType>(typeA))
     if (llvm::is_contained(allowedN, sizeN))
       return success();
 
diff --git a/mlir/lib/Dial...
[truncated]

@matthias-springer matthias-springer force-pushed the users/matthias-springer/fp_type_api branch from 55825a9 to bfe10b1 Compare January 17, 2025 11:43
@matthias-springer matthias-springer force-pushed the users/matthias-springer/fp_type_api branch from bfe10b1 to 51db6eb Compare January 17, 2025 11:45
@matthias-springer matthias-springer force-pushed the users/matthias-springer/fp_builder_api branch from 8c85f1f to 2e1833f Compare January 17, 2025 17:10
Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

Base automatically changed from users/matthias-springer/fp_builder_api to main January 18, 2025 09:38
@matthias-springer matthias-springer force-pushed the users/matthias-springer/fp_type_api branch from 51db6eb to dfe8194 Compare January 20, 2025 08:13
@matthias-springer matthias-springer merged commit 7a77f14 into main Jan 20, 2025
5 of 6 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/fp_type_api branch January 20, 2025 08:22
nirvedhmeshram added a commit to iree-org/llvm-project that referenced this pull request Jan 20, 2025
cota added a commit to cota/llvm-project that referenced this pull request Jan 21, 2025
…s isa<> calls in isa<> calls

To ease integration with downstream projects.

Follow-up to PR llvm#123326.
cota added a commit that referenced this pull request Jan 21, 2025
#123738)

…s isa<> calls in isa<> calls

To ease integration with downstream projects.

Follow-up to PR #123326.
marbre pushed a commit to iree-org/llvm-project that referenced this pull request Jan 23, 2025
AndreyPavlenko added a commit to AndreyPavlenko/intel-xpu-backend-for-triton that referenced this pull request Jan 26, 2025
The isF...() methods have been removed in the main LLVM branch:
llvm/llvm-project#123326
AndreyPavlenko added a commit to AndreyPavlenko/triton that referenced this pull request Jan 27, 2025
The isF...() methods have been removed in the main LLVM branch:
llvm/llvm-project#123326
AndreyPavlenko added a commit to AndreyPavlenko/intel-xpu-backend-for-triton that referenced this pull request Jan 28, 2025
The isF...() methods have been removed in the main LLVM branch:
llvm/llvm-project#123326
peterbell10 pushed a commit to triton-lang/triton that referenced this pull request Jan 29, 2025
justin-ngo-arm added a commit to justin-ngo-arm/torch-mlir that referenced this pull request Jan 29, 2025
* `isF...()` APIs are deprecated for low-precision FP types like
  Float4E2M1FN according to this LLVM PR:
  llvm/llvm-project#123326
* These API usages are replaced with `isa<Float...Type>(type)` instead.
  For example: `isa<Float4E2M1FNType>(type)`

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
Change-Id: Ia6073987cb4348fa4a701bc182f961786e85e20c
AndreyPavlenko added a commit to AndreyPavlenko/intel-xpu-backend-for-triton that referenced this pull request Jan 30, 2025
The isF...() methods have been removed in the main LLVM branch:
llvm/llvm-project#123326
makslevental added a commit to makslevental/triton that referenced this pull request Feb 19, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants