@@ -823,7 +823,8 @@ bool VPlanTransforms::handleMaxMinNumReductions(VPlan &Plan) {
823823 };
824824
825825 VPRegionBlock *LoopRegion = Plan.getVectorLoopRegion ();
826- SmallVector<VPReductionPHIRecipe *> ReductionsToConvert;
826+ SmallVector<std::pair<VPReductionPHIRecipe *, VPValue *>>
827+ MinMaxNumReductionsToHandle;
827828 bool HasUnsupportedPhi = false ;
828829 for (auto &R : LoopRegion->getEntryBasicBlock ()->phis ()) {
829830 if (isa<VPCanonicalIVPHIRecipe, VPWidenIntOrFpInductionRecipe>(&R))
@@ -839,10 +840,15 @@ bool VPlanTransforms::handleMaxMinNumReductions(VPlan &Plan) {
839840 HasUnsupportedPhi = true ;
840841 continue ;
841842 }
842- ReductionsToConvert.push_back (Cur);
843+
844+ VPValue *MinMaxOp = GetMinMaxCompareValue (Cur);
845+ if (!MinMaxOp)
846+ return false ;
847+
848+ MinMaxNumReductionsToHandle.emplace_back (Cur, MinMaxOp);
843849 }
844850
845- if (ReductionsToConvert .empty ())
851+ if (MinMaxNumReductionsToHandle .empty ())
846852 return true ;
847853
848854 // We won't be able to resume execution in the scalar tail, if there are
@@ -867,32 +873,29 @@ bool VPlanTransforms::handleMaxMinNumReductions(VPlan &Plan) {
867873 }
868874
869875 VPBasicBlock *LatchVPBB = LoopRegion->getExitingBasicBlock ();
870- VPBasicBlock *MiddleVPBB = Plan.getMiddleBlock ();
871- VPBuilder MiddleBuilder (MiddleVPBB, MiddleVPBB->begin ());
872- VPBuilder Builder (LatchVPBB->getTerminator ());
873- VPValue *AnyNaN = nullptr ;
876+ VPBuilder LatchBuilder (LatchVPBB->getTerminator ());
877+ VPValue *IsNaNLane = nullptr ;
874878 SmallPtrSet<VPValue *, 2 > RdxResults;
875- for (VPReductionPHIRecipe * RedPhiR : ReductionsToConvert ) {
879+ for (const auto &[ RedPhiR, MinMaxOp] : MinMaxNumReductionsToHandle ) {
876880 assert (RecurrenceDescriptor::isFPMinMaxNumRecurrenceKind (
877881 RedPhiR->getRecurrenceKind ()) &&
878882 " unsupported reduction" );
879883
880- VPValue *MinMaxOp = GetMinMaxCompareValue (RedPhiR);
881- if (!MinMaxOp)
882- return false ;
883-
884- VPValue *IsNaN = Builder.createFCmp (CmpInst::FCMP_UNO, MinMaxOp, MinMaxOp);
885- VPValue *HasNaN = Builder.createNaryOp (VPInstruction::AnyOf, {IsNaN});
886- if (AnyNaN)
887- AnyNaN = Builder.createOr (AnyNaN, HasNaN);
888- else
889- AnyNaN = HasNaN;
884+ VPValue *IsNaN =
885+ LatchBuilder.createFCmp (CmpInst::FCMP_UNO, MinMaxOp, MinMaxOp);
886+ IsNaNLane = IsNaNLane ? LatchBuilder.createOr (IsNaNLane, IsNaN) : IsNaN;
887+ }
890888
889+ VPValue *AnyNaNLane =
890+ LatchBuilder.createNaryOp (VPInstruction::AnyOf, {IsNaNLane});
891+ VPBasicBlock *MiddleVPBB = Plan.getMiddleBlock ();
892+ VPBuilder MiddleBuilder (MiddleVPBB, MiddleVPBB->begin ());
893+ for (const auto &[RedPhiR, MinMaxOp] : MinMaxNumReductionsToHandle) {
891894 // If we exit early due to NaNs, compute the final reduction result based
892895 // on the reduction phi at the beginning of the last vector iteration.
893896 auto *RdxResult = find_singleton<VPSingleDefRecipe>(
894897 RedPhiR->getBackedgeValue ()->users (),
895- [RedPhiR](VPUser *U, bool ) -> VPSingleDefRecipe * {
898+ [RedPhiR = RedPhiR ](VPUser *U, bool ) -> VPSingleDefRecipe * {
896899 auto *VPI = dyn_cast<VPInstruction>(U);
897900 if (VPI && VPI->getOpcode () == VPInstruction::ComputeReductionResult)
898901 return VPI;
@@ -902,24 +905,25 @@ bool VPlanTransforms::handleMaxMinNumReductions(VPlan &Plan) {
902905 return nullptr ;
903906 });
904907
905- auto *NewSel =
906- MiddleBuilder. createSelect (HasNaN, RedPhiR, RdxResult->getOperand (1 ));
908+ auto *NewSel = MiddleBuilder. createSelect (AnyNaNLane, RedPhiR,
909+ RdxResult->getOperand (1 ));
907910 RdxResult->setOperand (1 , NewSel);
911+ assert (!RdxResults.contains (RdxResult) && " RdxResult already used" );
908912 RdxResults.insert (RdxResult);
909913 }
910914
911915 auto *LatchExitingBranch = LatchVPBB->getTerminator ();
912916 assert (match (LatchExitingBranch, m_BranchOnCount (m_VPValue (), m_VPValue ())) &&
913917 " Unexpected terminator" );
914- auto *IsLatchExitTaken =
915- Builder. createICmp ( CmpInst::ICMP_EQ, LatchExitingBranch->getOperand (0 ),
916- LatchExitingBranch->getOperand (1 ));
917- auto *AnyExitTaken =
918- Builder. createNaryOp ( Instruction::Or, {AnyNaN , IsLatchExitTaken});
919- Builder .createNaryOp (VPInstruction::BranchOnCond, AnyExitTaken);
918+ auto *IsLatchExitTaken = LatchBuilder. createICmp (
919+ CmpInst::ICMP_EQ, LatchExitingBranch->getOperand (0 ),
920+ LatchExitingBranch->getOperand (1 ));
921+ auto *AnyExitTaken = LatchBuilder. createNaryOp (
922+ Instruction::Or, {AnyNaNLane , IsLatchExitTaken});
923+ LatchBuilder .createNaryOp (VPInstruction::BranchOnCond, AnyExitTaken);
920924 LatchExitingBranch->eraseFromParent ();
921925
922- // Update resume phis for inductions in the scalar preheader. If AnyNaN is
926+ // Update resume phis for inductions in the scalar preheader. If AnyNaNLane is
923927 // true, the resume from the start of the last vector iteration via the
924928 // canonical IV, otherwise from the original value.
925929 for (auto &R : Plan.getScalarPreheader ()->phis ()) {
@@ -930,8 +934,9 @@ bool VPlanTransforms::handleMaxMinNumReductions(VPlan &Plan) {
930934 if (auto *DerivedIV = dyn_cast<VPDerivedIVRecipe>(VecV)) {
931935 if (DerivedIV->getNumUsers () == 1 &&
932936 DerivedIV->getOperand (1 ) == &Plan.getVectorTripCount ()) {
933- auto *NewSel = MiddleBuilder.createSelect (
934- AnyNaN, LoopRegion->getCanonicalIV (), &Plan.getVectorTripCount ());
937+ auto *NewSel =
938+ MiddleBuilder.createSelect (AnyNaNLane, LoopRegion->getCanonicalIV (),
939+ &Plan.getVectorTripCount ());
935940 DerivedIV->moveAfter (&*MiddleBuilder.getInsertPoint ());
936941 DerivedIV->setOperand (1 , NewSel);
937942 continue ;
@@ -944,15 +949,16 @@ bool VPlanTransforms::handleMaxMinNumReductions(VPlan &Plan) {
944949 " FMaxNum/FMinNum reduction.\n " );
945950 return false ;
946951 }
947- auto *NewSel =
948- MiddleBuilder. createSelect (AnyNaN , LoopRegion->getCanonicalIV (), VecV);
952+ auto *NewSel = MiddleBuilder. createSelect (
953+ AnyNaNLane , LoopRegion->getCanonicalIV (), VecV);
949954 ResumeR->setOperand (0 , NewSel);
950955 }
951956
952957 auto *MiddleTerm = MiddleVPBB->getTerminator ();
953- Builder .setInsertPoint (MiddleTerm);
958+ MiddleBuilder .setInsertPoint (MiddleTerm);
954959 VPValue *MiddleCond = MiddleTerm->getOperand (0 );
955- VPValue *NewCond = Builder.createAnd (MiddleCond, Builder.createNot (AnyNaN));
960+ VPValue *NewCond =
961+ MiddleBuilder.createAnd (MiddleCond, MiddleBuilder.createNot (AnyNaNLane));
956962 MiddleTerm->setOperand (0 , NewCond);
957963 return true ;
958964}
0 commit comments