Skip to content

Conversation

@matthias-springer
Copy link
Member

The arith.cmpf lowering pattern used to generate invalid IR when an unsupported floating-point type was used.

@llvmbot
Copy link
Member

llvmbot commented Nov 6, 2025

@llvm/pr-subscribers-mlir-llvm

Author: Matthias Springer (matthias-springer)

Changes

The arith.cmpf lowering pattern used to generate invalid IR when an unsupported floating-point type was used.


Full diff: https://github.com/llvm/llvm-project/pull/166684.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h (+14-27)
  • (modified) mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp (+4)
  • (modified) mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp (+21)
  • (modified) mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir (+8-4)
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
index cad6cec761ab8..297af1a047f77 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
@@ -59,6 +59,12 @@ LogicalResult vectorOneToOneRewrite(
     ArrayRef<NamedAttribute> targetAttrs,
     const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
     IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none);
+
+/// Return "true" if the given type is an unsupported floating point type. In
+/// case of a vector type, return "true" if the element type is an unsupported
+/// floating point type.
+bool isUnsupportedFloatingPointType(const TypeConverter &typeConverter,
+                                    Type type);
 } // namespace detail
 } // namespace LLVM
 
@@ -92,16 +98,6 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
   using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
   using Super = VectorConvertToLLVMPattern<SourceOp, TargetOp>;
 
-  /// Return the given type if it's a floating point type. If the given type is
-  /// a vector type, return its element type if it's a floating point type.
-  static FloatType getFloatingPointType(Type type) {
-    if (auto floatType = dyn_cast<FloatType>(type))
-      return floatType;
-    if (auto vecType = dyn_cast<VectorType>(type))
-      return dyn_cast<FloatType>(vecType.getElementType());
-    return nullptr;
-  }
-
   LogicalResult
   matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -109,25 +105,16 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
         std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
         "expected single result op");
 
-    // The pattern should not apply if a floating-point operand is converted to
-    // a non-floating-point type. This indicates that the floating point type
-    // is not supported by the LLVM lowering. (Such types are converted to
-    // integers.)
-    auto checkType = [&](Value v) -> LogicalResult {
-      FloatType floatType = getFloatingPointType(v.getType());
-      if (!floatType)
-        return success();
-      Type convertedType = this->getTypeConverter()->convertType(floatType);
-      if (!isa_and_nonnull<FloatType>(convertedType))
+    // Bail on unsupported floating point types. (These are type-converted to
+    // integer types.)
+    for (Value operand : op->getOperands())
+      if (LLVM::detail::isUnsupportedFloatingPointType(
+              *this->getTypeConverter(), operand.getType()))
         return rewriter.notifyMatchFailure(op,
                                            "unsupported floating point type");
-      return success();
-    };
-    for (Value operand : op->getOperands())
-      if (failed(checkType(operand)))
-        return failure();
-    if (failed(checkType(op->getResult(0))))
-      return failure();
+    if (LLVM::detail::isUnsupportedFloatingPointType(
+            *this->getTypeConverter(), op->getResult(0).getType()))
+      return rewriter.notifyMatchFailure(op, "unsupported floating point type");
 
     // Determine attributes for the target op
     AttrConvert<SourceOp, TargetOp> attrConvert(op);
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 03ed4d51cc744..71a8128986e7f 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -459,6 +459,10 @@ CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
 LogicalResult
 CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
+  if (LLVM::detail::isUnsupportedFloatingPointType(*this->getTypeConverter(),
+                                                   op.getLhs().getType()))
+    return rewriter.notifyMatchFailure(op, "unsupported floating point type");
+
   Type operandType = adaptor.getLhs().getType();
   Type resultType = op.getResult().getType();
   LLVM::FastmathFlags fmf =
diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
index e7dd0b506e12d..b37b35d79901c 100644
--- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
@@ -131,3 +131,24 @@ LogicalResult LLVM::detail::vectorOneToOneRewrite(
   return handleMultidimensionalVectors(op, operands, typeConverter, callback,
                                        rewriter);
 }
+
+/// Return the given type if it's a floating point type. If the given type is
+/// a vector type, return its element type if it's a floating point type.
+static FloatType getFloatingPointType(Type type) {
+  if (auto floatType = dyn_cast<FloatType>(type))
+    return floatType;
+  if (auto vecType = dyn_cast<VectorType>(type))
+    return dyn_cast<FloatType>(vecType.getElementType());
+  return nullptr;
+}
+
+bool LLVM::detail::isUnsupportedFloatingPointType(
+    const TypeConverter &typeConverter, Type type) {
+  FloatType floatType = getFloatingPointType(type);
+  if (!floatType)
+    return false;
+  Type convertedType = typeConverter.convertType(floatType);
+  if (!convertedType)
+    return true;
+  return !isa<FloatType>(convertedType);
+}
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index b5dcb01d3dc6b..8cf736c0a559f 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -754,11 +754,13 @@ func.func @memref_bitcast(%1: memref<?xi16>) -> memref<?xbf16> {
 //       CHECK:   arith.addf {{.*}} : f4E2M1FN
 //       CHECK:   arith.addf {{.*}} : vector<4xf4E2M1FN>
 //       CHECK:   arith.addf {{.*}} : vector<8x4xf4E2M1FN>
-func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2: vector<8x4xf4E2M1FN>) -> (f4E2M1FN, vector<4xf4E2M1FN>, vector<8x4xf4E2M1FN>) {
+//       CHECK:   arith.cmpf {{.*}} : f4E2M1FN
+func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2: vector<8x4xf4E2M1FN>, %arg3: f4E2M1FN) {
   %0 = arith.addf %arg0, %arg0 : f4E2M1FN
   %1 = arith.addf %arg1, %arg1 : vector<4xf4E2M1FN>
   %2 = arith.addf %arg2, %arg2 : vector<8x4xf4E2M1FN>
-  return %0, %1, %2 : f4E2M1FN, vector<4xf4E2M1FN>, vector<8x4xf4E2M1FN>
+  %3 = arith.cmpf oeq, %arg0, %arg3 : f4E2M1FN
+  return
 }
 
 // -----
@@ -767,9 +769,11 @@ func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2
 //         CHECK:   llvm.fadd {{.*}} : f32
 //         CHECK:   llvm.fadd {{.*}} : vector<4xf32>
 // CHECK-COUNT-4:   llvm.fadd {{.*}} : vector<8xf32>
-func.func @supported_fp_type(%arg0: f32, %arg1: vector<4xf32>, %arg2: vector<4x8xf32>) -> (f32, vector<4xf32>, vector<4x8xf32>) {
+//         CHECK:   llvm.fcmp {{.*}} : f32
+func.func @supported_fp_type(%arg0: f32, %arg1: vector<4xf32>, %arg2: vector<4x8xf32>, %arg3: f32) {
   %0 = arith.addf %arg0, %arg0 : f32
   %1 = arith.addf %arg1, %arg1 : vector<4xf32>
   %2 = arith.addf %arg2, %arg2 : vector<4x8xf32>
-  return %0, %1, %2 : f32, vector<4xf32>, vector<4x8xf32>
+  %3 = arith.cmpf oeq, %arg0, %arg3 : f32
+  return
 }

@llvmbot
Copy link
Member

llvmbot commented Nov 6, 2025

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

The arith.cmpf lowering pattern used to generate invalid IR when an unsupported floating-point type was used.


Full diff: https://github.com/llvm/llvm-project/pull/166684.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h (+14-27)
  • (modified) mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp (+4)
  • (modified) mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp (+21)
  • (modified) mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir (+8-4)
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
index cad6cec761ab8..297af1a047f77 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
@@ -59,6 +59,12 @@ LogicalResult vectorOneToOneRewrite(
     ArrayRef<NamedAttribute> targetAttrs,
     const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
     IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none);
+
+/// Return "true" if the given type is an unsupported floating point type. In
+/// case of a vector type, return "true" if the element type is an unsupported
+/// floating point type.
+bool isUnsupportedFloatingPointType(const TypeConverter &typeConverter,
+                                    Type type);
 } // namespace detail
 } // namespace LLVM
 
