Skip to content

Commit

Permalink
[flang] Simplify hlfir.sum total reductions. (#119482)
Browse files Browse the repository at this point in the history
I am trying to switch to keeping the reduction value in a temporary
scalar location so that I can use hlfir::genLoopNest easily.
This also allows using omp.loop_nest with worksharing for OpenMP.
  • Loading branch information
vzakhari authored Dec 13, 2024
1 parent af5d3af commit a00946f
Show file tree
Hide file tree
Showing 4 changed files with 323 additions and 182 deletions.
35 changes: 35 additions & 0 deletions flang/include/flang/Optimizer/Builder/HLFIRTools.h
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,10 @@ struct LoopNest {
/// Generate a fir.do_loop nest looping from 1 to extents[i].
/// \p isUnordered specifies whether the loops in the loop nest
/// are unordered.
///
/// NOTE: prefer genLoopNestWithReductions() over this function
/// where possible. Be aware, though, that genLoopNestWithReductions()
/// currently cannot generate OpenMP workshare loop constructs.
LoopNest genLoopNest(mlir::Location loc, fir::FirOpBuilder &builder,
mlir::ValueRange extents, bool isUnordered = false,
bool emitWorkshareLoop = false);
Expand All @@ -376,6 +380,37 @@ inline LoopNest genLoopNest(mlir::Location loc, fir::FirOpBuilder &builder,
isUnordered, emitWorkshareLoop);
}

/// The type of a callback that generates the body of a reduction
/// loop nest. It takes a location and a builder, as usual.
/// The first mlir::ValueRange holds the induction variables of the loops:
/// genLoopNestWithReductions() passes them so that the i-th value is
/// the induction variable of the loop iterating over extents[i].
/// The second mlir::ValueRange holds the values of the reductions
/// on entry to the innermost loop (its region iter-args).
/// The callback must return the updated values of the reductions;
/// they are yielded from the innermost loop.
using ReductionLoopBodyGenerator = std::function<llvm::SmallVector<mlir::Value>(
mlir::Location, fir::FirOpBuilder &, mlir::ValueRange, mlir::ValueRange)>;

/// Generate a loop nest looping from 1 to \p extents[i] and reducing
/// a set of values.
/// \p extents must contain at least one extent (the implementation
/// asserts this).
/// \p isUnordered specifies whether the loops in the loop nest
/// are unordered.
/// \p reductionInits are the initial values of the reductions
/// on entry to the outermost loop.
/// \p genBody callback is responsible for generating the code
/// that updates the reduction values in the innermost loop.
/// \returns the results of the outermost loop, i.e. the final values
/// of the reductions. On return, the builder's insertion point is set
/// right after the outermost loop.
///
/// NOTE: the implementation of this function may decide
/// to perform the reductions on SSA or in memory.
/// In the latter case, this function is responsible for
/// allocating/loading/storing the reduction variables,
/// and making sure they have proper data sharing attributes
/// in case any parallel constructs are present around the point
/// of the loop nest insertion, or if the function decides
/// to use any worksharing loop constructs for the loop nest.
llvm::SmallVector<mlir::Value> genLoopNestWithReductions(
mlir::Location loc, fir::FirOpBuilder &builder, mlir::ValueRange extents,
mlir::ValueRange reductionInits, const ReductionLoopBodyGenerator &genBody,
bool isUnordered = false);

/// Inline the body of an hlfir.elemental at the current insertion point
/// given a list of one based indices. This generates the computation
/// of one element of the elemental expression. Return the YieldElementOp
Expand Down
50 changes: 50 additions & 0 deletions flang/lib/Optimizer/Builder/HLFIRTools.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -910,6 +910,56 @@ hlfir::LoopNest hlfir::genLoopNest(mlir::Location loc,
return loopNest;
}

/// Generate a nest of fir.do_loop operations iterating from 1 to
/// \p extents[i] and carrying the reduction values as loop iter-args.
///
/// The loop over the last extent is the outermost one and the loop over
/// extents[0] is the innermost one. \p genBody is called once, with the
/// insertion point at the start of the innermost loop body, and receives
/// the induction variables (the i-th index belongs to extents[i]) and
/// the innermost loop's region iter-args. The values it returns are
/// yielded by the innermost loop, and every enclosing loop yields
/// the results of its child loop.
///
/// \p reductionInits may be empty: in that case the loops carry no values
/// and no explicit fir.result operations are created.
///
/// \returns the results of the outermost loop. The builder's insertion
/// point is left right after the outermost loop.
llvm::SmallVector<mlir::Value> hlfir::genLoopNestWithReductions(
    mlir::Location loc, fir::FirOpBuilder &builder, mlir::ValueRange extents,
    mlir::ValueRange reductionInits, const ReductionLoopBodyGenerator &genBody,
    bool isUnordered) {
  assert(!extents.empty() && "must have at least one extent");
  // A fir.do_loop built without iter-args already gets its terminator
  // from the builder, so an explicit fir.result must only be created
  // when there are reduction values to yield. Otherwise, the loop body
  // would end up with two terminators.
  const bool hasReductions = !reductionInits.empty();

  // Build loop nest from column to row.
  auto one = builder.create<mlir::arith::ConstantIndexOp>(loc, 1);
  mlir::Type indexType = builder.getIndexType();
  unsigned dim = extents.size() - 1;
  fir::DoLoopOp outerLoop = nullptr;
  fir::DoLoopOp parentLoop = nullptr;
  llvm::SmallVector<mlir::Value> oneBasedIndices;
  oneBasedIndices.resize(dim + 1);
  for (auto extent : llvm::reverse(extents)) {
    auto ub = builder.createConvert(loc, indexType, extent);

    // The outermost loop takes reductionInits as the initial
    // values of its iter-args.
    // A child loop takes its iter-args from the region iter-args
    // of its parent loop.
    fir::DoLoopOp doLoop;
    if (!parentLoop) {
      doLoop = builder.create<fir::DoLoopOp>(loc, one, ub, one, isUnordered,
                                             /*finalCountValue=*/false,
                                             reductionInits);
    } else {
      doLoop = builder.create<fir::DoLoopOp>(loc, one, ub, one, isUnordered,
                                             /*finalCountValue=*/false,
                                             parentLoop.getRegionIterArgs());
      // Return the results of the child loop from its parent loop.
      if (hasReductions)
        builder.create<fir::ResultOp>(loc, doLoop.getResults());
    }

    builder.setInsertionPointToStart(doLoop.getBody());
    // Reverse the indices so they are in column-major order.
    // The unsigned wrap-around of dim after the last iteration
    // is harmless, because dim is not used after the loop.
    oneBasedIndices[dim--] = doLoop.getInductionVar();
    if (!outerLoop)
      outerLoop = doLoop;
    parentLoop = doLoop;
  }

  // Generate the body of the innermost loop and yield the updated
  // reduction values from it.
  llvm::SmallVector<mlir::Value> reductionValues =
      genBody(loc, builder, oneBasedIndices, parentLoop.getRegionIterArgs());
  assert(reductionValues.size() == reductionInits.size() &&
         "the body generator must return one value per reduction");
  builder.setInsertionPointToEnd(parentLoop.getBody());
  if (hasReductions)
    builder.create<fir::ResultOp>(loc, reductionValues);
  builder.setInsertionPointAfter(outerLoop);
  return outerLoop->getResults();
}

static fir::ExtendedValue translateVariableToExtendedValue(
mlir::Location loc, fir::FirOpBuilder &builder, hlfir::Entity variable,
bool forceHlfirBase = false, bool contiguousHint = false) {
Expand Down
234 changes: 128 additions & 106 deletions flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,34 +106,43 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
mlir::PatternRewriter &rewriter) const override {
mlir::Location loc = sum.getLoc();
fir::FirOpBuilder builder{rewriter, sum.getOperation()};
hlfir::ExprType expr = mlir::dyn_cast<hlfir::ExprType>(sum.getType());
assert(expr && "expected an expression type for the result of hlfir.sum");
mlir::Type elementType = expr.getElementType();
mlir::Type elementType = hlfir::getFortranElementType(sum.getType());
hlfir::Entity array = hlfir::Entity{sum.getArray()};
mlir::Value mask = sum.getMask();
mlir::Value dim = sum.getDim();
int64_t dimVal = fir::getIntIfConstant(dim).value_or(0);
bool isTotalReduction = hlfir::Entity{sum}.getRank() == 0;
int64_t dimVal =
isTotalReduction ? 0 : fir::getIntIfConstant(dim).value_or(0);
mlir::Value resultShape, dimExtent;
std::tie(resultShape, dimExtent) =
genResultShape(loc, builder, array, dimVal);
llvm::SmallVector<mlir::Value> arrayExtents;
if (isTotalReduction)
arrayExtents = genArrayExtents(loc, builder, array);
else
std::tie(resultShape, dimExtent) =
genResultShapeForPartialReduction(loc, builder, array, dimVal);

// If the mask is present and is a scalar, then we'd better load its value
// outside of the reduction loop making the loop unswitching easier.
mlir::Value isPresentPred, maskValue;
if (mask) {
if (mlir::isa<fir::BaseBoxType>(mask.getType())) {
// MASK represented by a box might be dynamically optional,
// so we have to check for its presence before accessing it.
isPresentPred =
builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), mask);
}

if (hlfir::Entity{mask}.isScalar())
maskValue = genMaskValue(loc, builder, mask, isPresentPred, {});
}

auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
mlir::ValueRange inputIndices) -> hlfir::Entity {
// Loop over all indices in the DIM dimension, and reduce all values.
// We do not need to create the reduction loop always: if we can
// slice the input array given the inputIndices, then we can
// just apply a new SUM operation (total reduction) to the slice.
// For the time being, generate the explicit loop because the slicing
// requires generating an elemental operation for the input array
// (and the mask, if present).
// TODO: produce the slices and new SUM after adding a pattern
// for expanding total reduction SUM case.
mlir::Type indexType = builder.getIndexType();
auto one = builder.createIntegerConstant(loc, indexType, 1);
auto ub = builder.createConvert(loc, indexType, dimExtent);
// If DIM is not present, do total reduction.

// Initial value for the reduction.
mlir::Value initValue = genInitValue(loc, builder, elementType);
mlir::Value reductionInitValue = genInitValue(loc, builder, elementType);

// The reduction loop may be unordered if FastMathFlags::reassoc
// transformations are allowed. The integer reduction is always
Expand All @@ -142,79 +151,83 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
static_cast<bool>(sum.getFastmath() &
mlir::arith::FastMathFlags::reassoc);

// If the mask is present and is a scalar, then we'd better load its value
// outside of the reduction loop making the loop unswitching easier.
// Maybe it is worth hoisting it from the elemental operation as well.
mlir::Value isPresentPred, maskValue;
if (mask) {
if (mlir::isa<fir::BaseBoxType>(mask.getType())) {
// MASK represented by a box might be dynamically optional,
// so we have to check for its presence before accessing it.
isPresentPred =
builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), mask);
llvm::SmallVector<mlir::Value> extents;
if (isTotalReduction)
extents = arrayExtents;
else
extents.push_back(
builder.createConvert(loc, builder.getIndexType(), dimExtent));

auto genBody = [&](mlir::Location loc, fir::FirOpBuilder &builder,
mlir::ValueRange oneBasedIndices,
mlir::ValueRange reductionArgs)
-> llvm::SmallVector<mlir::Value, 1> {
// Generate the reduction loop-nest body.
// The initial reduction value in the innermost loop
// is passed via reductionArgs[0].
llvm::SmallVector<mlir::Value> indices;
if (isTotalReduction) {
indices = oneBasedIndices;
} else {
indices = inputIndices;
indices.insert(indices.begin() + dimVal - 1, oneBasedIndices[0]);
}

if (hlfir::Entity{mask}.isScalar())
maskValue = genMaskValue(loc, builder, mask, isPresentPred, {});
}
mlir::Value reductionValue = reductionArgs[0];
fir::IfOp ifOp;
if (mask) {
// Make the reduction value update conditional on the value
// of the mask.
if (!maskValue) {
// If the mask is an array, use the elemental and the loop indices
// to address the proper mask element.
maskValue =
genMaskValue(loc, builder, mask, isPresentPred, indices);
}
mlir::Value isUnmasked = builder.create<fir::ConvertOp>(
loc, builder.getI1Type(), maskValue);
ifOp = builder.create<fir::IfOp>(loc, elementType, isUnmasked,
/*withElseRegion=*/true);
// In the 'else' block return the current reduction value.
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
builder.create<fir::ResultOp>(loc, reductionValue);

// In the 'then' block do the actual addition.
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
}

// NOTE: the outer elemental operation may be lowered into
// omp.workshare.loop_wrapper/omp.loop_nest later, so the reduction
// loop may appear disjoint from the workshare loop nest.
// Moreover, the inner loop is not strictly nested (due to the reduction
// starting value initialization), and the above omp dialect operations
// cannot produce results.
// It is unclear what we should do about it yet.
auto doLoop = builder.create<fir::DoLoopOp>(
loc, one, ub, one, isUnordered, /*finalCountValue=*/false,
mlir::ValueRange{initValue});

// Address the input array using the reduction loop's IV
// for the DIM dimension.
mlir::Value iv = doLoop.getInductionVar();
llvm::SmallVector<mlir::Value> indices{inputIndices};
indices.insert(indices.begin() + dimVal - 1, iv);

mlir::OpBuilder::InsertionGuard guard(builder);
builder.setInsertionPointToStart(doLoop.getBody());
mlir::Value reductionValue = doLoop.getRegionIterArgs()[0];
fir::IfOp ifOp;
if (mask) {
// Make the reduction value update conditional on the value
// of the mask.
if (!maskValue) {
// If the mask is an array, use the elemental and the loop indices
// to address the proper mask element.
maskValue = genMaskValue(loc, builder, mask, isPresentPred, indices);
hlfir::Entity element =
hlfir::getElementAt(loc, builder, array, indices);
hlfir::Entity elementValue =
hlfir::loadTrivialScalar(loc, builder, element);
// NOTE: we can use "Kahan summation" same way as the runtime
// (e.g. when fast-math is not allowed), but let's start with
// the simple version.
reductionValue =
genScalarAdd(loc, builder, reductionValue, elementValue);

if (ifOp) {
builder.create<fir::ResultOp>(loc, reductionValue);
builder.setInsertionPointAfter(ifOp);
reductionValue = ifOp.getResult(0);
}
mlir::Value isUnmasked =
builder.create<fir::ConvertOp>(loc, builder.getI1Type(), maskValue);
ifOp = builder.create<fir::IfOp>(loc, elementType, isUnmasked,
/*withElseRegion=*/true);
// In the 'else' block return the current reduction value.
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
builder.create<fir::ResultOp>(loc, reductionValue);

// In the 'then' block do the actual addition.
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
}

hlfir::Entity element = hlfir::getElementAt(loc, builder, array, indices);
hlfir::Entity elementValue =
hlfir::loadTrivialScalar(loc, builder, element);
// NOTE: we can use "Kahan summation" same way as the runtime
// (e.g. when fast-math is not allowed), but let's start with
// the simple version.
reductionValue = genScalarAdd(loc, builder, reductionValue, elementValue);
builder.create<fir::ResultOp>(loc, reductionValue);

if (ifOp) {
builder.setInsertionPointAfter(ifOp);
builder.create<fir::ResultOp>(loc, ifOp.getResult(0));
}
return {reductionValue};
};

return hlfir::Entity{doLoop.getResult(0)};
llvm::SmallVector<mlir::Value, 1> reductionFinalValues =
hlfir::genLoopNestWithReductions(loc, builder, extents,
{reductionInitValue}, genBody,
isUnordered);
return hlfir::Entity{reductionFinalValues[0]};
};

