@@ -525,6 +525,8 @@ class VPSingleDefRecipe : public VPRecipeBase, public VPValue {
525525 case VPRecipeBase::VPInstructionSC:
526526 case VPRecipeBase::VPReductionEVLSC:
527527 case VPRecipeBase::VPReductionSC:
528+ case VPRecipeBase::VPMulAccumulateReductionSC:
529+ case VPRecipeBase::VPExtendedReductionSC:
528530 case VPRecipeBase::VPReplicateSC:
529531 case VPRecipeBase::VPScalarIVStepsSC:
530532 case VPRecipeBase::VPVectorPointerSC:
@@ -609,13 +611,15 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
609611 DisjointFlagsTy (bool IsDisjoint) : IsDisjoint(IsDisjoint) {}
610612 };
611613
/// Storage for the non-negative (nneg) flag, kept in a single-bit field.
struct NonNegFlagsTy {
  char NonNeg : 1;

  /// Non-explicit on purpose: a plain bool converts to the flag wrapper.
  NonNegFlagsTy(bool Flag) : NonNeg(Flag) {}
};
618+
612619private:
/// Storage for a single-bit IsExact flag. This struct declares no
/// constructor, unlike the flag wrappers declared before the private section.
613620 struct ExactFlagsTy {
614621 char IsExact : 1 ;
615622 };
616- struct NonNegFlagsTy {
617- char NonNeg : 1 ;
618- };
619623 struct FastMathFlagsTy {
620624 char AllowReassoc : 1 ;
621625 char NoNaNs : 1 ;
@@ -709,6 +713,12 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
709713 : VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::DisjointOp),
710714 DisjointFlags(DisjointFlags) {}
711715
716+ template <typename IterT>
717+ VPRecipeWithIRFlags (const unsigned char SC, IterT Operands,
718+ NonNegFlagsTy NonNegFlags, DebugLoc DL = {})
719+ : VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::NonNegOp),
720+ NonNegFlags(NonNegFlags) {}
721+
712722protected:
713723 template <typename IterT>
714724 VPRecipeWithIRFlags (const unsigned char SC, IterT Operands,
@@ -728,7 +738,9 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
728738 R->getVPDefID () == VPRecipeBase::VPReductionEVLSC ||
729739 R->getVPDefID () == VPRecipeBase::VPReplicateSC ||
730740 R->getVPDefID () == VPRecipeBase::VPVectorEndPointerSC ||
731- R->getVPDefID () == VPRecipeBase::VPVectorPointerSC;
741+ R->getVPDefID () == VPRecipeBase::VPVectorPointerSC ||
742+ R->getVPDefID () == VPRecipeBase::VPExtendedReductionSC ||
743+ R->getVPDefID () == VPRecipeBase::VPMulAccumulateReductionSC;
732744 }
733745
734746 static inline bool classof (const VPUser *U) {
@@ -820,6 +832,15 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
820832
821833 FastMathFlags getFastMathFlags () const ;
822834
835+ // / Returns true if the recipe has non-negative flag.
836+ bool hasNonNegFlag () const { return OpType == OperationType::NonNegOp; }
837+
838+ bool isNonNeg () const {
839+ assert (OpType == OperationType::NonNegOp &&
840+ " recipe doesn't have a NNEG flag" );
841+ return NonNegFlags.NonNeg ;
842+ }
843+
823844 bool hasNoUnsignedWrap () const {
824845 assert (OpType == OperationType::OverflowingBinOp &&
825846 " recipe doesn't have a NUW flag" );
@@ -1231,11 +1252,22 @@ class VPWidenRecipe : public VPRecipeWithIRFlags, public VPIRMetadata {
12311252 : VPRecipeWithIRFlags(VPDefOpcode, Operands, I), VPIRMetadata(I),
12321253 Opcode (I.getOpcode()) {}
12331254
1255+ template <typename IterT>
1256+ VPWidenRecipe (unsigned VPDefOpcode, unsigned Opcode,
1257+ iterator_range<IterT> Operands, bool NUW, bool NSW, DebugLoc DL)
1258+ : VPRecipeWithIRFlags(VPDefOpcode, Operands, WrapFlagsTy(NUW, NSW), DL),
1259+ Opcode(Opcode) {}
1260+
12341261public:
/// Construct a widen recipe from the instruction \p I and \p Operands.
/// Delegates to the overload that takes an explicit VPDef ID, passing
/// VPDef::VPWidenSC.
12351262 template <typename IterT>
12361263 VPWidenRecipe (Instruction &I, iterator_range<IterT> Operands)
12371264 : VPWidenRecipe(VPDef::VPWidenSC, I, Operands) {}
12381265
1266+ template <typename IterT>
1267+ VPWidenRecipe (unsigned Opcode, iterator_range<IterT> Operands, bool NUW,
1268+ bool NSW, DebugLoc DL)
1269+ : VPWidenRecipe(VPDef::VPWidenSC, Opcode, Operands, NUW, NSW, DL) {}
1270+
12391271 ~VPWidenRecipe () override = default ;
12401272
12411273 VPWidenRecipe *clone () override {
@@ -1280,8 +1312,15 @@ class VPWidenCastRecipe : public VPRecipeWithIRFlags, public VPIRMetadata {
12801312 " opcode of underlying cast doesn't match" );
12811313 }
12821314
1283- VPWidenCastRecipe (Instruction::CastOps Opcode, VPValue *Op, Type *ResultTy)
1284- : VPRecipeWithIRFlags(VPDef::VPWidenCastSC, Op), VPIRMetadata(),
1315+ VPWidenCastRecipe (Instruction::CastOps Opcode, VPValue *Op, Type *ResultTy,
1316+ DebugLoc DL = {})
1317+ : VPRecipeWithIRFlags(VPDef::VPWidenCastSC, Op, DL), VPIRMetadata(),
1318+ Opcode(Opcode), ResultTy(ResultTy) {}
1319+
1320+ VPWidenCastRecipe (Instruction::CastOps Opcode, VPValue *Op, Type *ResultTy,
1321+ bool IsNonNeg, DebugLoc DL = {})
1322+ : VPRecipeWithIRFlags(VPDef::VPWidenCastSC, Op, NonNegFlagsTy(IsNonNeg),
1323+ DL),
12851324 Opcode(Opcode), ResultTy(ResultTy) {}
12861325
12871326 ~VPWidenCastRecipe () override = default ;
@@ -2376,6 +2415,28 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
23762415 setUnderlyingValue (I);
23772416 }
23782417
2418+ // / For VPExtendedReductionRecipe.
2419+ // / Note that the debug location is from the extend.
2420+ VPReductionRecipe (const unsigned char SC, const RecurKind RdxKind,
2421+ ArrayRef<VPValue *> Operands, VPValue *CondOp,
2422+ bool IsOrdered, DebugLoc DL)
2423+ : VPRecipeWithIRFlags(SC, Operands, DL), RdxKind(RdxKind),
2424+ IsOrdered(IsOrdered), IsConditional(CondOp) {
2425+ if (CondOp)
2426+ addOperand (CondOp);
2427+ }
2428+
2429+ // / For VPMulAccumulateReductionRecipe.
2430+ // / Note that the NUW/NSW flags and the debug location are from the Mul.
2431+ VPReductionRecipe (const unsigned char SC, const RecurKind RdxKind,
2432+ ArrayRef<VPValue *> Operands, VPValue *CondOp,
2433+ bool IsOrdered, WrapFlagsTy WrapFlags, DebugLoc DL)
2434+ : VPRecipeWithIRFlags(SC, Operands, WrapFlags, DL), RdxKind(RdxKind),
2435+ IsOrdered(IsOrdered), IsConditional(CondOp) {
2436+ if (CondOp)
2437+ addOperand (CondOp);
2438+ }
2439+
23792440public:
23802441 VPReductionRecipe (RecurKind RdxKind, FastMathFlags FMFs, Instruction *I,
23812442 VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
@@ -2384,6 +2445,13 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
23842445 ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp,
23852446 IsOrdered, DL) {}
23862447
2448+ VPReductionRecipe (const RecurKind RdxKind, FastMathFlags FMFs,
2449+ VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
2450+ bool IsOrdered, DebugLoc DL = {})
2451+ : VPReductionRecipe(VPDef::VPReductionSC, RdxKind, FMFs, nullptr ,
2452+ ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp,
2453+ IsOrdered, DL) {}
2454+
23872455 ~VPReductionRecipe () override = default ;
23882456
23892457 VPReductionRecipe *clone () override {
@@ -2394,7 +2462,9 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
23942462
/// Type-query hook on a VPRecipeBase: returns true when the VPDef ID of \p R
/// is one of the reduction kinds listed below. After this change the set is
/// VPReductionSC, VPReductionEVLSC, VPExtendedReductionSC and
/// VPMulAccumulateReductionSC.
23952463 static inline bool classof (const VPRecipeBase *R) {
23962464 return R->getVPDefID () == VPRecipeBase::VPReductionSC ||
2397- R->getVPDefID () == VPRecipeBase::VPReductionEVLSC;
2465+ R->getVPDefID () == VPRecipeBase::VPReductionEVLSC ||
2466+ R->getVPDefID () == VPRecipeBase::VPExtendedReductionSC ||
2467+ R->getVPDefID () == VPRecipeBase::VPMulAccumulateReductionSC;
23982468 }
23992469
24002470 static inline bool classof (const VPUser *U) {
@@ -2474,6 +2544,181 @@ class VPReductionEVLRecipe : public VPReductionRecipe {
24742544 }
24752545};
24762546
2547+ // / A recipe to represent inloop extended reduction operations, performing a
2548+ // / reduction on a extended vector operand into a scalar value, and adding the
2549+ // / result to a chain. This recipe is abstract and needs to be lowered to
2550+ // / concrete recipes before codegen. The operands are {ChainOp, VecOp,
2551+ // / [Condition]}.
2552+ class VPExtendedReductionRecipe : public VPReductionRecipe {
2553+ // / Opcode of the extend recipe will be lowered to.
2554+ Instruction::CastOps ExtOp;
2555+
2556+ Type *ResultTy;
2557+
2558+ // / For cloning VPExtendedReductionRecipe.
2559+ VPExtendedReductionRecipe (VPExtendedReductionRecipe *ExtRed)
2560+ : VPReductionRecipe(
2561+ VPDef::VPExtendedReductionSC, ExtRed->getRecurrenceKind (),
2562+ {ExtRed->getChainOp (), ExtRed->getVecOp ()}, ExtRed->getCondOp (),
2563+ ExtRed->isOrdered(), ExtRed->getDebugLoc()),
2564+ ExtOp(ExtRed->getExtOpcode ()), ResultTy(ExtRed->getResultType ()) {
2565+ transferFlags (*ExtRed);
2566+ }
2567+
2568+ public:
2569+ VPExtendedReductionRecipe (VPReductionRecipe *R, VPWidenCastRecipe *Ext)
2570+ : VPReductionRecipe(VPDef::VPExtendedReductionSC, R->getRecurrenceKind (),
2571+ {R->getChainOp (), Ext->getOperand (0 )}, R->getCondOp (),
2572+ R->isOrdered(), Ext->getDebugLoc()),
2573+ ExtOp(Ext->getOpcode ()), ResultTy(Ext->getResultType ()) {
2574+ // Not all WidenCastRecipes contain nneg flag. Need to transfer flags from
2575+ // the original recipe to prevent setting wrong flags.
2576+ transferFlags (*Ext);
2577+ }
2578+
2579+ ~VPExtendedReductionRecipe () override = default ;
2580+
2581+ VPExtendedReductionRecipe *clone () override {
2582+ auto *Copy = new VPExtendedReductionRecipe (this );
2583+ Copy->transferFlags (*this );
2584+ return Copy;
2585+ }
2586+
2587+ VP_CLASSOF_IMPL (VPDef::VPExtendedReductionSC);
2588+
2589+ void execute (VPTransformState &State) override {
2590+ llvm_unreachable (" VPExtendedReductionRecipe should be transform to "
2591+ " VPExtendedRecipe + VPReductionRecipe before execution." );
2592+ };
2593+
2594+ // / Return the cost of VPExtendedReductionRecipe.
2595+ InstructionCost computeCost (ElementCount VF,
2596+ VPCostContext &Ctx) const override ;
2597+
2598+ #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
2599+ // / Print the recipe.
2600+ void print (raw_ostream &O, const Twine &Indent,
2601+ VPSlotTracker &SlotTracker) const override ;
2602+ #endif
2603+
2604+ // / The scalar type after extending.
2605+ Type *getResultType () const { return ResultTy; }
2606+
2607+ // / Is the extend ZExt?
2608+ bool isZExt () const { return getExtOpcode () == Instruction::ZExt; }
2609+
2610+ // / The opcode of extend recipe.
2611+ Instruction::CastOps getExtOpcode () const { return ExtOp; }
2612+ };
2613+
2614+ // / A recipe to represent inloop MulAccumulateReduction operations, performing a
2615+ // / reduction.add on the result of vector operands (might be extended)
2616+ // / multiplication into a scalar value, and adding the result to a chain. This
2617+ // / recipe is abstract and needs to be lowered to concrete recipes before
2618+ // / codegen. The operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
2619+ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
2620+ // / Opcode of the extend recipe.
2621+ Instruction::CastOps ExtOp;
2622+
2623+ // / Non-neg flag of the extend recipe.
2624+ bool IsNonNeg = false ;
2625+
2626+ Type *ResultTy;
2627+
2628+ // / For cloning VPMulAccumulateReductionRecipe.
2629+ VPMulAccumulateReductionRecipe (VPMulAccumulateReductionRecipe *MulAcc)
2630+ : VPReductionRecipe(
2631+ VPDef::VPMulAccumulateReductionSC, MulAcc->getRecurrenceKind (),
2632+ {MulAcc->getChainOp (), MulAcc->getVecOp0 (), MulAcc->getVecOp1 ()},
2633+ MulAcc->getCondOp (), MulAcc->isOrdered(),
2634+ WrapFlagsTy(MulAcc->hasNoUnsignedWrap (), MulAcc->hasNoSignedWrap()),
2635+ MulAcc->getDebugLoc()),
2636+ ExtOp(MulAcc->getExtOpcode ()), IsNonNeg(MulAcc->isNonNeg ()),
2637+ ResultTy(MulAcc->getResultType ()) {}
2638+
2639+ public:
2640+ VPMulAccumulateReductionRecipe (VPReductionRecipe *R, VPWidenRecipe *Mul,
2641+ VPWidenCastRecipe *Ext0,
2642+ VPWidenCastRecipe *Ext1, Type *ResultTy)
2643+ : VPReductionRecipe(
2644+ VPDef::VPMulAccumulateReductionSC, R->getRecurrenceKind (),
2645+ {R->getChainOp (), Ext0->getOperand (0 ), Ext1->getOperand (0 )},
2646+ R->getCondOp (), R->isOrdered(),
2647+ WrapFlagsTy(Mul->hasNoUnsignedWrap (), Mul->hasNoSignedWrap()),
2648+ R->getDebugLoc()),
2649+ ExtOp(Ext0->getOpcode ()), ResultTy(ResultTy) {
2650+ assert (RecurrenceDescriptor::getOpcode (getRecurrenceKind ()) ==
2651+ Instruction::Add &&
2652+ " The reduction instruction in MulAccumulateteReductionRecipe must "
2653+ " be Add" );
2654+ // Only set the non-negative flag if the original recipe contains.
2655+ if (Ext0->hasNonNegFlag ())
2656+ IsNonNeg = Ext0->isNonNeg ();
2657+ }
2658+
2659+ VPMulAccumulateReductionRecipe (VPReductionRecipe *R, VPWidenRecipe *Mul)
2660+ : VPReductionRecipe(
2661+ VPDef::VPMulAccumulateReductionSC, R->getRecurrenceKind (),
2662+ {R->getChainOp (), Mul->getOperand (0 ), Mul->getOperand (1 )},
2663+ R->getCondOp (), R->isOrdered(),
2664+ WrapFlagsTy(Mul->hasNoUnsignedWrap (), Mul->hasNoSignedWrap()),
2665+ R->getDebugLoc()),
2666+ ExtOp(Instruction::CastOps::CastOpsEnd) {
2667+ assert (RecurrenceDescriptor::getOpcode (getRecurrenceKind ()) ==
2668+ Instruction::Add &&
2669+ " The reduction instruction in MulAccumulateReductionRecipe must be "
2670+ " Add" );
2671+ }
2672+
2673+ ~VPMulAccumulateReductionRecipe () override = default ;
2674+
2675+ VPMulAccumulateReductionRecipe *clone () override {
2676+ auto *Copy = new VPMulAccumulateReductionRecipe (this );
2677+ Copy->transferFlags (*this );
2678+ return Copy;
2679+ }
2680+
2681+ VP_CLASSOF_IMPL (VPDef::VPMulAccumulateReductionSC);
2682+
2683+ void execute (VPTransformState &State) override {
2684+ llvm_unreachable (" VPMulAccumulateReductionRecipe should transform to "
2685+ " VPWidenCastRecipe + "
2686+ " VPWidenRecipe + VPReductionRecipe before execution" );
2687+ }
2688+
2689+ // / Return the cost of VPMulAccumulateReductionRecipe.
2690+ InstructionCost computeCost (ElementCount VF,
2691+ VPCostContext &Ctx) const override ;
2692+
2693+ #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
2694+ // / Print the recipe.
2695+ void print (raw_ostream &O, const Twine &Indent,
2696+ VPSlotTracker &SlotTracker) const override ;
2697+ #endif
2698+
2699+ Type *getResultType () const {
2700+ assert (isExtended () && " Only support getResultType when this recipe "
2701+ " contains implicit extend." );
2702+ return ResultTy;
2703+ }
2704+
2705+ // / The VPValue of the vector value to be extended and reduced.
2706+ VPValue *getVecOp0 () const { return getOperand (1 ); }
2707+ VPValue *getVecOp1 () const { return getOperand (2 ); }
2708+
2709+ // / Return if this MulAcc recipe contains extended operands.
2710+ bool isExtended () const { return ExtOp != Instruction::CastOps::CastOpsEnd; }
2711+
2712+ // / Return the opcode of the extends for the operands.
2713+ Instruction::CastOps getExtOpcode () const { return ExtOp; }
2714+
2715+ // / Return if the operands are zero extended.
2716+ bool isZExt () const { return ExtOp == Instruction::CastOps::ZExt; }
2717+
2718+ // / Return the non negative flag of the ext recipe.
2719+ bool isNonNeg () const { return IsNonNeg; }
2720+ };
2721+
24772722// / VPReplicateRecipe replicates a given instruction producing multiple scalar
24782723// / copies of the original scalar type, one per lane, instead of producing a
24792724// / single copy of widened type for all lanes. If the instruction is known to be
0 commit comments