-
Notifications
You must be signed in to change notification settings - Fork 12.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[LoopVectorizer][ARM] Detect reduce(ext(mul(ext, ext))) patterns more reliably #115847
base: main
Are you sure you want to change the base?
Conversation
… reliably. We would detect ext(mul(ext, ext)) patterns when looking up through the tree, but not when looking down. This hopefully brings the cost model closer to the vplan version, avoiding some asserts.
@llvm/pr-subscribers-llvm-transforms @llvm/pr-subscribers-vectorizers Author: David Green (davemgreen) ChangesWe would detect ext(mul(ext, ext)) patterns when looking up through the tree, but not when looking down. This hopefully brings the cost model closer to the vplan version, avoiding some asserts and reducing the diffs needed in #113903. Full diff: https://github.com/llvm/llvm-project/pull/115847.diff 2 Files Affected:
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 1ebc62f9843905..568aeae2260f11 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -5818,6 +5818,15 @@ LoopVectorizationCostModel::getReductionPatternCost(
if (match(RetI, m_OneUse(m_Mul(m_Value(), m_Value()))) &&
RetI->user_back()->getOpcode() == Instruction::Add) {
RetI = RetI->user_back();
+ } else if (match(RetI, m_OneUse(m_Mul(m_Value(), m_Value()))) &&
+ ((match(I, m_ZExt(m_Value())) &&
+ match(RetI->user_back(), m_OneUse(m_ZExt(m_Value())))) ||
+ (match(I, m_SExt(m_Value())) &&
+ match(RetI->user_back(), m_OneUse(m_SExt(m_Value()))))) &&
+ RetI->user_back()->user_back()->getOpcode() == Instruction::Add) {
+ // This looks through ext(mul(ext, ext)), making sure that the extensions
+ // are the same sign.
+ RetI = RetI->user_back()->user_back();
}
// Test if the found instruction is a reduction, and if not return an invalid
@@ -7316,7 +7325,7 @@ LoopVectorizationPlanner::precomputeCosts(VPlan &Plan, ElementCount VF,
// Also include the operands of instructions in the chain, as the cost-model
// may mark extends as free.
//
- // For ARM, some of the instruction can folded into the reducion
+ // For ARM, some of the instructions can be folded into the reduction
// instruction. So we need to mark all folded instructions free.
// For example: We can fold reduce(mul(ext(A), ext(B))) into one
// instruction.
@@ -7324,6 +7333,10 @@ LoopVectorizationPlanner::precomputeCosts(VPlan &Plan, ElementCount VF,
for (Value *Op : ChainOp->operands()) {
if (auto *I = dyn_cast<Instruction>(Op)) {
ChainOpsAndOperands.insert(I);
+ if (IsZExtOrSExt(I->getOpcode())) {
+ ChainOpsAndOperands.insert(I);
+ I = dyn_cast<Instruction>(I->getOperand(0));
+ }
if (I->getOpcode() == Instruction::Mul) {
auto *Ext0 = dyn_cast<Instruction>(I->getOperand(0));
auto *Ext1 = dyn_cast<Instruction>(I->getOperand(1));
diff --git a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
index c115c91cff896c..a4f96adccb64b5 100644
--- a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
+++ b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
@@ -1722,10 +1722,10 @@ define i64 @test_fir_q15(ptr %x, ptr %y, i32 %n) #0 {
; CHECK-NEXT: [[TMP0:%.*]] = add nsw i32 [[N]], -1
; CHECK-NEXT: [[TMP1:%.*]] = lshr i32 [[TMP0]], 1
; CHECK-NEXT: [[TMP2:%.*]] = add nuw i32 [[TMP1]], 1
-; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i32 [[N]], 7
+; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i32 [[N]], 15
; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
; CHECK: vector.ph:
-; CHECK-NEXT: [[N_VEC:%.*]] = and i32 [[TMP2]], -4
+; CHECK-NEXT: [[N_VEC:%.*]] = and i32 [[TMP2]], -8
; CHECK-NEXT: [[IND_END:%.*]] = shl i32 [[N_VEC]], 1
; CHECK-NEXT: br label [[VECTOR_BODY:%.*]]
; CHECK: vector.body:
@@ -1733,26 +1733,26 @@ define i64 @test_fir_q15(ptr %x, ptr %y, i32 %n) #0 {
; CHECK-NEXT: [[VEC_PHI:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[TMP16:%.*]], [[VECTOR_BODY]] ]
; CHECK-NEXT: [[OFFSET_IDX:%.*]] = shl i32 [[INDEX]], 1
; CHECK-NEXT: [[TMP3:%.*]] = getelementptr inbounds i16, ptr [[X:%.*]], i32 [[OFFSET_IDX]]
-; CHECK-NEXT: [[WIDE_VEC:%.*]] = load <8 x i16>, ptr [[TMP3]], align 2
-; CHECK-NEXT: [[STRIDED_VEC:%.*]] = shufflevector <8 x i16> [[WIDE_VEC]], <8 x i16> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
-; CHECK-NEXT: [[STRIDED_VEC1:%.*]] = shufflevector <8 x i16> [[WIDE_VEC]], <8 x i16> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
-; CHECK-NEXT: [[TMP5:%.*]] = sext <4 x i16> [[STRIDED_VEC]] to <4 x i32>
+; CHECK-NEXT: [[WIDE_VEC:%.*]] = load <16 x i16>, ptr [[TMP3]], align 2
+; CHECK-NEXT: [[STRIDED_VEC:%.*]] = shufflevector <16 x i16> [[WIDE_VEC]], <16 x i16> poison, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
+; CHECK-NEXT: [[STRIDED_VEC1:%.*]] = shufflevector <16 x i16> [[WIDE_VEC]], <16 x i16> poison, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
+; CHECK-NEXT: [[TMP5:%.*]] = sext <8 x i16> [[STRIDED_VEC]] to <8 x i32>
; CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds i16, ptr [[Y:%.*]], i32 [[OFFSET_IDX]]
-; CHECK-NEXT: [[WIDE_VEC2:%.*]] = load <8 x i16>, ptr [[TMP4]], align 2
-; CHECK-NEXT: [[STRIDED_VEC3:%.*]] = shufflevector <8 x i16> [[WIDE_VEC2]], <8 x i16> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
-; CHECK-NEXT: [[STRIDED_VEC4:%.*]] = shufflevector <8 x i16> [[WIDE_VEC2]], <8 x i16> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
-; CHECK-NEXT: [[TMP6:%.*]] = sext <4 x i16> [[STRIDED_VEC3]] to <4 x i32>
-; CHECK-NEXT: [[TMP7:%.*]] = mul nsw <4 x i32> [[TMP6]], [[TMP5]]
-; CHECK-NEXT: [[TMP8:%.*]] = sext <4 x i32> [[TMP7]] to <4 x i64>
-; CHECK-NEXT: [[TMP13:%.*]] = sext <4 x i16> [[STRIDED_VEC1]] to <4 x i32>
-; CHECK-NEXT: [[TMP14:%.*]] = sext <4 x i16> [[STRIDED_VEC4]] to <4 x i32>
-; CHECK-NEXT: [[TMP11:%.*]] = mul nsw <4 x i32> [[TMP14]], [[TMP13]]
-; CHECK-NEXT: [[TMP12:%.*]] = sext <4 x i32> [[TMP11]] to <4 x i64>
-; CHECK-NEXT: [[TMP9:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP8]])
+; CHECK-NEXT: [[WIDE_VEC2:%.*]] = load <16 x i16>, ptr [[TMP4]], align 2
+; CHECK-NEXT: [[STRIDED_VEC3:%.*]] = shufflevector <16 x i16> [[WIDE_VEC2]], <16 x i16> poison, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
+; CHECK-NEXT: [[STRIDED_VEC4:%.*]] = shufflevector <16 x i16> [[WIDE_VEC2]], <16 x i16> poison, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
+; CHECK-NEXT: [[TMP6:%.*]] = sext <8 x i16> [[STRIDED_VEC3]] to <8 x i32>
+; CHECK-NEXT: [[TMP7:%.*]] = mul nsw <8 x i32> [[TMP6]], [[TMP5]]
+; CHECK-NEXT: [[TMP8:%.*]] = sext <8 x i32> [[TMP7]] to <8 x i64>
+; CHECK-NEXT: [[TMP13:%.*]] = sext <8 x i16> [[STRIDED_VEC1]] to <8 x i32>
+; CHECK-NEXT: [[TMP14:%.*]] = sext <8 x i16> [[STRIDED_VEC4]] to <8 x i32>
+; CHECK-NEXT: [[TMP11:%.*]] = mul nsw <8 x i32> [[TMP14]], [[TMP13]]
+; CHECK-NEXT: [[TMP12:%.*]] = sext <8 x i32> [[TMP11]] to <8 x i64>
+; CHECK-NEXT: [[TMP9:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP8]])
; CHECK-NEXT: [[TMP10:%.*]] = add i64 [[TMP9]], [[VEC_PHI]]
-; CHECK-NEXT: [[TMP15:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP12]])
+; CHECK-NEXT: [[TMP15:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP12]])
; CHECK-NEXT: [[TMP16]] = add i64 [[TMP15]], [[TMP10]]
-; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4
+; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 8
; CHECK-NEXT: [[TMP17:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
; CHECK-NEXT: br i1 [[TMP17]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP37:![0-9]+]]
; CHECK: middle.block:
|
@ElvisWang123 if this makes the cost model change easier or more difficult let me know. It should just bring the two cost-models closer together, but the new costmodel will get the benefit already. We can easily drop this if its simpler that way. |
These test cases show how VPlan-based cost model can calculate cost more accurately. |
We would detect ext(mul(ext, ext)) patterns when looking up through the tree, but not when looking down. This hopefully brings the cost model closer to the vplan version, avoiding some asserts and reducing the diffs needed in #113903.