@@ -92,16 +98,6 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
   using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
   using Super = VectorConvertToLLVMPattern<SourceOp, TargetOp>;
 
-  /// Return the given type if it's a floating point type. If the given type is
-  /// a vector type, return its element type if it's a floating point type.
-  static FloatType getFloatingPointType(Type type) {
-    if (auto floatType = dyn_cast<FloatType>(type))
-      return floatType;
-    if (auto vecType = dyn_cast<VectorType>(type))
-      return dyn_cast<FloatType>(vecType.getElementType());
-    return nullptr;
-  }
-
   LogicalResult
   matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -109,25 +105,16 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
         std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
         "expected single result op");
 
-    // The pattern should not apply if a floating-point operand is converted to
-    // a non-floating-point type. This indicates that the floating point type
-    // is not supported by the LLVM lowering. (Such types are converted to
-    // integers.)
-    auto checkType = [&](Value v) -> LogicalResult {
-      FloatType floatType = getFloatingPointType(v.getType());
-      if (!floatType)
-        return success();
-      Type convertedType = this->getTypeConverter()->convertType(floatType);
-      if (!isa_and_nonnull<FloatType>(convertedType))
+    // Bail on unsupported floating point types. (These are type-converted to
+    // integer types.)
+    for (Value operand : op->getOperands())
+      if (LLVM::detail::isUnsupportedFloatingPointType(
+              *this->getTypeConverter(), operand.getType()))
         return rewriter.notifyMatchFailure(op,
                                            "unsupported floating point type");
-      return success();
-    };
-    for (Value operand : op->getOperands())
-      if (failed(checkType(operand)))
-        return failure();
-    if (failed(checkType(op->getResult(0))))
-      return failure();
+    if (LLVM::detail::isUnsupportedFloatingPointType(
+            *this->getTypeConverter(), op->getResult(0).getType()))
+      return rewriter.notifyMatchFailure(op, "unsupported floating point type");
 
     // Determine attributes for the target op
     AttrConvert<SourceOp, TargetOp> attrConvert(op);
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 03ed4d51cc744..71a8128986e7f 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -459,6 +459,10 @@ CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
 LogicalResult
 CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
+  if (LLVM::detail::isUnsupportedFloatingPointType(*this->getTypeConverter(),
+                                                   op.getLhs().getType()))
+    return rewriter.notifyMatchFailure(op, "unsupported floating point type");
+
   Type operandType = adaptor.getLhs().getType();
   Type resultType = op.getResult().getType();
   LLVM::FastmathFlags fmf =
diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
index e7dd0b506e12d..b37b35d79901c 100644
--- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
@@ -131,3 +131,24 @@ LogicalResult LLVM::detail::vectorOneToOneRewrite(
   return handleMultidimensionalVectors(op, operands, typeConverter, callback,
                                        rewriter);
 }
