Skip to content

Commit 0bf1591

Browse files
authored
[VectorCombine] foldPermuteOfBinops - fold "shuffle (binop (shuffle, other)), undef" --> "binop (shuffle), (shuffle)". (#122118)
foldPermuteOfBinops currently requires both binop operands to be one-use shuffles in order to fold the shuffles across the binop, but there are cases where it's still profitable to fold across the binop with only one foldable shuffle.
1 parent 9988309 commit 0bf1591

File tree

5 files changed

+166
-151
lines changed

5 files changed

+166
-151
lines changed

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 43 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1592,17 +1592,21 @@ bool VectorCombine::foldPermuteOfBinops(Instruction &I) {
15921592
if (BinOp->isIntDivRem() && llvm::is_contained(OuterMask, PoisonMaskElem))
15931593
return false;
15941594

1595-
Value *Op00, *Op01;
1596-
ArrayRef<int> Mask0;
1597-
if (!match(BinOp->getOperand(0),
1598-
m_OneUse(m_Shuffle(m_Value(Op00), m_Value(Op01), m_Mask(Mask0)))))
1595+
Value *Op00, *Op01, *Op10, *Op11;
1596+
ArrayRef<int> Mask0, Mask1;
1597+
bool Match0 =
1598+
match(BinOp->getOperand(0),
1599+
m_OneUse(m_Shuffle(m_Value(Op00), m_Value(Op01), m_Mask(Mask0))));
1600+
bool Match1 =
1601+
match(BinOp->getOperand(1),
1602+
m_OneUse(m_Shuffle(m_Value(Op10), m_Value(Op11), m_Mask(Mask1))));
1603+
if (!Match0 && !Match1)
15991604
return false;
16001605

1601-
Value *Op10, *Op11;
1602-
ArrayRef<int> Mask1;
1603-
if (!match(BinOp->getOperand(1),
1604-
m_OneUse(m_Shuffle(m_Value(Op10), m_Value(Op11), m_Mask(Mask1)))))
1605-
return false;
1606+
Op00 = Match0 ? Op00 : BinOp->getOperand(0);
1607+
Op01 = Match0 ? Op01 : BinOp->getOperand(0);
1608+
Op10 = Match1 ? Op10 : BinOp->getOperand(1);
1609+
Op11 = Match1 ? Op11 : BinOp->getOperand(1);
16061610

16071611
Instruction::BinaryOps Opcode = BinOp->getOpcode();
16081612
auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
@@ -1620,37 +1624,46 @@ bool VectorCombine::foldPermuteOfBinops(Instruction &I) {
16201624
any_of(OuterMask, [NumSrcElts](int M) { return M >= (int)NumSrcElts; }))
16211625
return false;
16221626

1623-
// Merge outer / inner shuffles.
1627+
// Merge outer / inner (or identity if no match) shuffles.
16241628
SmallVector<int> NewMask0, NewMask1;
16251629
for (int M : OuterMask) {
16261630
if (M < 0 || M >= (int)NumSrcElts) {
16271631
NewMask0.push_back(PoisonMaskElem);
16281632
NewMask1.push_back(PoisonMaskElem);
16291633
} else {
1630-
NewMask0.push_back(Mask0[M]);
1631-
NewMask1.push_back(Mask1[M]);
1634+
NewMask0.push_back(Match0 ? Mask0[M] : M);
1635+
NewMask1.push_back(Match1 ? Mask1[M] : M);
16321636
}
16331637
}
16341638

1639+
unsigned NumOpElts = Op0Ty->getNumElements();
1640+
bool IsIdentity0 = ShuffleVectorInst::isIdentityMask(NewMask0, NumOpElts);
1641+
bool IsIdentity1 = ShuffleVectorInst::isIdentityMask(NewMask1, NumOpElts);
1642+
16351643
// Try to merge shuffles across the binop if the new shuffles are not costly.
16361644
InstructionCost OldCost =
16371645
TTI.getArithmeticInstrCost(Opcode, BinOpTy, CostKind) +
16381646
TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, BinOpTy,
1639-
OuterMask, CostKind, 0, nullptr, {BinOp}, &I) +
1640-
TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, Op0Ty, Mask0,
1641-
CostKind, 0, nullptr, {Op00, Op01},
1642-
cast<Instruction>(BinOp->getOperand(0))) +
1643-
TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, Op1Ty, Mask1,
1644-
CostKind, 0, nullptr, {Op10, Op11},
1645-
cast<Instruction>(BinOp->getOperand(1)));
1647+
OuterMask, CostKind, 0, nullptr, {BinOp}, &I);
1648+
if (Match0)
1649+
OldCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, Op0Ty,
1650+
Mask0, CostKind, 0, nullptr, {Op00, Op01},
1651+
cast<Instruction>(BinOp->getOperand(0)));
1652+
if (Match1)
1653+
OldCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, Op1Ty,
1654+
Mask1, CostKind, 0, nullptr, {Op10, Op11},
1655+
cast<Instruction>(BinOp->getOperand(1)));
16461656

16471657
InstructionCost NewCost =
1648-
TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, Op0Ty, NewMask0,
1649-
CostKind, 0, nullptr, {Op00, Op01}) +
1650-
TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, Op1Ty, NewMask1,
1651-
CostKind, 0, nullptr, {Op10, Op11}) +
16521658
TTI.getArithmeticInstrCost(Opcode, ShuffleDstTy, CostKind);
16531659

1660+
if (!IsIdentity0)
1661+
NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, Op0Ty,
1662+
NewMask0, CostKind, 0, nullptr, {Op00, Op01});
1663+
if (!IsIdentity1)
1664+
NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, Op1Ty,
1665+
NewMask1, CostKind, 0, nullptr, {Op10, Op11});
1666+
16541667
LLVM_DEBUG(dbgs() << "Found a shuffle feeding a shuffled binop: " << I
16551668
<< "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
16561669
<< "\n");
@@ -1659,16 +1672,18 @@ bool VectorCombine::foldPermuteOfBinops(Instruction &I) {
16591672
if (NewCost > OldCost)
16601673
return false;
16611674

1662-
Value *Shuf0 = Builder.CreateShuffleVector(Op00, Op01, NewMask0);
1663-
Value *Shuf1 = Builder.CreateShuffleVector(Op10, Op11, NewMask1);
1664-
Value *NewBO = Builder.CreateBinOp(Opcode, Shuf0, Shuf1);
1675+
Value *LHS =
1676+
IsIdentity0 ? Op00 : Builder.CreateShuffleVector(Op00, Op01, NewMask0);
1677+
Value *RHS =
1678+
IsIdentity1 ? Op10 : Builder.CreateShuffleVector(Op10, Op11, NewMask1);
1679+
Value *NewBO = Builder.CreateBinOp(Opcode, LHS, RHS);
16651680

16661681
// Intersect flags from the old binops.
16671682
if (auto *NewInst = dyn_cast<Instruction>(NewBO))
16681683
NewInst->copyIRFlags(BinOp);
16691684

1670-
Worklist.pushValue(Shuf0);
1671-
Worklist.pushValue(Shuf1);
1685+
Worklist.pushValue(LHS);
1686+
Worklist.pushValue(RHS);
16721687
replaceValue(I, *NewBO);
16731688
return true;
16741689
}

0 commit comments

Comments
 (0)