Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flang] Inline hlfir.cshift as hlfir.elemental. #119480

Merged
merged 2 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 122 additions & 0 deletions flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "flang/Optimizer/Builder/Complex.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/Builder/IntrinsicCall.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
Expand Down Expand Up @@ -331,6 +332,108 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
}
};

class CShiftAsElementalConversion
: public mlir::OpRewritePattern<hlfir::CShiftOp> {
public:
using mlir::OpRewritePattern<hlfir::CShiftOp>::OpRewritePattern;

explicit CShiftAsElementalConversion(mlir::MLIRContext *ctx)
: OpRewritePattern(ctx) {
setHasBoundedRewriteRecursion();
}

llvm::LogicalResult
matchAndRewrite(hlfir::CShiftOp cshift,
mlir::PatternRewriter &rewriter) const override {
using Fortran::common::maxRank;

mlir::Location loc = cshift.getLoc();
fir::FirOpBuilder builder{rewriter, cshift.getOperation()};
hlfir::ExprType expr = mlir::dyn_cast<hlfir::ExprType>(cshift.getType());
assert(expr &&
"expected an expression type for the result of hlfir.cshift");
mlir::Type elementType = expr.getElementType();
hlfir::Entity array = hlfir::Entity{cshift.getArray()};
mlir::Value arrayShape = hlfir::genShape(loc, builder, array);
llvm::SmallVector<mlir::Value> arrayExtents =
hlfir::getExplicitExtentsFromShape(arrayShape, builder);
unsigned arrayRank = expr.getRank();
llvm::SmallVector<mlir::Value, 1> typeParams;
hlfir::genLengthParameters(loc, builder, array, typeParams);
hlfir::Entity shift = hlfir::Entity{cshift.getShift()};
// The new index computation involves MODULO, which is not implemented
// for IndexType, so use I64 instead.
mlir::Type calcType = builder.getI64Type();

mlir::Value one = builder.createIntegerConstant(loc, calcType, 1);
mlir::Value shiftVal;
if (shift.isScalar()) {
shiftVal = hlfir::loadTrivialScalar(loc, builder, shift);
shiftVal = builder.createConvert(loc, calcType, shiftVal);
}

int64_t dimVal = 1;
if (arrayRank == 1) {
// When it is a 1D CSHIFT, we may assume that the DIM argument
// (whether it is present or absent) is equal to 1, otherwise,
// the program is illegal.
assert(shiftVal && "SHIFT must be scalar");
} else {
if (mlir::Value dim = cshift.getDim())
dimVal = fir::getIntIfConstant(dim).value_or(0);
assert(dimVal > 0 && dimVal <= arrayRank &&
"DIM must be present and a positive constant not exceeding "
"the array's rank");
}

auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
mlir::ValueRange inputIndices) -> hlfir::Entity {
llvm::SmallVector<mlir::Value, maxRank> indices{inputIndices};
if (!shift.isScalar()) {
// When the array is not a vector, section
// (s(1), s(2), ..., s(dim-1), :, s(dim+1), ..., s(n)
// of the result has a value equal to:
// CSHIFT(ARRAY(s(1), s(2), ..., s(dim-1), :, s(dim+1), ..., s(n)),
// SH, 1),
// where SH is either SHIFT (if scalar) or
// SHIFT(s(1), s(2), ..., s(dim-1), s(dim+1), ..., s(n)).
llvm::SmallVector<mlir::Value, maxRank> shiftIndices{indices};
shiftIndices.erase(shiftIndices.begin() + dimVal - 1);
hlfir::Entity shiftElement =
hlfir::getElementAt(loc, builder, shift, shiftIndices);
shiftVal = hlfir::loadTrivialScalar(loc, builder, shiftElement);
shiftVal = builder.createConvert(loc, calcType, shiftVal);
}

// Element i of the result (1-based) is element
// 'MODULO(i + SH - 1, SIZE(ARRAY)) + 1' (1-based) of the original
// ARRAY (or its section, when ARRAY is not a vector).
mlir::Value index =
builder.createConvert(loc, calcType, inputIndices[dimVal - 1]);
mlir::Value extent = arrayExtents[dimVal - 1];
mlir::Value newIndex =
builder.create<mlir::arith::AddIOp>(loc, index, shiftVal);
newIndex = builder.create<mlir::arith::SubIOp>(loc, newIndex, one);
newIndex = fir::IntrinsicLibrary{builder, loc}.genModulo(
calcType, {newIndex, builder.createConvert(loc, calcType, extent)});
newIndex = builder.create<mlir::arith::AddIOp>(loc, newIndex, one);
newIndex = builder.createConvert(loc, builder.getIndexType(), newIndex);

indices[dimVal - 1] = newIndex;
hlfir::Entity element = hlfir::getElementAt(loc, builder, array, indices);
return hlfir::loadTrivialScalar(loc, builder, element);
};

hlfir::ElementalOp elementalOp = hlfir::genElementalOp(
loc, builder, elementType, arrayShape, typeParams, genKernel,
/*isUnordered=*/true,
array.isPolymorphic() ? static_cast<mlir::Value>(array) : nullptr,
cshift.getResult().getType());
rewriter.replaceOp(cshift, elementalOp);
return mlir::success();
}
};

class SimplifyHLFIRIntrinsics
: public hlfir::impl::SimplifyHLFIRIntrinsicsBase<SimplifyHLFIRIntrinsics> {
public:
Expand All @@ -339,6 +442,7 @@ class SimplifyHLFIRIntrinsics
mlir::RewritePatternSet patterns(context);
patterns.insert<TransposeAsElementalConversion>(context);
patterns.insert<SumAsElementalConversion>(context);
patterns.insert<CShiftAsElementalConversion>(context);
mlir::ConversionTarget target(*context);
// don't transform transpose of polymorphic arrays (not currently supported
// by hlfir.elemental)
Expand Down Expand Up @@ -375,6 +479,24 @@ class SimplifyHLFIRIntrinsics
}
return true;
});
target.addDynamicallyLegalOp<hlfir::CShiftOp>([](hlfir::CShiftOp cshift) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be easier to turn this pass into a greedy conversion, and check all the conditions in the patterns themselves. If there are no objections, I will clean it up in a separate NFC patch.

unsigned resultRank = hlfir::Entity{cshift}.getRank();
if (resultRank == 1)
return false;

mlir::Value dim = cshift.getDim();
if (!dim)
return false;

// If DIM is present, then it must be constant to please
// the conversion. In addition, ignore cases with
// illegal DIM values.
if (auto dimVal = fir::getIntIfConstant(dim))
if (*dimVal > 0 && *dimVal <= resultRank)
return false;

return true;
});
target.markUnknownOpDynamicallyLegal(
[](mlir::Operation *) { return true; });
if (mlir::failed(mlir::applyFullConversion(getOperation(), target,
Expand Down
Loading
Loading