Skip to content

Commit 52f85d6

Browse files
committed
!fixup update getPtrstride to take DT as const reference
1 parent 013ed31 commit 52f85d6

File tree

7 files changed

+33
-29
lines changed

7 files changed

+33
-29
lines changed

llvm/include/llvm/Analysis/LoopAccessAnalysis.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -893,11 +893,10 @@ replaceSymbolicStrideSCEV(PredicatedScalarEvolution &PSE,
893893
/// result of this function is undefined.
894894
LLVM_ABI std::optional<int64_t>
895895
getPtrStride(PredicatedScalarEvolution &PSE, Type *AccessTy, Value *Ptr,
896-
const Loop *Lp,
896+
const Loop *Lp, const DominatorTree &DT,
897897
const DenseMap<Value *, const SCEV *> &StridesMap =
898898
DenseMap<Value *, const SCEV *>(),
899-
bool Assume = false, bool ShouldCheckWrap = true,
900-
DominatorTree *DT = nullptr);
899+
bool Assume = false, bool ShouldCheckWrap = true);
901900

902901
/// Returns the distance between the pointers \p PtrA and \p PtrB iff they are
903902
/// compatible and it is possible to calculate the distance between them. This

llvm/lib/Analysis/LoopAccessAnalysis.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1620,9 +1620,9 @@ void AccessAnalysis::processMemAccesses() {
16201620
/// Check whether the access through \p Ptr has a constant stride.
16211621
std::optional<int64_t>
16221622
llvm::getPtrStride(PredicatedScalarEvolution &PSE, Type *AccessTy, Value *Ptr,
1623-
const Loop *Lp,
1623+
const Loop *Lp, const DominatorTree &DT,
16241624
const DenseMap<Value *, const SCEV *> &StridesMap,
1625-
bool Assume, bool ShouldCheckWrap, DominatorTree *DT) {
1625+
bool Assume, bool ShouldCheckWrap) {
16261626
const SCEV *PtrScev = replaceSymbolicStrideSCEV(PSE, StridesMap, Ptr);
16271627
if (PSE.getSE()->isLoopInvariant(PtrScev, Lp))
16281628
return 0;
@@ -1644,7 +1644,7 @@ llvm::getPtrStride(PredicatedScalarEvolution &PSE, Type *AccessTy, Value *Ptr,
16441644
if (!ShouldCheckWrap || !Stride)
16451645
return Stride;
16461646

1647-
if (isNoWrap(PSE, AR, Ptr, AccessTy, Lp, Assume, *DT, Stride))
1647+
if (isNoWrap(PSE, AR, Ptr, AccessTy, Lp, Assume, DT, Stride))
16481648
return Stride;
16491649

16501650
LLVM_DEBUG(
@@ -2062,9 +2062,9 @@ MemoryDepChecker::getDependenceDistanceStrideAndSize(
20622062
return MemoryDepChecker::Dependence::Unknown;
20632063

20642064
std::optional<int64_t> StrideAPtr = getPtrStride(
2065-
PSE, ATy, APtr, InnermostLoop, SymbolicStrides, true, true, DT);
2065+
PSE, ATy, APtr, InnermostLoop, *DT, SymbolicStrides, true, true);
20662066
std::optional<int64_t> StrideBPtr = getPtrStride(
2067-
PSE, BTy, BPtr, InnermostLoop, SymbolicStrides, true, true, DT);
2067+
PSE, BTy, BPtr, InnermostLoop, *DT, SymbolicStrides, true, true);
20682068

20692069
const SCEV *Src = PSE.getSCEV(APtr);
20702070
const SCEV *Sink = PSE.getSCEV(BPtr);
@@ -2706,8 +2706,8 @@ bool LoopAccessInfo::analyzeLoop(AAResults *AA, const LoopInfo *LI,
27062706
bool IsReadOnlyPtr = false;
27072707
Type *AccessTy = getLoadStoreType(LD);
27082708
if (Seen.insert({Ptr, AccessTy}).second ||
2709-
!getPtrStride(*PSE, AccessTy, Ptr, TheLoop, SymbolicStrides, false,
2710-
true, DT)) {
2709+
!getPtrStride(*PSE, AccessTy, Ptr, TheLoop, *DT, SymbolicStrides, false,
2710+
true)) {
27112711
++NumReads;
27122712
IsReadOnlyPtr = true;
27132713
}

llvm/lib/Analysis/VectorUtils.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1388,8 +1388,8 @@ void InterleavedAccessInfo::collectConstStrideAccesses(
13881388
// even without the transformation. The wrapping checks are therefore
13891389
// deferred until after we've formed the interleaved groups.
13901390
int64_t Stride =
1391-
getPtrStride(PSE, ElementTy, Ptr, TheLoop, Strides,
1392-
/*Assume=*/true, /*ShouldCheckWrap=*/false, DT)
1391+
getPtrStride(PSE, ElementTy, Ptr, TheLoop, *DT, Strides,
1392+
/*Assume=*/true, /*ShouldCheckWrap=*/false)
13931393
.value_or(0);
13941394

13951395
const SCEV *Scev = replaceSymbolicStrideSCEV(PSE, Strides, Ptr);
@@ -1644,8 +1644,8 @@ void InterleavedAccessInfo::analyzeInterleaving(
16441644
assert(Member && "Group member does not exist");
16451645
Value *MemberPtr = getLoadStorePointerOperand(Member);
16461646
Type *AccessTy = getLoadStoreType(Member);
1647-
if (getPtrStride(PSE, AccessTy, MemberPtr, TheLoop, Strides,
1648-
/*Assume=*/false, /*ShouldCheckWrap=*/true, DT)
1647+
if (getPtrStride(PSE, AccessTy, MemberPtr, TheLoop, *DT, Strides,
1648+
/*Assume=*/false, /*ShouldCheckWrap=*/true)
16491649
.value_or(0))
16501650
return false;
16511651
LLVM_DEBUG(dbgs() << "LV: Invalidate candidate interleaved group due to "

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6122,7 +6122,8 @@ AArch64TTIImpl::getShuffleCost(TTI::ShuffleKind Kind, VectorType *DstTy,
61226122
}
61236123

61246124
static bool containsDecreasingPointers(Loop *TheLoop,
6125-
PredicatedScalarEvolution *PSE) {
6125+
PredicatedScalarEvolution *PSE,
6126+
const DominatorTree &DT) {
61266127
const auto &Strides = DenseMap<Value *, const SCEV *>();
61276128
for (BasicBlock *BB : TheLoop->blocks()) {
61286129
// Scan the instructions in the block and look for addresses that are
@@ -6131,8 +6132,8 @@ static bool containsDecreasingPointers(Loop *TheLoop,
61316132
if (isa<LoadInst>(&I) || isa<StoreInst>(&I)) {
61326133
Value *Ptr = getLoadStorePointerOperand(&I);
61336134
Type *AccessTy = getLoadStoreType(&I);
6134-
if (getPtrStride(*PSE, AccessTy, Ptr, TheLoop, Strides, /*Assume=*/true,
6135-
/*ShouldCheckWrap=*/false)
6135+
if (getPtrStride(*PSE, AccessTy, Ptr, TheLoop, DT, Strides,
6136+
/*Assume=*/true, /*ShouldCheckWrap=*/false)
61366137
.value_or(0) < 0)
61376138
return true;
61386139
}
@@ -6177,7 +6178,8 @@ bool AArch64TTIImpl::preferPredicateOverEpilogue(TailFoldingInfo *TFI) const {
61776178
// negative strides. This will require extra work to reverse the loop
61786179
// predicate, which may be expensive.
61796180
if (containsDecreasingPointers(TFI->LVL->getLoop(),
6180-
TFI->LVL->getPredicatedScalarEvolution()))
6181+
TFI->LVL->getPredicatedScalarEvolution(),
6182+
*TFI->LVL->getDominatorTree()))
61816183
Required |= TailFoldingOpts::Reverse;
61826184
if (Required == TailFoldingOpts::Disabled)
61836185
Required |= TailFoldingOpts::Simple;

llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2448,7 +2448,8 @@ static bool canTailPredicateInstruction(Instruction &I, int &ICmpCount) {
24482448
//
24492449
static bool canTailPredicateLoop(Loop *L, LoopInfo *LI, ScalarEvolution &SE,
24502450
const DataLayout &DL,
2451-
const LoopAccessInfo *LAI) {
2451+
const LoopAccessInfo *LAI,
2452+
const DominatorTree &DT) {
24522453
LLVM_DEBUG(dbgs() << "Tail-predication: checking allowed instructions\n");
24532454

24542455
// If there are live-out values, it is probably a reduction. We can predicate
@@ -2498,7 +2499,8 @@ static bool canTailPredicateLoop(Loop *L, LoopInfo *LI, ScalarEvolution &SE,
24982499
if (isa<StoreInst>(I) || isa<LoadInst>(I)) {
24992500
Value *Ptr = getLoadStorePointerOperand(&I);
25002501
Type *AccessTy = getLoadStoreType(&I);
2501-
int64_t NextStride = getPtrStride(PSE, AccessTy, Ptr, L).value_or(0);
2502+
int64_t NextStride = getPtrStride(PSE, AccessTy, Ptr, L,
2503+
DT).value_or(0);
25022504
if (NextStride == 1) {
25032505
// TODO: for now only allow consecutive strides of 1. We could support
25042506
// other strides as long as it is uniform, but let's keep it simple
@@ -2585,7 +2587,8 @@ bool ARMTTIImpl::preferPredicateOverEpilogue(TailFoldingInfo *TFI) const {
25852587
return false;
25862588
}
25872589

2588-
return canTailPredicateLoop(L, LI, *SE, DL, LVL->getLAI());
2590+
return canTailPredicateLoop(L, LI, *SE, DL, LVL->getLAI(),
2591+
*LVL->getDominatorTree());
25892592
}
25902593

25912594
TailFoldingStyle

llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,8 @@ struct StoreToLoadForwardingCandidate {
8989
/// Return true if the dependence from the store to the load has an
9090
/// absolute distance of one.
9191
/// E.g. A[i+1] = A[i] (or A[i-1] = A[i] for descending loop)
92-
bool isDependenceDistanceOfOne(PredicatedScalarEvolution &PSE,
93-
Loop *L) const {
92+
bool isDependenceDistanceOfOne(PredicatedScalarEvolution &PSE, Loop *L,
93+
const DominatorTree &DT) const {
9494
Value *LoadPtr = Load->getPointerOperand();
9595
Value *StorePtr = Store->getPointerOperand();
9696
Type *LoadType = getLoadStoreType(Load);
@@ -102,8 +102,8 @@ struct StoreToLoadForwardingCandidate {
102102
DL.getTypeSizeInBits(getLoadStoreType(Store)) &&
103103
"Should be a known dependence");
104104

105-
int64_t StrideLoad = getPtrStride(PSE, LoadType, LoadPtr, L).value_or(0);
106-
int64_t StrideStore = getPtrStride(PSE, LoadType, StorePtr, L).value_or(0);
105+
int64_t StrideLoad = getPtrStride(PSE, LoadType, LoadPtr, L, DT).value_or(0);
106+
int64_t StrideStore = getPtrStride(PSE, LoadType, StorePtr, L, DT).value_or(0);
107107
if (!StrideLoad || !StrideStore || StrideLoad != StrideStore)
108108
return false;
109109

@@ -287,8 +287,8 @@ class LoadEliminationForLoop {
287287
// so deciding which one forwards is easy. The later one forwards as
288288
// long as they both have a dependence distance of one to the load.
289289
if (Cand.Store->getParent() == OtherCand->Store->getParent() &&
290-
Cand.isDependenceDistanceOfOne(PSE, L) &&
291-
OtherCand->isDependenceDistanceOfOne(PSE, L)) {
290+
Cand.isDependenceDistanceOfOne(PSE, L, *DT) &&
291+
OtherCand->isDependenceDistanceOfOne(PSE, L, *DT)) {
292292
// They are in the same block, the later one will forward to the load.
293293
if (getInstrIndex(OtherCand->Store) < getInstrIndex(Cand.Store))
294294
OtherCand = &Cand;
@@ -538,7 +538,7 @@ class LoadEliminationForLoop {
538538

539539
// Check whether the SCEV difference is the same as the induction step,
540540
// thus we load the value in the next iteration.
541-
if (!Cand.isDependenceDistanceOfOne(PSE, L))
541+
if (!Cand.isDependenceDistanceOfOne(PSE, L, *DT))
542542
continue;
543543

544544
assert(isa<SCEVAddRecExpr>(PSE.getSCEV(Cand.Load->getPointerOperand())) &&

llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,7 @@ int LoopVectorizationLegality::isConsecutivePtr(Type *AccessTy,
462462

463463
bool CanAddPredicate = !llvm::shouldOptimizeForSize(
464464
TheLoop->getHeader(), PSI, BFI, PGSOQueryType::IRPass);
465-
int Stride = getPtrStride(PSE, AccessTy, Ptr, TheLoop, Strides,
465+
int Stride = getPtrStride(PSE, AccessTy, Ptr, TheLoop, *DT, Strides,
466466
CanAddPredicate, false).value_or(0);
467467
if (Stride == 1 || Stride == -1)
468468
return Stride;

0 commit comments

Comments
 (0)