[flang] Expand SUM(DIM=CONSTANT) into an hlfir.elemental. #118556
Conversation
An array SUM with a constant DIM argument may be expanded into an hlfir.elemental with a reduction loop inside it that processes all elements of the specified dimension. The expansion allows further optimization of cases like `A=SUM(B+1,DIM=1)` in the optimized bufferization pass (given that it can prove there are no read/write conflicts).
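As a quick illustration, here is a minimal Fortran sketch (the array names and extents are hypothetical, not taken from the patch) of the case the pattern targets and the loop nest the expansion conceptually enables:

program sum_dim_sketch
  implicit none
  integer :: b(4,3), a(3), a_loops(3), j, k
  b = 1
  ! The case the pattern targets: reduce dimension 1 of B+1.
  a = sum(b + 1, dim=1)
  ! Conceptually, the expansion yields one result element per value
  ! of the surviving dimension, each computed by a reduction loop
  ! over the reduced dimension (here, dimension 1).
  do k = 1, 3
    a_loops(k) = 0
    do j = 1, 4
      a_loops(k) = a_loops(k) + (b(j, k) + 1)
    end do
  end do
  print *, all(a == a_loops)  ! prints T
end program sum_dim_sketch

The point of producing an hlfir.elemental is that later passes can then optimize it together with the elemental for B+1, as the description above notes.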
@llvm/pr-subscribers-flang-fir-hlfir
Author: Slava Zakharin (vzakhari)
Changes: An array SUM with the specified constant DIM argument
Patch is 35.21 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/118556.diff
2 Files Affected:
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
index 60b06437e6a987..35dc881e880df2 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
@@ -10,6 +10,7 @@
// into the calling function.
//===----------------------------------------------------------------------===//
+#include "flang/Optimizer/Builder/Complex.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
@@ -90,6 +91,190 @@ class TransposeAsElementalConversion
}
};
+// Expand the SUM(DIM=CONSTANT) operation into hlfir.elemental.
+class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
+public:
+ using mlir::OpRewritePattern<hlfir::SumOp>::OpRewritePattern;
+
+ llvm::LogicalResult
+ matchAndRewrite(hlfir::SumOp sum,
+ mlir::PatternRewriter &rewriter) const override {
+ mlir::Location loc = sum.getLoc();
+ fir::FirOpBuilder builder{rewriter, sum.getOperation()};
+ hlfir::ExprType expr = mlir::dyn_cast<hlfir::ExprType>(sum.getType());
+ assert(expr && "expected an expression type for the result of hlfir.sum");
+ mlir::Type elementType = expr.getElementType();
+ hlfir::Entity array = hlfir::Entity{sum.getArray()};
+ mlir::Value mask = sum.getMask();
+ mlir::Value dim = sum.getDim();
+ int64_t dimVal = fir::getIntIfConstant(dim).value_or(0);
+ assert(dimVal > 0 && "DIM must be present and a positive constant");
+ mlir::Value resultShape, dimExtent;
+ std::tie(resultShape, dimExtent) =
+ genResultShape(loc, builder, array, dimVal);
+
+ auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
+ mlir::ValueRange inputIndices) -> hlfir::Entity {
+ // Loop over all indices in the DIM dimension, and reduce all values.
+ // We do not need to create the reduction loop always: if we can
+ // slice the input array given the inputIndices, then we can
+ // just apply a new SUM operation (total reduction) to the slice.
+ // For the time being, generate the explicit loop because the slicing
+ // requires generating an elemental operation for the input array
+ // (and the mask, if present).
+ // TODO: produce the slices and new SUM after adding a pattern
+ // for expanding total reduction SUM case.
+ mlir::Type indexType = builder.getIndexType();
+ auto one = builder.createIntegerConstant(loc, indexType, 1);
+ auto ub = builder.createConvert(loc, indexType, dimExtent);
+
+ // Initial value for the reduction.
+ mlir::Value initValue = genInitValue(loc, builder, elementType);
+
+ // The reduction loop may be unordered if FastMathFlags::reassoc
+ // transformations are allowed. The integer reduction is always
+ // unordered.
+ bool isUnordered = mlir::isa<mlir::IntegerType>(elementType) ||
+ static_cast<bool>(sum.getFastmath() &
+ mlir::arith::FastMathFlags::reassoc);
+
+ // If the mask is present and is a scalar, then we'd better load its value
+ // outside of the reduction loop making the loop unswitching easier.
+ // Maybe it is worth hoisting it from the elemental operation as well.
+ if (mask) {
+ hlfir::Entity maskValue{mask};
+ if (maskValue.isScalar())
+ mask = hlfir::loadTrivialScalar(loc, builder, maskValue);
+ }
+
+ // NOTE: the outer elemental operation may be lowered into
+ // omp.workshare.loop_wrapper/omp.loop_nest later, so the reduction
+ // loop may appear disjoint from the workshare loop nest.
+ // Moreover, the inner loop is not strictly nested (due to the reduction
+ // starting value initialization), and the above omp dialect operations
+ // cannot produce results.
+ // It is unclear what we should do about it yet.
+ auto doLoop = builder.create<fir::DoLoopOp>(
+ loc, one, ub, one, isUnordered, /*finalCountValue=*/false,
+ mlir::ValueRange{initValue});
+
+ // Address the input array using the reduction loop's IV
+ // for the DIM dimension.
+ mlir::Value iv = doLoop.getInductionVar();
+ llvm::SmallVector<mlir::Value> indices{inputIndices};
+ indices.insert(indices.begin() + dimVal - 1, iv);
+
+ mlir::OpBuilder::InsertionGuard guard(builder);
+ builder.setInsertionPointToStart(doLoop.getBody());
+ mlir::Value reductionValue = doLoop.getRegionIterArgs()[0];
+ fir::IfOp ifOp;
+ if (mask) {
+ // Make the reduction value update conditional on the value
+ // of the mask.
+ hlfir::Entity maskValue{mask};
+ if (!maskValue.isScalar()) {
+ // If the mask is an array, use the elemental and the loop indices
+ // to address the proper mask element.
+ maskValue = hlfir::getElementAt(loc, builder, maskValue, indices);
+ maskValue = hlfir::loadTrivialScalar(loc, builder, maskValue);
+ }
+ mlir::Value isUnmasked =
+ builder.create<fir::ConvertOp>(loc, builder.getI1Type(), maskValue);
+ ifOp = builder.create<fir::IfOp>(loc, elementType, isUnmasked,
+ /*withElseRegion=*/true);
+ // In the 'else' block return the current reduction value.
+ builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+ builder.create<fir::ResultOp>(loc, reductionValue);
+
+ // In the 'then' block do the actual addition.
+ builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+ }
+
+ hlfir::Entity element = hlfir::getElementAt(loc, builder, array, indices);
+ hlfir::Entity elementValue =
+ hlfir::loadTrivialScalar(loc, builder, element);
+ // NOTE: we can use "Kahan summation" same way as the runtime
+ // (e.g. when fast-math is not allowed), but let's start with
+ // the simple version.
+ reductionValue = genScalarAdd(loc, builder, reductionValue, elementValue);
+ builder.create<fir::ResultOp>(loc, reductionValue);
+
+ if (ifOp) {
+ builder.setInsertionPointAfter(ifOp);
+ builder.create<fir::ResultOp>(loc, ifOp.getResult(0));
+ }
+
+ return hlfir::Entity{doLoop.getResult(0)};
+ };
+ hlfir::ElementalOp elementalOp = hlfir::genElementalOp(
+ loc, builder, elementType, resultShape, {}, genKernel,
+ /*isUnordered=*/true, /*polymorphicMold=*/nullptr,
+ sum.getResult().getType());
+
+ // it wouldn't be safe to replace block arguments with a different
+ // hlfir.expr type. Types can differ due to differing amounts of shape
+ // information
+ assert(elementalOp.getResult().getType() == sum.getResult().getType());
+
+ rewriter.replaceOp(sum, elementalOp);
+ return mlir::success();
+ }
+
+private:
+ // Return fir.shape specifying the shape of the result
+ // of a SUM reduction with DIM=dimVal. The second return value
+ // is the extent of the DIM dimension.
+ static std::tuple<mlir::Value, mlir::Value>
+ genResultShape(mlir::Location loc, fir::FirOpBuilder &builder,
+ hlfir::Entity array, int64_t dimVal) {
+ mlir::Value inShape = hlfir::genShape(loc, builder, array);
+ llvm::SmallVector<mlir::Value> inExtents =
+ hlfir::getExplicitExtentsFromShape(inShape, builder);
+ if (inShape.getUses().empty())
+ inShape.getDefiningOp()->erase();
+
+ mlir::Value dimExtent = inExtents[dimVal - 1];
+ inExtents.erase(inExtents.begin() + dimVal - 1);
+ return {builder.create<fir::ShapeOp>(loc, inExtents), dimExtent};
+ }
+
+ // Generate the initial value for a SUM reduction with the given
+ // data type.
+ static mlir::Value genInitValue(mlir::Location loc,
+ fir::FirOpBuilder &builder,
+ mlir::Type elementType) {
+ if (auto ty = mlir::dyn_cast<mlir::FloatType>(elementType)) {
+ const llvm::fltSemantics &sem = ty.getFloatSemantics();
+ return builder.createRealConstant(loc, elementType,
+ llvm::APFloat::getZero(sem));
+ } else if (auto ty = mlir::dyn_cast<mlir::ComplexType>(elementType)) {
+ mlir::Value initValue = genInitValue(loc, builder, ty.getElementType());
+ return fir::factory::Complex{builder, loc}.createComplex(ty, initValue,
+ initValue);
+ } else if (mlir::isa<mlir::IntegerType>(elementType)) {
+ return builder.createIntegerConstant(loc, elementType, 0);
+ }
+
+ llvm_unreachable("unsupported SUM reduction type");
+ }
+
+ // Generate scalar addition of the two values (of the same data type).
+ static mlir::Value genScalarAdd(mlir::Location loc,
+ fir::FirOpBuilder &builder,
+ mlir::Value value1, mlir::Value value2) {
+ mlir::Type ty = value1.getType();
+ assert(ty == value2.getType() && "reduction values' types do not match");
+ if (mlir::isa<mlir::FloatType>(ty))
+ return builder.create<mlir::arith::AddFOp>(loc, value1, value2);
+ else if (mlir::isa<mlir::ComplexType>(ty))
+ return builder.create<fir::AddcOp>(loc, value1, value2);
+ else if (mlir::isa<mlir::IntegerType>(ty))
+ return builder.create<mlir::arith::AddIOp>(loc, value1, value2);
+
+ llvm_unreachable("unsupported SUM reduction type");
+ }
+};
+
class SimplifyHLFIRIntrinsics
: public hlfir::impl::SimplifyHLFIRIntrinsicsBase<SimplifyHLFIRIntrinsics> {
public:
@@ -97,6 +282,7 @@ class SimplifyHLFIRIntrinsics
mlir::MLIRContext *context = &getContext();
mlir::RewritePatternSet patterns(context);
patterns.insert<TransposeAsElementalConversion>(context);
+ patterns.insert<SumAsElementalConversion>(context);
mlir::ConversionTarget target(*context);
// don't transform transpose of polymorphic arrays (not currently supported
// by hlfir.elemental)
@@ -105,6 +291,24 @@ class SimplifyHLFIRIntrinsics
return mlir::cast<hlfir::ExprType>(transpose.getType())
.isPolymorphic();
});
+ // Handle only SUM(DIM=CONSTANT) case for now.
+ // It may be beneficial to expand the non-DIM case as well.
+ // E.g. when the input array is an elemental array expression,
+ // expanding the SUM into a total reduction loop nest
+ // would avoid creating a temporary for the elemental array expression.
+ target.addDynamicallyLegalOp<hlfir::SumOp>([](hlfir::SumOp sum) {
+ if (mlir::Value dim = sum.getDim()) {
+ if (fir::getIntIfConstant(dim)) {
+ if (!fir::isa_trivial(sum.getType())) {
+ // Ignore the case SUM(a, DIM=X), where 'a' is a 1D array.
+ // It is only legal when X is 1, and it should probably be
+ // canonicalized into SUM(a).
+ return false;
+ }
+ }
+ }
+ return true;
+ });
target.markUnknownOpDynamicallyLegal(
[](mlir::Operation *) { return true; });
if (mlir::failed(mlir::applyFullConversion(getOperation(), target,
diff --git a/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir b/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir
new file mode 100644
index 00000000000000..05a4dfde6344e2
--- /dev/null
+++ b/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir
@@ -0,0 +1,361 @@
+// RUN: fir-opt --simplify-hlfir-intrinsics %s | FileCheck %s
+
+// box with known extents
+func.func @sum_box_known_extents(%arg0: !fir.box<!fir.array<2x3xi32>>) {
+ %cst = arith.constant 2 : i32
+ %res = hlfir.sum %arg0 dim %cst : (!fir.box<!fir.array<2x3xi32>>, i32) -> !hlfir.expr<2xi32>
+ return
+}
+// CHECK-LABEL: func.func @sum_box_known_extents(
+// CHECK-SAME: %[[VAL_0:.*]]: !fir.box<!fir.array<2x3xi32>>) {
+// CHECK: %[[VAL_1:.*]] = arith.constant 2 : i32
+// CHECK: %[[VAL_2:.*]] = arith.constant 2 : index
+// CHECK: %[[VAL_3:.*]] = arith.constant 3 : index
+// CHECK: %[[VAL_4:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
+// CHECK: %[[VAL_5:.*]] = hlfir.elemental %[[VAL_4]] unordered : (!fir.shape<1>) -> !hlfir.expr<2xi32> {
+// CHECK: ^bb0(%[[VAL_6:.*]]: index):
+// CHECK: %[[VAL_7:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_8:.*]] = arith.constant 0 : i32
+// CHECK: %[[VAL_9:.*]] = fir.do_loop %[[VAL_10:.*]] = %[[VAL_7]] to %[[VAL_3]] step %[[VAL_7]] unordered iter_args(%[[VAL_11:.*]] = %[[VAL_8]]) -> (i32) {
+// CHECK: %[[VAL_12:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_13:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_12]] : (!fir.box<!fir.array<2x3xi32>>, index) -> (index, index, index)
+// CHECK: %[[VAL_14:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_15:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_14]] : (!fir.box<!fir.array<2x3xi32>>, index) -> (index, index, index)
+// CHECK: %[[VAL_16:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_17:.*]] = arith.subi %[[VAL_13]]#0, %[[VAL_16]] : index
+// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_6]], %[[VAL_17]] : index
+// CHECK: %[[VAL_19:.*]] = arith.subi %[[VAL_15]]#0, %[[VAL_16]] : index
+// CHECK: %[[VAL_20:.*]] = arith.addi %[[VAL_10]], %[[VAL_19]] : index
+// CHECK: %[[VAL_21:.*]] = hlfir.designate %[[VAL_0]] (%[[VAL_18]], %[[VAL_20]]) : (!fir.box<!fir.array<2x3xi32>>, index, index) -> !fir.ref<i32>
+// CHECK: %[[VAL_22:.*]] = fir.load %[[VAL_21]] : !fir.ref<i32>
+// CHECK: %[[VAL_23:.*]] = arith.addi %[[VAL_11]], %[[VAL_22]] : i32
+// CHECK: fir.result %[[VAL_23]] : i32
+// CHECK: }
+// CHECK: hlfir.yield_element %[[VAL_9]] : i32
+// CHECK: }
+// CHECK: return
+// CHECK: }
+
+// expr with known extents
+func.func @sum_expr_known_extents(%arg0: !hlfir.expr<2x3xi32>) {
+ %cst = arith.constant 1 : i32
+ %res = hlfir.sum %arg0 dim %cst : (!hlfir.expr<2x3xi32>, i32) -> !hlfir.expr<3xi32>
+ return
+}
+// CHECK-LABEL: func.func @sum_expr_known_extents(
+// CHECK-SAME: %[[VAL_0:.*]]: !hlfir.expr<2x3xi32>) {
+// CHECK: %[[VAL_1:.*]] = arith.constant 1 : i32
+// CHECK: %[[VAL_2:.*]] = arith.constant 2 : index
+// CHECK: %[[VAL_3:.*]] = arith.constant 3 : index
+// CHECK: %[[VAL_4:.*]] = fir.shape %[[VAL_3]] : (index) -> !fir.shape<1>
+// CHECK: %[[VAL_5:.*]] = hlfir.elemental %[[VAL_4]] unordered : (!fir.shape<1>) -> !hlfir.expr<3xi32> {
+// CHECK: ^bb0(%[[VAL_6:.*]]: index):
+// CHECK: %[[VAL_7:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_8:.*]] = arith.constant 0 : i32
+// CHECK: %[[VAL_9:.*]] = fir.do_loop %[[VAL_10:.*]] = %[[VAL_7]] to %[[VAL_2]] step %[[VAL_7]] unordered iter_args(%[[VAL_11:.*]] = %[[VAL_8]]) -> (i32) {
+// CHECK: %[[VAL_12:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_10]], %[[VAL_6]] : (!hlfir.expr<2x3xi32>, index, index) -> i32
+// CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_11]], %[[VAL_12]] : i32
+// CHECK: fir.result %[[VAL_13]] : i32
+// CHECK: }
+// CHECK: hlfir.yield_element %[[VAL_9]] : i32
+// CHECK: }
+// CHECK: return
+// CHECK: }
+
+// box with unknown extent
+func.func @sum_box_unknown_extent1(%arg0: !fir.box<!fir.array<?x3xcomplex<f64>>>) {
+ %cst = arith.constant 1 : i32
+ %res = hlfir.sum %arg0 dim %cst : (!fir.box<!fir.array<?x3xcomplex<f64>>>, i32) -> !hlfir.expr<3xcomplex<f64>>
+ return
+}
+// CHECK-LABEL: func.func @sum_box_unknown_extent1(
+// CHECK-SAME: %[[VAL_0:.*]]: !fir.box<!fir.array<?x3xcomplex<f64>>>) {
+// CHECK: %[[VAL_1:.*]] = arith.constant 1 : i32
+// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_3:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_2]] : (!fir.box<!fir.array<?x3xcomplex<f64>>>, index) -> (index, index, index)
+// CHECK: %[[VAL_4:.*]] = arith.constant 3 : index
+// CHECK: %[[VAL_5:.*]] = fir.shape %[[VAL_4]] : (index) -> !fir.shape<1>
+// CHECK: %[[VAL_6:.*]] = hlfir.elemental %[[VAL_5]] unordered : (!fir.shape<1>) -> !hlfir.expr<3xcomplex<f64>> {
+// CHECK: ^bb0(%[[VAL_7:.*]]: index):
+// CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK: %[[VAL_10:.*]] = fir.undefined complex<f64>
+// CHECK: %[[VAL_11:.*]] = fir.insert_value %[[VAL_10]], %[[VAL_9]], [0 : index] : (complex<f64>, f64) -> complex<f64>
+// CHECK: %[[VAL_12:.*]] = fir.insert_value %[[VAL_11]], %[[VAL_9]], [1 : index] : (complex<f64>, f64) -> complex<f64>
+// CHECK: %[[VAL_13:.*]] = fir.do_loop %[[VAL_14:.*]] = %[[VAL_8]] to %[[VAL_3]]#1 step %[[VAL_8]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (complex<f64>) {
+// CHECK: %[[VAL_16:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_17:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_16]] : (!fir.box<!fir.array<?x3xcomplex<f64>>>, index) -> (index, index, index)
+// CHECK: %[[VAL_18:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_19:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_18]] : (!fir.box<!fir.array<?x3xcomplex<f64>>>, index) -> (index, index, index)
+// CHECK: %[[VAL_20:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_21:.*]] = arith.subi %[[VAL_17]]#0, %[[VAL_20]] : index
+// CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_14]], %[[VAL_21]] : index
+// CHECK: %[[VAL_23:.*]] = arith.subi %[[VAL_19]]#0, %[[VAL_20]] : index
+// CHECK: %[[VAL_24:.*]] = arith.addi %[[VAL_7]], %[[VAL_23]] : index
+// CHECK: %[[VAL_25:.*]] = hlfir.designate %[[VAL_0]] (%[[VAL_22]], %[[VAL_24]]) : (!fir.box<!fir.array<?x3xcomplex<f64>>>, index, index) -> !fir.ref<complex<f64>>
+// CHECK: %[[VAL_26:.*]] = fir.load %[[VAL_25]] : !fir.ref<complex<f64>>
+// CHECK: %[[VAL_27:.*]] = fir.addc %[[VAL_15]], %[[VAL_26]] : complex<f64>
+// CHECK: fir.result %[[VAL_27]] : complex<f64>
+// CHECK: }
+// CHECK: hlfir.yield_element %[[VAL_13]] : complex<f64>
+// CHECK: }
+// CHECK: return
+// CHECK: }
+
+func.func @sum_box_unknown_extent2(%arg0: !fir.box<!fir.array<?x3xcomplex<f64>>>) {
+ %cst = arith.constant 2 : i32
+ %res = hlfir.sum %arg0 dim %cst : (!fir.box<!fir.array<?x3xcomplex<f64>>>, i32) -> !hlfir.expr<?xcomplex<f64>>
+ return
+}
+// CHECK-LABEL: func.func @sum_box_unknown_extent2(
+// CHECK-SAME: %[[VAL_0:.*]]: !fir.box<!fir.array<?x3xcomplex<f64>>>) {
+// CHECK: %[[VAL_1:.*]] = arith.constant 2 : i32
+// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_3:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_2]] : (!fir.box<!fir.array<?x3xcomplex<f64>>>, index) -> (index, index, index)
+// CHECK: %[[VAL_4:.*]] = arith.constant 3 : index
+// CHECK: %[[VAL_5:.*]] = fir.shape %[[VAL_3]]#1 : (index) -> !fir.shape<1>
+// CHECK: %[[VAL_6:.*]] = hlfir.elemental %[[VAL_5]] unordered : (!fir.shape<1>) -> !hlfir.expr<?xcomplex<f64>> {
+// CHECK: ^bb0(%[[VAL_7:.*]]: index):
+// CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK: %[[VAL_10:.*]] = fir.undefined complex<f64>
+// CHECK: %[[VAL_11:.*]] = fir.insert_value %[[VAL_10]], %[[VAL_9]], [0 : index] : (complex<f64>, f64) -> complex<f64>
+// CHECK: %[[VAL_12:.*]] = fir.insert_value %[[VAL_11]], %[[VAL_9]], [1 : index] : (complex<f64>, f64) -> complex<f64>
+// CHECK: %[[VAL_13:.*]] = fir.do_loop %[[VAL_14:.*]] = %[[VAL_8]] to %[[VAL_4]] step %[[VAL_8]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (complex<f64>) {
+//...
[truncated]
Thanks, this looks great!
// Moreover, the inner loop is not strictly nested (due to the reduction
// starting value initialization), and the above omp dialect operations
// cannot produce results.
// It is unclear what we should do about it yet.
Maybe we will need some reduction concept in HLFIR, and we can let later passes map that to whatever parallelism concept.
Also, I do not know if OpenMP/ACC can handle reductions of expressions (looking at OpenMP standard 5.2, section 5.5.8, it looks like the list items of an OpenMP reduction must be variables (arrays/array sections)), so I am not sure the MLIR operations will be able to represent a reduction on expression operands (A+B), while there is technically an opportunity to parallelize. Anyway, I think that is not in the scope of this patch.
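To make that constraint concrete, here is a small hedged OpenMP Fortran sketch (all names hypothetical): the reduction list item must be a variable (here the scalar s), so an expression like a(i) + b(i) can only participate by being folded into that variable inside the loop body; the expression itself cannot be a list item.

program omp_reduction_sketch
  implicit none
  integer :: i
  real :: s, a(100), b(100)
  a = 1.0
  b = 2.0
  s = 0.0
  ! reduction(+:s) names the variable s as the list item; the
  ! expression a(i) + b(i) is merely accumulated into it.
  !$omp parallel do reduction(+:s)
  do i = 1, 100
    s = s + (a(i) + b(i))
  end do
  !$omp end parallel do
  print *, s  ! prints 300.0
end program omp_reduction_sketch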
I like the idea of having a reduction concept. That might make it easier to share code across these various optimized intrinsic implementations. I think it could take a lot of work for OpenMP to take advantage of this. I'm happy to discuss more if you are interested.
Yes, it is possible to represent the reduction using a temporary reduction storage instead of the iter_args, but this requires making sure that the storage has proper data sharing attributes with regards to the enclosing parallel constructs. I would prefer to keep it as-is right now, and then think about OpenMP cases as they arise.
// requires generating an elemental operation for the input array
// (and the mask, if present).
// TODO: produce the slices and new SUM after adding a pattern
// for expanding total reduction SUM case.
+1
LGTM, thanks!
hlfir::ExprType expr = mlir::dyn_cast<hlfir::ExprType>(sum.getType());
assert(expr && "expected an expression type for the result of hlfir.sum");
Suggested change:
- hlfir::ExprType expr = mlir::dyn_cast<hlfir::ExprType>(sum.getType());
- assert(expr && "expected an expression type for the result of hlfir.sum");
+ hlfir::ExprType expr = mlir::cast<hlfir::ExprType>(sum.getType());
nit: cast<>() already has the assertion built in: https://llvm.org/docs/ProgrammersManual.html#the-isa-cast-and-dyn-cast-templates
I am usually in favor of more explanatory assertion messages, so I would prefer to leave it here.
// Moreover, the inner loop is not strictly nested (due to the reduction
// starting value initialization), and the above omp dialect operations
// cannot produce results.
// It is unclear what we should do about it yet.
I think this is okay. Most intrinsics are going to be evaluated in a single thread in WORKSHARE for now (which is what some other compilers do too). In this case I think SUM would be best implemented with a special rewrite pattern for OpenMP using a reduction clause.
In general, implementing good multithreaded versions of these intrinsics that are useful on both CPU and offloading devices is quite hard. My opinion is that we should only attempt this when there is a concrete performance case to benchmark. I wouldn't want this relatively rare OpenMP construct (with historically poor compiler support) to make performance work in the rest of the compiler more difficult.
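For context on the WORKSHARE remark, a minimal sketch (shapes hypothetical) of the construct being discussed; the array assignment is a single unit of work, which implementations today typically evaluate on one thread:

program workshare_sum_sketch
  implicit none
  integer :: b(4,3), a(3)
  b = 1
  ! Each array assignment inside WORKSHARE is a unit of work;
  ! the SUM itself is not parallelized further by current compilers.
  !$omp parallel workshare
  a = sum(b + 1, dim=1)
  !$omp end parallel workshare
  print *, a  ! prints 8 8 8
end program workshare_sum_sketch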
Update to deal with dynamically absent masks
LGTM, thanks
Hi @vzakhari, we have a regression of roughly 5% on exchange2_r from SPEC 2017 since this patch. This was measured on an AWS Graviton 3 machine (Arm Neoverse-V1). The flags I used were
Would it be okay to revert this commit until the problem is solved? Please get in touch if I can help with reproducing this issue.
Edit: some other engineers at Arm had a quick look. This does not appear to reproduce on Neoverse-V2 (e.g. Grace) but this might be influenced by
Hi Tom! Thanks for letting me know! I will try to reproduce it and see what I can do.
To temporarily address exchange2 perf regression reported in llvm#118556 I disabled the inlining by default, and put it under engineering option `-flang-simplify-hlfir-sum`.
I tried reproducing it today on Grace. @tblah, by any chance can you confirm that the regression appears only with the flags you mentioned? In the meantime, I disabled the simplification by default in #119287.
Just FYI, it seems the
Thanks for disabling this for now and looking into it. I can only reproduce this with the flags I mentioned.
I had a quick look and I think function specialization is behaving differently with and without your patch:
With:
Without:
But I am not familiar with the function specialization pass, so I might have misunderstood something.
To temporarily address exchange2 perf regression reported in #118556 I disabled the inlining by default, and put it under engineering option `-flang-simplify-hlfir-sum`.
I was able to reproduce a ~3% slowdown with the latest compiler. The function specialization works the same way with and without partial SUM inlining. The difference appears later, after LLVM inlining. I suppose the inlining size threshold is affected by the SUM inlining. I do not have any idea how to fix that, so I decided to experiment with further improving exchange2 performance. I have two prototype patches that improve the InlineElementals and OptimizedBufferization passes. Together with the total SUM reduction patch, these patches make exchange2 a little bit faster than with the current flang-new. I will prepare the two patches for review and will flip the engineering switch.