@@ -1592,17 +1592,21 @@ bool VectorCombine::foldPermuteOfBinops(Instruction &I) {
1592
1592
if (BinOp->isIntDivRem () && llvm::is_contained (OuterMask, PoisonMaskElem))
1593
1593
return false ;
1594
1594
1595
- Value *Op00, *Op01;
1596
- ArrayRef<int > Mask0;
1597
- if (!match (BinOp->getOperand (0 ),
1598
- m_OneUse (m_Shuffle (m_Value (Op00), m_Value (Op01), m_Mask (Mask0)))))
1595
+ Value *Op00, *Op01, *Op10, *Op11;
1596
+ ArrayRef<int > Mask0, Mask1;
1597
+ bool Match0 =
1598
+ match (BinOp->getOperand (0 ),
1599
+ m_OneUse (m_Shuffle (m_Value (Op00), m_Value (Op01), m_Mask (Mask0))));
1600
+ bool Match1 =
1601
+ match (BinOp->getOperand (1 ),
1602
+ m_OneUse (m_Shuffle (m_Value (Op10), m_Value (Op11), m_Mask (Mask1))));
1603
+ if (!Match0 && !Match1)
1599
1604
return false ;
1600
1605
1601
- Value *Op10, *Op11;
1602
- ArrayRef<int > Mask1;
1603
- if (!match (BinOp->getOperand (1 ),
1604
- m_OneUse (m_Shuffle (m_Value (Op10), m_Value (Op11), m_Mask (Mask1)))))
1605
- return false ;
1606
+ Op00 = Match0 ? Op00 : BinOp->getOperand (0 );
1607
+ Op01 = Match0 ? Op01 : BinOp->getOperand (0 );
1608
+ Op10 = Match1 ? Op10 : BinOp->getOperand (1 );
1609
+ Op11 = Match1 ? Op11 : BinOp->getOperand (1 );
1606
1610
1607
1611
Instruction::BinaryOps Opcode = BinOp->getOpcode ();
1608
1612
auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType ());
@@ -1620,37 +1624,46 @@ bool VectorCombine::foldPermuteOfBinops(Instruction &I) {
1620
1624
any_of (OuterMask, [NumSrcElts](int M) { return M >= (int )NumSrcElts; }))
1621
1625
return false ;
1622
1626
1623
- // Merge outer / inner shuffles.
1627
+ // Merge outer / inner (or identity if no match) shuffles.
1624
1628
SmallVector<int > NewMask0, NewMask1;
1625
1629
for (int M : OuterMask) {
1626
1630
if (M < 0 || M >= (int )NumSrcElts) {
1627
1631
NewMask0.push_back (PoisonMaskElem);
1628
1632
NewMask1.push_back (PoisonMaskElem);
1629
1633
} else {
1630
- NewMask0.push_back (Mask0[M]);
1631
- NewMask1.push_back (Mask1[M]);
1634
+ NewMask0.push_back (Match0 ? Mask0[M] : M );
1635
+ NewMask1.push_back (Match1 ? Mask1[M] : M );
1632
1636
}
1633
1637
}
1634
1638
1639
+ unsigned NumOpElts = Op0Ty->getNumElements ();
1640
+ bool IsIdentity0 = ShuffleVectorInst::isIdentityMask (NewMask0, NumOpElts);
1641
+ bool IsIdentity1 = ShuffleVectorInst::isIdentityMask (NewMask1, NumOpElts);
1642
+
1635
1643
// Try to merge shuffles across the binop if the new shuffles are not costly.
1636
1644
InstructionCost OldCost =
1637
1645
TTI.getArithmeticInstrCost (Opcode, BinOpTy, CostKind) +
1638
1646
TTI.getShuffleCost (TargetTransformInfo::SK_PermuteSingleSrc, BinOpTy,
1639
- OuterMask, CostKind, 0 , nullptr , {BinOp}, &I) +
1640
- TTI.getShuffleCost (TargetTransformInfo::SK_PermuteTwoSrc, Op0Ty, Mask0,
1641
- CostKind, 0 , nullptr , {Op00, Op01},
1642
- cast<Instruction>(BinOp->getOperand (0 ))) +
1643
- TTI.getShuffleCost (TargetTransformInfo::SK_PermuteTwoSrc, Op1Ty, Mask1,
1644
- CostKind, 0 , nullptr , {Op10, Op11},
1645
- cast<Instruction>(BinOp->getOperand (1 )));
1647
+ OuterMask, CostKind, 0 , nullptr , {BinOp}, &I);
1648
+ if (Match0)
1649
+ OldCost += TTI.getShuffleCost (TargetTransformInfo::SK_PermuteTwoSrc, Op0Ty,
1650
+ Mask0, CostKind, 0 , nullptr , {Op00, Op01},
1651
+ cast<Instruction>(BinOp->getOperand (0 )));
1652
+ if (Match1)
1653
+ OldCost += TTI.getShuffleCost (TargetTransformInfo::SK_PermuteTwoSrc, Op1Ty,
1654
+ Mask1, CostKind, 0 , nullptr , {Op10, Op11},
1655
+ cast<Instruction>(BinOp->getOperand (1 )));
1646
1656
1647
1657
InstructionCost NewCost =
1648
- TTI.getShuffleCost (TargetTransformInfo::SK_PermuteTwoSrc, Op0Ty, NewMask0,
1649
- CostKind, 0 , nullptr , {Op00, Op01}) +
1650
- TTI.getShuffleCost (TargetTransformInfo::SK_PermuteTwoSrc, Op1Ty, NewMask1,
1651
- CostKind, 0 , nullptr , {Op10, Op11}) +
1652
1658
TTI.getArithmeticInstrCost (Opcode, ShuffleDstTy, CostKind);
1653
1659
1660
+ if (!IsIdentity0)
1661
+ NewCost += TTI.getShuffleCost (TargetTransformInfo::SK_PermuteTwoSrc, Op0Ty,
1662
+ NewMask0, CostKind, 0 , nullptr , {Op00, Op01});
1663
+ if (!IsIdentity1)
1664
+ NewCost += TTI.getShuffleCost (TargetTransformInfo::SK_PermuteTwoSrc, Op1Ty,
1665
+ NewMask1, CostKind, 0 , nullptr , {Op10, Op11});
1666
+
1654
1667
LLVM_DEBUG (dbgs () << " Found a shuffle feeding a shuffled binop: " << I
1655
1668
<< " \n OldCost: " << OldCost << " vs NewCost: " << NewCost
1656
1669
<< " \n " );
@@ -1659,16 +1672,18 @@ bool VectorCombine::foldPermuteOfBinops(Instruction &I) {
1659
1672
if (NewCost > OldCost)
1660
1673
return false ;
1661
1674
1662
- Value *Shuf0 = Builder.CreateShuffleVector (Op00, Op01, NewMask0);
1663
- Value *Shuf1 = Builder.CreateShuffleVector (Op10, Op11, NewMask1);
1664
- Value *NewBO = Builder.CreateBinOp (Opcode, Shuf0, Shuf1);
1675
+ Value *LHS =
1676
+ IsIdentity0 ? Op00 : Builder.CreateShuffleVector (Op00, Op01, NewMask0);
1677
+ Value *RHS =
1678
+ IsIdentity1 ? Op10 : Builder.CreateShuffleVector (Op10, Op11, NewMask1);
1679
+ Value *NewBO = Builder.CreateBinOp (Opcode, LHS, RHS);
1665
1680
1666
1681
// Intersect flags from the old binops.
1667
1682
if (auto *NewInst = dyn_cast<Instruction>(NewBO))
1668
1683
NewInst->copyIRFlags (BinOp);
1669
1684
1670
- Worklist.pushValue (Shuf0 );
1671
- Worklist.pushValue (Shuf1 );
1685
+ Worklist.pushValue (LHS );
1686
+ Worklist.pushValue (RHS );
1672
1687
replaceValue (I, *NewBO);
1673
1688
return true ;
1674
1689
}
0 commit comments