-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir][IR] Remove `isF...()` type API for low-precision FP types
#123326
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-tosa @llvm/pr-subscribers-backend-amdgpu Author: Matthias Springer (matthias-springer) Changes: Remove `type.isFloat4E2M1FN()` etc. Use `isa<Float4E2M1FNType>(type)` instead. For details, see: https://discourse.llvm.org/t/rethink-on-approach-to-low-precision-fp-types/82361/28 Depends on #123321. Patch is 22.11 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/123326.diff 11 Files Affected:
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 6f52195c1d7c92..e752cdfb47fbb1 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -329,31 +329,31 @@ def F64 : F<64>;
def F80 : F<80>;
def F128 : F<128>;
-def BF16 : Type<CPred<"$_self.isBF16()">, "bfloat16 type">,
+def BF16 : Type<CPred<"::llvm::isa<BFloat16Type>($_self)">, "bfloat16 type">,
BuildableType<"$_builder.getType<BFloat16Type>()">;
-def TF32 : Type<CPred<"$_self.isTF32()">, "tf32 type">,
+def TF32 : Type<CPred<"::llvm::isa<FloatTF32Type>($_self)">, "tf32 type">,
BuildableType<"$_builder.getType<FloatTF32Type>()">;
-def F8E4M3FN : Type<CPred<"$_self.isFloat8E4M3FN()">, "f8E4M3FN type">,
+def F8E4M3FN : Type<CPred<"::llvm::isa<Float8E4M3FNType>($_self)">, "f8E4M3FN type">,
BuildableType<"$_builder.getType<Float8E4M3FNType>()">;
-def F8E5M2 : Type<CPred<"$_self.isFloat8E5M2()">, "f8E5M2 type">,
+def F8E5M2 : Type<CPred<"::llvm::isa<Float8E5M2Type>($_self)">, "f8E5M2 type">,
BuildableType<"$_builder.getType<Float8E5M2Type>()">;
-def F8E4M3 : Type<CPred<"$_self.isFloat8E4M3()">, "f8E4M3 type">,
+def F8E4M3 : Type<CPred<"::llvm::isa<Float8E4M3Type>($_self)">, "f8E4M3 type">,
BuildableType<"$_builder.getType<Float8E4M3Type>()">;
-def F8E4M3FNUZ : Type<CPred<"$_self.isFloat8E4M3FNUZ()">, "f8E4M3FNUZ type">,
+def F8E4M3FNUZ : Type<CPred<"::llvm::isa<Float8E4M3FNUZType>($_self)">, "f8E4M3FNUZ type">,
BuildableType<"$_builder.getType<Float8E4M3FNUZType>()">;
-def F8E4M3B11FNUZ : Type<CPred<"$_self.isFloat8E4M3B11FNUZ()">, "f8E4M3B11FNUZ type">,
+def F8E4M3B11FNUZ : Type<CPred<"::llvm::isa<Float8E4M3B11FNUZType>($_self)">, "f8E4M3B11FNUZ type">,
BuildableType<"$_builder.getType<Float8E4M3B11FNUZType>()">;
-def F8E5M2FNUZ : Type<CPred<"$_self.isFloat8E5M2FNUZ()">, "f8E5M2FNUZ type">,
+def F8E5M2FNUZ : Type<CPred<"::llvm::isa<Float8E5M2FNUZType>($_self)">, "f8E5M2FNUZ type">,
BuildableType<"$_builder.getType<Float8E5M2FNUZType>()">;
-def F8E3M4 : Type<CPred<"$_self.isFloat8E3M4()">, "f8E3M4 type">,
+def F8E3M4 : Type<CPred<"::llvm::isa<Float8E3M4Type>($_self)">, "f8E3M4 type">,
BuildableType<"$_builder.getType<Float8E3M4Type>()">;
-def F4E2M1FN : Type<CPred<"$_self.isFloat4E2M1FN()">, "f4E2M1FN type">,
+def F4E2M1FN : Type<CPred<"::llvm::isa<Float4E2M1FNType>($_self)">, "f4E2M1FN type">,
BuildableType<"$_builder.getType<Float4E2M1FNType>()">;
-def F6E2M3FN : Type<CPred<"$_self.isFloat6E2M3FN()">, "f6E2M3FN type">,
+def F6E2M3FN : Type<CPred<"::llvm::isa<Float6E2M3FNType>($_self)">, "f6E2M3FN type">,
BuildableType<"$_builder.getType<Float6E2M3FNType>()">;
-def F6E3M2FN : Type<CPred<"$_self.isFloat6E3M2FN()">, "f6E3M2FN type">,
+def F6E3M2FN : Type<CPred<"::llvm::isa<Float6E3M2FNType>($_self)">, "f6E3M2FN type">,
BuildableType<"$_builder.getType<Float6E3M2FNType>()">;
-def F8E8M0FNU : Type<CPred<"$_self.isFloat8E8M0FNU()">, "f8E8M0FNU type">,
+def F8E8M0FNU : Type<CPred<"::llvm::isa<Float8E8M0FNUType>($_self)">, "f8E8M0FNU type">,
BuildableType<"$_builder.getType<Float8E8M0FNUType>()">;
def AnyComplex : Type<CPred<"::llvm::isa<::mlir::ComplexType>($_self)">,
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index acd0f894abbbe6..0e82ad2be907ab 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -125,17 +125,6 @@ class Type {
// Convenience predicates. This is only for floating point types,
// derived types should use isa/dyn_cast.
bool isIndex() const;
- bool isFloat4E2M1FN() const;
- bool isFloat6E2M3FN() const;
- bool isFloat6E3M2FN() const;
- bool isFloat8E5M2() const;
- bool isFloat8E4M3() const;
- bool isFloat8E4M3FN() const;
- bool isFloat8E5M2FNUZ() const;
- bool isFloat8E4M3FNUZ() const;
- bool isFloat8E4M3B11FNUZ() const;
- bool isFloat8E3M4() const;
- bool isFloat8E8M0FNU() const;
bool isBF16() const;
bool isF16() const;
bool isTF32() const;
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index 250e4a6bbf8dfd..313d6830b41b2a 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -90,7 +90,7 @@ MlirTypeID mlirFloat4E2M1FNTypeGetTypeID() {
}
bool mlirTypeIsAFloat4E2M1FN(MlirType type) {
- return unwrap(type).isFloat4E2M1FN();
+ return llvm::isa<Float4E2M1FNType>(unwrap(type));
}
MlirType mlirFloat4E2M1FNTypeGet(MlirContext ctx) {
@@ -102,7 +102,7 @@ MlirTypeID mlirFloat6E2M3FNTypeGetTypeID() {
}
bool mlirTypeIsAFloat6E2M3FN(MlirType type) {
- return unwrap(type).isFloat6E2M3FN();
+ return llvm::isa<Float6E2M3FNType>(unwrap(type));
}
MlirType mlirFloat6E2M3FNTypeGet(MlirContext ctx) {
@@ -114,7 +114,7 @@ MlirTypeID mlirFloat6E3M2FNTypeGetTypeID() {
}
bool mlirTypeIsAFloat6E3M2FN(MlirType type) {
- return unwrap(type).isFloat6E3M2FN();
+ return llvm::isa<Float6E3M2FNType>(unwrap(type));
}
MlirType mlirFloat6E3M2FNTypeGet(MlirContext ctx) {
@@ -126,7 +126,7 @@ MlirTypeID mlirFloat8E5M2TypeGetTypeID() {
}
bool mlirTypeIsAFloat8E5M2(MlirType type) {
- return unwrap(type).isFloat8E5M2();
+ return llvm::isa<Float8E5M2Type>(unwrap(type));
}
MlirType mlirFloat8E5M2TypeGet(MlirContext ctx) {
@@ -138,7 +138,7 @@ MlirTypeID mlirFloat8E4M3TypeGetTypeID() {
}
bool mlirTypeIsAFloat8E4M3(MlirType type) {
- return unwrap(type).isFloat8E4M3();
+ return llvm::isa<Float8E4M3Type>(unwrap(type));
}
MlirType mlirFloat8E4M3TypeGet(MlirContext ctx) {
@@ -150,7 +150,7 @@ MlirTypeID mlirFloat8E4M3FNTypeGetTypeID() {
}
bool mlirTypeIsAFloat8E4M3FN(MlirType type) {
- return unwrap(type).isFloat8E4M3FN();
+ return llvm::isa<Float8E4M3FNType>(unwrap(type));
}
MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx) {
@@ -162,7 +162,7 @@ MlirTypeID mlirFloat8E5M2FNUZTypeGetTypeID() {
}
bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type) {
- return unwrap(type).isFloat8E5M2FNUZ();
+ return llvm::isa<Float8E5M2FNUZType>(unwrap(type));
}
MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx) {
@@ -174,7 +174,7 @@ MlirTypeID mlirFloat8E4M3FNUZTypeGetTypeID() {
}
bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type) {
- return unwrap(type).isFloat8E4M3FNUZ();
+ return llvm::isa<Float8E4M3FNUZType>(unwrap(type));
}
MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx) {
@@ -186,7 +186,7 @@ MlirTypeID mlirFloat8E4M3B11FNUZTypeGetTypeID() {
}
bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type) {
- return unwrap(type).isFloat8E4M3B11FNUZ();
+ return llvm::isa<Float8E4M3B11FNUZType>(unwrap(type));
}
MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx) {
@@ -198,7 +198,7 @@ MlirTypeID mlirFloat8E3M4TypeGetTypeID() {
}
bool mlirTypeIsAFloat8E3M4(MlirType type) {
- return unwrap(type).isFloat8E3M4();
+ return llvm::isa<Float8E3M4Type>(unwrap(type));
}
MlirType mlirFloat8E3M4TypeGet(MlirContext ctx) {
@@ -210,7 +210,7 @@ MlirTypeID mlirFloat8E8M0FNUTypeGetTypeID() {
}
bool mlirTypeIsAFloat8E8M0FNU(MlirType type) {
- return unwrap(type).isFloat8E8M0FNU();
+ return llvm::isa<Float8E8M0FNUType>(unwrap(type));
}
MlirType mlirFloat8E8M0FNUTypeGet(MlirContext ctx) {
@@ -221,7 +221,9 @@ MlirTypeID mlirBFloat16TypeGetTypeID() {
return wrap(BFloat16Type::getTypeID());
}
-bool mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); }
+bool mlirTypeIsABF16(MlirType type) {
+ return llvm::isa<BFloat16Type>(unwrap(type));
+}
MlirType mlirBF16TypeGet(MlirContext ctx) {
return wrap(BFloat16Type::get(unwrap(ctx)));
@@ -229,7 +231,9 @@ MlirType mlirBF16TypeGet(MlirContext ctx) {
MlirTypeID mlirFloat16TypeGetTypeID() { return wrap(Float16Type::getTypeID()); }
-bool mlirTypeIsAF16(MlirType type) { return unwrap(type).isF16(); }
+bool mlirTypeIsAF16(MlirType type) {
+ return llvm::isa<Float16Type>(unwrap(type));
+}
MlirType mlirF16TypeGet(MlirContext ctx) {
return wrap(Float16Type::get(unwrap(ctx)));
@@ -239,7 +243,7 @@ MlirTypeID mlirFloatTF32TypeGetTypeID() {
return wrap(FloatTF32Type::getTypeID());
}
-bool mlirTypeIsATF32(MlirType type) { return unwrap(type).isTF32(); }
+bool mlirTypeIsATF32(MlirType type) { return llvm::isa<FloatTF32Type>(unwrap(type)); }
MlirType mlirTF32TypeGet(MlirContext ctx) {
return wrap(FloatTF32Type::get(unwrap(ctx)));
@@ -247,7 +251,9 @@ MlirType mlirTF32TypeGet(MlirContext ctx) {
MlirTypeID mlirFloat32TypeGetTypeID() { return wrap(Float32Type::getTypeID()); }
-bool mlirTypeIsAF32(MlirType type) { return unwrap(type).isF32(); }
+bool mlirTypeIsAF32(MlirType type) {
+ return llvm::isa<Float32Type>(unwrap(type));
+}
MlirType mlirF32TypeGet(MlirContext ctx) {
return wrap(Float32Type::get(unwrap(ctx)));
@@ -255,7 +261,9 @@ MlirType mlirF32TypeGet(MlirContext ctx) {
MlirTypeID mlirFloat64TypeGetTypeID() { return wrap(Float64Type::getTypeID()); }
-bool mlirTypeIsAF64(MlirType type) { return unwrap(type).isF64(); }
+bool mlirTypeIsAF64(MlirType type) {
+ return llvm::isa<Float64Type>(unwrap(type));
+}
MlirType mlirF64TypeGet(MlirContext ctx) {
return wrap(Float64Type::get(unwrap(ctx)));
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 1564e417a7a48e..5d09d6f1d69523 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -564,38 +564,40 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
return ROCDL::mfma_f64_4x4x4f64::getOperationName();
}
- if (sourceElem.isFloat8E5M2FNUZ() && destElem.isF32() && chipset >= kGfx940) {
+ if (isa<Float8E5M2FNUZType>(sourceElem) && destElem.isF32() &&
+ chipset >= kGfx940) {
// Known to be correct because there are no scalar f8 instructions and
// because a length mismatch will have been caught by the verifier.
Type sourceBElem =
cast<VectorType>(mfma.getSourceB().getType()).getElementType();
if (m == 16 && n == 16 && k == 32 && b == 1) {
- if (sourceBElem.isFloat8E5M2FNUZ())
+ if (isa<Float8E5M2FNUZType>(sourceBElem))
return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
- if (sourceBElem.isFloat8E4M3FNUZ())
+ if (isa<Float8E4M3FNUZType>(sourceBElem))
return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
}
if (m == 32 && n == 32 && k == 16 && b == 1) {
- if (sourceBElem.isFloat8E5M2FNUZ())
+ if (isa<Float8E5M2FNUZType>(sourceBElem))
return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
- if (sourceBElem.isFloat8E4M3FNUZ())
+ if (isa<Float8E4M3FNUZType>(sourceBElem))
return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
}
}
- if (sourceElem.isFloat8E4M3FNUZ() && destElem.isF32() && chipset >= kGfx940) {
+ if (isa<Float8E4M3FNUZType>(sourceElem) && destElem.isF32() &&
+ chipset >= kGfx940) {
Type sourceBElem =
cast<VectorType>(mfma.getSourceB().getType()).getElementType();
if (m == 16 && n == 16 && k == 32 && b == 1) {
- if (sourceBElem.isFloat8E5M2FNUZ())
+ if (isa<Float8E5M2FNUZType>(sourceBElem))
return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
- if (sourceBElem.isFloat8E4M3FNUZ())
+ if (isa<Float8E4M3FNUZType>(sourceBElem))
return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
}
if (m == 32 && n == 32 && k == 16 && b == 1) {
- if (sourceBElem.isFloat8E5M2FNUZ())
+ if (isa<Float8E5M2FNUZType>(sourceBElem))
return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
- if (sourceBElem.isFloat8E4M3FNUZ())
+ if (isa<Float8E4M3FNUZType>(sourceBElem))
return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
}
}
@@ -623,9 +625,9 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
- if (elemSourceType.isFloat8E4M3FN() && elemDestType.isF32())
+ if (isa<Float8E4M3FNType>(elemSourceType) && elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_fp8::getOperationName();
- if (elemSourceType.isFloat8E5M2() && elemDestType.isF32())
+ if (isa<Float8E5M2Type>(elemSourceType) && elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_bf8::getOperationName();
return std::nullopt;
}
@@ -803,10 +805,10 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
}
Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
Value wordSel = createI32Constant(rewriter, loc, op.getIndex());
- if (sourceElemType.isFloat8E5M2FNUZ()) {
+ if (isa<Float8E5M2FNUZType>(sourceElemType)) {
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
wordSel);
- } else if (sourceElemType.isFloat8E4M3FNUZ()) {
+ } else if (isa<Float8E4M3FNUZType>(sourceElemType)) {
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
wordSel);
}
@@ -838,10 +840,10 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());
Value result;
- if (resultElemType.isFloat8E5M2FNUZ())
+ if (isa<Float8E5M2FNUZType>(resultElemType))
result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
existing, wordSel);
- else if (resultElemType.isFloat8E4M3FNUZ())
+ else if (isa<Float8E4M3FNUZType>(resultElemType))
result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
existing, wordSel);
@@ -873,10 +875,10 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex());
Value result;
- if (resultElemType.isFloat8E5M2FNUZ())
+ if (isa<Float8E5M2FNUZType>(resultElemType))
result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
existing, byteSel);
- else if (resultElemType.isFloat8E4M3FNUZ())
+ else if (isa<Float8E4M3FNUZType>(resultElemType))
result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,
existing, byteSel);
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index a8283023afc53d..33370566996eee 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -86,7 +86,7 @@ LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
return failure();
inType = inVecType.getElementType();
}
- return success(inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ());
+ return success(isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(inType));
}
void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
@@ -216,7 +216,7 @@ LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
if (inType && inType.getWidth() <= 8 && saturateFP8)
// Conversion between 8-bit floats is not supported with truncation enabled.
return failure();
- return success(outType.isFloat8E5M2FNUZ() || outType.isFloat8E4M3FNUZ());
+ return success(isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(outType));
}
void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 64bdb248dff430..247a8ab28a44be 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -299,11 +299,10 @@ Type LLVMTypeConverter::convertFloatType(FloatType type) const {
return type;
// F4, F6, F8 types are converted to integer types with the same bit width.
- if (type.isFloat8E5M2() || type.isFloat8E4M3() || type.isFloat8E4M3FN() ||
- type.isFloat8E5M2FNUZ() || type.isFloat8E4M3FNUZ() ||
- type.isFloat8E4M3B11FNUZ() || type.isFloat8E3M4() ||
- type.isFloat4E2M1FN() || type.isFloat6E2M3FN() || type.isFloat6E3M2FN() ||
- type.isFloat8E8M0FNU())
+ if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
+ Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
+ Float4E2M1FNType, Float6E2M3FNType, Float6E3M2FNType,
+ Float8E8M0FNUType>(type))
return IntegerType::get(&getContext(), type.getWidth());
// Other floating-point types: A custom type conversion rule must be
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 34a6b1d506540d..7e97fb84434f89 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -1254,8 +1254,8 @@ struct NVGPUWarpgroupMmaOpLowering
wgmmaK = 8;
} else if (inputElemType.isF16() || inputElemType.isBF16()) {
wgmmaK = 16;
- } else if (inputElemType.isFloat8E4M3FN() ||
- inputElemType.isFloat8E5M2() || inputElemType.isInteger(16)) {
+ } else if (isa<Float8E4M3FNType, Float8E5M2Type>(inputElemType) ||
+ inputElemType.isInteger(16)) {
wgmmaK = 32;
} else if (inputElemType.isInteger(1)) {
wgmmaK = 256;
@@ -1276,9 +1276,9 @@ struct NVGPUWarpgroupMmaOpLowering
return NVVM::WGMMATypes::f16;
if (elemType.isBF16())
return NVVM::WGMMATypes::bf16;
- if (elemType.isFloat8E4M3FN())
+ if (isa<Float8E4M3FNType>(elemType))
return NVVM::WGMMATypes::e4m3;
- if (elemType.isFloat8E5M2())
+ if (isa<Float8E5M2Type>(elemType))
return NVVM::WGMMATypes::e5m2;
if (elemType.isInteger(1))
return NVVM::WGMMATypes::b1;
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 492e4781f57810..5af0cb0c7ba1cc 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -272,14 +272,14 @@ LogicalResult MFMAOp::verify() {
}
Type sourceBType = getSourceB().getType();
- if (sourceElem.isFloat8E5M2FNUZ() || sourceElem.isFloat8E4M3FNUZ()) {
+ if (isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(sourceElem)) {
int64_t sourceBLen = 1;
Type sourceBElem = sourceBType;
if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
sourceBLen = sourceBVector.getNumElements();
sourceBElem = sourceBVector.getElementType();
}
- if (!sourceBElem.isFloat8E5M2FNUZ() && !sourceBElem.isFloat8E4M3FNUZ())
+ if (!isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(sourceBElem))
return emitOpError("expected both source operands to have f8 elements");
if (sourceLen != sourceBLen)
return emitOpError(
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index a027350e8a5f70..47d1b8492e06ec 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -525,8 +525,8 @@ LogicalResult isAllowedWGMMADataType(Type typeD, Type typeA, Type typeB) {
return success();
// F16 += f8 + f8
// F32 += f8 + f8
- if ((typeA.isFloat8E5M2() || typeA.isFloat8E4M3FN()) &&
- (typeB.isFloat8E5M2() || typeB.isFloat8E4M3FN()) &&
+ if (isa<Float8E5M2Type, Float8E4M3FNType>(typeA) &&
+ isa<Float8E5M2Type, Float8E4M3FNType>(typeB) &&
(typeD.isF32() || typeD.isF16()))
return success();
@@ -548,7 +548,7 @@ LogicalResult isAllowedSizeN(int sizeN, Type typeA) {
80, 96, 112, 128, 144, 160,
176, 192, 208, 224, 240, 256};
if (typeA.isBF16() || typeA.isF16() || typeA.isF32() || typeA.isTF32() ||
- typeA.isFloat8E4M3FN() || typeA.isFloat8E5M2())
+ isa<Float8E5M2Type, Float8E4M3FNType>(typeA))
if (llvm::is_contained(allowedN, sizeN))
return success();
diff --git a/mlir/lib/Dial...
[truncated]
|
55825a9
to
bfe10b1
Compare
bfe10b1
to
51db6eb
Compare
8c85f1f
to
2e1833f
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice!
51db6eb
to
dfe8194
Compare
…pes (llvm#123326)" This reverts commit 7a77f14.
…s isa<> calls in isa<> calls To ease integration with downstream projects. Follow-up to PR llvm#123326.
…pes (llvm#123326)" This reverts commit 7a77f14.
The isF...() methods have been removed in the main LLVM branch: llvm/llvm-project#123326
The isF...() methods have been removed in the main LLVM branch: llvm/llvm-project#123326
The isF...() methods have been removed in the main LLVM branch: llvm/llvm-project#123326
Pulls in llvm/llvm-project#123200 which is useful and also handles #5664. Integrations were required due to llvm/llvm-project#123026, llvm/llvm-project#123321 and llvm/llvm-project#123326. Also closes #5685
* `isF...()` APIs are deprecated for low-precision FP types like Float4E2M1FN according to this LLVM PR: llvm/llvm-project#123326 * These API usages are replaced with `isa<Float...Type>(type)` instead. For example: `isa<Float4E2M1FNType>(type)` Signed-off-by: Justin Ngo <justin.ngo@arm.com> Change-Id: Ia6073987cb4348fa4a701bc182f961786e85e20c
The isF...() methods have been removed in the main LLVM branch: llvm/llvm-project#123326
Pulls in llvm/llvm-project#123200 which is useful and also handles triton-lang#5664. Integrations were required due to llvm/llvm-project#123026, llvm/llvm-project#123321 and llvm/llvm-project#123326. Also closes triton-lang#5685
Remove `type.isFloat4E2M1FN()` etc. Use `isa<Float4E2M1FNType>(type)` instead.
For details, see: https://discourse.llvm.org/t/rethink-on-approach-to-low-precision-fp-types/82361/28
Depends on #123321.