
Commit 67278b8

[LV] Support Interleaved Store Group With Gaps
Teach LV to use masked-store to support interleave-store-group with gaps
(instead of scatters/scalarization).

The symmetric case of using masked-load to support interleaved-load-group
with gaps was introduced a while ago, by https://reviews.llvm.org/D53668;
this patch completes the store scenario left over from D53668, and solves
PR50566.

Reviewed by: Ayal Zaks

Differential Revision: https://reviews.llvm.org/D104750
1 parent 657bb72 commit 67278b8
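For context, a loop of the following shape forms an interleaved store group of factor 3 with a gap at index 2; with this patch such a group can be vectorized with a single wide masked store instead of being scalarized. This is a hypothetical illustration, not code from the patch or its tests.

// Hypothetical example of an interleaved store group with a gap: the loop
// writes fields x and y (member indices 0 and 1) but never z (index 2).
struct Point { int x, y, z; };

void writeXY(Point *P, int N) {
  for (int i = 0; i < N; ++i) {
    P[i].x = i;     // member at index 0 of the store group
    P[i].y = 2 * i; // member at index 1
                    // P[i].z is untouched -- the gap at index 2
  }
}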

6 files changed: +557 −90 lines changed


llvm/include/llvm/Analysis/VectorUtils.h

Lines changed: 2 additions & 4 deletions
@@ -686,10 +686,8 @@ template <typename InstTy> class InterleaveGroup {
     if (getMember(getFactor() - 1))
       return false;

-    // We have a group with gaps. It therefore cannot be a group of stores,
-    // and it can't be a reversed access, because such groups get invalidated.
-    assert(!getMember(0)->mayWriteToMemory() &&
-           "Group should have been invalidated");
+    // We have a group with gaps. It therefore can't be a reversed access,
+    // because such groups get invalidated (TODO).
     assert(!isReverse() && "Group should have been invalidated");

     // This is a group of loads, with gaps, and without a last-member
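As a side note on the load case this member function still covers: the sketch below (illustrative loop and numbers, not from the patch) shows why a load group whose last member is missing needs a scalar epilogue, while a store group with gaps can rely on masking instead.

// Illustrative only: a load group of factor 2 whose member at index 1 is a
// gap. With VF = 4, the wide load that replaces the last vector iteration
// reads 8 consecutive ints starting at A[2*i], so it also touches A[2*i + 7],
// one element past the final scalar read A[2*i + 6]. Peeling one scalar
// iteration keeps that extra element in bounds; store groups with gaps avoid
// the issue by masking the gap lanes instead.
int sumEven(const int *A, int N) {
  int S = 0;
  for (int i = 0; i < N; ++i)
    S += A[2 * i]; // member at index 0; index 1 is never read
  return S;
}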

llvm/include/llvm/CodeGen/BasicTTIImpl.h

Lines changed: 33 additions & 26 deletions
@@ -1212,9 +1212,9 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     // used (those corresponding to elements [0:1] and [8:9] of the unlegalized
     // type). The other loads are unused.
     //
-    // We only scale the cost of loads since interleaved store groups aren't
-    // allowed to have gaps.
-    if (Opcode == Instruction::Load && VecTySize > VecTyLTSize) {
+    // TODO: Note that legalization can turn masked loads/stores into unmasked
+    // (legalized) loads/stores. This can be reflected in the cost.
+    if (VecTySize > VecTyLTSize) {
       // The number of loads of a legal type it will take to represent a load
       // of the unlegalized vector type.
       unsigned NumLegalInsts = divideCeil(VecTySize, VecTyLTSize);
@@ -1235,6 +1235,8 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     }

     // Then plus the cost of interleave operation.
+    assert(Indices.size() <= Factor &&
+           "Interleaved memory op has too many members");
     if (Opcode == Instruction::Load) {
       // The interleave cost is similar to extract sub vectors' elements
       // from the wide vector, and insert them into sub vectors.
@@ -1244,44 +1246,49 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
       // %v0 = shuffle %vec, undef, <0, 2, 4, 6> ; Index 0
       // The cost is estimated as extract elements at 0, 2, 4, 6 from the
       // <8 x i32> vector and insert them into a <4 x i32> vector.
-
-      assert(Indices.size() <= Factor &&
-             "Interleaved memory op has too many members");
-
       for (unsigned Index : Indices) {
         assert(Index < Factor && "Invalid index for interleaved memory op");

         // Extract elements from loaded vector for each sub vector.
-        for (unsigned i = 0; i < NumSubElts; i++)
+        for (unsigned Elm = 0; Elm < NumSubElts; Elm++)
           Cost += thisT()->getVectorInstrCost(Instruction::ExtractElement, VT,
-                                              Index + i * Factor);
+                                              Index + Elm * Factor);
       }

       InstructionCost InsSubCost = 0;
-      for (unsigned i = 0; i < NumSubElts; i++)
+      for (unsigned Elm = 0; Elm < NumSubElts; Elm++)
         InsSubCost +=
-            thisT()->getVectorInstrCost(Instruction::InsertElement, SubVT, i);
+            thisT()->getVectorInstrCost(Instruction::InsertElement, SubVT, Elm);

       Cost += Indices.size() * InsSubCost;
     } else {
-      // The interleave cost is extract all elements from sub vectors, and
+      // The interleave cost is extract elements from sub vectors, and
       // insert them into the wide vector.
       //
-      // E.g. An interleaved store of factor 2:
-      //     %v0_v1 = shuffle %v0, %v1, <0, 4, 1, 5, 2, 6, 3, 7>
-      //     store <8 x i32> %interleaved.vec, <8 x i32>* %ptr
-      // The cost is estimated as extract all elements from both <4 x i32>
-      // vectors and insert into the <8 x i32> vector.
-
+      // E.g. An interleaved store of factor 3 with 2 members at indices 0,1:
+      // (using VF=4):
+      // %v0_v1 = shuffle %v0, %v1, <0,4,undef,1,5,undef,2,6,undef,3,7,undef>
+      // %gaps.mask = <true, true, false, true, true, false,
+      //               true, true, false, true, true, false>
+      // call llvm.masked.store <12 x i32> %v0_v1, <12 x i32>* %ptr,
+      //                        i32 Align, <12 x i1> %gaps.mask
+      // The cost is estimated as extract all elements (of actual members,
+      // excluding gaps) from both <4 x i32> vectors and insert into the <12 x
+      // i32> vector.
       InstructionCost ExtSubCost = 0;
-      for (unsigned i = 0; i < NumSubElts; i++)
-        ExtSubCost +=
-            thisT()->getVectorInstrCost(Instruction::ExtractElement, SubVT, i);
-      Cost += ExtSubCost * Factor;
-
-      for (unsigned i = 0; i < NumElts; i++)
-        Cost += static_cast<T *>(this)
-                    ->getVectorInstrCost(Instruction::InsertElement, VT, i);
+      for (unsigned Elm = 0; Elm < NumSubElts; Elm++)
+        ExtSubCost += thisT()->getVectorInstrCost(Instruction::ExtractElement,
+                                                  SubVT, Elm);
+      Cost += ExtSubCost * Indices.size();
+
+      for (unsigned Index : Indices) {
+        assert(Index < Factor && "Invalid index for interleaved memory op");
+
+        // Insert elements from loaded vector for each sub vector.
+        for (unsigned Elm = 0; Elm < NumSubElts; Elm++)
+          Cost += thisT()->getVectorInstrCost(Instruction::InsertElement, VT,
+                                              Index + Elm * Factor);
+      }
     }

     if (!UseMaskForCond)
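A worked example of the store-side counting above, under the simplifying assumption that every extract and insert has unit cost (real targets return these costs from getVectorInstrCost, so the numbers are illustrative only):

#include <cstdio>

// Illustrative arithmetic only; in LLVM the per-element costs come from
// getVectorInstrCost and need not be 1 as assumed here.
int main() {
  unsigned Factor = 3, NumSubElts = 4, NumIndices = 2; // gap at index 2
  unsigned Extracts = NumIndices * NumSubElts; // 2 * 4 = 8 extracts from sub vectors
  unsigned Inserts = NumIndices * NumSubElts;  // 8 inserts into the <12 x i32> vector
  unsigned OldExtracts = Factor * NumSubElts;  // 12: the old formula also paid for gaps
  unsigned OldInserts = Factor * NumSubElts;   // 12: one insert per wide-vector element
  printf("new shuffle cost %u vs old %u\n", Extracts + Inserts,
         OldExtracts + OldInserts);
  return 0;
}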

llvm/lib/Analysis/VectorUtils.cpp

Lines changed: 55 additions & 28 deletions
@@ -1193,15 +1193,23 @@ void InterleavedAccessInfo::analyzeInterleaving(
     } // Iteration over A accesses.
   } // Iteration over B accesses.

-  // Remove interleaved store groups with gaps.
-  for (auto *Group : StoreGroups)
-    if (Group->getNumMembers() != Group->getFactor()) {
-      LLVM_DEBUG(
-          dbgs() << "LV: Invalidate candidate interleaved store group due "
-                    "to gaps.\n");
-      releaseGroup(Group);
-    }
-  // Remove interleaved groups with gaps (currently only loads) whose memory
+  auto InvalidateGroupIfMemberMayWrap = [&](InterleaveGroup<Instruction> *Group,
+                                            int Index,
+                                            std::string FirstOrLast) -> bool {
+    Instruction *Member = Group->getMember(Index);
+    assert(Member && "Group member does not exist");
+    Value *MemberPtr = getLoadStorePointerOperand(Member);
+    if (getPtrStride(PSE, MemberPtr, TheLoop, Strides, /*Assume=*/false,
+                     /*ShouldCheckWrap=*/true))
+      return false;
+    LLVM_DEBUG(dbgs() << "LV: Invalidate candidate interleaved group due to "
+                      << FirstOrLast
+                      << " group member potentially pointer-wrapping.\n");
+    releaseGroup(Group);
+    return true;
+  };
+
+  // Remove interleaved groups with gaps whose memory
   // accesses may wrap around. We have to revisit the getPtrStride analysis,
   // this time with ShouldCheckWrap=true, since collectConstStrideAccesses does
   // not check wrapping (see documentation there).
@@ -1227,26 +1235,12 @@ void InterleavedAccessInfo::analyzeInterleaving(
     // So we check only group member 0 (which is always guaranteed to exist),
     // and group member Factor - 1; If the latter doesn't exist we rely on
     // peeling (if it is a non-reversed accsess -- see Case 3).
-    Value *FirstMemberPtr = getLoadStorePointerOperand(Group->getMember(0));
-    if (!getPtrStride(PSE, FirstMemberPtr, TheLoop, Strides, /*Assume=*/false,
-                      /*ShouldCheckWrap=*/true)) {
-      LLVM_DEBUG(
-          dbgs() << "LV: Invalidate candidate interleaved group due to "
-                    "first group member potentially pointer-wrapping.\n");
-      releaseGroup(Group);
+    if (InvalidateGroupIfMemberMayWrap(Group, 0, std::string("first")))
       continue;
-    }
-    Instruction *LastMember = Group->getMember(Group->getFactor() - 1);
-    if (LastMember) {
-      Value *LastMemberPtr = getLoadStorePointerOperand(LastMember);
-      if (!getPtrStride(PSE, LastMemberPtr, TheLoop, Strides, /*Assume=*/false,
-                        /*ShouldCheckWrap=*/true)) {
-        LLVM_DEBUG(
-            dbgs() << "LV: Invalidate candidate interleaved group due to "
-                      "last group member potentially pointer-wrapping.\n");
-        releaseGroup(Group);
-      }
-    } else {
+    if (Group->getMember(Group->getFactor() - 1))
+      InvalidateGroupIfMemberMayWrap(Group, Group->getFactor() - 1,
+                                     std::string("last"));
+    else {
       // Case 3: A non-reversed interleaved load group with gaps: We need
       // to execute at least one scalar epilogue iteration. This will ensure
       // we don't speculatively access memory out-of-bounds. We only need
@@ -1264,6 +1258,39 @@ void InterleavedAccessInfo::analyzeInterleaving(
       RequiresScalarEpilogue = true;
     }
   }
+
+  for (auto *Group : StoreGroups) {
+    // Case 1: A full group. Can Skip the checks; For full groups, if the wide
+    // store would wrap around the address space we would do a memory access at
+    // nullptr even without the transformation.
+    if (Group->getNumMembers() == Group->getFactor())
+      continue;
+
+    // Interleave-store-group with gaps is implemented using masked wide store.
+    // Remove interleaved store groups with gaps if
+    // masked-interleaved-accesses are not enabled by the target.
+    if (!EnablePredicatedInterleavedMemAccesses) {
+      LLVM_DEBUG(
+          dbgs() << "LV: Invalidate candidate interleaved store group due "
+                    "to gaps.\n");
+      releaseGroup(Group);
+      continue;
+    }
+
+    // Case 2: If first and last members of the group don't wrap this implies
+    // that all the pointers in the group don't wrap.
+    // So we check only group member 0 (which is always guaranteed to exist),
+    // and the last group member. Case 3 (scalar epilog) is not relevant for
+    // stores with gaps, which are implemented with masked-store (rather than
+    // speculative access, as in loads).
+    if (InvalidateGroupIfMemberMayWrap(Group, 0, std::string("first")))
+      continue;
+    for (int Index = Group->getFactor() - 1; Index > 0; Index--)
+      if (Group->getMember(Index)) {
+        InvalidateGroupIfMemberMayWrap(Group, Index, std::string("last"));
+        break;
+      }
+  }
 }

 void InterleavedAccessInfo::invalidateGroupsRequiringScalarEpilogue() {
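The store-group loop above wrap-checks member 0 and the last member that actually exists, found by scanning downward from Factor - 1. A minimal standalone sketch of that scan, with a plain bool vector standing in for Group->getMember():

#include <vector>

// Minimal sketch of the downward scan above: HasMember[i] stands in for
// Group->getMember(i) != nullptr; returns the index to wrap-check besides 0.
int lastExistingMember(const std::vector<bool> &HasMember) {
  for (int Index = (int)HasMember.size() - 1; Index > 0; --Index)
    if (HasMember[Index])
      return Index;
  return 0; // only member 0 exists
}
// E.g. for a factor-4 group with members at {0, 2}, this returns 2, so
// members 0 and 2 are checked for potential pointer wrapping.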

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 47 additions & 24 deletions
@@ -2837,12 +2837,25 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup(
   auto *SubVT = VectorType::get(ScalarTy, VF);

   // Vectorize the interleaved store group.
+  MaskForGaps = createBitMaskForGaps(Builder, VF.getKnownMinValue(), *Group);
+  assert((!MaskForGaps || useMaskedInterleavedAccesses(*TTI)) &&
+         "masked interleaved groups are not allowed.");
+  assert((!MaskForGaps || !VF.isScalable()) &&
+         "masking gaps for scalable vectors is not yet supported.");
   for (unsigned Part = 0; Part < UF; Part++) {
     // Collect the stored vector from each member.
     SmallVector<Value *, 4> StoredVecs;
     for (unsigned i = 0; i < InterleaveFactor; i++) {
-      // Interleaved store group doesn't allow a gap, so each index has a member
-      assert(Group->getMember(i) && "Fail to get a member from an interleaved store group");
+      assert((Group->getMember(i) || MaskForGaps) &&
+             "Fail to get a member from an interleaved store group");
+      Instruction *Member = Group->getMember(i);
+
+      // Skip the gaps in the group.
+      if (!Member) {
+        Value *Undef = PoisonValue::get(SubVT);
+        StoredVecs.push_back(Undef);
+        continue;
+      }

       Value *StoredVec = State.get(StoredValues[i], Part);

@@ -2866,16 +2879,21 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup(
         "interleaved.vec");

     Instruction *NewStoreInstr;
-    if (BlockInMask) {
-      Value *BlockInMaskPart = State.get(BlockInMask, Part);
-      Value *ShuffledMask = Builder.CreateShuffleVector(
-          BlockInMaskPart,
-          createReplicatedMask(InterleaveFactor, VF.getKnownMinValue()),
-          "interleaved.mask");
-      NewStoreInstr = Builder.CreateMaskedStore(
-          IVec, AddrParts[Part], Group->getAlign(), ShuffledMask);
-    }
-    else
+    if (BlockInMask || MaskForGaps) {
+      Value *GroupMask = MaskForGaps;
+      if (BlockInMask) {
+        Value *BlockInMaskPart = State.get(BlockInMask, Part);
+        Value *ShuffledMask = Builder.CreateShuffleVector(
+            BlockInMaskPart,
+            createReplicatedMask(InterleaveFactor, VF.getKnownMinValue()),
+            "interleaved.mask");
+        GroupMask = MaskForGaps ? Builder.CreateBinOp(Instruction::And,
+                                                      ShuffledMask, MaskForGaps)
+                                : ShuffledMask;
+      }
+      NewStoreInstr = Builder.CreateMaskedStore(IVec, AddrParts[Part],
+                                                Group->getAlign(), GroupMask);
+    } else
       NewStoreInstr =
           Builder.CreateAlignedStore(IVec, AddrParts[Part], Group->getAlign());

@@ -5274,12 +5292,19 @@ bool LoopVectorizationCostModel::interleavedAccessCanBeWidened(

   // Check if masking is required.
   // A Group may need masking for one of two reasons: it resides in a block that
-  // needs predication, or it was decided to use masking to deal with gaps.
+  // needs predication, or it was decided to use masking to deal with gaps
+  // (either a gap at the end of a load-access that may result in a speculative
+  // load, or any gaps in a store-access).
   bool PredicatedAccessRequiresMasking =
       Legal->blockNeedsPredication(I->getParent()) && Legal->isMaskRequired(I);
-  bool AccessWithGapsRequiresMasking =
-      Group->requiresScalarEpilogue() && !isScalarEpilogueAllowed();
-  if (!PredicatedAccessRequiresMasking && !AccessWithGapsRequiresMasking)
+  bool LoadAccessWithGapsRequiresEpilogMasking =
+      isa<LoadInst>(I) && Group->requiresScalarEpilogue() &&
+      !isScalarEpilogueAllowed();
+  bool StoreAccessWithGapsRequiresMasking =
+      isa<StoreInst>(I) && (Group->getNumMembers() < Group->getFactor());
+  if (!PredicatedAccessRequiresMasking &&
+      !LoadAccessWithGapsRequiresEpilogMasking &&
+      !StoreAccessWithGapsRequiresMasking)
     return true;

   // If masked interleaving is required, we expect that the user/target had
@@ -7118,18 +7143,16 @@ LoopVectorizationCostModel::getInterleaveGroupCost(Instruction *I,
   unsigned InterleaveFactor = Group->getFactor();
   auto *WideVecTy = VectorType::get(ValTy, VF * InterleaveFactor);

-  // Holds the indices of existing members in an interleaved load group.
-  // An interleaved store group doesn't need this as it doesn't allow gaps.
+  // Holds the indices of existing members in the interleaved group.
   SmallVector<unsigned, 4> Indices;
-  if (isa<LoadInst>(I)) {
-    for (unsigned i = 0; i < InterleaveFactor; i++)
-      if (Group->getMember(i))
-        Indices.push_back(i);
-  }
+  for (unsigned IF = 0; IF < InterleaveFactor; IF++)
+    if (Group->getMember(IF))
+      Indices.push_back(IF);

   // Calculate the cost of the whole interleaved group.
   bool UseMaskForGaps =
-      Group->requiresScalarEpilogue() && !isScalarEpilogueAllowed();
+      (Group->requiresScalarEpilogue() && !isScalarEpilogueAllowed()) ||
+      (isa<StoreInst>(I) && (Group->getNumMembers() < Group->getFactor()));
   InstructionCost Cost = TTI.getInterleavedMemoryOpCost(
       I->getOpcode(), WideVecTy, Group->getFactor(), Indices, Group->getAlign(),
       AS, TTI::TCK_RecipThroughput, Legal->isMaskRequired(I), UseMaskForGaps);
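To make the masking in vectorizeInterleaveGroup concrete, here is a hedged sketch of how a gaps mask is laid out and how it combines with the replicated block mask. buildGapsMask and combineMasks are illustrative helpers, not LLVM APIs; the pass itself uses createBitMaskForGaps and an IR 'and', as shown in the diff above.

#include <vector>

// Illustrative helpers only (not LLVM APIs). Lane layout matches the wide
// interleaved vector: VF groups of Factor lanes each.
std::vector<bool> buildGapsMask(unsigned VF, unsigned Factor,
                                const std::vector<bool> &HasMember) {
  std::vector<bool> Mask;
  for (unsigned Lane = 0; Lane < VF; ++Lane)
    for (unsigned J = 0; J < Factor; ++J)
      Mask.push_back(HasMember[J]); // false lanes are the gaps
  return Mask;
}

// The final store mask is the replicated block mask AND-ed with the gaps
// mask; if only one of them exists, it is used alone (mirrors GroupMask).
std::vector<bool> combineMasks(const std::vector<bool> &ReplicatedBlockMask,
                               const std::vector<bool> &GapsMask) {
  if (ReplicatedBlockMask.empty())
    return GapsMask;
  if (GapsMask.empty())
    return ReplicatedBlockMask;
  std::vector<bool> GroupMask;
  for (size_t I = 0; I < GapsMask.size(); ++I)
    GroupMask.push_back(ReplicatedBlockMask[I] && GapsMask[I]);
  return GroupMask;
}

// For VF = 4, Factor = 3, HasMember = {true, true, false}, buildGapsMask
// yields <1,1,0, 1,1,0, 1,1,0, 1,1,0>, matching %gaps.mask in the cost-model
// comment earlier in this commit.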
