Skip to content

Commit 9c42dcd

Browse files
[mlir][arith] Fix arith.cmpf lowering with unsupported FP types
1 parent 6c640b8 commit 9c42dcd

File tree

4 files changed

+47
-31
lines changed

4 files changed

+47
-31
lines changed

mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h

Lines changed: 14 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,12 @@ LogicalResult vectorOneToOneRewrite(
5959
ArrayRef<NamedAttribute> targetAttrs,
6060
const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
6161
IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none);
62+
63+
/// Return "true" if the given type is an unsupported floating point type. In
64+
/// case of a vector type, return "true" if the element type is an unsupported
65+
/// floating point type.
66+
bool isUnsupportedFloatingPointType(const TypeConverter &typeConverter,
67+
Type type);
6268
} // namespace detail
6369
} // namespace LLVM
6470

@@ -92,42 +98,23 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
9298
using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
9399
using Super = VectorConvertToLLVMPattern<SourceOp, TargetOp>;
94100

95-
/// Return the given type if it's a floating point type. If the given type is
96-
/// a vector type, return its element type if it's a floating point type.
97-
static FloatType getFloatingPointType(Type type) {
98-
if (auto floatType = dyn_cast<FloatType>(type))
99-
return floatType;
100-
if (auto vecType = dyn_cast<VectorType>(type))
101-
return dyn_cast<FloatType>(vecType.getElementType());
102-
return nullptr;
103-
}
104-
105101
LogicalResult
106102
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
107103
ConversionPatternRewriter &rewriter) const override {
108104
static_assert(
109105
std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
110106
"expected single result op");
111107

112-
// The pattern should not apply if a floating-point operand is converted to
113-
// a non-floating-point type. This indicates that the floating point type
114-
// is not supported by the LLVM lowering. (Such types are converted to
115-
// integers.)
116-
auto checkType = [&](Value v) -> LogicalResult {
117-
FloatType floatType = getFloatingPointType(v.getType());
118-
if (!floatType)
119-
return success();
120-
Type convertedType = this->getTypeConverter()->convertType(floatType);
121-
if (!isa_and_nonnull<FloatType>(convertedType))
108+
// Bail on unsupported floating point types. (These are type-converted to
109+
// integer types.)
110+
for (Value operand : op->getOperands())
111+
if (LLVM::detail::isUnsupportedFloatingPointType(
112+
*this->getTypeConverter(), operand.getType()))
122113
return rewriter.notifyMatchFailure(op,
123114
"unsupported floating point type");
124-
return success();
125-
};
126-
for (Value operand : op->getOperands())
127-
if (failed(checkType(operand)))
128-
return failure();
129-
if (failed(checkType(op->getResult(0))))
130-
return failure();
115+
if (LLVM::detail::isUnsupportedFloatingPointType(
116+
*this->getTypeConverter(), op->getResult(0).getType()))
117+
return rewriter.notifyMatchFailure(op, "unsupported floating point type");
131118

132119
// Determine attributes for the target op
133120
AttrConvert<SourceOp, TargetOp> attrConvert(op);

mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,10 @@ CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
459459
LogicalResult
460460
CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
461461
ConversionPatternRewriter &rewriter) const {
462+
if (LLVM::detail::isUnsupportedFloatingPointType(*this->getTypeConverter(),
463+
op.getLhs().getType()))
464+
return rewriter.notifyMatchFailure(op, "unsupported floating point type");
465+
462466
Type operandType = adaptor.getLhs().getType();
463467
Type resultType = op.getResult().getType();
464468
LLVM::FastmathFlags fmf =

mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,3 +131,24 @@ LogicalResult LLVM::detail::vectorOneToOneRewrite(
131131
return handleMultidimensionalVectors(op, operands, typeConverter, callback,
132132
rewriter);
133133
}
134+
135+
/// Return the given type if it's a floating point type. If the given type is
136+
/// a vector type, return its element type if it's a floating point type.
137+
static FloatType getFloatingPointType(Type type) {
138+
if (auto floatType = dyn_cast<FloatType>(type))
139+
return floatType;
140+
if (auto vecType = dyn_cast<VectorType>(type))
141+
return dyn_cast<FloatType>(vecType.getElementType());
142+
return nullptr;
143+
}
144+
145+
bool LLVM::detail::isUnsupportedFloatingPointType(
146+
const TypeConverter &typeConverter, Type type) {
147+
FloatType floatType = getFloatingPointType(type);
148+
if (!floatType)
149+
return false;
150+
Type convertedType = typeConverter.convertType(floatType);
151+
if (!convertedType)
152+
return true;
153+
return !isa<FloatType>(convertedType);
154+
}

mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -754,11 +754,13 @@ func.func @memref_bitcast(%1: memref<?xi16>) -> memref<?xbf16> {
754754
// CHECK: arith.addf {{.*}} : f4E2M1FN
755755
// CHECK: arith.addf {{.*}} : vector<4xf4E2M1FN>
756756
// CHECK: arith.addf {{.*}} : vector<8x4xf4E2M1FN>
757-
func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2: vector<8x4xf4E2M1FN>) -> (f4E2M1FN, vector<4xf4E2M1FN>, vector<8x4xf4E2M1FN>) {
757+
// CHECK: arith.cmpf {{.*}} : f4E2M1FN
758+
func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2: vector<8x4xf4E2M1FN>, %arg3: f4E2M1FN) {
758759
%0 = arith.addf %arg0, %arg0 : f4E2M1FN
759760
%1 = arith.addf %arg1, %arg1 : vector<4xf4E2M1FN>
760761
%2 = arith.addf %arg2, %arg2 : vector<8x4xf4E2M1FN>
761-
return %0, %1, %2 : f4E2M1FN, vector<4xf4E2M1FN>, vector<8x4xf4E2M1FN>
762+
%3 = arith.cmpf oeq, %arg0, %arg3 : f4E2M1FN
763+
return
762764
}
763765

764766
// -----
@@ -767,9 +769,11 @@ func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2
767769
// CHECK: llvm.fadd {{.*}} : f32
768770
// CHECK: llvm.fadd {{.*}} : vector<4xf32>
769771
// CHECK-COUNT-4: llvm.fadd {{.*}} : vector<8xf32>
770-
func.func @supported_fp_type(%arg0: f32, %arg1: vector<4xf32>, %arg2: vector<4x8xf32>) -> (f32, vector<4xf32>, vector<4x8xf32>) {
772+
// CHECK: llvm.fcmp {{.*}} : f32
773+
func.func @supported_fp_type(%arg0: f32, %arg1: vector<4xf32>, %arg2: vector<4x8xf32>, %arg3: f32) {
771774
%0 = arith.addf %arg0, %arg0 : f32
772775
%1 = arith.addf %arg1, %arg1 : vector<4xf32>
773776
%2 = arith.addf %arg2, %arg2 : vector<4x8xf32>
774-
return %0, %1, %2 : f32, vector<4xf32>, vector<4x8xf32>
777+
%3 = arith.cmpf oeq, %arg0, %arg3 : f32
778+
return
775779
}

0 commit comments

Comments
 (0)