[LoopVectorizer] Bundle partial reductions with different extensions #136997

Open. Wants to merge 1 commit into base branch users/SamTebbs33/mulacc-partial-reductions.

Conversation

SamTebbs33 (Collaborator)

This PR adds support for extensions of different signedness to VPMulAccumulateReductionRecipe and allows such partial reductions to be bundled into that class.
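For context, the kind of scalar loop this targets multiplies a zero-extended operand by a sign-extended one before accumulating. A minimal C++ sketch (illustrative only; it mirrors the dotp_z_s test updated below, and is not taken from the patch):

// 'a' is zero-extended (unsigned char) and 'b' sign-extended (signed char)
// before the multiply that feeds the i32 add reduction.
int dotp_z_s(const unsigned char *a, const signed char *b) {
  int sum = 0;
  for (int i = 0; i < 1024; ++i)
    sum += int(a[i]) * int(b[i]); // zext * sext, accumulated into sum
  return sum;
}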

llvmbot (Member) commented Apr 23, 2025

@llvm/pr-subscribers-vectorizers

@llvm/pr-subscribers-llvm-transforms

Author: Sam Tebbs (SamTebbs33)

Patch is 25.75 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/136997.diff

5 Files Affected:

  • (modified) llvm/lib/Transforms/Vectorize/VPlan.h (+27-15)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp (+19-8)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp (+12-13)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-mixed.ll (+28-28)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll (+13-16)
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 20d272e69e6e7..e11f608d068da 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2493,11 +2493,13 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
 /// recipe is abstract and needs to be lowered to concrete recipes before
 /// codegen. The Operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
 class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
-  /// Opcode of the extend recipe.
-  Instruction::CastOps ExtOp;
+  /// Opcodes of the extend recipes.
+  Instruction::CastOps ExtOp0;
+  Instruction::CastOps ExtOp1;
 
-  /// Non-neg flag of the extend recipe.
-  bool IsNonNeg = false;
+  /// Non-neg flags of the extend recipes.
+  bool IsNonNeg0 = false;
+  bool IsNonNeg1 = false;
 
   Type *ResultTy;
 
@@ -2512,7 +2514,8 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
             MulAcc->getCondOp(), MulAcc->isOrdered(),
             WrapFlagsTy(MulAcc->hasNoUnsignedWrap(), MulAcc->hasNoSignedWrap()),
             MulAcc->getDebugLoc()),
-        ExtOp(MulAcc->getExtOpcode()), IsNonNeg(MulAcc->isNonNeg()),
+        ExtOp0(MulAcc->getExt0Opcode()), ExtOp1(MulAcc->getExt1Opcode()),
+        IsNonNeg0(MulAcc->isNonNeg0()), IsNonNeg1(MulAcc->isNonNeg1()),
         ResultTy(MulAcc->getResultType()),
         IsPartialReduction(MulAcc->isPartialReduction()) {}
 
@@ -2526,7 +2529,8 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
             R->getCondOp(), R->isOrdered(),
             WrapFlagsTy(Mul->hasNoUnsignedWrap(), Mul->hasNoSignedWrap()),
             R->getDebugLoc()),
-        ExtOp(Ext0->getOpcode()), IsNonNeg(Ext0->isNonNeg()),
+        ExtOp0(Ext0->getOpcode()), ExtOp1(Ext1->getOpcode()),
+        IsNonNeg0(Ext0->isNonNeg()), IsNonNeg1(Ext1->isNonNeg()),
         ResultTy(ResultTy),
         IsPartialReduction(isa<VPPartialReductionRecipe>(R)) {
     assert(RecurrenceDescriptor::getOpcode(getRecurrenceKind()) ==
@@ -2542,7 +2546,8 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
             R->getCondOp(), R->isOrdered(),
             WrapFlagsTy(Mul->hasNoUnsignedWrap(), Mul->hasNoSignedWrap()),
             R->getDebugLoc()),
-        ExtOp(Instruction::CastOps::CastOpsEnd) {
+        ExtOp0(Instruction::CastOps::CastOpsEnd),
+        ExtOp1(Instruction::CastOps::CastOpsEnd) {
     assert(RecurrenceDescriptor::getOpcode(getRecurrenceKind()) ==
                Instruction::Add &&
            "The reduction instruction in MulAccumulateReductionRecipe must be "
@@ -2586,19 +2591,26 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
   VPValue *getVecOp1() const { return getOperand(2); }
 
   /// Return if this MulAcc recipe contains extend instructions.
-  bool isExtended() const { return ExtOp != Instruction::CastOps::CastOpsEnd; }
+  bool isExtended() const { return ExtOp0 != Instruction::CastOps::CastOpsEnd; }
 
   /// Return if the operands of mul instruction come from same extend.
-  bool isSameExtend() const { return getVecOp0() == getVecOp1(); }
+  bool isSameExtendVal() const { return getVecOp0() == getVecOp1(); }
 
-  /// Return the opcode of the underlying extend.
-  Instruction::CastOps getExtOpcode() const { return ExtOp; }
+  /// Return the opcode of the underlying extends.
+  Instruction::CastOps getExt0Opcode() const { return ExtOp0; }
+  Instruction::CastOps getExt1Opcode() const { return ExtOp1; }
+
+  /// Return if the first extend's opcode is ZExt.
+  bool isZExt0() const { return ExtOp0 == Instruction::CastOps::ZExt; }
+
+  /// Return if the second extend's opcode is ZExt.
+  bool isZExt1() const { return ExtOp1 == Instruction::CastOps::ZExt; }
 
-  /// Return if the extend opcode is ZExt.
-  bool isZExt() const { return ExtOp == Instruction::CastOps::ZExt; }
+  /// Return the non-negative flag of the first ext recipe.
+  bool isNonNeg0() const { return IsNonNeg0; }
 
-  /// Return the non negative flag of the ext recipe.
-  bool isNonNeg() const { return IsNonNeg; }
+  /// Return the non-negative flag of the second ext recipe.
+  bool isNonNeg1() const { return IsNonNeg1; }
 
   /// Return if the underlying reduction recipe is a partial reduction.
   bool isPartialReduction() const { return IsPartialReduction; }
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index bdc1d49ec88d9..53698fe15d4f8 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2438,14 +2438,14 @@ VPMulAccumulateReductionRecipe::computeCost(ElementCount VF,
     return Ctx.TTI.getPartialReductionCost(
         Instruction::Add, Ctx.Types.inferScalarType(getVecOp0()),
         Ctx.Types.inferScalarType(getVecOp1()), getResultType(), VF,
-        TTI::getPartialReductionExtendKind(getExtOpcode()),
-        TTI::getPartialReductionExtendKind(getExtOpcode()), Instruction::Mul);
+        TTI::getPartialReductionExtendKind(getExt0Opcode()),
+        TTI::getPartialReductionExtendKind(getExt1Opcode()), Instruction::Mul);
   }
 
   Type *RedTy = Ctx.Types.inferScalarType(this);
   auto *SrcVecTy =
       cast<VectorType>(toVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
-  return Ctx.TTI.getMulAccReductionCost(isZExt(), RedTy, SrcVecTy,
+  return Ctx.TTI.getMulAccReductionCost(isZExt0(), RedTy, SrcVecTy,
                                         Ctx.CostKind);
 }
 
@@ -2530,13 +2530,24 @@ void VPMulAccumulateReductionRecipe::print(raw_ostream &O, const Twine &Indent,
   if (isExtended())
     O << "(";
   getVecOp0()->printAsOperand(O, SlotTracker);
-  if (isExtended())
-    O << " extended to " << *getResultType() << "), (";
-  else
+  if (isExtended()) {
+    O << " ";
+    if (isZExt0())
+      O << "zero-";
+    else
+      O << "sign-";
+    O << "extended to " << *getResultType() << "), (";
+  } else
     O << ", ";
   getVecOp1()->printAsOperand(O, SlotTracker);
-  if (isExtended())
-    O << " extended to " << *getResultType() << ")";
+  if (isExtended()) {
+    O << " ";
+    if (isZExt1())
+      O << "zero-";
+    else
+      O << "sign-";
+    O << "extended to " << *getResultType() << ")";
+  }
   if (isConditional()) {
     O << ", ";
     getCondOp()->printAsOperand(O, SlotTracker);
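With the print changes above, each multiply operand now carries its own extension annotation, so a mixed zext/sext bundle would print its operands along the lines of (ir<%a> zero-extended to i32), (ir<%b> sign-extended to i32); this rendering is inferred from the code in this hunk rather than copied from test output.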
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 7a8cbd908c795..f305e09396c1c 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -2121,12 +2121,12 @@ expandVPMulAccumulateReduction(VPMulAccumulateReductionRecipe *MulAcc) {
   VPValue *Op0, *Op1;
   if (MulAcc->isExtended()) {
     Type *RedTy = MulAcc->getResultType();
-    if (MulAcc->isZExt())
-      Op0 = new VPWidenCastRecipe(MulAcc->getExtOpcode(), MulAcc->getVecOp0(),
-                                  RedTy, MulAcc->isNonNeg(),
+    if (MulAcc->isZExt0())
+      Op0 = new VPWidenCastRecipe(MulAcc->getExt0Opcode(), MulAcc->getVecOp0(),
+                                  RedTy, MulAcc->isNonNeg0(),
                                   MulAcc->getDebugLoc());
     else
-      Op0 = new VPWidenCastRecipe(MulAcc->getExtOpcode(), MulAcc->getVecOp0(),
+      Op0 = new VPWidenCastRecipe(MulAcc->getExt0Opcode(), MulAcc->getVecOp0(),
                                   RedTy, MulAcc->getDebugLoc());
     Op0->getDefiningRecipe()->insertBefore(MulAcc);
     // Prevent reduce.add(mul(ext(A), ext(A))) generate duplicate
@@ -2134,13 +2134,14 @@ expandVPMulAccumulateReduction(VPMulAccumulateReductionRecipe *MulAcc) {
     if (MulAcc->getVecOp0() == MulAcc->getVecOp1()) {
       Op1 = Op0;
     } else {
-      if (MulAcc->isZExt())
-        Op1 = new VPWidenCastRecipe(MulAcc->getExtOpcode(), MulAcc->getVecOp1(),
-                                    RedTy, MulAcc->isNonNeg(),
-                                    MulAcc->getDebugLoc());
+      if (MulAcc->isZExt1())
+        Op1 = new VPWidenCastRecipe(MulAcc->getExt1Opcode(),
+                                    MulAcc->getVecOp1(), RedTy,
+                                    MulAcc->isNonNeg1(), MulAcc->getDebugLoc());
       else
-        Op1 = new VPWidenCastRecipe(MulAcc->getExtOpcode(), MulAcc->getVecOp1(),
-                                    RedTy, MulAcc->getDebugLoc());
+        Op1 =
+            new VPWidenCastRecipe(MulAcc->getExt1Opcode(), MulAcc->getVecOp1(),
+                                  RedTy, MulAcc->getDebugLoc());
       Op1->getDefiningRecipe()->insertBefore(MulAcc);
     }
   } else {
@@ -2451,10 +2452,8 @@ tryToCreateAbstractPartialReductionRecipe(VPPartialReductionRecipe *PRed) {
 
   auto *Ext0 = dyn_cast<VPWidenCastRecipe>(BinOp->getOperand(0));
   auto *Ext1 = dyn_cast<VPWidenCastRecipe>(BinOp->getOperand(1));
-  // TODO: Make work with extends of different signedness
   if (!Ext0 || Ext0->hasMoreThanOneUniqueUser() || !Ext1 ||
-      Ext1->hasMoreThanOneUniqueUser() ||
-      Ext0->getOpcode() != Ext1->getOpcode())
+      Ext1->hasMoreThanOneUniqueUser())
     return;
 
   auto *AbstractR = new VPMulAccumulateReductionRecipe(PRed, BinOp, Ext0, Ext1,
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-mixed.ll b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-mixed.ll
index f581b6f384bc8..6e1dc7230205b 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-mixed.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-mixed.ll
@@ -22,19 +22,19 @@ define i32 @dotp_z_s(ptr %a, ptr %b) #0 {
 ; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr i8, ptr [[TMP1]], i32 16
 ; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP2]], align 1
 ; CHECK-NEXT:    [[WIDE_LOAD2:%.*]] = load <16 x i8>, ptr [[TMP3]], align 1
-; CHECK-NEXT:    [[TMP4:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
-; CHECK-NEXT:    [[TMP5:%.*]] = zext <16 x i8> [[WIDE_LOAD2]] to <16 x i32>
 ; CHECK-NEXT:    [[TMP6:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP0]]
 ; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr i8, ptr [[TMP6]], i32 0
 ; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr i8, ptr [[TMP6]], i32 16
 ; CHECK-NEXT:    [[WIDE_LOAD3:%.*]] = load <16 x i8>, ptr [[TMP7]], align 1
 ; CHECK-NEXT:    [[WIDE_LOAD4:%.*]] = load <16 x i8>, ptr [[TMP8]], align 1
-; CHECK-NEXT:    [[TMP9:%.*]] = sext <16 x i8> [[WIDE_LOAD3]] to <16 x i32>
-; CHECK-NEXT:    [[TMP10:%.*]] = sext <16 x i8> [[WIDE_LOAD4]] to <16 x i32>
-; CHECK-NEXT:    [[TMP11:%.*]] = mul <16 x i32> [[TMP9]], [[TMP4]]
+; CHECK-NEXT:    [[TMP10:%.*]] = sext <16 x i8> [[WIDE_LOAD3]] to <16 x i32>
+; CHECK-NEXT:    [[TMP5:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
 ; CHECK-NEXT:    [[TMP12:%.*]] = mul <16 x i32> [[TMP10]], [[TMP5]]
-; CHECK-NEXT:    [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP11]])
-; CHECK-NEXT:    [[PARTIAL_REDUCE5]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI1]], <16 x i32> [[TMP12]])
+; CHECK-NEXT:    [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP12]])
+; CHECK-NEXT:    [[TMP15:%.*]] = sext <16 x i8> [[WIDE_LOAD4]] to <16 x i32>
+; CHECK-NEXT:    [[TMP11:%.*]] = zext <16 x i8> [[WIDE_LOAD2]] to <16 x i32>
+; CHECK-NEXT:    [[TMP16:%.*]] = mul <16 x i32> [[TMP15]], [[TMP11]]
+; CHECK-NEXT:    [[PARTIAL_REDUCE5]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI1]], <16 x i32> [[TMP16]])
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 32
 ; CHECK-NEXT:    [[TMP13:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1024
 ; CHECK-NEXT:    br i1 [[TMP13]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
@@ -60,19 +60,19 @@ define i32 @dotp_z_s(ptr %a, ptr %b) #0 {
 ; CHECK-NOI8MM-NEXT:    [[TMP3:%.*]] = getelementptr i8, ptr [[TMP1]], i32 16
 ; CHECK-NOI8MM-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP2]], align 1
 ; CHECK-NOI8MM-NEXT:    [[WIDE_LOAD2:%.*]] = load <16 x i8>, ptr [[TMP3]], align 1
-; CHECK-NOI8MM-NEXT:    [[TMP4:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
-; CHECK-NOI8MM-NEXT:    [[TMP5:%.*]] = zext <16 x i8> [[WIDE_LOAD2]] to <16 x i32>
 ; CHECK-NOI8MM-NEXT:    [[TMP6:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP0]]
 ; CHECK-NOI8MM-NEXT:    [[TMP7:%.*]] = getelementptr i8, ptr [[TMP6]], i32 0
 ; CHECK-NOI8MM-NEXT:    [[TMP8:%.*]] = getelementptr i8, ptr [[TMP6]], i32 16
 ; CHECK-NOI8MM-NEXT:    [[WIDE_LOAD3:%.*]] = load <16 x i8>, ptr [[TMP7]], align 1
 ; CHECK-NOI8MM-NEXT:    [[WIDE_LOAD4:%.*]] = load <16 x i8>, ptr [[TMP8]], align 1
-; CHECK-NOI8MM-NEXT:    [[TMP9:%.*]] = sext <16 x i8> [[WIDE_LOAD3]] to <16 x i32>
-; CHECK-NOI8MM-NEXT:    [[TMP10:%.*]] = sext <16 x i8> [[WIDE_LOAD4]] to <16 x i32>
-; CHECK-NOI8MM-NEXT:    [[TMP11:%.*]] = mul <16 x i32> [[TMP9]], [[TMP4]]
+; CHECK-NOI8MM-NEXT:    [[TMP10:%.*]] = sext <16 x i8> [[WIDE_LOAD3]] to <16 x i32>
+; CHECK-NOI8MM-NEXT:    [[TMP5:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
 ; CHECK-NOI8MM-NEXT:    [[TMP12:%.*]] = mul <16 x i32> [[TMP10]], [[TMP5]]
-; CHECK-NOI8MM-NEXT:    [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP11]])
-; CHECK-NOI8MM-NEXT:    [[PARTIAL_REDUCE5]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI1]], <16 x i32> [[TMP12]])
+; CHECK-NOI8MM-NEXT:    [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP12]])
+; CHECK-NOI8MM-NEXT:    [[TMP15:%.*]] = sext <16 x i8> [[WIDE_LOAD4]] to <16 x i32>
+; CHECK-NOI8MM-NEXT:    [[TMP11:%.*]] = zext <16 x i8> [[WIDE_LOAD2]] to <16 x i32>
+; CHECK-NOI8MM-NEXT:    [[TMP16:%.*]] = mul <16 x i32> [[TMP15]], [[TMP11]]
+; CHECK-NOI8MM-NEXT:    [[PARTIAL_REDUCE5]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI1]], <16 x i32> [[TMP16]])
 ; CHECK-NOI8MM-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 32
 ; CHECK-NOI8MM-NEXT:    [[TMP13:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1024
 ; CHECK-NOI8MM-NEXT:    br i1 [[TMP13]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
@@ -121,19 +121,19 @@ define i32 @dotp_s_z(ptr %a, ptr %b) #0 {
 ; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr i8, ptr [[TMP1]], i32 16
 ; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP2]], align 1
 ; CHECK-NEXT:    [[WIDE_LOAD2:%.*]] = load <16 x i8>, ptr [[TMP3]], align 1
-; CHECK-NEXT:    [[TMP4:%.*]] = sext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
-; CHECK-NEXT:    [[TMP5:%.*]] = sext <16 x i8> [[WIDE_LOAD2]] to <16 x i32>
 ; CHECK-NEXT:    [[TMP6:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP0]]
 ; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr i8, ptr [[TMP6]], i32 0
 ; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr i8, ptr [[TMP6]], i32 16
 ; CHECK-NEXT:    [[WIDE_LOAD3:%.*]] = load <16 x i8>, ptr [[TMP7]], align 1
 ; CHECK-NEXT:    [[WIDE_LOAD4:%.*]] = load <16 x i8>, ptr [[TMP8]], align 1
-; CHECK-NEXT:    [[TMP9:%.*]] = zext <16 x i8> [[WIDE_LOAD3]] to <16 x i32>
-; CHECK-NEXT:    [[TMP10:%.*]] = zext <16 x i8> [[WIDE_LOAD4]] to <16 x i32>
-; CHECK-NEXT:    [[TMP11:%.*]] = mul <16 x i32> [[TMP9]], [[TMP4]]
+; CHECK-NEXT:    [[TMP10:%.*]] = zext <16 x i8> [[WIDE_LOAD3]] to <16 x i32>
+; CHECK-NEXT:    [[TMP5:%.*]] = sext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
 ; CHECK-NEXT:    [[TMP12:%.*]] = mul <16 x i32> [[TMP10]], [[TMP5]]
-; CHECK-NEXT:    [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP11]])
-; CHECK-NEXT:    [[PARTIAL_REDUCE5]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI1]], <16 x i32> [[TMP12]])
+; CHECK-NEXT:    [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP12]])
+; CHECK-NEXT:    [[TMP15:%.*]] = zext <16 x i8> [[WIDE_LOAD4]] to <16 x i32>
+; CHECK-NEXT:    [[TMP11:%.*]] = sext <16 x i8> [[WIDE_LOAD2]] to <16 x i32>
+; CHECK-NEXT:    [[TMP16:%.*]] = mul <16 x i32> [[TMP15]], [[TMP11]]
+; CHECK-NEXT:    [[PARTIAL_REDUCE5]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI1]], <16 x i32> [[TMP16]])
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 32
 ; CHECK-NEXT:    [[TMP13:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1024
 ; CHECK-NEXT:    br i1 [[TMP13]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
@@ -159,19 +159,19 @@ define i32 @dotp_s_z(ptr %a, ptr %b) #0 {
 ; CHECK-NOI8MM-NEXT:    [[TMP3:%.*]] = getelementptr i8, ptr [[TMP1]], i32 16
 ; CHECK-NOI8MM-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP2]], align 1
 ; CHECK-NOI8MM-NEXT:    [[WIDE_LOAD2:%.*]] = load <16 x i8>, ptr [[TMP3]], align 1
-; CHECK-NOI8MM-NEXT:    [[TMP4:%.*]] = sext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
-; CHECK-NOI8MM-NEXT:    [[TMP5:%.*]] = sext <16 x i8> [[WIDE_LOAD2]] to <16 x i32>
 ; CHECK-NOI8MM-NEXT:    [[TMP6:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP0]]
 ; CHECK-NOI8MM-NEXT:    [[TMP7:%.*]] = getelementptr i8, ptr [[TMP6]], i32 0
 ; CHECK-NOI8MM-NEXT:    [[TMP8:%.*]] = getelementptr i8, ptr [[TMP6]], i32 16
 ; CHECK-NOI8MM-NEXT:    [[WIDE_LOAD3:%.*]] = load <16 x i8>, ptr [[TMP7]], align 1
 ; CHECK-NOI8MM-NEXT:    [[WIDE_LOAD4:%.*]] = load <16 x i8>, ptr [[TMP8]], align 1
-; CHECK-NOI8MM-NEXT:    [[TMP9:%.*]] = zext <16 x i8> [[WIDE_LOAD3]] to <16 x i32>
-; CHECK-NOI8MM-NEXT:    [[TMP10:%.*]] = zext <16 x i8> [[WIDE_LOAD4]] to <16 x i32>
-; CHECK-NOI8MM-NEXT:    [[TMP11:%.*]] = mul <16 x i32> [[TMP9]], [[TMP4]]
+; CHECK-NOI8MM-NEXT:    [[TMP10:%.*]] = zext <16 x i8> [[WIDE_LOAD3]] to <16 x i32>
+; CHECK-NOI8MM-NEXT:    [[TMP5:%.*]] = sext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
 ; CHECK-NOI8MM-NEXT:    [[TMP12:%.*]] = mul <16 x i32> [[TMP10]], [[TMP5]]
-; CHECK-NOI8MM-NEXT:    [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP11]])
-; CHECK-NOI8MM-NEXT:    [[PARTIAL_REDUCE5]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI1]], <16 x i32> [[TMP12]])
+; CHECK-NOI8MM-NEXT:    [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP12]])
+; CHECK-NOI8MM-NEXT:    [[TMP15:%.*]] = zext <16 x i8> [[WIDE_LOAD4]] to <16 x i32>
+; CHECK-NOI8MM-NEXT:    [[TMP11:%.*]] = sext <16 x i8> [[WIDE_LOAD2]] to <16 x i32>
+; CHECK-NOI8MM-NEXT:    [[TMP16:%.*]] = mul <16 x i32> [[TMP15]], [[TMP11]]
+; CHECK-NOI8MM-NEXT:    [[PARTIAL_REDUCE5]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI1]], <16 x i32> [[TMP16]])
 ; CHECK-NOI8MM-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 32
 ; CHECK-NOI8MM-NEXT:    [[TMP13:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1024
 ; CHECK-NOI8MM-NEXT:    br i1 [[TMP13]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll b/llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll
index 4dc83ed8a95b5..3911e03de2f50 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll
@@ -6,7 +6,7 @@ target triple = "aarch64-none-unknown-elf"
 
 ; Tests for printing VPlans that are enabled under AArch64
 
-define i32 @print_partial_reduction(ptr %a, ptr %b) {
+define i32 @print_partial_reduction_sext_zext(ptr %a, ptr %b) {
 ; CHECK:      VPlan 'Initial VPlan for VF={8,16},UF>=1' {
 ; CHECK-NEXT: Live-in vp<[[VFxUF:%.]]> = VF * UF
 ; CHECK-NEXT: Live-in vp<[[VEC_TC:%.+]]> = vector-trip-count
@@ -21,18 +21,15 @@ define i32 @print_partial_reduction(ptr %a, ptr %b) {
 ; CHECK-NEXT: <x1> vector loop: {
 ; CHECK-NEXT: vector.body:
 ; CHECK-NEXT:   EMIT vp<[[CAN_IV:%.+]]> = CANONICAL-INDUCTION ir<0>, vp<[[CAN_IV_NEXT:%.+]]>
-; CHECK-NEXT:   WIDEN-REDUCTION-PHI ir<[[ACC:%.+]]> = phi ir<0>, ir<[[REDUCE:%.+]]> (VF scaled b...
[truncated]

@@ -2493,11 +2493,13 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
/// recipe is abstract and needs to be lowered to concrete recipes before
/// codegen. The Operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
/// Opcode of the extend recipe.
Instruction::CastOps ExtOp;
/// Opcodes of the extend recipes.
A reviewer (Member) commented:
nit: It might be clearer to avoid having many getter functions suffixed with 0 or 1. One option could be something like:

struct VecOperandInfo {
  Instruction::CastOps ExtOp{Instruction::CastOps::CastOpsEnd};
  bool IsNonNeg = false;
};

VecOperandInfo VecOpInfo[2];

Then instead of getters for each value you could just have getVecOp0Info() and getVecOp1Info(), which return VecOperandInfo.
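A sketch of what those accessors could look like, assuming the two-element VecOpInfo array above (hypothetical code following the reviewer's suggestion, not the patch):

// Hypothetical accessors; index 0 corresponds to getVecOp0() and
// index 1 to getVecOp1().
VecOperandInfo getVecOp0Info() const { return VecOpInfo[0]; }
VecOperandInfo getVecOp1Info() const { return VecOpInfo[1]; }

// A caller such as computeCost() would then query the per-operand extend via
//   TTI::getPartialReductionExtendKind(getVecOp0Info().ExtOp)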

}

Type *RedTy = Ctx.Types.inferScalarType(this);
auto *SrcVecTy =
cast<VectorType>(toVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
return Ctx.TTI.getMulAccReductionCost(isZExt(), RedTy, SrcVecTy,
return Ctx.TTI.getMulAccReductionCost(isZExt0(), RedTy, SrcVecTy,
A reviewer (Collaborator) commented:
The TTI hook also needs updating to reflect the separate extends.
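One possible shape for that update, sketched purely as an assumption (the actual interface change would be settled in a follow-up), is to pass one signedness flag per multiply operand instead of the current single flag:

// Hypothetical variant of the TTI hook with a signedness flag for each
// multiply operand rather than one IsUnsigned flag covering both.
InstructionCost getMulAccReductionCost(bool IsUnsigned0, bool IsUnsigned1,
                                       Type *ResTy, VectorType *Ty,
                                       TTI::TargetCostKind CostKind) const;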
