Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flang] Improve disjoint/identical slices recognition in opt-bufferization. #119780

Merged
merged 2 commits into from
Dec 13, 2024

Conversation

vzakhari
Copy link
Contributor

The changes are needed to be able to optimize 'x(9,:)=SUM(x(1:8,:),DIM=1)'
without a temporary array. This pattern exists in exchange2.

The patch also fixes an existing problem in Flang with this test:

program main
  integer :: a(10) = (/1,2,3,4,5,6,7,8,9,10/)
  integer :: expected(10) = (/1,10,9,8,7,6,5,4,3,2/)
  print *, 'INPUT: ', a
  print *, 'EXPECTED: ', expected
  call test(a, 10, 2, 10, 9)
  print *, 'RESULT: ', a
contains
  subroutine test(a, size, x, y, z)
    integer :: x, y, z, size
    integer :: a(:)
    a(x:y:1) = a(z:x-1:-1) + 1
  end subroutine test
end program main

…ation.

The changes are needed to be able to optimize 'x(9,:)=SUM(x(1:8,:),DIM=1)'
without a temporary array. This pattern exists in exchange2.

The patch also fixes an existing problem in Flang with this test:
```
program main
  integer :: a(10) = (/1,2,3,4,5,6,7,8,9,10/)
  integer :: expected(10) = (/1,10,9,8,7,6,5,4,3,2/)
  print *, 'INPUT: ', a
  print *, 'EXPECTED: ', expected
  call test(a, 10, 2, 10, 9)
  print *, 'RESULT: ', a
contains
  subroutine test(a, size, x, y, z)
    integer :: x, y, z, size
    integer :: a(:)
    a(x:y:1) = a(z:x-1:-1) + 1
  end subroutine test
end program main
```
@vzakhari vzakhari requested review from tblah and jeanPerier December 12, 2024 22:04
@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir labels Dec 12, 2024
@llvmbot
Copy link
Member

llvmbot commented Dec 12, 2024

@llvm/pr-subscribers-flang-fir-hlfir

Author: Slava Zakharin (vzakhari)

Changes

The changes are needed to be able to optimize 'x(9,:)=SUM(x(1:8,:),DIM=1)'
without a temporary array. This pattern exists in exchange2.

The patch also fixes an existing problem in Flang with this test:

program main
  integer :: a(10) = (/1,2,3,4,5,6,7,8,9,10/)
  integer :: expected(10) = (/1,10,9,8,7,6,5,4,3,2/)
  print *, 'INPUT: ', a
  print *, 'EXPECTED: ', expected
  call test(a, 10, 2, 10, 9)
  print *, 'RESULT: ', a
contains
  subroutine test(a, size, x, y, z)
    integer :: x, y, z, size
    integer :: a(:)
    a(x:y:1) = a(z:x-1:-1) + 1
  end subroutine test
end program main

Patch is 39.82 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/119780.diff

2 Files Affected:

  • (modified) flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp (+241-100)
  • (modified) flang/test/HLFIR/opt-array-slice-assign.fir (+424)
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
index ef6aabbceacb76..e3d97e6ad54305 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
@@ -159,28 +159,162 @@ containsReadOrWriteEffectOn(const mlir::MemoryEffects::EffectInstance &effect,
   return mlir::AliasResult::NoAlias;
 }
 