+
+/// Return the given type if it's a floating point type. If the given type is
+/// a vector type, return its element type if it's a floating point type.
+static FloatType getFloatingPointType(Type type) {
+  if (auto floatType = dyn_cast<FloatType>(type))
+    return floatType;
+  if (auto vecType = dyn_cast<VectorType>(type))
+    return dyn_cast<FloatType>(vecType.getElementType());
+  return nullptr;
+}
+
+bool LLVM::detail::isUnsupportedFloatingPointType(
+    const TypeConverter &typeConverter, Type type) {
+  FloatType floatType = getFloatingPointType(type);
+  if (!floatType)
+    return false;
+  Type convertedType = typeConverter.convertType(floatType);
+  if (!convertedType)
+    return true;
+  return !isa<FloatType>(convertedType);
+}
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index b5dcb01d3dc6b..8cf736c0a559f 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -754,11 +754,13 @@ func.func @memref_bitcast(%1: memref<?xi16>) -> memref<?xbf16> {
 //       CHECK:   arith.addf {{.*}} : f4E2M1FN
 //       CHECK:   arith.addf {{.*}} : vector<4xf4E2M1FN>
 //       CHECK:   arith.addf {{.*}} : vector<8x4xf4E2M1FN>
-func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2: vector<8x4xf4E2M1FN>) -> (f4E2M1FN, vector<4xf4E2M1FN>, vector<8x4xf4E2M1FN>) {
+//       CHECK:   arith.cmpf {{.*}} : f4E2M1FN
+func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2: vector<8x4xf4E2M1FN>, %arg3: f4E2M1FN) {
   %0 = arith.addf %arg0, %arg0 : f4E2M1FN
   %1 = arith.addf %arg1, %arg1 : vector<4xf4E2M1FN>
   %2 = arith.addf %arg2, %arg2 : vector<8x4xf4E2M1FN>
-  return %0, %1, %2 : f4E2M1FN, vector<4xf4E2M1FN>, vector<8x4xf4E2M1FN>
+  %3 = arith.cmpf oeq, %arg0, %arg3 : f4E2M1FN
+  return
 }
 
 // -----
@@ -767,9 +769,11 @@ func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2
 //         CHECK:   llvm.fadd {{.*}} : f32
 //         CHECK:   llvm.fadd {{.*}} : vector<4xf32>
 // CHECK-COUNT-4:   llvm.fadd {{.*}} : vector<8xf32>
-func.func @supported_fp_type(%arg0: f32, %arg1: vector<4xf32>, %arg2: vector<4x8xf32>) -> (f32, vector<4xf32>, vector<4x8xf32>) {
+//         CHECK:   llvm.fcmp {{.*}} : f32
+func.func @supported_fp_type(%arg0: f32, %arg1: vector<4xf32>, %arg2: vector<4x8xf32>, %arg3: f32) {
   %0 = arith.addf %arg0, %arg0 : f32
   %1 = arith.addf %arg1, %arg1 : vector<4xf32>
   %2 = arith.addf %arg2, %arg2 : vector<4x8xf32>
-  return %0, %1, %2 : f32, vector<4xf32>, vector<4x8xf32>
+  %3 = arith.cmpf oeq, %arg0, %arg3 : f32
+  return
 }

Copy link
Contributor

@makslevental makslevental left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks reasonable to me but before we merge maybe @hanhanW @krzysz00 can test (because the other one broke IREE...)

@matthias-springer matthias-springer force-pushed the users/matthias-springer/fix_cmpf branch from 9c42dcd to b55e698 Compare November 9, 2025 04:45
@krzysz00
Copy link
Contributor

This still live?

@matthias-springer matthias-springer force-pushed the users/matthias-springer/fix_cmpf branch from b55e698 to a6287dc Compare November 28, 2025 02:41
@matthias-springer matthias-springer enabled auto-merge (squash) November 28, 2025 02:48
@matthias-springer matthias-springer merged commit b142912 into main Nov 28, 2025
10 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/fix_cmpf branch November 28, 2025 03:01
@llvm-ci
Copy link
Collaborator

llvm-ci commented Nov 28, 2025

LLVM Buildbot has detected a new failure on builder ppc64le-mlir-rhel-clang running on ppc64le-mlir-rhel-test while building mlir at step 3 "clean-build-dir".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/129/builds/34002

Here is the relevant piece of the build log for the reference
Step 3 (clean-build-dir) failure: Delete failed. (failure) (timed out)
Step 4 (cmake-configure) failure: cmake (failure) (timed out)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants