@@ -59,6 +59,12 @@ LogicalResult vectorOneToOneRewrite(
5959 ArrayRef<NamedAttribute> targetAttrs,
6060 const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
6161 IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none);
62+
63+ // / Return "true" if the given type is an unsupported floating point type. In
64+ // / case of a vector type, return "true" if the element type is an unsupported
65+ // / floating point type.
66+ bool isUnsupportedFloatingPointType (const TypeConverter &typeConverter,
67+ Type type);
6268} // namespace detail
6369} // namespace LLVM
6470
@@ -92,42 +98,23 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
9298 using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
9399 using Super = VectorConvertToLLVMPattern<SourceOp, TargetOp>;
94100
95- // / Return the given type if it's a floating point type. If the given type is
96- // / a vector type, return its element type if it's a floating point type.
97- static FloatType getFloatingPointType (Type type) {
98- if (auto floatType = dyn_cast<FloatType>(type))
99- return floatType;
100- if (auto vecType = dyn_cast<VectorType>(type))
101- return dyn_cast<FloatType>(vecType.getElementType ());
102- return nullptr ;
103- }
104-
105101 LogicalResult
106102 matchAndRewrite (SourceOp op, typename SourceOp::Adaptor adaptor,
107103 ConversionPatternRewriter &rewriter) const override {
108104 static_assert (
109105 std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
110106 " expected single result op" );
111107
112- // The pattern should not apply if a floating-point operand is converted to
113- // a non-floating-point type. This indicates that the floating point type
114- // is not supported by the LLVM lowering. (Such types are converted to
115- // integers.)
116- auto checkType = [&](Value v) -> LogicalResult {
117- FloatType floatType = getFloatingPointType (v.getType ());
118- if (!floatType)
119- return success ();
120- Type convertedType = this ->getTypeConverter ()->convertType (floatType);
121- if (!isa_and_nonnull<FloatType>(convertedType))
108+ // Bail on unsupported floating point types. (These are type-converted to
109+ // integer types.)
110+ for (Value operand : op->getOperands ())
111+ if (LLVM::detail::isUnsupportedFloatingPointType (
112+ *this ->getTypeConverter (), operand.getType ()))
122113 return rewriter.notifyMatchFailure (op,
123114 " unsupported floating point type" );
124- return success ();
125- };
126- for (Value operand : op->getOperands ())
127- if (failed (checkType (operand)))
128- return failure ();
129- if (failed (checkType (op->getResult (0 ))))
130- return failure ();
115+ if (LLVM::detail::isUnsupportedFloatingPointType (
116+ *this ->getTypeConverter (), op->getResult (0 ).getType ()))
117+ return rewriter.notifyMatchFailure (op, " unsupported floating point type" );
131118
132119 // Determine attributes for the target op
133120 AttrConvert<SourceOp, TargetOp> attrConvert (op);
0 commit comments