-// Returns true if the given array references represent identical
-// or completely disjoint array slices. The callers may use this
-// method when the alias analysis reports an alias of some kind,
-// so that we can run Fortran specific analysis on the array slices
-// to see if they are identical or disjoint. Note that the alias
-// analysis are not able to give such an answer about the references.
-static bool areIdenticalOrDisjointSlices(mlir::Value ref1, mlir::Value ref2) {
+// Helper class for analyzing two array slices represented
+// by two hlfir.designate operations.
+class ArraySectionAnalyzer {
+public:
+  // The result of the analyzis is one of the values below.
+  enum class SlicesOverlapKind {
+    // Slices overlap is unknown.
+    Unknown,
+    // Slices are definitely disjoint.
+    DefinitelyIdentical,
+    // Slices are definitely identical.
+    DefinitelyDisjoint,
+    // Slices may be either disjoint or identical,
+    // i.e. there is definitely no partial overlap.
+    EitherIdenticalOrDisjoint
+  };
+
+  // Analyzes two hlfir.designate results and returns the overlap kind.
+  // The callers may use this method when the alias analysis reports
+  // an alias of some kind, so that we can run Fortran specific analysis
+  // on the array slices to see if they are identical or disjoint.
+  // Note that the alias analysis are not able to give such an answer
+  // about the references.
+  static SlicesOverlapKind analyze(mlir::Value ref1, mlir::Value ref2);
+
+private:
+  struct SectionDesc {
+    // An array section is described by <lb, ub, stride> tuple.
+    // If the designator's subscript is not a triple, then
+    // the section descriptor is constructed as <lb, nullptr, nullptr>.
+    mlir::Value lb, ub, stride;
+
+    SectionDesc(mlir::Value lb, mlir::Value ub, mlir::Value stride)
+        : lb(lb), ub(ub), stride(stride) {
+      assert(lb && "lower bound or index must be specified");
+      normalize();
+    }
+
+    // Normalize the section descriptor:
+    //   1. If UB is nullptr, then it is set to LB.
+    //   2. If LB==UB, then stride does not matter,
+    //      so it is reset to nullptr.
+    //   3. If STRIDE==1, then it is reset to nullptr.
+    void normalize() {
+      if (!ub)
+        ub = lb;
+      if (lb == ub)
+        stride = nullptr;
+      if (stride)
+        if (auto val = fir::getIntIfConstant(stride))
+          if (*val == 1)
+            stride = nullptr;
+    }
+
+    bool operator==(const SectionDesc &other) const {
+      return lb == other.lb && ub == other.ub && stride == other.stride;
+    }
+  };
+
+  // Given an operand_iterator over the indices operands,
+  // read the subscript values and return them as SectionDesc
+  // updating the iterator. If isTriplet is true,
+  // the subscript is a triplet, and the result is <lb, ub, stride>.
+  // Otherwise, the subscript is a scalar index, and the result
+  // is <index, nullptr, nullptr>.
+  static SectionDesc readSectionDesc(mlir::Operation::operand_iterator &it,
+                                     bool isTriplet) {
+    if (isTriplet)
+      return {*it++, *it++, *it++};
+    return {*it++, nullptr, nullptr};
+  }
+
+  // Return the ordered lower and upper bounds of the section.
+  // If stride is known to be non-negative, then the ordered
+  // bounds match the <lb, ub> of the descriptor.
+  // If stride is known to be negative, then the ordered
+  // bounds are <ub, lb> of the descriptor.
+  // If stride is unknown, we cannot deduce any order,
+  // so the result is <nullptr, nullptr>
+  static std::pair<mlir::Value, mlir::Value>
+  getOrderedBounds(const SectionDesc &desc) {
+    mlir::Value stride = desc.stride;
+    // Null stride means stride-1.
+    if (!stride)
+      return {desc.lb, desc.ub};
+    // Reverse the bounds, if stride is negative.
+    if (auto val = fir::getIntIfConstant(stride)) {
+      if (*val >= 0)
+        return {desc.lb, desc.ub};
+      else
+        return {desc.ub, desc.lb};
+    }
+
+    return {nullptr, nullptr};
+  }
+
+  // Given two array sections <lb1, ub1, stride1> and
+  // <lb2, ub2, stride2>, return true only if the sections
+  // are known to be disjoint.
+  //
+  // For example, for any positive constant C:
+  //   X:Y does not overlap with (Y+C):Z
+  //   X:Y does not overlap with Z:(X-C)
+  static bool areDisjointSections(const SectionDesc &desc1,
+                                  const SectionDesc &desc2) {
+    auto [lb1, ub1] = getOrderedBounds(desc1);
+    auto [lb2, ub2] = getOrderedBounds(desc2);
+    if (!lb1 || !lb2)
+      return false;
+    // Note that this comparison must be made on the ordered bounds,
+    // otherwise 'a(x:y:1) = a(z:x-1:-1) + 1' may be incorrectly treated
+    // as not overlapping (x=2, y=10, z=9).
+    if (isLess(ub1, lb2) || isLess(ub2, lb1))
+      return true;
+    return false;
+  }
+
+  // Given two array sections <lb1, ub1, stride1> and
+  // <lb2, ub2, stride2>, return true only if the sections
+  // are known to be identical.
+  //
+  // For example:
+  //   <x, x, stride>
+  //   <x, nullptr, nullptr>
+  //
+  // These sections are identical, from the point of which array
+  // elements are being addresses, even though the shape
+  // of the array slices might be different.
+  static bool areIdenticalSections(const SectionDesc &desc1,
+                                   const SectionDesc &desc2) {
+    if (desc1 == desc2)
+      return true;
+    return false;
+  }
+
+  // Return true, if v1 is known to be less than v2.
+  static bool isLess(mlir::Value v1, mlir::Value v2);
+};
+
+ArraySectionAnalyzer::SlicesOverlapKind
+ArraySectionAnalyzer::analyze(mlir::Value ref1, mlir::Value ref2) {
   if (ref1 == ref2)
-    return true;
+    return SlicesOverlapKind::DefinitelyIdentical;
 
   auto des1 = ref1.getDefiningOp<hlfir::DesignateOp>();
   auto des2 = ref2.getDefiningOp<hlfir::DesignateOp>();
   // We only support a pair of designators right now.
   if (!des1 || !des2)
-    return false;
+    return SlicesOverlapKind::Unknown;
 
   if (des1.getMemref() != des2.getMemref()) {
     // If the bases are different, then there is unknown overlap.
     LLVM_DEBUG(llvm::dbgs() << "No identical base for:\n"
                             << des1 << "and:\n"
                             << des2 << "\n");
-    return false;
+    return SlicesOverlapKind::Unknown;
   }
 
   // Require all components of the designators to be the same.
@@ -194,104 +328,105 @@ static bool areIdenticalOrDisjointSlices(mlir::Value ref1, mlir::Value ref2) {
     LLVM_DEBUG(llvm::dbgs() << "Different designator specs for:\n"
                             << des1 << "and:\n"
                             << des2 << "\n");
-    return false;
-  }
-
-  if (des1.getIsTriplet() != des2.getIsTriplet()) {
-    LLVM_DEBUG(llvm::dbgs() << "Different sections for:\n"
-                            << des1 << "and:\n"
-                            << des2 << "\n");
-    return false;
+    return SlicesOverlapKind::Unknown;
   }
 
   // Analyze the subscripts.
-  // For example:
-  //   hlfir.designate %6#0 (%c2:%c7999:%c1, %c1:%c120:%c1, %0)  shape %9
-  //   hlfir.designate %6#0 (%c2:%c7999:%c1, %c1:%c120:%c1, %1)  shape %9
-  //
-  // If all the triplets (section speficiers) are the same, then
-  // we do not care if %0 is equal to %1 - the slices are either
-  // identical or completely disjoint.
   auto des1It = des1.getIndices().begin();
   auto des2It = des2.getIndices().begin();
   bool identicalTriplets = true;
-  for (bool isTriplet : des1.getIsTriplet()) {
-    if (isTriplet) {
-      for (int i = 0; i < 3; ++i)
-        if (*des1It++ != *des2It++) {
-          LLVM_DEBUG(llvm::dbgs() << "Triplet mismatch for:\n"
-                                  << des1 << "and:\n"
-                                  << des2 << "\n");
-          identicalTriplets = false;
-          break;
-        }
-    } else {
-      ++des1It;
-      ++des2It;
+  bool identicalIndices = true;
+  for (auto [isTriplet1, isTriplet2] :
+       llvm::zip(des1.getIsTriplet(), des2.getIsTriplet())) {
+    SectionDesc desc1 = readSectionDesc(des1It, isTriplet1);
+    SectionDesc desc2 = readSectionDesc(des2It, isTriplet2);
+
+    // See if we can prove that any of the sections do not overlap.
+    // This is mostly a Polyhedron/nf performance hack that looks for
+    // particular relations between the lower and upper bounds
+    // of the array sections, e.g. for any positive constant C:
+    //   X:Y does not overlap with (Y+C):Z
+    //   X:Y does not overlap with Z:(X-C)
+    if (areDisjointSections(desc1, desc2))
+      return SlicesOverlapKind::DefinitelyDisjoint;
+
+    if (!areIdenticalSections(desc1, desc2)) {
+      if (isTriplet1 || isTriplet2) {
+        // For example:
+        //   hlfir.designate %6#0 (%c2:%c7999:%c1, %c1:%c120:%c1, %0)
+        //   hlfir.designate %6#0 (%c2:%c7999:%c1, %c1:%c120:%c1, %1)
+        //
+        // If all the triplets (section speficiers) are the same, then
+        // we do not care if %0 is equal to %1 - the slices are either
+        // identical or completely disjoint.
+        //
+        // Also, treat these as identical sections:
+        //   hlfir.designate %6#0 (%c2:%c2:%c1)
+        //   hlfir.designate %6#0 (%c2)
+        identicalTriplets = false;
+        LLVM_DEBUG(llvm::dbgs() << "Triplet mismatch for:\n"
+                                << des1 << "and:\n"
+                                << des2 << "\n");
+      } else {
+        identicalIndices = false;
+        LLVM_DEBUG(llvm::dbgs() << "Indices mismatch for:\n"
+                                << des1 << "and:\n"
+                                << des2 << "\n");
+      }
     }
   }
-  if (identicalTriplets)
-    return true;
 
-  // See if we can prove that any of the triplets do not overlap.
-  // This is mostly a Polyhedron/nf performance hack that looks for
-  // particular relations between the lower and upper bounds
-  // of the array sections, e.g. for any positive constant C:
-  //   X:Y does not overlap with (Y+C):Z
-  //   X:Y does not overlap with Z:(X-C)
-  auto displacedByConstant = [](mlir::Value v1, mlir::Value v2) {
-    auto removeConvert = [](mlir::Value v) -> mlir::Operation * {
-      auto *op = v.getDefiningOp();
-      while (auto conv = mlir::dyn_cast_or_null<fir::ConvertOp>(op))
-        op = conv.getValue().getDefiningOp();
-      return op;
-    };
+  if (identicalTriplets) {
+    if (identicalIndices)
+      return SlicesOverlapKind::DefinitelyIdentical;
+    else
+      return SlicesOverlapKind::EitherIdenticalOrDisjoint;
+  }
 
-    auto isPositiveConstant = [](mlir::Value v) -> bool {
-      if (auto conOp =
-              mlir::dyn_cast<mlir::arith::ConstantOp>(v.getDefiningOp()))
-        if (auto iattr = mlir::dyn_cast<mlir::IntegerAttr>(conOp.getValue()))
-          return iattr.getInt() > 0;
-      return false;
-    };
+  LLVM_DEBUG(llvm::dbgs() << "Different sections for:\n"
+                          << des1 << "and:\n"
+                          << des2 << "\n");
+  return SlicesOverlapKind::Unknown;
+}
 
-    auto *op1 = removeConvert(v1);
-    auto *op2 = removeConvert(v2);
-    if (!op1 || !op2)
-      return false;
-    if (auto addi = mlir::dyn_cast<mlir::arith::AddIOp>(op2))
-      if ((addi.getLhs().getDefiningOp() == op1 &&
-           isPositiveConstant(addi.getRhs())) ||
-          (addi.getRhs().getDefiningOp() == op1 &&
-           isPositiveConstant(addi.getLhs())))
-        return true;
-    if (auto subi = mlir::dyn_cast<mlir::arith::SubIOp>(op1))
-      if (subi.getLhs().getDefiningOp() == op2 &&
-          isPositiveConstant(subi.getRhs()))
-        return true;
+bool ArraySectionAnalyzer::isLess(mlir::Value v1, mlir::Value v2) {
+  auto removeConvert = [](mlir::Value v) -> mlir::Operation * {
+    auto *op = v.getDefiningOp();
+    while (auto conv = mlir::dyn_cast_or_null<fir::ConvertOp>(op))
+      op = conv.getValue().getDefiningOp();
+    return op;
+  };
+
+  auto isPositiveConstant = [](mlir::Value v) -> bool {
+    if (auto val = fir::getIntIfConstant(v))
+      return *val > 0;
     return false;
   };
 
-  des1It = des1.getIndices().begin();
-  des2It = des2.getIndices().begin();
-  for (bool isTriplet : des1.getIsTriplet()) {
-    if (isTriplet) {
-      mlir::Value des1Lb = *des1It++;
-      mlir::Value des1Ub = *des1It++;
-      mlir::Value des2Lb = *des2It++;
-      mlir::Value des2Ub = *des2It++;
-      // Ignore strides.
-      ++des1It;
-      ++des2It;
-      if (displacedByConstant(des1Ub, des2Lb) ||
-          displacedByConstant(des2Ub, des1Lb))
-        return true;
-    } else {
-      ++des1It;
-      ++des2It;
-    }
-  }
+  auto *op1 = removeConvert(v1);
+  auto *op2 = removeConvert(v2);
+  if (!op1 || !op2)
+    return false;
 
+  // Check if they are both constants.
+  if (auto val1 = fir::getIntIfConstant(op1->getResult(0)))
+    if (auto val2 = fir::getIntIfConstant(op2->getResult(0)))
+      return *val1 < *val2;
+
+  // Handle some variable cases (C > 0):
+  //   v2 = v1 + C
+  //   v2 = C + v1
+  //   v1 = v2 - C
+  if (auto addi = mlir::dyn_cast<mlir::arith::AddIOp>(op2))
+    if ((addi.getLhs().getDefiningOp() == op1 &&
+         isPositiveConstant(addi.getRhs())) ||
+        (addi.getRhs().getDefiningOp() == op1 &&
+         isPositiveConstant(addi.getLhs())))
+      return true;
+  if (auto subi = mlir::dyn_cast<mlir::arith::SubIOp>(op1))
+    if (subi.getLhs().getDefiningOp() == op2 &&
+        isPositiveConstant(subi.getRhs()))
+      return true;
   return false;
 }
 
@@ -405,21 +540,27 @@ ElementalAssignBufferization::findMatch(hlfir::ElementalOp elemental) {
     if (!res.isPartial()) {
       if (auto designate =
               effect.getValue().getDefiningOp<hlfir::DesignateOp>()) {
-        if (!areIdenticalOrDisjointSlices(match.array, designate.getMemref())) {
+        ArraySectionAnalyzer::SlicesOverlapKind overlap =
+            ArraySectionAnalyzer::analyze(match.array, designate.getMemref());
+        if (overlap ==
+            ArraySectionAnalyzer::SlicesOverlapKind::DefinitelyDisjoint)
+          continue;
+
+        if (overlap == ArraySectionAnalyzer::SlicesOverlapKind::Unknown) {
           LLVM_DEBUG(llvm::dbgs() << "possible read conflict: " << designate
                                   << " at " << elemental.getLoc() << "\n");
           return std::nullopt;
         }
         auto indices = designate.getIndices();
         auto elementalIndices = elemental.getIndices();
-        if (indices.size() != elementalIndices.size()) {
-          LLVM_DEBUG(llvm::dbgs() << "possible read conflict: " << designate
-                                  << " at " << elemental.getLoc() << "\n");
-          return std::nullopt;
-        }
-        if (std::equal(indices.begin(), indices.end(), elementalIndices.begin(),
+        if (indices.size() == elementalIndices.size() &&
+            std::equal(indices.begin(), indices.end(), elementalIndices.begin(),
                        elementalIndices.end()))
           continue;
+
+        LLVM_DEBUG(llvm::dbgs() << "possible read conflict: " << designate
+                                << " at " << elemental.getLoc() << "\n");
+        return std::nullopt;
       }
     }
     LLVM_DEBUG(llvm::dbgs() << "disallowed side-effect: " << effect.getValue()
diff --git a/flang/test/HLFIR/opt-array-slice-assign.fir b/flang/test/HLFIR/opt-array-slice-assign.fir
index 11bd97c1158342..3db47b1da8cd33 100644
--- a/flang/test/HLFIR/opt-array-slice-assign.fir
+++ b/flang/test/HLFIR/opt-array-slice-assign.fir
@@ -382,3 +382,427 @@ func.func @_QPtest6(%arg0: !fir.ref<!fir.array<?x?xf32>> {fir.bindc_name = "x"},
 }
 // CHECK-LABEL:   func.func @_QPtest6(
 // CHECK-NOT: hlfir.elemental
+
+// Check that 'x(9,:)=SUM(x(1:8,:),DIM=1)' is optimized
+// due to the LHS and RHS being disjoint array sections.
+func.func @test_disjoint_triple_index(%arg0: !fir.box<!fir.array<?x?xf32>> {fir.bindc_name = "x"}) {
+  %cst = arith.constant 0.000000e+00 : f32
+  %c9 = arith.constant 9 : index
+  %c0 = arith.constant 0 : index
+  %c8 = arith.constant 8 : index
+  %c1 = arith.constant 1 : index
+  %0 = fir.dummy_scope : !fir.dscope
+  %1:2 = hlfir.declare %arg0 dummy_scope %0 {uniq_name = "_QFtestEx"} : (!fir.box<!fir.array<?x?xf32>>, !fir.dscope) -> (!fir.box<!fir.array<?x?xf32>>, !fir.box<!fir.array<?x?xf32>>)
+  %2:3 = fir.box_dims %1#1, %c1 : (!fir.box<!fir.array<?x?xf32>>, index) -> (index, index, index)
+  %3 = arith.cmpi sgt, %2#1, %c0 : index
+  %4 = arith.select %3, %2#1, %c0 : index
+  %5 = fir.shape %c8, %4 : (index, index) -> !fir.shape<2>
+  %6 = hlfir.designate %1#0 (%c1:%c8:%c1, %c1:%2#1:%c1)  shape %5 : (!fir.box<!fir.array<?x?xf32>>, index, index, index, index, index, index, !fir.shape<2>) -> !fir.box<!fir.array<8x?xf32>>
+  %7 = fir.shape %4 : (index) -> !fir.shape<1>
+  %8 = hlfir.elemental %7 unordered : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
+  ^bb0(%arg1: index):
+    %10 = fir.alloca f32 {bindc_name = ".sum.reduction"}
+    fir.store %cst to %10 : !fir.ref<f32>
+    fir.do_loop %arg2 = %c1 to %c8 step %c1 unordered {
+      %12 = fir.load %10 : !fir.ref<f32>
+      %13 = hlfir.designate %6 (%arg2, %arg1)  : (!fir.box<!fir.array<8x?xf32>>, index, index) -> !fir.ref<f32>
+      %14 = fir.load %13 : !fir.ref<f32>
+      %15 = arith.addf %12, %14 fastmath<fast> : f32
+      fir.store %15 to %10 : !fir.ref<f32>
+    }
+    %11 = fir.load %10 : !fir.ref<f32>
+    hlfir.yield_element %11 : f32
+  }
+  %9 = hlfir.designate %1#0 (%c9, %c1:%2#1:%c1)  shape %7 : (!fir.box<!fir.array<?x?xf32>>, index, index, index, index, !fir.shape<1>) -> !fir.box<!fir.array<?xf32>>
+  hlfir.assign %8 to %9 : !hlfir.expr<?xf32>, !fir.box<!fir.array<?xf32>>
+  hlfir.destroy %8 : !hlfir.expr<?xf32>
+  return
+}
+// CHECK-LABEL:   func.func @test_disjoint_triple_index(
+// CHECK-NOT: hlfir.elemental
+
+// Check that 'x(9,:)=SUM(x(9:9,:),DIM=1)' is not optimized.
+func.func @test_overlapping_triple_index(%arg0: !fir.box<!fir.array<?x?xf32>> {fir.bindc_name = "x"}) {
+  %cst = arith.constant 0.000000e+00 : f32
+  %c9 = arith.constant 9 : index
+  %c0 = arith.constant 0 : index
+  %c8 = arith.constant 8 : index
+  %c1 = arith.constant 1 : index
+  %0 = fir.dummy_scope : !fir.dscope
+  %1:2 = hlfir.declare %arg0 dummy_scope %0 {uniq_name = "_QFtestEx"} : (!fir.box<!fir.array<?x?xf32>>, !fir.dscope) -> (!fir.box<!fir.array<?x?xf32>>, !fir.box<!fir.array<?x?xf32>>)
+  %2:3 = fir.box_dims %1#1, %c1 : (!fir.box<!fir.array<?x?xf32>>, index) -> (index, index, index)
+  %3 = arith.cmpi sgt, %2#1, %c0 : index
+  %4 = arith.select %3, %2#1, %c0 : index
+  %5 = fir.shape %c8, %4 : (index, index) -> !fir.shape<2>
+  %6 = hlfir.designate %1#0 (%c9:%c9:%c1, %c1:%2#1:%c1)  shape %5 : (!fir.box<!fir.array<?x?xf32>>, index, index, index, index, index, index, !fir.shape<2>) -> !fir.box<!fir.array<8x?xf32>>
+  %7 = fir.shape %4 : (index) -> !fir.shape<1>
+  %8 = hlfir.elemental %7 unordered : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
+  ^bb0(%arg1: index):
+    %10 = fir.alloca f32 {bindc_name = ".sum.reduction"}
+    fir.store %cst to %10 : !fir.ref<f32>
+    fir.do_loop %arg2 = %c1 to %c8 step %c1 unordered {
+      %12 = fir.load %10 : !fir.ref<f32>
+      %13 = hlfir.designate %6 (%arg2, %arg1)  : (!fir.box<!fir.array<8x?xf32>>, index, index) -> !fir.ref<f32>
+      %14 = fir.load %13 : !fir.ref<f32>
+      %15 = arith.addf %12, %14 fastmath<fast> : f32
+      fir.store %15 to %10 : !fir.ref<f32>
+    }
+    %11 = fir.load %10 : !fir.ref<f32>
+    hlfir.yield_element %11 : f32
+  }
+  %9 = hlfir.designate %1#0 (%c9, %c1:%2#1:%c1)  shape %7 : (!fir.box<!fir.array<?x?xf32>>, index, index, index, index, !fir.shape<1>) -> !fir.box<!fir.array<?xf32>>
+  hlfir.assign %8 to %9 : !hlfir.expr<?xf32>, !fir.box<!fir.array<?xf32>>
+  hlfir.destroy %8 : !hlfir.expr<?xf32>
+  return
+}
+// CHECK-LABEL:   func.func @test_overlapping_triple_index(
+// CHECK: hlfir.elemental
+
+// Check that 'x(9:ub) = x(lb:6) ...
[truncated]

Copy link
Contributor

@tblah tblah left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is really nice

Comment on lines 170 to 172
// Slices are definitely disjoint.
DefinitelyIdentical,
// Slices are definitely identical.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// Slices are definitely disjoint.
DefinitelyIdentical,
// Slices are definitely identical.
// Slices are definitely identical.
DefinitelyIdentical,
// Slices are definitely disjoint.

Copy link
Contributor

@jeanPerier jeanPerier left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great, thanks!

I wonder if the isLess could be improved/made even more generic by using MLIR/LLVM Presburger (I have never used it, so this is a random comment, and I do not expect a follow-up here).

static std::pair<mlir::Value, mlir::Value>
getOrderedBounds(const SectionDesc &desc) {
mlir::Value stride = desc.stride;
// Null stride means stride-1.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// Null stride means stride-1.
// Null stride means stride=1.

@vzakhari
Copy link
Contributor Author

Looks great, thanks!

I wonder if the isLess could be improved/made even more generic by using MLIR/LLVM Presburger (I have never used it, so this is a random comment, and I do not expect a follow-up here).

Thanks for the hint, Jean! I will learn about MLIR Presburger to see how it can be applied.

@vzakhari vzakhari merged commit af5d3af into llvm:main Dec 13, 2024
8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants