Skip to content

Commit b142912

Browse files
[mlir][arith] Fix arith.cmpf lowering with unsupported FP types (#166684)
The `arith.cmpf` lowering pattern used to generate invalid IR when an unsupported floating-point type was used.
1 parent 6f7ea34 commit b142912

File tree

4 files changed

+48
-31
lines changed

4 files changed

+48
-31
lines changed

mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h

Lines changed: 16 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,12 @@ LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp,
6060
Attribute propertiesAttr,
6161
const LLVMTypeConverter &typeConverter,
6262
ConversionPatternRewriter &rewriter);
63+
64+
/// Return "true" if the given type is an unsupported floating point type. In
65+
/// case of a vector type, return "true" if the element type is an unsupported
66+
/// floating point type.
67+
bool isUnsupportedFloatingPointType(const TypeConverter &typeConverter,
68+
Type type);
6369
} // namespace detail
6470
} // namespace LLVM
6571

@@ -97,43 +103,25 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
97103
using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
98104
using Super = VectorConvertToLLVMPattern<SourceOp, TargetOp>;
99105

100-
/// Return the given type if it's a floating point type. If the given type is
101-
/// a vector type, return its element type if it's a floating point type.
102-
static FloatType getFloatingPointType(Type type) {
103-
if (auto floatType = dyn_cast<FloatType>(type))
104-
return floatType;
105-
if (auto vecType = dyn_cast<VectorType>(type))
106-
return dyn_cast<FloatType>(vecType.getElementType());
107-
return nullptr;
108-
}
109-
110106
LogicalResult
111107
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
112108
ConversionPatternRewriter &rewriter) const override {
113109
static_assert(
114110
std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
115111
"expected single result op");
116112

117-
// The pattern should not apply if a floating-point operand is converted to
118-
// a non-floating-point type. This indicates that the floating point type
119-
// is not supported by the LLVM lowering. (Such types are converted to
120-
// integers.)
121-
auto checkType = [&](Value v) -> LogicalResult {
122-
FloatType floatType = getFloatingPointType(v.getType());
123-
if (!floatType)
124-
return success();
125-
Type convertedType = this->getTypeConverter()->convertType(floatType);
126-
if (!isa_and_nonnull<FloatType>(convertedType))
127-
return rewriter.notifyMatchFailure(op,
128-
"unsupported floating point type");
129-
return success();
130-
};
113+
// Bail on unsupported floating point types. (These are type-converted to
114+
// integer types.)
131115
if (FailOnUnsupportedFP) {
132116
for (Value operand : op->getOperands())
133-
if (failed(checkType(operand)))
134-
return failure();
135-
if (failed(checkType(op->getResult(0))))
136-
return failure();
117+
if (LLVM::detail::isUnsupportedFloatingPointType(
118+
*this->getTypeConverter(), operand.getType()))
119+
return rewriter.notifyMatchFailure(op,
120+
"unsupported floating point type");
121+
if (LLVM::detail::isUnsupportedFloatingPointType(
122+
*this->getTypeConverter(), op->getResult(0).getType()))
123+
return rewriter.notifyMatchFailure(op,
124+
"unsupported floating point type");
137125
}
138126

139127
// Determine attributes for the target op

mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,10 @@ CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
483483
LogicalResult
484484
CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
485485
ConversionPatternRewriter &rewriter) const {
486+
if (LLVM::detail::isUnsupportedFloatingPointType(*this->getTypeConverter(),
487+
op.getLhs().getType()))
488+
return rewriter.notifyMatchFailure(op, "unsupported floating point type");
489+
486490
Type operandType = adaptor.getLhs().getType();
487491
Type resultType = op.getResult().getType();
488492
LLVM::FastmathFlags fmf =

mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,3 +130,24 @@ LogicalResult LLVM::detail::vectorOneToOneRewrite(
130130
return handleMultidimensionalVectors(op, operands, typeConverter, callback,
131131
rewriter);
132132
}
133+
134+
/// Return the given type if it's a floating point type. If the given type is
135+
/// a vector type, return its element type if it's a floating point type.
136+
static FloatType getFloatingPointType(Type type) {
137+
if (auto floatType = dyn_cast<FloatType>(type))
138+
return floatType;
139+
if (auto vecType = dyn_cast<VectorType>(type))
140+
return dyn_cast<FloatType>(vecType.getElementType());
141+
return nullptr;
142+
}
143+
144+
bool LLVM::detail::isUnsupportedFloatingPointType(
145+
const TypeConverter &typeConverter, Type type) {
146+
FloatType floatType = getFloatingPointType(type);
147+
if (!floatType)
148+
return false;
149+
Type convertedType = typeConverter.convertType(floatType);
150+
if (!convertedType)
151+
return true;
152+
return !isa<FloatType>(convertedType);
153+
}

mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -770,12 +770,14 @@ func.func @memref_bitcast(%1: memref<?xi16>) -> memref<?xbf16> {
770770
// CHECK: arith.addf {{.*}} : f4E2M1FN
771771
// CHECK: arith.addf {{.*}} : vector<4xf4E2M1FN>
772772
// CHECK: arith.addf {{.*}} : vector<8x4xf4E2M1FN>
773+
// CHECK: arith.cmpf {{.*}} : f4E2M1FN
773774
// CHECK: llvm.select {{.*}} : i1, i4
774775
func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2: vector<8x4xf4E2M1FN>, %arg3: f4E2M1FN, %arg4: i1) {
775776
%0 = arith.addf %arg0, %arg0 : f4E2M1FN
776777
%1 = arith.addf %arg1, %arg1 : vector<4xf4E2M1FN>
777778
%2 = arith.addf %arg2, %arg2 : vector<8x4xf4E2M1FN>
778-
%3 = arith.select %arg4, %arg0, %arg3 : f4E2M1FN
779+
%3 = arith.cmpf oeq, %arg0, %arg3 : f4E2M1FN
780+
%4 = arith.select %arg4, %arg0, %arg3 : f4E2M1FN
779781
return
780782
}
781783

@@ -785,9 +787,11 @@ func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2
785787
// CHECK: llvm.fadd {{.*}} : f32
786788
// CHECK: llvm.fadd {{.*}} : vector<4xf32>
787789
// CHECK-COUNT-4: llvm.fadd {{.*}} : vector<8xf32>
788-
func.func @supported_fp_type(%arg0: f32, %arg1: vector<4xf32>, %arg2: vector<4x8xf32>) -> (f32, vector<4xf32>, vector<4x8xf32>) {
790+
// CHECK: llvm.fcmp {{.*}} : f32
791+
func.func @supported_fp_type(%arg0: f32, %arg1: vector<4xf32>, %arg2: vector<4x8xf32>, %arg3: f32) {
789792
%0 = arith.addf %arg0, %arg0 : f32
790793
%1 = arith.addf %arg1, %arg1 : vector<4xf32>
791794
%2 = arith.addf %arg2, %arg2 : vector<4x8xf32>
792-
return %0, %1, %2 : f32, vector<4xf32>, vector<4x8xf32>
795+
%3 = arith.cmpf oeq, %arg0, %arg3 : f32
796+
return
793797
}

0 commit comments

Comments
 (0)