Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LoopVectorizer][ARM] Detect reduce(ext(mul(ext, ext))) patterns more reliably #115847

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

davemgreen
Copy link
Collaborator

We would detect ext(mul(ext, ext)) patterns when looking up through the tree, but not when looking down. This hopefully brings the cost model closer to the vplan version, avoiding some asserts and reducing the diffs needed in #113903.

… reliably.

We would detect ext(mul(ext, ext)) patterns when looking up through the tree,
but not when looking down. This hopefully brings the cost model closer to the
vplan version, avoiding some asserts.
@llvmbot
Copy link
Member

llvmbot commented Nov 12, 2024

@llvm/pr-subscribers-llvm-transforms

@llvm/pr-subscribers-vectorizers

Author: David Green (davemgreen)

Changes

We would detect ext(mul(ext, ext)) patterns when looking up through the tree, but not when looking down. This hopefully brings the cost model closer to the vplan version, avoiding some asserts and reducing the diffs needed in #113903.


Full diff: https://github.com/llvm/llvm-project/pull/115847.diff

2 Files Affected:

  • (modified) llvm/lib/Transforms/Vectorize/LoopVectorize.cpp (+14-1)
  • (modified) llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll (+19-19)
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 1ebc62f9843905..568aeae2260f11 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -5818,6 +5818,15 @@ LoopVectorizationCostModel::getReductionPatternCost(
   if (match(RetI, m_OneUse(m_Mul(m_Value(), m_Value()))) &&
       RetI->user_back()->getOpcode() == Instruction::Add) {
     RetI = RetI->user_back();
+  } else if (match(RetI, m_OneUse(m_Mul(m_Value(), m_Value()))) &&
+             ((match(I, m_ZExt(m_Value())) &&
+               match(RetI->user_back(), m_OneUse(m_ZExt(m_Value())))) ||
+              (match(I, m_SExt(m_Value())) &&
+               match(RetI->user_back(), m_OneUse(m_SExt(m_Value()))))) &&
+             RetI->user_back()->user_back()->getOpcode() == Instruction::Add) {
+    // This looks through ext(mul(ext, ext)), making sure that the extensions
+    // are the same sign.
+    RetI = RetI->user_back()->user_back();
   }
 
   // Test if the found instruction is a reduction, and if not return an invalid
@@ -7316,7 +7325,7 @@ LoopVectorizationPlanner::precomputeCosts(VPlan &Plan, ElementCount VF,
     // Also include the operands of instructions in the chain, as the cost-model
     // may mark extends as free.
     //
-    // For ARM, some of the instruction can folded into the reducion
+    // For ARM, some of the instructions can be folded into the reduction
     // instruction. So we need to mark all folded instructions free.
     // For example: We can fold reduce(mul(ext(A), ext(B))) into one
     // instruction.
@@ -7324,6 +7333,10 @@ LoopVectorizationPlanner::precomputeCosts(VPlan &Plan, ElementCount VF,
       for (Value *Op : ChainOp->operands()) {
         if (auto *I = dyn_cast<Instruction>(Op)) {
           ChainOpsAndOperands.insert(I);
+          if (IsZExtOrSExt(I->getOpcode())) {
+            ChainOpsAndOperands.insert(I);
+            I = dyn_cast<Instruction>(I->getOperand(0));
+          }
           if (I->getOpcode() == Instruction::Mul) {
             auto *Ext0 = dyn_cast<Instruction>(I->getOperand(0));
             auto *Ext1 = dyn_cast<Instruction>(I->getOperand(1));
diff --git a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
index c115c91cff896c..a4f96adccb64b5 100644
--- a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
+++ b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
@@ -1722,10 +1722,10 @@ define i64 @test_fir_q15(ptr %x, ptr %y, i32 %n) #0 {
 ; CHECK-NEXT:    [[TMP0:%.*]] = add nsw i32 [[N]], -1
 ; CHECK-NEXT:    [[TMP1:%.*]] = lshr i32 [[TMP0]], 1
 ; CHECK-NEXT:    [[TMP2:%.*]] = add nuw i32 [[TMP1]], 1
-; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i32 [[N]], 7
+; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i32 [[N]], 15
 ; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
 ; CHECK:       vector.ph:
-; CHECK-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP2]], -4
+; CHECK-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP2]], -8
 ; CHECK-NEXT:    [[IND_END:%.*]] = shl i32 [[N_VEC]], 1
 ; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK:       vector.body:
@@ -1733,26 +1733,26 @@ define i64 @test_fir_q15(ptr %x, ptr %y, i32 %n) #0 {
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[TMP16:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[OFFSET_IDX:%.*]] = shl i32 [[INDEX]], 1
 ; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds i16, ptr [[X:%.*]], i32 [[OFFSET_IDX]]
-; CHECK-NEXT:    [[WIDE_VEC:%.*]] = load <8 x i16>, ptr [[TMP3]], align 2
-; CHECK-NEXT:    [[STRIDED_VEC:%.*]] = shufflevector <8 x i16> [[WIDE_VEC]], <8 x i16> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
-; CHECK-NEXT:    [[STRIDED_VEC1:%.*]] = shufflevector <8 x i16> [[WIDE_VEC]], <8 x i16> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
-; CHECK-NEXT:    [[TMP5:%.*]] = sext <4 x i16> [[STRIDED_VEC]] to <4 x i32>
+; CHECK-NEXT:    [[WIDE_VEC:%.*]] = load <16 x i16>, ptr [[TMP3]], align 2
+; CHECK-NEXT:    [[STRIDED_VEC:%.*]] = shufflevector <16 x i16> [[WIDE_VEC]], <16 x i16> poison, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
+; CHECK-NEXT:    [[STRIDED_VEC1:%.*]] = shufflevector <16 x i16> [[WIDE_VEC]], <16 x i16> poison, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
+; CHECK-NEXT:    [[TMP5:%.*]] = sext <8 x i16> [[STRIDED_VEC]] to <8 x i32>
 ; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr inbounds i16, ptr [[Y:%.*]], i32 [[OFFSET_IDX]]
-; CHECK-NEXT:    [[WIDE_VEC2:%.*]] = load <8 x i16>, ptr [[TMP4]], align 2
-; CHECK-NEXT:    [[STRIDED_VEC3:%.*]] = shufflevector <8 x i16> [[WIDE_VEC2]], <8 x i16> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
-; CHECK-NEXT:    [[STRIDED_VEC4:%.*]] = shufflevector <8 x i16> [[WIDE_VEC2]], <8 x i16> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
-; CHECK-NEXT:    [[TMP6:%.*]] = sext <4 x i16> [[STRIDED_VEC3]] to <4 x i32>
-; CHECK-NEXT:    [[TMP7:%.*]] = mul nsw <4 x i32> [[TMP6]], [[TMP5]]
-; CHECK-NEXT:    [[TMP8:%.*]] = sext <4 x i32> [[TMP7]] to <4 x i64>
-; CHECK-NEXT:    [[TMP13:%.*]] = sext <4 x i16> [[STRIDED_VEC1]] to <4 x i32>
-; CHECK-NEXT:    [[TMP14:%.*]] = sext <4 x i16> [[STRIDED_VEC4]] to <4 x i32>
-; CHECK-NEXT:    [[TMP11:%.*]] = mul nsw <4 x i32> [[TMP14]], [[TMP13]]
-; CHECK-NEXT:    [[TMP12:%.*]] = sext <4 x i32> [[TMP11]] to <4 x i64>
-; CHECK-NEXT:    [[TMP9:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP8]])
+; CHECK-NEXT:    [[WIDE_VEC2:%.*]] = load <16 x i16>, ptr [[TMP4]], align 2
+; CHECK-NEXT:    [[STRIDED_VEC3:%.*]] = shufflevector <16 x i16> [[WIDE_VEC2]], <16 x i16> poison, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
+; CHECK-NEXT:    [[STRIDED_VEC4:%.*]] = shufflevector <16 x i16> [[WIDE_VEC2]], <16 x i16> poison, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
+; CHECK-NEXT:    [[TMP6:%.*]] = sext <8 x i16> [[STRIDED_VEC3]] to <8 x i32>
+; CHECK-NEXT:    [[TMP7:%.*]] = mul nsw <8 x i32> [[TMP6]], [[TMP5]]
+; CHECK-NEXT:    [[TMP8:%.*]] = sext <8 x i32> [[TMP7]] to <8 x i64>
+; CHECK-NEXT:    [[TMP13:%.*]] = sext <8 x i16> [[STRIDED_VEC1]] to <8 x i32>
+; CHECK-NEXT:    [[TMP14:%.*]] = sext <8 x i16> [[STRIDED_VEC4]] to <8 x i32>
+; CHECK-NEXT:    [[TMP11:%.*]] = mul nsw <8 x i32> [[TMP14]], [[TMP13]]
+; CHECK-NEXT:    [[TMP12:%.*]] = sext <8 x i32> [[TMP11]] to <8 x i64>
+; CHECK-NEXT:    [[TMP9:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP8]])
 ; CHECK-NEXT:    [[TMP10:%.*]] = add i64 [[TMP9]], [[VEC_PHI]]
-; CHECK-NEXT:    [[TMP15:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP12]])
+; CHECK-NEXT:    [[TMP15:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP12]])
 ; CHECK-NEXT:    [[TMP16]] = add i64 [[TMP15]], [[TMP10]]
-; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4
+; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 8
 ; CHECK-NEXT:    [[TMP17:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[TMP17]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP37:![0-9]+]]
 ; CHECK:       middle.block:

@davemgreen
Copy link
Collaborator Author

@ElvisWang123 if this makes the cost model change easier or more difficult let me know. It should just bring the two cost-models closer together, but the new costmodel will get the benefit already. We can easily drop this if its simpler that way.

@ElvisWang123
Copy link
Contributor

@ElvisWang123 if this makes the cost model change easier or more difficult let me know. It should just bring the two cost-models closer together, but the new costmodel will get the benefit already. We can easily drop this if its simpler that way.

These test cases show how VPlan-based cost model can calculate cost more accurately.
Without the assertion of checking VF between legacy and vplan-based, I think it is good to keep these tests.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants