Skip to content

Commit

Permalink
[SLP]Fix mask generation after cost estimation
Browse files Browse the repository at this point in the history
When estimating the cost of entries shuffles for buildvectors, need to
rebuild original mask, not a generated submask, used for subregisters
analysis.

Fixes llvm#122430
  • Loading branch information
alexey-bataev committed Jan 10, 2025
1 parent cc88a5e commit 681c83a
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 10 deletions.
29 changes: 21 additions & 8 deletions llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13445,14 +13445,15 @@ BoUpSLP::isGatherShuffledSingleRegisterEntry(
for_each(SubMask, [&](int &Idx) {
if (Idx == PoisonMaskElem)
return;
Idx = (Idx % VF) - (MinElement % VF) +
Idx = ((Idx % VF) - (((MinElement % VF) / NewVF) * NewVF)) % NewVF +
(Idx >= static_cast<int>(VF) ? NewVF : 0);
});
VF = NewVF;
} else {
NewVF = VF;
}

constexpr TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
auto *VecTy = getWidenedType(VL.front()->getType(), VF);
auto *VecTy = getWidenedType(VL.front()->getType(), NewVF);
auto *MaskVecTy = getWidenedType(VL.front()->getType(), SubMask.size());
auto GetShuffleCost = [&,
&TTI = *TTI](ArrayRef<int> Mask,
Expand All @@ -13477,7 +13478,7 @@ BoUpSLP::isGatherShuffledSingleRegisterEntry(
APInt DemandedElts = APInt::getAllOnes(SubMask.size());
bool IsIdentity = true;
for (auto [I, Idx] : enumerate(FirstMask)) {
if (Idx >= static_cast<int>(VF)) {
if (Idx >= static_cast<int>(NewVF)) {
Idx = PoisonMaskElem;
} else {
DemandedElts.clearBit(I);
Expand All @@ -13500,12 +13501,12 @@ BoUpSLP::isGatherShuffledSingleRegisterEntry(
APInt DemandedElts = APInt::getAllOnes(SubMask.size());
bool IsIdentity = true;
for (auto [I, Idx] : enumerate(SecondMask)) {
if (Idx < static_cast<int>(VF) && Idx >= 0) {
if (Idx < static_cast<int>(NewVF) && Idx >= 0) {
Idx = PoisonMaskElem;
} else {
DemandedElts.clearBit(I);
if (Idx != PoisonMaskElem) {
Idx -= VF;
Idx -= NewVF;
IsIdentity &= static_cast<int>(I) == Idx;
}
}
Expand All @@ -13525,12 +13526,24 @@ BoUpSLP::isGatherShuffledSingleRegisterEntry(
/*Extract=*/false, CostKind);
const TreeEntry *BestEntry = nullptr;
if (FirstShuffleCost < ShuffleCost) {
copy(FirstMask, std::next(Mask.begin(), Part * VL.size()));
std::for_each(std::next(Mask.begin(), Part * VL.size()),
std::next(Mask.begin(), (Part + 1) * VL.size()),
[&](int &Idx) {
if (Idx >= static_cast<int>(VF))
Idx = PoisonMaskElem;
});
BestEntry = Entries.front();
ShuffleCost = FirstShuffleCost;
}
if (SecondShuffleCost < ShuffleCost) {
copy(SecondMask, std::next(Mask.begin(), Part * VL.size()));
std::for_each(std::next(Mask.begin(), Part * VL.size()),
std::next(Mask.begin(), (Part + 1) * VL.size()),
[&](int &Idx) {
if (Idx < static_cast<int>(VF))
Idx = PoisonMaskElem;
else
Idx -= VF;
});
BestEntry = Entries[1];
ShuffleCost = SecondShuffleCost;
}
Expand Down
4 changes: 2 additions & 2 deletions llvm/test/Transforms/SLPVectorizer/X86/bv-shuffle-mask.ll
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ define i16 @test(i16 %v1, i16 %v2) {
; CHECK-NEXT: [[TMP2:%.*]] = or <4 x i16> [[TMP0]], [[TMP1]]
; CHECK-NEXT: [[TMP3:%.*]] = and <4 x i16> [[TMP0]], [[TMP1]]
; CHECK-NEXT: [[TMP4:%.*]] = shufflevector <4 x i16> [[TMP2]], <4 x i16> [[TMP3]], <4 x i32> <i32 0, i32 1, i32 2, i32 7>
; CHECK-NEXT: [[TMP5:%.*]] = shufflevector <4 x i16> [[TMP1]], <4 x i16> poison, <2 x i32> <i32 0, i32 poison>
; CHECK-NEXT: [[TMP6:%.*]] = insertelement <2 x i16> [[TMP5]], i16 [[V2]], i32 1
; CHECK-NEXT: [[TMP5:%.*]] = shufflevector <4 x i16> [[TMP0]], <4 x i16> poison, <2 x i32> <i32 poison, i32 3>
; CHECK-NEXT: [[TMP6:%.*]] = insertelement <2 x i16> [[TMP5]], i16 [[V1]], i32 0
; CHECK-NEXT: [[TMP7:%.*]] = shufflevector <2 x i16> [[TMP6]], <2 x i16> poison, <4 x i32> <i32 0, i32 0, i32 0, i32 1>
; CHECK-NEXT: [[TMP8:%.*]] = or <4 x i16> [[TMP7]], zeroinitializer
; CHECK-NEXT: [[TMP9:%.*]] = and <4 x i16> [[TMP4]], zeroinitializer
Expand Down

0 comments on commit 681c83a

Please sign in to comment.