-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[mlir][arith] Fix arith.cmpf lowering with unsupported FP types
#166684
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-mlir-llvm Author: Matthias Springer (matthias-springer) ChangesThe Full diff: https://github.com/llvm/llvm-project/pull/166684.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
index cad6cec761ab8..297af1a047f77 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
@@ -59,6 +59,12 @@ LogicalResult vectorOneToOneRewrite(
ArrayRef<NamedAttribute> targetAttrs,
const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none);
+
+/// Return "true" if the given type is an unsupported floating point type. In
+/// case of a vector type, return "true" if the element type is an unsupported
+/// floating point type.
+bool isUnsupportedFloatingPointType(const TypeConverter &typeConverter,
+ Type type);
} // namespace detail
} // namespace LLVM
@@ -92,16 +98,6 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
using Super = VectorConvertToLLVMPattern<SourceOp, TargetOp>;
- /// Return the given type if it's a floating point type. If the given type is
- /// a vector type, return its element type if it's a floating point type.
- static FloatType getFloatingPointType(Type type) {
- if (auto floatType = dyn_cast<FloatType>(type))
- return floatType;
- if (auto vecType = dyn_cast<VectorType>(type))
- return dyn_cast<FloatType>(vecType.getElementType());
- return nullptr;
- }
-
LogicalResult
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -109,25 +105,16 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
"expected single result op");
- // The pattern should not apply if a floating-point operand is converted to
- // a non-floating-point type. This indicates that the floating point type
- // is not supported by the LLVM lowering. (Such types are converted to
- // integers.)
- auto checkType = [&](Value v) -> LogicalResult {
- FloatType floatType = getFloatingPointType(v.getType());
- if (!floatType)
- return success();
- Type convertedType = this->getTypeConverter()->convertType(floatType);
- if (!isa_and_nonnull<FloatType>(convertedType))
+ // Bail on unsupported floating point types. (These are type-converted to
+ // integer types.)
+ for (Value operand : op->getOperands())
+ if (LLVM::detail::isUnsupportedFloatingPointType(
+ *this->getTypeConverter(), operand.getType()))
return rewriter.notifyMatchFailure(op,
"unsupported floating point type");
- return success();
- };
- for (Value operand : op->getOperands())
- if (failed(checkType(operand)))
- return failure();
- if (failed(checkType(op->getResult(0))))
- return failure();
+ if (LLVM::detail::isUnsupportedFloatingPointType(
+ *this->getTypeConverter(), op->getResult(0).getType()))
+ return rewriter.notifyMatchFailure(op, "unsupported floating point type");
// Determine attributes for the target op
AttrConvert<SourceOp, TargetOp> attrConvert(op);
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 03ed4d51cc744..71a8128986e7f 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -459,6 +459,10 @@ CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
LogicalResult
CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
+ if (LLVM::detail::isUnsupportedFloatingPointType(*this->getTypeConverter(),
+ op.getLhs().getType()))
+ return rewriter.notifyMatchFailure(op, "unsupported floating point type");
+
Type operandType = adaptor.getLhs().getType();
Type resultType = op.getResult().getType();
LLVM::FastmathFlags fmf =
diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
index e7dd0b506e12d..b37b35d79901c 100644
--- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
@@ -131,3 +131,24 @@ LogicalResult LLVM::detail::vectorOneToOneRewrite(
return handleMultidimensionalVectors(op, operands, typeConverter, callback,
rewriter);
}
+
+/// Return the given type if it's a floating point type. If the given type is
+/// a vector type, return its element type if it's a floating point type.
+static FloatType getFloatingPointType(Type type) {
+ if (auto floatType = dyn_cast<FloatType>(type))
+ return floatType;
+ if (auto vecType = dyn_cast<VectorType>(type))
+ return dyn_cast<FloatType>(vecType.getElementType());
+ return nullptr;
+}
+
+bool LLVM::detail::isUnsupportedFloatingPointType(
+ const TypeConverter &typeConverter, Type type) {
+ FloatType floatType = getFloatingPointType(type);
+ if (!floatType)
+ return false;
+ Type convertedType = typeConverter.convertType(floatType);
+ if (!convertedType)
+ return true;
+ return !isa<FloatType>(convertedType);
+}
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index b5dcb01d3dc6b..8cf736c0a559f 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -754,11 +754,13 @@ func.func @memref_bitcast(%1: memref<?xi16>) -> memref<?xbf16> {
// CHECK: arith.addf {{.*}} : f4E2M1FN
// CHECK: arith.addf {{.*}} : vector<4xf4E2M1FN>
// CHECK: arith.addf {{.*}} : vector<8x4xf4E2M1FN>
-func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2: vector<8x4xf4E2M1FN>) -> (f4E2M1FN, vector<4xf4E2M1FN>, vector<8x4xf4E2M1FN>) {
+// CHECK: arith.cmpf {{.*}} : f4E2M1FN
+func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2: vector<8x4xf4E2M1FN>, %arg3: f4E2M1FN) {
%0 = arith.addf %arg0, %arg0 : f4E2M1FN
%1 = arith.addf %arg1, %arg1 : vector<4xf4E2M1FN>
%2 = arith.addf %arg2, %arg2 : vector<8x4xf4E2M1FN>
- return %0, %1, %2 : f4E2M1FN, vector<4xf4E2M1FN>, vector<8x4xf4E2M1FN>
+ %3 = arith.cmpf oeq, %arg0, %arg3 : f4E2M1FN
+ return
}
// -----
@@ -767,9 +769,11 @@ func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2
// CHECK: llvm.fadd {{.*}} : f32
// CHECK: llvm.fadd {{.*}} : vector<4xf32>
// CHECK-COUNT-4: llvm.fadd {{.*}} : vector<8xf32>
-func.func @supported_fp_type(%arg0: f32, %arg1: vector<4xf32>, %arg2: vector<4x8xf32>) -> (f32, vector<4xf32>, vector<4x8xf32>) {
+// CHECK: llvm.fcmp {{.*}} : f32
+func.func @supported_fp_type(%arg0: f32, %arg1: vector<4xf32>, %arg2: vector<4x8xf32>, %arg3: f32) {
%0 = arith.addf %arg0, %arg0 : f32
%1 = arith.addf %arg1, %arg1 : vector<4xf32>
%2 = arith.addf %arg2, %arg2 : vector<4x8xf32>
- return %0, %1, %2 : f32, vector<4xf32>, vector<4x8xf32>
+ %3 = arith.cmpf oeq, %arg0, %arg3 : f32
+ return
}
|
|
@llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) ChangesThe Full diff: https://github.com/llvm/llvm-project/pull/166684.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
index cad6cec761ab8..297af1a047f77 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
@@ -59,6 +59,12 @@ LogicalResult vectorOneToOneRewrite(
ArrayRef<NamedAttribute> targetAttrs,
const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none);
+
+/// Return "true" if the given type is an unsupported floating point type. In
+/// case of a vector type, return "true" if the element type is an unsupported
+/// floating point type.
+bool isUnsupportedFloatingPointType(const TypeConverter &typeConverter,
+ Type type);
} // namespace detail
} // namespace LLVM
@@ -92,16 +98,6 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
using Super = VectorConvertToLLVMPattern<SourceOp, TargetOp>;
- /// Return the given type if it's a floating point type. If the given type is
- /// a vector type, return its element type if it's a floating point type.
- static FloatType getFloatingPointType(Type type) {
- if (auto floatType = dyn_cast<FloatType>(type))
- return floatType;
- if (auto vecType = dyn_cast<VectorType>(type))
- return dyn_cast<FloatType>(vecType.getElementType());
- return nullptr;
- }
-
LogicalResult
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -109,25 +105,16 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
"expected single result op");
- // The pattern should not apply if a floating-point operand is converted to
- // a non-floating-point type. This indicates that the floating point type
- // is not supported by the LLVM lowering. (Such types are converted to
- // integers.)
- auto checkType = [&](Value v) -> LogicalResult {
- FloatType floatType = getFloatingPointType(v.getType());
- if (!floatType)
- return success();
- Type convertedType = this->getTypeConverter()->convertType(floatType);
- if (!isa_and_nonnull<FloatType>(convertedType))
+ // Bail on unsupported floating point types. (These are type-converted to
+ // integer types.)
+ for (Value operand : op->getOperands())
+ if (LLVM::detail::isUnsupportedFloatingPointType(
+ *this->getTypeConverter(), operand.getType()))
return rewriter.notifyMatchFailure(op,
"unsupported floating point type");
- return success();
- };
- for (Value operand : op->getOperands())
- if (failed(checkType(operand)))
- return failure();
- if (failed(checkType(op->getResult(0))))
- return failure();
+ if (LLVM::detail::isUnsupportedFloatingPointType(
+ *this->getTypeConverter(), op->getResult(0).getType()))
+ return rewriter.notifyMatchFailure(op, "unsupported floating point type");
// Determine attributes for the target op
AttrConvert<SourceOp, TargetOp> attrConvert(op);
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 03ed4d51cc744..71a8128986e7f 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -459,6 +459,10 @@ CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
LogicalResult
CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
+ if (LLVM::detail::isUnsupportedFloatingPointType(*this->getTypeConverter(),
+ op.getLhs().getType()))
+ return rewriter.notifyMatchFailure(op, "unsupported floating point type");
+
Type operandType = adaptor.getLhs().getType();
Type resultType = op.getResult().getType();
LLVM::FastmathFlags fmf =
diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
index e7dd0b506e12d..b37b35d79901c 100644
--- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
@@ -131,3 +131,24 @@ LogicalResult LLVM::detail::vectorOneToOneRewrite(
return handleMultidimensionalVectors(op, operands, typeConverter, callback,
rewriter);
}
+
+/// Return the given type if it's a floating point type. If the given type is
+/// a vector type, return its element type if it's a floating point type.
+static FloatType getFloatingPointType(Type type) {
+ if (auto floatType = dyn_cast<FloatType>(type))
+ return floatType;
+ if (auto vecType = dyn_cast<VectorType>(type))
+ return dyn_cast<FloatType>(vecType.getElementType());
+ return nullptr;
+}
+
+bool LLVM::detail::isUnsupportedFloatingPointType(
+ const TypeConverter &typeConverter, Type type) {
+ FloatType floatType = getFloatingPointType(type);
+ if (!floatType)
+ return false;
+ Type convertedType = typeConverter.convertType(floatType);
+ if (!convertedType)
+ return true;
+ return !isa<FloatType>(convertedType);
+}
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index b5dcb01d3dc6b..8cf736c0a559f 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -754,11 +754,13 @@ func.func @memref_bitcast(%1: memref<?xi16>) -> memref<?xbf16> {
// CHECK: arith.addf {{.*}} : f4E2M1FN
// CHECK: arith.addf {{.*}} : vector<4xf4E2M1FN>
// CHECK: arith.addf {{.*}} : vector<8x4xf4E2M1FN>
-func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2: vector<8x4xf4E2M1FN>) -> (f4E2M1FN, vector<4xf4E2M1FN>, vector<8x4xf4E2M1FN>) {
+// CHECK: arith.cmpf {{.*}} : f4E2M1FN
+func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2: vector<8x4xf4E2M1FN>, %arg3: f4E2M1FN) {
%0 = arith.addf %arg0, %arg0 : f4E2M1FN
%1 = arith.addf %arg1, %arg1 : vector<4xf4E2M1FN>
%2 = arith.addf %arg2, %arg2 : vector<8x4xf4E2M1FN>
- return %0, %1, %2 : f4E2M1FN, vector<4xf4E2M1FN>, vector<8x4xf4E2M1FN>
+ %3 = arith.cmpf oeq, %arg0, %arg3 : f4E2M1FN
+ return
}
// -----
@@ -767,9 +769,11 @@ func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2
// CHECK: llvm.fadd {{.*}} : f32
// CHECK: llvm.fadd {{.*}} : vector<4xf32>
// CHECK-COUNT-4: llvm.fadd {{.*}} : vector<8xf32>
-func.func @supported_fp_type(%arg0: f32, %arg1: vector<4xf32>, %arg2: vector<4x8xf32>) -> (f32, vector<4xf32>, vector<4x8xf32>) {
+// CHECK: llvm.fcmp {{.*}} : f32
+func.func @supported_fp_type(%arg0: f32, %arg1: vector<4xf32>, %arg2: vector<4x8xf32>, %arg3: f32) {
%0 = arith.addf %arg0, %arg0 : f32
%1 = arith.addf %arg1, %arg1 : vector<4xf32>
%2 = arith.addf %arg2, %arg2 : vector<4x8xf32>
- return %0, %1, %2 : f32, vector<4xf32>, vector<4x8xf32>
+ %3 = arith.cmpf oeq, %arg0, %arg3 : f32
+ return
}
|
makslevental
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
9c42dcd to
b55e698
Compare
|
This still live? |
b55e698 to
a6287dc
Compare
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/129/builds/34002 Here is the relevant piece of the build log for the reference |
The
arith.cmpflowering pattern used to generate invalid IR when an unsupported floating-point type was used.