if (isTotalReduction) {
hlfir::Entity result = genKernel(loc, builder, mlir::ValueRange{});
rewriter.replaceOp(sum, result);
return mlir::success();
}

hlfir::ElementalOp elementalOp = hlfir::genElementalOp(
loc, builder, elementType, resultShape, {}, genKernel,
/*isUnordered=*/true, /*polymorphicMold=*/nullptr,
Expand All @@ -230,20 +243,29 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
}

private:
// Return the explicit extents of the given array.
// The extents are extracted from the shape produced by hlfir::genShape.
// If that shape value ends up without any uses, its defining operation
// is erased so that no dead shape operation is left behind.
static llvm::SmallVector<mlir::Value>
genArrayExtents(mlir::Location loc, fir::FirOpBuilder &builder,
                hlfir::Entity array) {
  mlir::Value shape = hlfir::genShape(loc, builder, array);
  llvm::SmallVector<mlir::Value> extents =
      hlfir::getExplicitExtentsFromShape(shape, builder);
  if (shape.use_empty())
    shape.getDefiningOp()->erase();
  return extents;
}

// Return fir.shape specifying the shape of the result
// of a SUM reduction with DIM=dimVal. The second return value
// is the extent of the DIM dimension.
static std::tuple<mlir::Value, mlir::Value>
genResultShape(mlir::Location loc, fir::FirOpBuilder &builder,
hlfir::Entity array, int64_t dimVal) {
mlir::Value inShape = hlfir::genShape(loc, builder, array);
genResultShapeForPartialReduction(mlir::Location loc,
fir::FirOpBuilder &builder,
hlfir::Entity array, int64_t dimVal) {
llvm::SmallVector<mlir::Value> inExtents =
hlfir::getExplicitExtentsFromShape(inShape, builder);
genArrayExtents(loc, builder, array);
assert(dimVal > 0 && dimVal <= static_cast<int64_t>(inExtents.size()) &&
"DIM must be present and a positive constant not exceeding "
"the array's rank");
if (inShape.getUses().empty())
inShape.getDefiningOp()->erase();

mlir::Value dimExtent = inExtents[dimVal - 1];
inExtents.erase(inExtents.begin() + dimVal - 1);
Expand Down Expand Up @@ -459,22 +481,22 @@ class SimplifyHLFIRIntrinsics
target.addDynamicallyLegalOp<hlfir::SumOp>([](hlfir::SumOp sum) {
if (!simplifySum)
return true;
if (mlir::Value dim = sum.getDim()) {
if (auto dimVal = fir::getIntIfConstant(dim)) {
if (!fir::isa_trivial(sum.getType())) {
// Ignore the case SUM(a, DIM=X), where 'a' is a 1D array.
// It is only legal when X is 1, and it should probably be
// canonicalized into SUM(a).
fir::SequenceType arrayTy = mlir::cast<fir::SequenceType>(
hlfir::getFortranElementOrSequenceType(
sum.getArray().getType()));
if (*dimVal > 0 && *dimVal <= arrayTy.getDimension()) {
// Ignore SUMs with illegal DIM values.
// They may appear in dead code,
// and they do not have to be converted.
return false;
}
}

// Always inline total reductions.
if (hlfir::Entity{sum}.getRank() == 0)
return false;
mlir::Value dim = sum.getDim();
if (!dim)
return false;

if (auto dimVal = fir::getIntIfConstant(dim)) {
fir::SequenceType arrayTy = mlir::cast<fir::SequenceType>(
hlfir::getFortranElementOrSequenceType(sum.getArray().getType()));
if (*dimVal > 0 && *dimVal <= arrayTy.getDimension()) {
// Ignore SUMs with illegal DIM values.
// They may appear in dead code,
// and they do not have to be converted.
return false;
}
}
return true;
Expand Down
Loading

0 comments on commit a00946f

Please sign in to comment.