Skip to content

[mlir][math] Fix intrinsic conversions to LLVM for 0D-vector types #141020

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jun 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 69 additions & 54 deletions mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ using ATanOpLowering = ConvertFMFMathToLLVMPattern<math::AtanOp, LLVM::ATanOp>;
using ATan2OpLowering =
ConvertFMFMathToLLVMPattern<math::Atan2Op, LLVM::ATan2Op>;
// A `CtLz/CtTz/absi(a)` is converted into `CtLz/CtTz/absi(a, false)`.
// TODO: Result and operand types match for `absi` as opposed to `ct*z`, so it
// may be better to separate the patterns.
template <typename MathOp, typename LLVMOp>
struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
using ConvertOpToLLVMPattern<MathOp>::ConvertOpToLLVMPattern;
Expand All @@ -81,26 +83,29 @@ struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
LogicalResult
matchAndRewrite(MathOp op, typename MathOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
const auto &typeConverter = *this->getTypeConverter();
auto operandType = adaptor.getOperand().getType();

if (!operandType || !LLVM::isCompatibleType(operandType))
auto llvmOperandType = typeConverter.convertType(operandType);
if (!llvmOperandType)
return failure();

auto loc = op.getLoc();
auto resultType = op.getResult().getType();
auto llvmResultType = typeConverter.convertType(resultType);
if (!llvmResultType)
return failure();

if (!isa<LLVM::LLVMArrayType>(operandType)) {
rewriter.replaceOpWithNewOp<LLVMOp>(op, resultType, adaptor.getOperand(),
false);
if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
rewriter.replaceOpWithNewOp<LLVMOp>(op, llvmResultType,
adaptor.getOperand(), false);
return success();
}

auto vectorType = dyn_cast<VectorType>(resultType);
if (!vectorType)
if (!isa<VectorType>(llvmResultType))
return failure();

return LLVM::detail::handleMultidimensionalVectors(
op.getOperation(), adaptor.getOperands(), *this->getTypeConverter(),
op.getOperation(), adaptor.getOperands(), typeConverter,
[&](Type llvm1DVectorTy, ValueRange operands) {
return rewriter.create<LLVMOp>(loc, llvm1DVectorTy, operands[0],
false);
Expand All @@ -123,40 +128,42 @@ struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
LogicalResult
matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
const auto &typeConverter = *this->getTypeConverter();
auto operandType = adaptor.getOperand().getType();

if (!operandType || !LLVM::isCompatibleType(operandType))
auto llvmOperandType = typeConverter.convertType(operandType);
if (!llvmOperandType)
return failure();

auto loc = op.getLoc();
auto resultType = op.getResult().getType();
auto floatType = cast<FloatType>(getElementTypeOrSelf(resultType));
auto floatType = cast<FloatType>(
typeConverter.convertType(getElementTypeOrSelf(resultType)));
auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
ConvertFastMath<math::ExpM1Op, LLVM::ExpOp> expAttrs(op);
ConvertFastMath<math::ExpM1Op, LLVM::FSubOp> subAttrs(op);

if (!isa<LLVM::LLVMArrayType>(operandType)) {
if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
LLVM::ConstantOp one;
if (LLVM::isCompatibleVectorType(operandType)) {
if (LLVM::isCompatibleVectorType(llvmOperandType)) {
one = rewriter.create<LLVM::ConstantOp>(
loc, operandType,
SplatElementsAttr::get(cast<ShapedType>(resultType), floatOne));
loc, llvmOperandType,
SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
floatOne));
} else {
one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
one = rewriter.create<LLVM::ConstantOp>(loc, llvmOperandType, floatOne);
}
auto exp = rewriter.create<LLVM::ExpOp>(loc, adaptor.getOperand(),
expAttrs.getAttrs());
rewriter.replaceOpWithNewOp<LLVM::FSubOp>(
op, operandType, ValueRange{exp, one}, subAttrs.getAttrs());
op, llvmOperandType, ValueRange{exp, one}, subAttrs.getAttrs());
return success();
}

auto vectorType = dyn_cast<VectorType>(resultType);
if (!vectorType)
if (!isa<VectorType>(resultType))
return rewriter.notifyMatchFailure(op, "expected vector result type");

return LLVM::detail::handleMultidimensionalVectors(
op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
op.getOperation(), adaptor.getOperands(), typeConverter,
[&](Type llvm1DVectorTy, ValueRange operands) {
auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
auto splatAttr = SplatElementsAttr::get(
Expand All @@ -181,41 +188,43 @@ struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
LogicalResult
matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
const auto &typeConverter = *this->getTypeConverter();
auto operandType = adaptor.getOperand().getType();

if (!operandType || !LLVM::isCompatibleType(operandType))
auto llvmOperandType = typeConverter.convertType(operandType);
if (!llvmOperandType)
return rewriter.notifyMatchFailure(op, "unsupported operand type");

auto loc = op.getLoc();
auto resultType = op.getResult().getType();
auto floatType = cast<FloatType>(getElementTypeOrSelf(resultType));
auto floatType = cast<FloatType>(
typeConverter.convertType(getElementTypeOrSelf(resultType)));
auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
ConvertFastMath<math::Log1pOp, LLVM::FAddOp> addAttrs(op);
ConvertFastMath<math::Log1pOp, LLVM::LogOp> logAttrs(op);

if (!isa<LLVM::LLVMArrayType>(operandType)) {
if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
LLVM::ConstantOp one =
LLVM::isCompatibleVectorType(operandType)
isa<VectorType>(llvmOperandType)
? rewriter.create<LLVM::ConstantOp>(
loc, operandType,
SplatElementsAttr::get(cast<ShapedType>(resultType),
loc, llvmOperandType,
SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
floatOne))
: rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
: rewriter.create<LLVM::ConstantOp>(loc, llvmOperandType,
floatOne);

auto add = rewriter.create<LLVM::FAddOp>(
loc, operandType, ValueRange{one, adaptor.getOperand()},
loc, llvmOperandType, ValueRange{one, adaptor.getOperand()},
addAttrs.getAttrs());
rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, ValueRange{add},
logAttrs.getAttrs());
rewriter.replaceOpWithNewOp<LLVM::LogOp>(
op, llvmOperandType, ValueRange{add}, logAttrs.getAttrs());
return success();
}

auto vectorType = dyn_cast<VectorType>(resultType);
if (!vectorType)
if (!isa<VectorType>(resultType))
return rewriter.notifyMatchFailure(op, "expected vector result type");

return LLVM::detail::handleMultidimensionalVectors(
op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
op.getOperation(), adaptor.getOperands(), typeConverter,
[&](Type llvm1DVectorTy, ValueRange operands) {
auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
auto splatAttr = SplatElementsAttr::get(
Expand All @@ -241,40 +250,42 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
LogicalResult
matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
const auto &typeConverter = *this->getTypeConverter();
auto operandType = adaptor.getOperand().getType();

if (!operandType || !LLVM::isCompatibleType(operandType))
auto llvmOperandType = typeConverter.convertType(operandType);
if (!llvmOperandType)
return failure();

auto loc = op.getLoc();
auto resultType = op.getResult().getType();
auto floatType = cast<FloatType>(getElementTypeOrSelf(resultType));
auto floatType = cast<FloatType>(
typeConverter.convertType(getElementTypeOrSelf(resultType)));
auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
ConvertFastMath<math::RsqrtOp, LLVM::SqrtOp> sqrtAttrs(op);
ConvertFastMath<math::RsqrtOp, LLVM::FDivOp> divAttrs(op);

if (!isa<LLVM::LLVMArrayType>(operandType)) {
if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
LLVM::ConstantOp one;
if (LLVM::isCompatibleVectorType(operandType)) {
if (isa<VectorType>(llvmOperandType)) {
one = rewriter.create<LLVM::ConstantOp>(
loc, operandType,
SplatElementsAttr::get(cast<ShapedType>(resultType), floatOne));
loc, llvmOperandType,
SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
floatOne));
} else {
one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
one = rewriter.create<LLVM::ConstantOp>(loc, llvmOperandType, floatOne);
}
auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, adaptor.getOperand(),
sqrtAttrs.getAttrs());
rewriter.replaceOpWithNewOp<LLVM::FDivOp>(
op, operandType, ValueRange{one, sqrt}, divAttrs.getAttrs());
op, llvmOperandType, ValueRange{one, sqrt}, divAttrs.getAttrs());
return success();
}

auto vectorType = dyn_cast<VectorType>(resultType);
if (!vectorType)
if (!isa<VectorType>(resultType))
return failure();

return LLVM::detail::handleMultidimensionalVectors(
op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
op.getOperation(), adaptor.getOperands(), typeConverter,
[&](Type llvm1DVectorTy, ValueRange operands) {
auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
auto splatAttr = SplatElementsAttr::get(
Expand All @@ -298,13 +309,15 @@ struct IsNaNOpLowering : public ConvertOpToLLVMPattern<math::IsNaNOp> {
LogicalResult
matchAndRewrite(math::IsNaNOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto operandType = adaptor.getOperand().getType();

if (!operandType || !LLVM::isCompatibleType(operandType))
const auto &typeConverter = *this->getTypeConverter();
auto operandType =
typeConverter.convertType(adaptor.getOperand().getType());
auto resultType = typeConverter.convertType(op.getResult().getType());
if (!operandType || !resultType)
return failure();

rewriter.replaceOpWithNewOp<LLVM::IsFPClass>(
op, op.getType(), adaptor.getOperand(), llvm::fcNan);
op, resultType, adaptor.getOperand(), llvm::fcNan);
return success();
}
};
Expand All @@ -315,13 +328,15 @@ struct IsFiniteOpLowering : public ConvertOpToLLVMPattern<math::IsFiniteOp> {
LogicalResult
matchAndRewrite(math::IsFiniteOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto operandType = adaptor.getOperand().getType();

if (!operandType || !LLVM::isCompatibleType(operandType))
const auto &typeConverter = *this->getTypeConverter();
auto operandType =
typeConverter.convertType(adaptor.getOperand().getType());
auto resultType = typeConverter.convertType(op.getResult().getType());
if (!operandType || !resultType)
return failure();

rewriter.replaceOpWithNewOp<LLVM::IsFPClass>(
op, op.getType(), adaptor.getOperand(), llvm::fcFinite);
op, resultType, adaptor.getOperand(), llvm::fcFinite);
return success();
}
};
Expand Down
Loading