[flang] Expand SUM(DIM=CONSTANT) into an hlfir.elemental. #118556

Merged: 3 commits into llvm:main on Dec 5, 2024

Conversation

@vzakhari (Contributor) commented Dec 3, 2024

An array SUM with the specified constant DIM argument may be expanded into an hlfir.elemental with a reduction loop inside it processing all elements of the specified dimension. The expansion allows further optimization of cases like `A=SUM(B+1,DIM=1)` in the optimized bufferization pass (given that it can prove there are no read/write conflicts).
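
To illustrate the rewrite, here is a condensed sketch based on the `sum_expr_known_extents` test added by this patch (SSA names simplified). A SUM over DIM=1 of a 2x3 i32 array expression:

%c1 = arith.constant 1 : i32
%res = hlfir.sum %array dim %c1 : (!hlfir.expr<2x3xi32>, i32) -> !hlfir.expr<3xi32>

becomes an hlfir.elemental over the remaining extent, with a reduction loop inside:

%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%shape = fir.shape %c3 : (index) -> !fir.shape<1>
%res = hlfir.elemental %shape unordered : (!fir.shape<1>) -> !hlfir.expr<3xi32> {
^bb0(%j: index):
  %one = arith.constant 1 : index
  %zero = arith.constant 0 : i32
  // Reduce over the DIM=1 extent, threading the accumulator through iter_args.
  %sum = fir.do_loop %i = %one to %c2 step %one unordered iter_args(%acc = %zero) -> (i32) {
    %elt = hlfir.apply %array, %i, %j : (!hlfir.expr<2x3xi32>, index, index) -> i32
    %new = arith.addi %acc, %elt : i32
    fir.result %new : i32
  }
  hlfir.yield_element %sum : i32
}
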
@vzakhari requested review from tblah and jeanPerier on December 3, 2024 at 22:22
@llvmbot added the flang (Flang issues not falling into any other category) and flang:fir-hlfir labels on Dec 3, 2024
@llvmbot (Member) commented Dec 3, 2024

@llvm/pr-subscribers-flang-fir-hlfir

Author: Slava Zakharin (vzakhari)

Changes


Patch is 35.21 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/118556.diff

2 Files Affected:

  • (modified) flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp (+204)
  • (added) flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir (+361)
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
index 60b06437e6a987..35dc881e880df2 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
@@ -10,6 +10,7 @@
 // into the calling function.
 //===----------------------------------------------------------------------===//
 
+#include "flang/Optimizer/Builder/Complex.h"
 #include "flang/Optimizer/Builder/FIRBuilder.h"
 #include "flang/Optimizer/Builder/HLFIRTools.h"
 #include "flang/Optimizer/Dialect/FIRDialect.h"
@@ -90,6 +91,190 @@ class TransposeAsElementalConversion
   }
 };
 
+// Expand the SUM(DIM=CONSTANT) operation into an hlfir.elemental.
+class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
+public:
+  using mlir::OpRewritePattern<hlfir::SumOp>::OpRewritePattern;
+
+  llvm::LogicalResult
+  matchAndRewrite(hlfir::SumOp sum,
+                  mlir::PatternRewriter &rewriter) const override {
+    mlir::Location loc = sum.getLoc();
+    fir::FirOpBuilder builder{rewriter, sum.getOperation()};
+    hlfir::ExprType expr = mlir::dyn_cast<hlfir::ExprType>(sum.getType());
+    assert(expr && "expected an expression type for the result of hlfir.sum");
+    mlir::Type elementType = expr.getElementType();
+    hlfir::Entity array = hlfir::Entity{sum.getArray()};
+    mlir::Value mask = sum.getMask();
+    mlir::Value dim = sum.getDim();
+    int64_t dimVal = fir::getIntIfConstant(dim).value_or(0);
+    assert(dimVal > 0 && "DIM must be present and a positive constant");
+    mlir::Value resultShape, dimExtent;
+    std::tie(resultShape, dimExtent) =
+        genResultShape(loc, builder, array, dimVal);
+
+    auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
+                         mlir::ValueRange inputIndices) -> hlfir::Entity {
+      // Loop over all indices in the DIM dimension, and reduce all values.
+      // We do not need to create the reduction loop always: if we can
+      // slice the input array given the inputIndices, then we can
+      // just apply a new SUM operation (total reduction) to the slice.
+      // For the time being, generate the explicit loop because the slicing
+      // requires generating an elemental operation for the input array
+      // (and the mask, if present).
+      // TODO: produce the slices and new SUM after adding a pattern
+      // for expanding total reduction SUM case.
+      mlir::Type indexType = builder.getIndexType();
+      auto one = builder.createIntegerConstant(loc, indexType, 1);
+      auto ub = builder.createConvert(loc, indexType, dimExtent);
+
+      // Initial value for the reduction.
+      mlir::Value initValue = genInitValue(loc, builder, elementType);
+
+      // The reduction loop may be unordered if FastMathFlags::reassoc
+      // transformations are allowed. The integer reduction is always
+      // unordered.
+      bool isUnordered = mlir::isa<mlir::IntegerType>(elementType) ||
+                         static_cast<bool>(sum.getFastmath() &
+                                           mlir::arith::FastMathFlags::reassoc);
+
+      // If the mask is present and is a scalar, then we'd better load its value
+      // outside of the reduction loop making the loop unswitching easier.
+      // Maybe it is worth hoisting it from the elemental operation as well.
+      if (mask) {
+        hlfir::Entity maskValue{mask};
+        if (maskValue.isScalar())
+          mask = hlfir::loadTrivialScalar(loc, builder, maskValue);
+      }
+
+      // NOTE: the outer elemental operation may be lowered into
+      // omp.workshare.loop_wrapper/omp.loop_nest later, so the reduction
+      // loop may appear disjoint from the workshare loop nest.
+      // Moreover, the inner loop is not strictly nested (due to the reduction
+      // starting value initialization), and the above omp dialect operations
+      // cannot produce results.
+      // It is unclear what we should do about it yet.
+      auto doLoop = builder.create<fir::DoLoopOp>(
+          loc, one, ub, one, isUnordered, /*finalCountValue=*/false,
+          mlir::ValueRange{initValue});
+
+      // Address the input array using the reduction loop's IV
+      // for the DIM dimension.
+      mlir::Value iv = doLoop.getInductionVar();
+      llvm::SmallVector<mlir::Value> indices{inputIndices};
+      indices.insert(indices.begin() + dimVal - 1, iv);
+
+      mlir::OpBuilder::InsertionGuard guard(builder);
+      builder.setInsertionPointToStart(doLoop.getBody());
+      mlir::Value reductionValue = doLoop.getRegionIterArgs()[0];
+      fir::IfOp ifOp;
+      if (mask) {
+        // Make the reduction value update conditional on the value
+        // of the mask.
+        hlfir::Entity maskValue{mask};
+        if (!maskValue.isScalar()) {
+          // If the mask is an array, use the elemental and the loop indices
+          // to address the proper mask element.
+          maskValue = hlfir::getElementAt(loc, builder, maskValue, indices);
+          maskValue = hlfir::loadTrivialScalar(loc, builder, maskValue);
+        }
+        mlir::Value isUnmasked =
+            builder.create<fir::ConvertOp>(loc, builder.getI1Type(), maskValue);
+        ifOp = builder.create<fir::IfOp>(loc, elementType, isUnmasked,
+                                         /*withElseRegion=*/true);
+        // In the 'else' block return the current reduction value.
+        builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+        builder.create<fir::ResultOp>(loc, reductionValue);
+
+        // In the 'then' block do the actual addition.
+        builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+      }
+
+      hlfir::Entity element = hlfir::getElementAt(loc, builder, array, indices);
+      hlfir::Entity elementValue =
+          hlfir::loadTrivialScalar(loc, builder, element);
+      // NOTE: we can use "Kahan summation" same way as the runtime
+      // (e.g. when fast-math is not allowed), but let's start with
+      // the simple version.
+      reductionValue = genScalarAdd(loc, builder, reductionValue, elementValue);
+      builder.create<fir::ResultOp>(loc, reductionValue);
+
+      if (ifOp) {
+        builder.setInsertionPointAfter(ifOp);
+        builder.create<fir::ResultOp>(loc, ifOp.getResult(0));
+      }
+
+      return hlfir::Entity{doLoop.getResult(0)};
+    };
+    hlfir::ElementalOp elementalOp = hlfir::genElementalOp(
+        loc, builder, elementType, resultShape, {}, genKernel,
+        /*isUnordered=*/true, /*polymorphicMold=*/nullptr,
+        sum.getResult().getType());
+
+    // It wouldn't be safe to replace block arguments with a different
+    // hlfir.expr type. Types can differ due to differing amounts of shape
+    // information.
+    assert(elementalOp.getResult().getType() == sum.getResult().getType());
+
+    rewriter.replaceOp(sum, elementalOp);
+    return mlir::success();
+  }
+
+private:
+  // Return fir.shape specifying the shape of the result
+  // of a SUM reduction with DIM=dimVal. The second return value
+  // is the extent of the DIM dimension.
+  static std::tuple<mlir::Value, mlir::Value>
+  genResultShape(mlir::Location loc, fir::FirOpBuilder &builder,
+                 hlfir::Entity array, int64_t dimVal) {
+    mlir::Value inShape = hlfir::genShape(loc, builder, array);
+    llvm::SmallVector<mlir::Value> inExtents =
+        hlfir::getExplicitExtentsFromShape(inShape, builder);
+    if (inShape.getUses().empty())
+      inShape.getDefiningOp()->erase();
+
+    mlir::Value dimExtent = inExtents[dimVal - 1];
+    inExtents.erase(inExtents.begin() + dimVal - 1);
+    return {builder.create<fir::ShapeOp>(loc, inExtents), dimExtent};
+  }
+
+  // Generate the initial value for a SUM reduction with the given
+  // data type.
+  static mlir::Value genInitValue(mlir::Location loc,
+                                  fir::FirOpBuilder &builder,
+                                  mlir::Type elementType) {
+    if (auto ty = mlir::dyn_cast<mlir::FloatType>(elementType)) {
+      const llvm::fltSemantics &sem = ty.getFloatSemantics();
+      return builder.createRealConstant(loc, elementType,
+                                        llvm::APFloat::getZero(sem));
+    } else if (auto ty = mlir::dyn_cast<mlir::ComplexType>(elementType)) {
+      mlir::Value initValue = genInitValue(loc, builder, ty.getElementType());
+      return fir::factory::Complex{builder, loc}.createComplex(ty, initValue,
+                                                               initValue);
+    } else if (mlir::isa<mlir::IntegerType>(elementType)) {
+      return builder.createIntegerConstant(loc, elementType, 0);
+    }
+
+    llvm_unreachable("unsupported SUM reduction type");
+  }
+
+  // Generate scalar addition of the two values (of the same data type).
+  static mlir::Value genScalarAdd(mlir::Location loc,
+                                  fir::FirOpBuilder &builder,
+                                  mlir::Value value1, mlir::Value value2) {
+    mlir::Type ty = value1.getType();
+    assert(ty == value2.getType() && "reduction values' types do not match");
+    if (mlir::isa<mlir::FloatType>(ty))
+      return builder.create<mlir::arith::AddFOp>(loc, value1, value2);
+    else if (mlir::isa<mlir::ComplexType>(ty))
+      return builder.create<fir::AddcOp>(loc, value1, value2);
+    else if (mlir::isa<mlir::IntegerType>(ty))
+      return builder.create<mlir::arith::AddIOp>(loc, value1, value2);
+
+    llvm_unreachable("unsupported SUM reduction type");
+  }
+};
+
 class SimplifyHLFIRIntrinsics
     : public hlfir::impl::SimplifyHLFIRIntrinsicsBase<SimplifyHLFIRIntrinsics> {
 public:
@@ -97,6 +282,7 @@ class SimplifyHLFIRIntrinsics
     mlir::MLIRContext *context = &getContext();
     mlir::RewritePatternSet patterns(context);
     patterns.insert<TransposeAsElementalConversion>(context);
+    patterns.insert<SumAsElementalConversion>(context);
     mlir::ConversionTarget target(*context);
     // don't transform transpose of polymorphic arrays (not currently supported
     // by hlfir.elemental)
@@ -105,6 +291,24 @@ class SimplifyHLFIRIntrinsics
           return mlir::cast<hlfir::ExprType>(transpose.getType())
               .isPolymorphic();
         });
+    // Handle only SUM(DIM=CONSTANT) case for now.
+    // It may be beneficial to expand the non-DIM case as well.
+    // E.g. when the input array is an elemental array expression,
+    // expanding the SUM into a total reduction loop nest
+    // would avoid creating a temporary for the elemental array expression.
+    target.addDynamicallyLegalOp<hlfir::SumOp>([](hlfir::SumOp sum) {
+      if (mlir::Value dim = sum.getDim()) {
+        if (fir::getIntIfConstant(dim)) {
+          if (!fir::isa_trivial(sum.getType())) {
+            // Ignore the case SUM(a, DIM=X), where 'a' is a 1D array.
+            // It is only legal when X is 1, and it should probably be
+            // canonicalized into SUM(a).
+            return false;
+          }
+        }
+      }
+      return true;
+    });
     target.markUnknownOpDynamicallyLegal(
         [](mlir::Operation *) { return true; });
     if (mlir::failed(mlir::applyFullConversion(getOperation(), target,
diff --git a/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir b/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir
new file mode 100644
index 00000000000000..05a4dfde6344e2
--- /dev/null
+++ b/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir
@@ -0,0 +1,361 @@
+// RUN: fir-opt --simplify-hlfir-intrinsics %s | FileCheck %s
+
+// box with known extents
+func.func @sum_box_known_extents(%arg0: !fir.box<!fir.array<2x3xi32>>) {
+  %cst = arith.constant 2 : i32
+  %res = hlfir.sum %arg0 dim %cst : (!fir.box<!fir.array<2x3xi32>>, i32) -> !hlfir.expr<2xi32>
+  return
+}
+// CHECK-LABEL:   func.func @sum_box_known_extents(
+// CHECK-SAME:                                     %[[VAL_0:.*]]: !fir.box<!fir.array<2x3xi32>>) {
+// CHECK:           %[[VAL_1:.*]] = arith.constant 2 : i32
+// CHECK:           %[[VAL_2:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_3:.*]] = arith.constant 3 : index
+// CHECK:           %[[VAL_4:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
+// CHECK:           %[[VAL_5:.*]] = hlfir.elemental %[[VAL_4]] unordered : (!fir.shape<1>) -> !hlfir.expr<2xi32> {
+// CHECK:           ^bb0(%[[VAL_6:.*]]: index):
+// CHECK:             %[[VAL_7:.*]] = arith.constant 1 : index
+// CHECK:             %[[VAL_8:.*]] = arith.constant 0 : i32
+// CHECK:             %[[VAL_9:.*]] = fir.do_loop %[[VAL_10:.*]] = %[[VAL_7]] to %[[VAL_3]] step %[[VAL_7]] unordered iter_args(%[[VAL_11:.*]] = %[[VAL_8]]) -> (i32) {
+// CHECK:               %[[VAL_12:.*]] = arith.constant 0 : index
+// CHECK:               %[[VAL_13:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_12]] : (!fir.box<!fir.array<2x3xi32>>, index) -> (index, index, index)
+// CHECK:               %[[VAL_14:.*]] = arith.constant 1 : index
+// CHECK:               %[[VAL_15:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_14]] : (!fir.box<!fir.array<2x3xi32>>, index) -> (index, index, index)
+// CHECK:               %[[VAL_16:.*]] = arith.constant 1 : index
+// CHECK:               %[[VAL_17:.*]] = arith.subi %[[VAL_13]]#0, %[[VAL_16]] : index
+// CHECK:               %[[VAL_18:.*]] = arith.addi %[[VAL_6]], %[[VAL_17]] : index
+// CHECK:               %[[VAL_19:.*]] = arith.subi %[[VAL_15]]#0, %[[VAL_16]] : index
+// CHECK:               %[[VAL_20:.*]] = arith.addi %[[VAL_10]], %[[VAL_19]] : index
+// CHECK:               %[[VAL_21:.*]] = hlfir.designate %[[VAL_0]] (%[[VAL_18]], %[[VAL_20]])  : (!fir.box<!fir.array<2x3xi32>>, index, index) -> !fir.ref<i32>
+// CHECK:               %[[VAL_22:.*]] = fir.load %[[VAL_21]] : !fir.ref<i32>
+// CHECK:               %[[VAL_23:.*]] = arith.addi %[[VAL_11]], %[[VAL_22]] : i32
+// CHECK:               fir.result %[[VAL_23]] : i32
+// CHECK:             }
+// CHECK:             hlfir.yield_element %[[VAL_9]] : i32
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+
+// expr with known extents
+func.func @sum_expr_known_extents(%arg0: !hlfir.expr<2x3xi32>) {
+  %cst = arith.constant 1 : i32
+  %res = hlfir.sum %arg0 dim %cst : (!hlfir.expr<2x3xi32>, i32) -> !hlfir.expr<3xi32>
+  return
+}
+// CHECK-LABEL:   func.func @sum_expr_known_extents(
+// CHECK-SAME:                                      %[[VAL_0:.*]]: !hlfir.expr<2x3xi32>) {
+// CHECK:           %[[VAL_1:.*]] = arith.constant 1 : i32
+// CHECK:           %[[VAL_2:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_3:.*]] = arith.constant 3 : index
+// CHECK:           %[[VAL_4:.*]] = fir.shape %[[VAL_3]] : (index) -> !fir.shape<1>
+// CHECK:           %[[VAL_5:.*]] = hlfir.elemental %[[VAL_4]] unordered : (!fir.shape<1>) -> !hlfir.expr<3xi32> {
+// CHECK:           ^bb0(%[[VAL_6:.*]]: index):
+// CHECK:             %[[VAL_7:.*]] = arith.constant 1 : index
+// CHECK:             %[[VAL_8:.*]] = arith.constant 0 : i32
+// CHECK:             %[[VAL_9:.*]] = fir.do_loop %[[VAL_10:.*]] = %[[VAL_7]] to %[[VAL_2]] step %[[VAL_7]] unordered iter_args(%[[VAL_11:.*]] = %[[VAL_8]]) -> (i32) {
+// CHECK:               %[[VAL_12:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_10]], %[[VAL_6]] : (!hlfir.expr<2x3xi32>, index, index) -> i32
+// CHECK:               %[[VAL_13:.*]] = arith.addi %[[VAL_11]], %[[VAL_12]] : i32
+// CHECK:               fir.result %[[VAL_13]] : i32
+// CHECK:             }
+// CHECK:             hlfir.yield_element %[[VAL_9]] : i32
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+
+// box with unknown extent
+func.func @sum_box_unknown_extent1(%arg0: !fir.box<!fir.array<?x3xcomplex<f64>>>) {
+  %cst = arith.constant 1 : i32
+  %res = hlfir.sum %arg0 dim %cst : (!fir.box<!fir.array<?x3xcomplex<f64>>>, i32) -> !hlfir.expr<3xcomplex<f64>>
+  return
+}
+// CHECK-LABEL:   func.func @sum_box_unknown_extent1(
+// CHECK-SAME:                                       %[[VAL_0:.*]]: !fir.box<!fir.array<?x3xcomplex<f64>>>) {
+// CHECK:           %[[VAL_1:.*]] = arith.constant 1 : i32
+// CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_3:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_2]] : (!fir.box<!fir.array<?x3xcomplex<f64>>>, index) -> (index, index, index)
+// CHECK:           %[[VAL_4:.*]] = arith.constant 3 : index
+// CHECK:           %[[VAL_5:.*]] = fir.shape %[[VAL_4]] : (index) -> !fir.shape<1>
+// CHECK:           %[[VAL_6:.*]] = hlfir.elemental %[[VAL_5]] unordered : (!fir.shape<1>) -> !hlfir.expr<3xcomplex<f64>> {
+// CHECK:           ^bb0(%[[VAL_7:.*]]: index):
+// CHECK:             %[[VAL_8:.*]] = arith.constant 1 : index
+// CHECK:             %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK:             %[[VAL_10:.*]] = fir.undefined complex<f64>
+// CHECK:             %[[VAL_11:.*]] = fir.insert_value %[[VAL_10]], %[[VAL_9]], [0 : index] : (complex<f64>, f64) -> complex<f64>
+// CHECK:             %[[VAL_12:.*]] = fir.insert_value %[[VAL_11]], %[[VAL_9]], [1 : index] : (complex<f64>, f64) -> complex<f64>
+// CHECK:             %[[VAL_13:.*]] = fir.do_loop %[[VAL_14:.*]] = %[[VAL_8]] to %[[VAL_3]]#1 step %[[VAL_8]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (complex<f64>) {
+// CHECK:               %[[VAL_16:.*]] = arith.constant 0 : index
+// CHECK:               %[[VAL_17:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_16]] : (!fir.box<!fir.array<?x3xcomplex<f64>>>, index) -> (index, index, index)
+// CHECK:               %[[VAL_18:.*]] = arith.constant 1 : index
+// CHECK:               %[[VAL_19:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_18]] : (!fir.box<!fir.array<?x3xcomplex<f64>>>, index) -> (index, index, index)
+// CHECK:               %[[VAL_20:.*]] = arith.constant 1 : index
+// CHECK:               %[[VAL_21:.*]] = arith.subi %[[VAL_17]]#0, %[[VAL_20]] : index
+// CHECK:               %[[VAL_22:.*]] = arith.addi %[[VAL_14]], %[[VAL_21]] : index
+// CHECK:               %[[VAL_23:.*]] = arith.subi %[[VAL_19]]#0, %[[VAL_20]] : index
+// CHECK:               %[[VAL_24:.*]] = arith.addi %[[VAL_7]], %[[VAL_23]] : index
+// CHECK:               %[[VAL_25:.*]] = hlfir.designate %[[VAL_0]] (%[[VAL_22]], %[[VAL_24]])  : (!fir.box<!fir.array<?x3xcomplex<f64>>>, index, index) -> !fir.ref<complex<f64>>
+// CHECK:               %[[VAL_26:.*]] = fir.load %[[VAL_25]] : !fir.ref<complex<f64>>
+// CHECK:               %[[VAL_27:.*]] = fir.addc %[[VAL_15]], %[[VAL_26]] : complex<f64>
+// CHECK:               fir.result %[[VAL_27]] : complex<f64>
+// CHECK:             }
+// CHECK:             hlfir.yield_element %[[VAL_13]] : complex<f64>
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+
+func.func @sum_box_unknown_extent2(%arg0: !fir.box<!fir.array<?x3xcomplex<f64>>>) {
+  %cst = arith.constant 2 : i32
+  %res = hlfir.sum %arg0 dim %cst : (!fir.box<!fir.array<?x3xcomplex<f64>>>, i32) -> !hlfir.expr<?xcomplex<f64>>
+  return
+}
+// CHECK-LABEL:   func.func @sum_box_unknown_extent2(
+// CHECK-SAME:                                       %[[VAL_0:.*]]: !fir.box<!fir.array<?x3xcomplex<f64>>>) {
+// CHECK:           %[[VAL_1:.*]] = arith.constant 2 : i32
+// CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_3:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_2]] : (!fir.box<!fir.array<?x3xcomplex<f64>>>, index) -> (index, index, index)
+// CHECK:           %[[VAL_4:.*]] = arith.constant 3 : index
+// CHECK:           %[[VAL_5:.*]] = fir.shape %[[VAL_3]]#1 : (index) -> !fir.shape<1>
+// CHECK:           %[[VAL_6:.*]] = hlfir.elemental %[[VAL_5]] unordered : (!fir.shape<1>) -> !hlfir.expr<?xcomplex<f64>> {
+// CHECK:           ^bb0(%[[VAL_7:.*]]: index):
+// CHECK:             %[[VAL_8:.*]] = arith.constant 1 : index
+// CHECK:             %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK:             %[[VAL_10:.*]] = fir.undefined complex<f64>
+// CHECK:             %[[VAL_11:.*]] = fir.insert_value %[[VAL_10]], %[[VAL_9]], [0 : index] : (complex<f64>, f64) -> complex<f64>
+// CHECK:             %[[VAL_12:.*]] = fir.insert_value %[[VAL_11]], %[[VAL_9]], [1 : index] : (complex<f64>, f64) -> complex<f64>
+// CHECK:             %[[VAL_13:.*]] = fir.do_loop %[[VAL_14:.*]] = %[[VAL_8]] to %[[VAL_4]] step %[[VAL_8]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (complex<f64>) {
+//...
[truncated]

@jeanPerier (Contributor) left a comment:

Thanks, this looks great!

// Moreover, the inner loop is not strictly nested (due to the reduction
// starting value initialization), and the above omp dialect operations
// cannot produce results.
// It is unclear what we should do about it yet.
Contributor:

Maybe we will need some reduction concept in HLFIR, and we can let later passes map that to whatever parallelism concept fits.

Also, I do not know whether OpenMP/ACC can handle reductions of expressions (looking at OpenMP standard 5.2, section 5.5.8, it looks like the list item of an OpenMP reduction must be a variable (array/array section)), so I am not sure the MLIR operations will be able to represent a reduction on expression operands (A+B), while there is technically an opportunity to parallelize. Anyway, I think that is not in the scope of this patch.

Contributor:

I like the idea of having a reduction concept. That might make it easier to share code across these various optimized intrinsic implementations. I think it could take a lot of work for OpenMP to take advantage of this. I'm happy to discuss more if you are interested.

Contributor Author (vzakhari):

Yes, it is possible to represent the reduction using temporary reduction storage instead of the iter_args, but this requires making sure that the storage has proper data-sharing attributes with regard to the enclosing parallel constructs. I would prefer to keep it as-is right now, and then think about OpenMP cases as they arise.
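
For illustration, a minimal sketch of the two representations being weighed here (hypothetical; %init, %elt, %lb, %ub, and %one are placeholders):

// Current form: the running sum is threaded through the loop as an
// iter_args block argument, i.e. it stays in SSA values.
%sum = fir.do_loop %i = %lb to %ub step %one iter_args(%acc = %init) -> (i32) {
  %new = arith.addi %acc, %elt : i32
  fir.result %new : i32
}

// Alternative: the running sum lives in a temporary. Under OpenMP this
// storage would need the proper data-sharing attributes with regard to
// the enclosing parallel constructs.
%tmp = fir.alloca i32
fir.store %init to %tmp : !fir.ref<i32>
fir.do_loop %i = %lb to %ub step %one {
  %acc = fir.load %tmp : !fir.ref<i32>
  %new = arith.addi %acc, %elt : i32
  fir.store %new to %tmp : !fir.ref<i32>
}
%sum2 = fir.load %tmp : !fir.ref<i32>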

// requires generating an elemental operation for the input array
// (and the mask, if present).
// TODO: produce the slices and new SUM after adding a pattern
// for expanding total reduction SUM case.
Contributor:

+1

@tblah (Contributor) left a comment:

LGTM, thanks!

Comment on lines +104 to +105
hlfir::ExprType expr = mlir::dyn_cast<hlfir::ExprType>(sum.getType());
assert(expr && "expected an expression type for the result of hlfir.sum");
Contributor:

Suggested change:

-  hlfir::ExprType expr = mlir::dyn_cast<hlfir::ExprType>(sum.getType());
-  assert(expr && "expected an expression type for the result of hlfir.sum");
+  hlfir::ExprType expr = mlir::cast<hlfir::ExprType>(sum.getType());

nit: cast<>() already has the assertion built in: https://llvm.org/docs/ProgrammersManual.html#the-isa-cast-and-dyn-cast-templates

Contributor Author (vzakhari):

I am usually in favor of more explanatory assertion messages, so I would prefer to leave it here.

// Moreover, the inner loop is not strictly nested (due to the reduction
// starting value initialization), and the above omp dialect operations
// cannot produce results.
// It is unclear what we should do about it yet.
Contributor:

I think this is okay. Most intrinsics are going to be evaluated in a single thread in WORKSHARE for now (which is what some other compilers do too). In this case I think SUM would be best implemented with a special rewrite pattern for OpenMP using a reduction clause.

In general, implementing good multithreaded versions of these intrinsics that are useful on both CPU and offloading devices is quite hard. My opinion is that we should only attempt this when there is a concrete performance case to benchmark. I wouldn't want this relatively rare openmp construct (with historically poor compiler support) to make performance work in the rest of the compiler more difficult.

@jeanPerier (Contributor) left a comment:

The update to deal with dynamically absent masks LGTM, thanks!
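
For context, when a MASK is present the pattern makes the accumulator update conditional on the mask element; condensed from the C++ above (names illustrative):

%isUnmasked = fir.convert %maskElt : (!fir.logical<4>) -> i1
%new = fir.if %isUnmasked -> (i32) {
  // Unmasked: add the array element to the running sum.
  %add = arith.addi %acc, %elt : i32
  fir.result %add : i32
} else {
  // Masked out: carry the current reduction value through.
  fir.result %acc : i32
}
// Yield the updated value back to the reduction loop.
fir.result %new : i32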

@vzakhari merged commit cc46d0b into llvm:main on Dec 5, 2024. 8 checks passed.
@tblah (Contributor) commented Dec 9, 2024

Hi @vzakhari, we have a regression of roughly 5% on exchange2_r from SPEC2017 since this patch. This was measured on an AWS Graviton 3 machine (Arm Neoverse-V1). The flags I used were -Ofast -mcpu=native -flto.

Would it be okay to revert this commit until the problem is solved? Please get in touch if I can help with reproducing this issue.

Edit: some other engineers at Arm had a quick look. This does not appear to reproduce on Neoverse-V2 (e.g. Grace), but it might be influenced by -mtune=/-mcpu= rather than the actual CPU design, so if you don't have a V1 system available you could try -mtune=neoverse-v1 on a Grace system. The difference might be down to changes in inlining or loop unrolling as a result of the different code generated after this patch (although this has not been confirmed). I hope this helps!

@vzakhari (Contributor, Author) commented Dec 9, 2024

Hi Tom! Thanks for letting me know! I will try to reproduce it and see what I can do.

vzakhari added a commit to vzakhari/llvm-project that referenced this pull request on Dec 9, 2024:

To temporarily address the exchange2 perf regression reported in llvm#118556, I disabled the inlining by default and put it under the engineering option `-flang-simplify-hlfir-sum`.
@vzakhari (Contributor, Author) commented Dec 9, 2024

I tried reproducing it today on Grace with -mtune=neoverse-v1, but with no luck. I could not use -flto, because I did not have the LLVMgold.so plugin readily available. Without -flto there is no perf difference, though I do see the assembly change in digits_2 for the only SUM(DIM) case: it looks like a complete unroll of an 8x9 loop nest.

@tblah, by any chance can you confirm that the regression appears only with -flto? I will build the plugin and try it myself, but not until tomorrow. If you get the regression without -flto, then I will need to find the right testing machine.

In the meantime, I disabled the simplification by default in #119287.

@vzakhari (Contributor, Author) commented Dec 9, 2024

Just FYI: it seems the SUM(DIM) case in digits_2 might be optimized further by proving that the LHS and RHS accesses never overlap. Flang currently creates an unnecessary temp.

@tblah (Contributor) commented Dec 10, 2024

Thanks for disabling this for now and looking into it.

I can only reproduce this with -flto. The LTO pipeline gives significant improvement to exchange2 because of aggressive inlining after function specialization.

I had a quick look and I think function specialization is behaving differently with and without your patch:

With the patch:

  nm -C benchspec/CPU/548.exchange2_r/run/run_base_refrate_flang-64.0001/exchange2_r_base.flang-64 | grep digits_2
  00000000002ca180 t _QMbrute_forcePdigits_2.specialized.1
  00000000002cd040 t _QMbrute_forcePdigits_2.specialized.5

Without the patch:

  nm -C benchspec/CPU/548.exchange2_r/run/run_base_refrate_flang-64.0001/exchange2_r_base.flang-64 | grep digits_2
  00000000002e0860 t _QMbrute_forcePdigits_2.specialized.4
But I am not familiar with the function specialization pass so I might have misunderstood something.

vzakhari added a commit that referenced this pull request on Dec 10, 2024:

To temporarily address the exchange2 perf regression reported in #118556, I disabled the inlining by default and put it under the engineering option `-flang-simplify-hlfir-sum`.
@vzakhari (Contributor, Author) commented:

I was able to reproduce a ~3% slowdown with the latest flang-new -Ofast -mcpu=neoverse-v1 -mtune=neoverse-v1 -flto -fuse-ld=lld -mmlir -flang-simplify-hlfir-sum=true on a Grace machine.

The function specialization works the same way with and without the partial SUM inlining. The difference appears later, after LLVM inlining; I suppose the size threshold is affected by the SUM inlining. I do not have any idea how to fix that, so I decided to experiment with further improving exchange2 performance.

I have two prototype patches that improve the InlineElementals and OptimizedBufferization passes. Together with the total SUM reduction patch, these patches make exchange2 a little bit faster than with the current flang-new. I will prepare the two patches for review, and will then flip the engineering switch.
