Skip to content

Conversation

@kasuga-fj
Copy link
Contributor

@kasuga-fj kasuga-fj commented Jul 14, 2025

Before this patch, when a reduction exists in the loop, the legality check of LoopInterchange only verified if there exists a non-reassociative floating-point instruction in the reduction calculation. However, it is insufficient, because reordering integer reductions can also lead to incorrect transformations. Consider the following example:

int A[2][2] = {
  { INT_MAX, INT_MAX },
  { INT_MIN, INT_MIN },
};

int sum = 0;
for (int i = 0; i < 2; i++)
  for (int j = 0; j < 2; j++)
    sum += A[j][i];

To make this exchange legal, we must drop nuw/nsw flags from the instructions involved in the reduction operations.

This patch extends the legality check to correctly handle such cases. In particular, for integer addition and multiplication, it verifies that the nsw and nuw flags are set on involved instructions, and drop them when the transformation actually performed. This patch also introduces explicit checks for the kind of reduction and permits only those that are known to be safe for interchange. Consequently, some "unknown" reductions (at the moment, FindFirst* and FindLast*) are rejected.

Fix #148228

@llvmbot
Copy link
Member

llvmbot commented Jul 14, 2025

@llvm/pr-subscribers-llvm-transforms

Author: Ryotaro Kasuga (kasuga-fj)

Changes

Patch is 29.52 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/148612.diff

3 Files Affected:

  • (modified) llvm/lib/Transforms/Scalar/LoopInterchange.cpp (+56-1)
  • (modified) llvm/test/Transforms/LoopInterchange/pr48212.ll (+1-1)
  • (added) llvm/test/Transforms/LoopInterchange/reductions-kind.ll (+864)
diff --git a/llvm/lib/Transforms/Scalar/LoopInterchange.cpp b/llvm/lib/Transforms/Scalar/LoopInterchange.cpp
index a5008907b9014..a2aa72e1a01f2 100644
--- a/llvm/lib/Transforms/Scalar/LoopInterchange.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopInterchange.cpp
@@ -812,7 +812,62 @@ static PHINode *findInnerReductionPhi(Loop *L, Value *V) {
         // Detect floating point reduction only when it can be reordered.
         if (RD.getExactFPMathInst() != nullptr)
           return nullptr;
-        return PHI;
+
+        RecurKind RK = RD.getRecurrenceKind();
+        switch (RK) {
+        case RecurKind::Or:
+        case RecurKind::And:
+        case RecurKind::Xor:
+        case RecurKind::SMin:
+        case RecurKind::SMax:
+        case RecurKind::UMin:
+        case RecurKind::UMax:
+        case RecurKind::FAdd:
+        case RecurKind::FMul:
+        case RecurKind::FMin:
+        case RecurKind::FMax:
+        case RecurKind::FMinimum:
+        case RecurKind::FMaximum:
+        case RecurKind::FMinimumNum:
+        case RecurKind::FMaximumNum:
+        case RecurKind::FMulAdd:
+        case RecurKind::AnyOf:
+          return PHI;
+
+        // Change the order of integer addition/multiplication may change the
+        // semantics. Consider the following case:
+        //
+        //  int A[2][2] = {{ INT_MAX, INT_MAX }, { INT_MIN, INT_MIN }};
+        //  int sum = 0;
+        //  for (int i = 0; i < 2; i++)
+        //    for (int j = 0; j < 2; j++)
+        //      sum += A[j][i];
+        //
+        // If the above loops are exchanged, the addition will cause an
+        // overflow. To prove the legality, we must ensure that all reduction
+        // operations don't have nuw/nsw flags.
+        case RecurKind::Add:
+        case RecurKind::Mul: {
+          unsigned OpCode = RecurrenceDescriptor::getOpcode(RK);
+          SmallVector<Instruction *, 4> Ops = RD.getReductionOpChain(PHI, L);
+
+          // FIXME: Is this check necessary?
+          if (Ops.empty())
+            return nullptr;
+          for (Instruction *I : Ops) {
+            // FIXME: Is this check necessary?
+            if (I->getOpcode() != OpCode)
+              return nullptr;
+
+            // Reject if the reduction operation has nuw/nsw flags.
+            if (I->hasNoSignedWrap() || I->hasNoUnsignedWrap())
+              return nullptr;
+          }
+          return PHI;
+        }
+        default:
+          return nullptr;
+        }
       }
       return nullptr;
     }
diff --git a/llvm/test/Transforms/LoopInterchange/pr48212.ll b/llvm/test/Transforms/LoopInterchange/pr48212.ll
index 936c53e217540..cb1300846cf0f 100644
--- a/llvm/test/Transforms/LoopInterchange/pr48212.ll
+++ b/llvm/test/Transforms/LoopInterchange/pr48212.ll
@@ -38,7 +38,7 @@ for.body3:                                        ; preds = %L2, %for.inc
   %idxprom4 = sext i32 %k1.03 to i64
   %arrayidx5 = getelementptr inbounds [5 x i32], ptr %arrayidx, i64 0, i64 %idxprom4
   %0 = load i32, ptr %arrayidx5
-  %add = add nsw i32 %temp.12, %0
+  %add = add i32 %temp.12, %0
   br label %for.inc
 
 for.inc:                                          ; preds = %for.body3
diff --git a/llvm/test/Transforms/LoopInterchange/reductions-kind.ll b/llvm/test/Transforms/LoopInterchange/reductions-kind.ll
new file mode 100644
index 0000000000000..d9e4d58a1780e
--- /dev/null
+++ b/llvm/test/Transforms/LoopInterchange/reductions-kind.ll
@@ -0,0 +1,864 @@
+; RUN: opt < %s -passes=loop-interchange -cache-line-size=64 -pass-remarks-output=%t -disable-output \
+; RUN:     -verify-dom-info -verify-loop-info -verify-loop-lcssa
+; RUN: FileCheck -input-file=%t %s
+
+; int sum = 0;
+; for (int i = 0; i < 2; i++)
+;   for (int j = 0; j < 2; j++)
+;     sum += A[j][i];
+
+; CHECK:      --- !Missed
+; CHECK-NEXT: Pass:            loop-interchange
+; CHECK-NEXT: Name:            UnsupportedPHIOuter
+; CHECK-NEXT: Function:        reduction_add
+define void @reduction_add(ptr %A) {
+entry:
+  br label %for.i.header
+
+for.i.header:
+  %i = phi i32 [ 0, %entry ], [ %i.inc, %for.i.latch ]
+  %sum.i = phi i32 [ 0, %entry ], [ %sum.i.lcssa, %for.i.latch ]
+  br label %for.j
+
+for.j:
+  %j = phi i32 [ 0, %for.i.header ], [ %j.inc, %for.j ]
+  %sum.j = phi i32 [ %sum.i, %for.i.header ], [ %sum.j.next, %for.j ]
+  %idx = getelementptr inbounds [2 x [2 x i32]], ptr %A, i32 0, i32 %j, i32 %i
+  %a = load i32, ptr %idx, align 4
+  %sum.j.next = add nsw i32 %sum.j, %a
+  %j.inc = add i32 %j, 1
+  %cmp.j = icmp slt i32 %j.inc, 2
+  br i1 %cmp.j, label %for.j, label %for.i.latch
+
+for.i.latch:
+  %sum.i.lcssa = phi i32 [ %sum.j.next, %for.j ]
+  %i.inc = add i32 %i, 1
+  %cmp.i = icmp slt i32 %i.inc, 2
+  br i1 %cmp.i, label %for.i.header, label %exit
+
+exit:
+  ret void
+}
+
+; CHECK:      --- !Pass
+; CHECK-NEXT: Pass:            loop-interchange
+; CHECK-NEXT: Name:            Interchanged
+; CHECK-NEXT: Function:        reduction_wrap_add
+define void @reduction_wrap_add(ptr %A) {
+entry:
+  br label %for.i.header
+
+for.i.header:
+  %i = phi i32 [ 0, %entry ], [ %i.inc, %for.i.latch ]
+  %sum.i = phi i32 [ 0, %entry ], [ %sum.i.lcssa, %for.i.latch ]
+  br label %for.j
+
+for.j:
+  %j = phi i32 [ 0, %for.i.header ], [ %j.inc, %for.j ]
+  %sum.j = phi i32 [ %sum.i, %for.i.header ], [ %sum.j.next, %for.j ]
+  %idx = getelementptr inbounds [2 x [2 x i32]], ptr %A, i32 0, i32 %j, i32 %i
+  %a = load i32, ptr %idx, align 4
+  %sum.j.next = add i32 %sum.j, %a
+  %j.inc = add i32 %j, 1
+  %cmp.j = icmp slt i32 %j.inc, 2
+  br i1 %cmp.j, label %for.j, label %for.i.latch
+
+for.i.latch:
+  %sum.i.lcssa = phi i32 [ %sum.j.next, %for.j ]
+  %i.inc = add i32 %i, 1
+  %cmp.i = icmp slt i32 %i.inc, 2
+  br i1 %cmp.i, label %for.i.header, label %exit
+
+exit:
+  ret void
+}
+
+; CHECK:      --- !Missed
+; CHECK-NEXT: Pass:            loop-interchange
+; CHECK-NEXT: Name:            UnsupportedPHIOuter
+; CHECK-NEXT: Function:        reduction_cast_add
+define void @reduction_cast_add(ptr %A) {
+entry:
+  br label %for.i.header
+
+for.i.header:
+  %i = phi i32 [ 0, %entry ], [ %i.inc, %for.i.latch ]
+  %sum.i = phi i32 [ 0, %entry ], [ %sum.i.lcssa, %for.i.latch ]
+  br label %for.j
+
+for.j:
+  %j = phi i32 [ 0, %for.i.header ], [ %j.inc, %for.j ]
+  %sum.j = phi i32 [ %sum.i, %for.i.header ], [ %sum.j.next, %for.j ]
+  %idx = getelementptr inbounds [2 x [2 x i32]], ptr %A, i32 0, i32 %j, i32 %i
+  %a = load i32, ptr %idx, align 4
+  %sum.j.trunc = trunc i32 %sum.j to i16
+  %sum.j.ext = zext i16 %sum.j.trunc to i32
+  %sum.j.next = add nsw i32 %sum.j.ext, %a
+  %j.inc = add i32 %j, 1
+  %cmp.j = icmp slt i32 %j.inc, 2
+  br i1 %cmp.j, label %for.j, label %for.i.latch
+
+for.i.latch:
+  %sum.i.lcssa = phi i32 [ %sum.j.next, %for.j ]
+  %i.inc = add i32 %i, 1
+  %cmp.i = icmp slt i32 %i.inc, 2
+  br i1 %cmp.i, label %for.i.header, label %exit
+
+exit:
+  ret void
+}
+
+
+; int prod = 1;
+; for (int i = 0; i < 2; i++)
+;   for (int j = 0; j < 2; j++)
+;     prod *= A[j][i];
+
+; CHECK:      --- !Missed
+; CHECK-NEXT: Pass:            loop-interchange
+; CHECK-NEXT: Name:            UnsupportedPHIOuter
+; CHECK-NEXT: Function:        reduction_mul
+define void @reduction_mul(ptr %A) {
+entry:
+  br label %for.i.header
+
+for.i.header:
+  %i = phi i32 [ 0, %entry ], [ %i.inc, %for.i.latch ]
+  %prod.i = phi i32 [ 1, %entry ], [ %prod.i.lcssa, %for.i.latch ]
+  br label %for.j
+
+for.j:
+  %j = phi i32 [ 0, %for.i.header ], [ %j.inc, %for.j ]
+  %prod.j = phi i32 [ %prod.i, %for.i.header ], [ %prod.j.next, %for.j ]
+  %idx = getelementptr inbounds [2 x [2 x i32]], ptr %A, i32 0, i32 %j, i32 %i
+  %a = load i32, ptr %idx, align 4
+  %prod.j.next = mul nsw i32 %prod.j, %a
+  %j.inc = add i32 %j, 1
+  %cmp.j = icmp slt i32 %j.inc, 2
+  br i1 %cmp.j, label %for.j, label %for.i.latch
+
+for.i.latch:
+  %prod.i.lcssa = phi i32 [ %prod.j.next, %for.j ]
+  %i.inc = add i32 %i, 1
+  %cmp.i = icmp slt i32 %i.inc, 2
+  br i1 %cmp.i, label %for.i.header, label %exit
+
+exit:
+  ret void
+}
+
+; CHECK:      --- !Pass
+; CHECK-NEXT: Pass:            loop-interchange
+; CHECK-NEXT: Name:            Interchanged
+; CHECK-NEXT: Function:        reduction_wrap_mul
+define void @reduction_wrap_mul(ptr %A) {
+entry:
+  br label %for.i.header
+
+for.i.header:
+  %i = phi i32 [ 0, %entry ], [ %i.inc, %for.i.latch ]
+  %prod.i = phi i32 [ 1, %entry ], [ %prod.i.lcssa, %for.i.latch ]
+  br label %for.j
+
+for.j:
+  %j = phi i32 [ 0, %for.i.header ], [ %j.inc, %for.j ]
+  %prod.j = phi i32 [ %prod.i, %for.i.header ], [ %prod.j.next, %for.j ]
+  %idx = getelementptr inbounds [2 x [2 x i32]], ptr %A, i32 0, i32 %j, i32 %i
+  %a = load i32, ptr %idx, align 4
+  %prod.j.next = mul i32 %prod.j, %a
+  %j.inc = add i32 %j, 1
+  %cmp.j = icmp slt i32 %j.inc, 2
+  br i1 %cmp.j, label %for.j, label %for.i.latch
+
+for.i.latch:
+  %prod.i.lcssa = phi i32 [ %prod.j.next, %for.j ]
+  %i.inc = add i32 %i, 1
+  %cmp.i = icmp slt i32 %i.inc, 2
+  br i1 %cmp.i, label %for.i.header, label %exit
+
+exit:
+  ret void
+}
+
+
+; int b_or = 0;
+; for (int i = 0; i < 2; i++)
+;   for (int j = 0; j < 2; j++)
+;     b_or |= A[j][i];
+
+; CHECK:      --- !Pass
+; CHECK-NEXT: Pass:            loop-interchange
+; CHECK-NEXT: Name:            Interchanged
+; CHECK-NEXT: Function:        reduction_or
+define void @reduction_or(ptr %A) {
+entry:
+  br label %for.i.header
+
+for.i.header:
+  %i = phi i32 [ 0, %entry ], [ %i.inc, %for.i.latch ]
+  %or.i = phi i32 [ 0, %entry ], [ %or.i.lcssa, %for.i.latch ]
+  br label %for.j
+
+for.j:
+  %j = phi i32 [ 0, %for.i.header ], [ %j.inc, %for.j ]
+  %or.j = phi i32 [ %or.i, %for.i.header ], [ %or.j.next, %for.j ]
+  %idx = getelementptr inbounds [2 x [2 x i32]], ptr %A, i32 0, i32 %j, i32 %i
+  %a = load i32, ptr %idx, align 4
+  %or.j.next = or i32 %or.j, %a
+  %j.inc = add i32 %j, 1
+  %cmp.j = icmp slt i32 %j.inc, 2
+  br i1 %cmp.j, label %for.j, label %for.i.latch
+
+for.i.latch:
+  %or.i.lcssa = phi i32 [ %or.j.next, %for.j ]
+  %i.inc = add i32 %i, 1
+  %cmp.i = icmp slt i32 %i.inc, 2
+  br i1 %cmp.i, label %for.i.header, label %exit
+
+exit:
+  ret void
+}
+
+
+; int b_and = -1;
+; for (int i = 0; i < 2; i++)
+;   for (int j = 0; j < 2; j++)
+;     b_and &= A[j][i];
+
+; CHECK:      --- !Pass
+; CHECK-NEXT: Pass:            loop-interchange
+; CHECK-NEXT: Name:            Interchanged
+; CHECK-NEXT: Function:        reduction_and
+define void @reduction_and(ptr %A) {
+entry:
+  br label %for.i.header
+
+for.i.header:
+  %i = phi i32 [ 0, %entry ], [ %i.inc, %for.i.latch ]
+  %and.i = phi i32 [ -1, %entry ], [ %and.i.lcssa, %for.i.latch ]
+  br label %for.j
+
+for.j:
+  %j = phi i32 [ 0, %for.i.header ], [ %j.inc, %for.j ]
+  %and.j = phi i32 [ %and.i, %for.i.header ], [ %and.j.next, %for.j ]
+  %idx = getelementptr inbounds [2 x [2 x i32]], ptr %A, i32 0, i32 %j, i32 %i
+  %a = load i32, ptr %idx, align 4
+  %and.j.next = and i32 %and.j, %a
+  %j.inc = add i32 %j, 1
+  %cmp.j = icmp slt i32 %j.inc, 2
+  br i1 %cmp.j, label %for.j, label %for.i.latch
+
+for.i.latch:
+  %and.i.lcssa = phi i32 [ %and.j.next, %for.j ]
+  %i.inc = add i32 %i, 1
+  %cmp.i = icmp slt i32 %i.inc, 2
+  br i1 %cmp.i, label %for.i.header, label %exit
+
+exit:
+  ret void
+}
+
+
+; int b_xor = 0;
+; for (int i = 0; i < 2; i++)
+;   for (int j = 0; j < 2; j++)
+;     b_xor ^= A[j][i];
+
+; CHECK:      --- !Pass
+; CHECK-NEXT: Pass:            loop-interchange
+; CHECK-NEXT: Name:            Interchanged
+; CHECK-NEXT: Function:        reduction_xor
+define void @reduction_xor(ptr %A) {
+entry:
+  br label %for.i.header
+
+for.i.header:
+  %i = phi i32 [ 0, %entry ], [ %i.inc, %for.i.latch ]
+  %xor.i = phi i32 [ 0, %entry ], [ %xor.i.lcssa, %for.i.latch ]
+  br label %for.j
+
+for.j:
+  %j = phi i32 [ 0, %for.i.header ], [ %j.inc, %for.j ]
+  %xor.j = phi i32 [ %xor.i, %for.i.header ], [ %xor.j.next, %for.j ]
+  %idx = getelementptr inbounds [2 x [2 x i32]], ptr %A, i32 0, i32 %j, i32 %i
+  %a = load i32, ptr %idx, align 4
+  %xor.j.next = xor i32 %xor.j, %a
+  %j.inc = add i32 %j, 1
+  %cmp.j = icmp slt i32 %j.inc, 2
+  br i1 %cmp.j, label %for.j, label %for.i.latch
+
+for.i.latch:
+  %xor.i.lcssa = phi i32 [ %xor.j.next, %for.j ]
+  %i.inc = add i32 %i, 1
+  %cmp.i = icmp slt i32 %i.inc, 2
+  br i1 %cmp.i, label %for.i.header, label %exit
+
+exit:
+  ret void
+}
+
+
+; int smin = init;
+; for (int i = 0; i < 2; i++)
+;   for (int j = 0; j < 2; j++)
+;     smin = (A[j][i] < smin) ? A[j][i] : smin;
+
+; CHECK:      --- !Pass
+; CHECK-NEXT: Pass:            loop-interchange
+; CHECK-NEXT: Name:            Interchanged
+; CHECK-NEXT: Function:        reduction_smin
+define void @reduction_smin(ptr %A, i32 %init) {
+entry:
+  br label %for.i.header
+
+for.i.header:
+  %i = phi i32 [ 0, %entry ], [ %i.inc, %for.i.latch ]
+  %smin.i = phi i32 [ %init, %entry ], [ %smin.i.lcssa, %for.i.latch ]
+  br label %for.j
+
+for.j:
+  %j = phi i32 [ 0, %for.i.header ], [ %j.inc, %for.j ]
+  %smin.j = phi i32 [ %smin.i, %for.i.header ], [ %smin.j.next, %for.j ]
+  %idx = getelementptr inbounds [2 x [2 x i32]], ptr %A, i32 0, i32 %j, i32 %i
+  %a = load i32, ptr %idx, align 4
+  %cmp = icmp slt i32 %a, %smin.j
+  %smin.j.next = select i1 %cmp, i32 %a, i32 %smin.j
+  %j.inc = add i32 %j, 1
+  %cmp.j = icmp slt i32 %j.inc, 2
+  br i1 %cmp.j, label %for.j, label %for.i.latch
+
+for.i.latch:
+  %smin.i.lcssa = phi i32 [ %smin.j.next, %for.j ]
+  %i.inc = add i32 %i, 1
+  %cmp.i = icmp slt i32 %i.inc, 2
+  br i1 %cmp.i, label %for.i.header, label %exit
+
+exit:
+  ret void
+}
+
+
+; int smax = init;
+; for (int i = 0; i < 2; i++)
+;   for (int j = 0; j < 2; j++)
+;     smax = (A[j][i] > smax) ? A[j][i] : smax;
+
+; CHECK:      --- !Pass
+; CHECK-NEXT: Pass:            loop-interchange
+; CHECK-NEXT: Name:            Interchanged
+; CHECK-NEXT: Function:        reduction_smax
+define void @reduction_smax(ptr %A, i32 %init) {
+entry:
+  br label %for.i.header
+
+for.i.header:
+  %i = phi i32 [ 0, %entry ], [ %i.inc, %for.i.latch ]
+  %smax.i = phi i32 [ %init, %entry ], [ %smax.i.lcssa, %for.i.latch ]
+  br label %for.j
+
+for.j:
+  %j = phi i32 [ 0, %for.i.header ], [ %j.inc, %for.j ]
+  %smax.j = phi i32 [ %smax.i, %for.i.header ], [ %smax.j.next, %for.j ]
+  %idx = getelementptr inbounds [2 x [2 x i32]], ptr %A, i32 0, i32 %j, i32 %i
+  %a = load i32, ptr %idx, align 4
+  %cmp = icmp sgt i32 %a, %smax.j
+  %smax.j.next = select i1 %cmp, i32 %a, i32 %smax.j
+  %j.inc = add i32 %j, 1
+  %cmp.j = icmp slt i32 %j.inc, 2
+  br i1 %cmp.j, label %for.j, label %for.i.latch
+
+for.i.latch:
+  %smax.i.lcssa = phi i32 [ %smax.j.next, %for.j ]
+  %i.inc = add i32 %i, 1
+  %cmp.i = icmp slt i32 %i.inc, 2
+  br i1 %cmp.i, label %for.i.header, label %exit
+
+exit:
+  ret void
+}
+
+
+; unsigned umin = init;
+; for (int i = 0; i < 2; i++)
+;   for (int j = 0; j < 2; j++)
+;     umin = (A[j][i] < umin) ? A[j][i] : umin;
+
+; CHECK:      --- !Pass
+; CHECK-NEXT: Pass:            loop-interchange
+; CHECK-NEXT: Name:            Interchanged
+; CHECK-NEXT: Function:        reduction_umin
+define void @reduction_umin(ptr %A, i32 %init) {
+entry:
+  br label %for.i.header
+
+for.i.header:
+  %i = phi i32 [ 0, %entry ], [ %i.inc, %for.i.latch ]
+  %umin.i = phi i32 [ %init, %entry ], [ %umin.i.lcssa, %for.i.latch ]
+  br label %for.j
+
+for.j:
+  %j = phi i32 [ 0, %for.i.header ], [ %j.inc, %for.j ]
+  %umin.j = phi i32 [ %umin.i, %for.i.header ], [ %umin.j.next, %for.j ]
+  %idx = getelementptr inbounds [2 x [2 x i32]], ptr %A, i32 0, i32 %j, i32 %i
+  %a = load i32, ptr %idx, align 4
+  %cmp = icmp ult i32 %a, %umin.j
+  %umin.j.next = select i1 %cmp, i32 %a, i32 %umin.j
+  %j.inc = add i32 %j, 1
+  %cmp.j = icmp slt i32 %j.inc, 2
+  br i1 %cmp.j, label %for.j, label %for.i.latch
+
+for.i.latch:
+  %umin.i.lcssa = phi i32 [ %umin.j.next, %for.j ]
+  %i.inc = add i32 %i, 1
+  %cmp.i = icmp slt i32 %i.inc, 2
+  br i1 %cmp.i, label %for.i.header, label %exit
+
+exit:
+  ret void
+}
+
+
+; unsigned umax = 0;
+; for (int i = 0; i < 2; i++)
+;   for (int j = 0; j < 2; j++)
+;     smax = (A[j][i] > smax) ? A[j][i] : smax;
+
+; CHECK:      --- !Pass
+; CHECK-NEXT: Pass:            loop-interchange
+; CHECK-NEXT: Name:            Interchanged
+; CHECK-NEXT: Function:        reduction_umax
+define void @reduction_umax(ptr %A) {
+entry:
+  br label %for.i.header
+
+for.i.header:
+  %i = phi i32 [ 0, %entry ], [ %i.inc, %for.i.latch ]
+  %umax.i = phi i32 [ 0, %entry ], [ %umax.i.lcssa, %for.i.latch ]
+  br label %for.j
+
+for.j:
+  %j = phi i32 [ 0, %for.i.header ], [ %j.inc, %for.j ]
+  %umax.j = phi i32 [ %umax.i, %for.i.header ], [ %umax.j.next, %for.j ]
+  %idx = getelementptr inbounds [2 x [2 x i32]], ptr %A, i32 0, i32 %j, i32 %i
+  %a = load i32, ptr %idx, align 4
+  %cmp = icmp ugt i32 %a, %umax.j
+  %umax.j.next = select i1 %cmp, i32 %a, i32 %umax.j
+  %j.inc = add i32 %j, 1
+  %cmp.j = icmp slt i32 %j.inc, 2
+  br i1 %cmp.j, label %for.j, label %for.i.latch
+
+for.i.latch:
+  %umax.i.lcssa = phi i32 [ %umax.j.next, %for.j ]
+  %i.inc = add i32 %i, 1
+  %cmp.i = icmp slt i32 %i.inc, 2
+  br i1 %cmp.i, label %for.i.header, label %exit
+
+exit:
+  ret void
+}
+
+
+; int any_of = 0;
+; for (int i = 0; i < 2; i++)
+;   for (int j = 0; j < 2; j++)
+;     any_of = (A[j][i] == 42) ? 1 : any_of;
+
+; CHECK:      --- !Pass
+; CHECK-NEXT: Pass:            loop-interchange
+; CHECK-NEXT: Name:            Interchanged
+; CHECK-NEXT: Function:        reduction_anyof
+define void @reduction_anyof(ptr %A) {
+entry:
+  br label %for.i.header
+
+for.i.header:
+  %i = phi i32 [ 0, %entry ], [ %i.inc, %for.i.latch ]
+  %anyof.i = phi i32 [ 0, %entry ], [ %anyof.i.lcssa, %for.i.latch ]
+  br label %for.j
+
+for.j:
+  %j = phi i32 [ 0, %for.i.header ], [ %j.inc, %for.j ]
+  %anyof.j = phi i32 [ %anyof.i, %for.i.header ], [ %anyof.j.next, %for.j ]
+  %idx = getelementptr inbounds [2 x [2 x i32]], ptr %A, i32 0, i32 %j, i32 %i
+  %a = load i32, ptr %idx, align 4
+  %cmp = icmp eq i32 %a, 42
+  %anyof.j.next = select i1 %cmp, i32 1, i32 %anyof.j
+  %j.inc = add i32 %j, 1
+  %cmp.j = icmp slt i32 %j.inc, 2
+  br i1 %cmp.j, label %for.j, label %for.i.latch
+
+for.i.latch:
+  %anyof.i.lcssa = phi i32 [ %anyof.j.next, %for.j ]
+  %i.inc = add i32 %i, 1
+  %cmp.i = icmp slt i32 %i.inc, 2
+  br i1 %cmp.i, label %for.i.header, label %exit
+
+exit:
+  ret void
+}
+
+; float sum = 0;
+; for (int i = 0; i < 2; i++)
+;   for (int j = 0; j < 2; j++)
+;     sum += A[j][i];
+
+; CHECK:      --- !Missed
+; CHECK-NEXT: Pass:            loop-interchange
+; CHECK-NEXT: Name:            UnsupportedPHIOuter
+; CHECK-NEXT: Function:        reduction_fadd
+define void @reduction_fadd(ptr %A) {
+entry:
+  br label %for.i.header
+
+for.i.header:
+  %i = phi i32 [ 0, %entry ], [ %i.inc, %for.i.latch ]
+  %sum.i = phi float [ 0.0, %entry ], [ %sum.i.lcssa, %for.i.latch ]
+  br label %for.j
+
+for.j:
+  %j = phi i32 [ 0, %for.i.header ], [ %j.inc, %for.j ]
+  %sum.j = phi float [ %sum.i, %for.i.header ], [ %sum.j.next, %for.j ]
+  %idx = getelementptr inbounds [2 x [2 x i32]], ptr %A, i32 0, i32 %j, i32 %i
+  %a = load float, ptr %idx, align 4
+  %sum.j.next = fadd float %sum.j, %a
+  %j.inc = add i32 %j, 1
+  %cmp.j = icmp slt i32 %j.inc, 2
+  br i1 %cmp.j, label %for.j, label %for.i.latch
+
+for.i.latch:
+  %sum.i.lcssa = phi float [ %sum.j.next, %for.j ]
+  %i.inc = add i32 %i, 1
+  %cmp.i = icmp slt i32 %i.inc, 2
+  br i1 %cmp.i, label %for.i.header, label %exit
+
+exit:
+  ret void
+}
+
+; CHECK:      --- !Pass
+; CHECK-NEXT: Pass:            loop-interchange
+; CHECK-NEXT: Name:            Interchanged
+; CHECK-NEXT: Function:        reduction_reassoc_fadd
+define void @reduction_reassoc_fadd(ptr %A) {
+entry:
+  br label %for.i.header
+
+for.i.header:
+  %i = phi i32 [ 0, %entry ], [ %i.inc, %for.i.latch ]
+  %sum.i = phi float [ 0.0, %entry ], [ %sum.i.lcssa, %for.i.latch ]
+  br label %for.j
+
+for.j:
+  %j = phi i32 [ 0, %for.i.header ], [ %j.inc, %for.j ]
+  %sum.j = phi float [ %sum.i, %for.i.header ], [ %sum.j.next, %for.j ]
+  %idx = getelementptr inbounds [2 x [2 x i32]], ptr %A, i32 0, i32 %j, i32 %i
+  %a = load float, ptr %idx, align 4
+  %sum.j.next = fadd reassoc float %sum.j, %a
+  %j.inc = add i32 %j, 1
+  %cmp.j = icmp slt i32 %j.inc, 2
+  br i1 %cmp.j, label %for.j, label %for.i.latch
+
+for.i.latch:
+  %sum.i.lcssa = phi float [ %sum.j.next, %for.j ]
+  %i.inc = add i32 %i, 1
+  %cmp.i...
[truncated]

Comment on lines 849 to 867
case RecurKind::Add:
case RecurKind::Mul: {
unsigned OpCode = RecurrenceDescriptor::getOpcode(RK);
SmallVector<Instruction *, 4> Ops = RD.getReductionOpChain(PHI, L);

// FIXME: Is this check necessary?
if (Ops.empty())
return nullptr;
for (Instruction *I : Ops) {
// FIXME: Is this check necessary?
if (I->getOpcode() != OpCode)
return nullptr;

// Reject if the reduction operation has nuw/nsw flags.
if (I->hasNoSignedWrap() || I->hasNoUnsignedWrap())
return nullptr;
}
return PHI;
}
Copy link
Contributor Author

@kasuga-fj kasuga-fj Jul 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not quite sure if this check is enough...

Copy link
Contributor

@nikic nikic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be better to drop the nuw/nsw flags during the transform instead.

(Generally speaking, adding flags should never regress optimization.)

Copy link
Collaborator

@sjoerdmeijer sjoerdmeijer left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This approach is also what I had mind at the end of last week when we discussed this case, so the approach looks good to me. Two nits on FIXMEs inlined.

unsigned OpCode = RecurrenceDescriptor::getOpcode(RK);
SmallVector<Instruction *, 4> Ops = RD.getReductionOpChain(PHI, L);

// FIXME: Is this check necessary?
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't tell for certain because I haven't tried, but when I look at the implementation of getReductionOpChain I see it has some bail outs to return an empty set, for which I suspect a test-case could be written. I don't know if such cases would then arrive here though, or be rejected earlier. If that is the case, maybe an assert is in place.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not certain this scenario can actually occur, but for now, I'd prefer to bail out here rather than adding an assertion for safety.

Copy link
Collaborator

@sjoerdmeijer sjoerdmeijer Jul 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In that case, you should be able to write a test-case for it.
But I don't have a too strong opinions on this, it's just a bail out anyway, so is fine by me.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, it seems unlikely that anything unexpected would happen, so using assert here might be fine, but I'm not entirely confident.

return nullptr;
for (Instruction *I : Ops) {
// FIXME: Is this check necessary?
if (I->getOpcode() != OpCode)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aren't these two different opcodes: the instruction opcode, and the recurrence descriptor opcode? And then we are checking whether these are the same for all instructions in the reduction chain. I can see how that maybe is not the case (maybe with adds and subs), but more importantly, I don't really see what this is guaranteeing, or what you would like to achieve with this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At a glance, it doesn't seem to occur for Add or Mul. I replaced it with assert.

@sjoerdmeijer sjoerdmeijer requested a review from sebpop July 14, 2025 13:07
@sjoerdmeijer
Copy link
Collaborator

My feedback and @nikic's crossed, I had not seen his when I pressed the button, so maybe look first into his remarks first.

@kasuga-fj kasuga-fj changed the title [LoopInterchange] Reject interchange if non-reassociative reduction exists [LoopInterchange] Drop nuw/nsw flags from reduction ops when interchanging Jul 15, 2025
Copy link
Contributor Author

@kasuga-fj kasuga-fj left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be better to drop the nuw/nsw flags during the transform instead.

(Generally speaking, adding flags should never regress optimization.)

Thanks, changed in such a way.

unsigned OpCode = RecurrenceDescriptor::getOpcode(RK);
SmallVector<Instruction *, 4> Ops = RD.getReductionOpChain(PHI, L);

// FIXME: Is this check necessary?
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not certain this scenario can actually occur, but for now, I'd prefer to bail out here rather than adding an assertion for safety.

return nullptr;
for (Instruction *I : Ops) {
// FIXME: Is this check necessary?
if (I->getOpcode() != OpCode)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At a glance, it doesn't seem to occur for Add or Mul. I replaced it with assert.

Copy link
Collaborator

@sjoerdmeijer sjoerdmeijer left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am happy with this, if @nikic is happy too.

Copy link
Contributor

@nikic nikic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@kasuga-fj
Copy link
Contributor Author

Thanks for the review!

@kasuga-fj kasuga-fj merged commit b3c293c into llvm:main Jul 15, 2025
9 checks passed
@kasuga-fj kasuga-fj deleted the loop-interchange-fix-reduction branch September 2, 2025 11:16
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[LoopInterchange] Incorrect handling of reductions

4 participants