Skip to content

Commit 3fb7efd

Browse files
Update based on review comments
NOTE: We need to re-write the overall commit message, as it is not close to accurate any longer.
1 parent 966954c commit 3fb7efd

File tree

2 files changed

+126
-46
lines changed

2 files changed

+126
-46
lines changed

llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -183,8 +183,6 @@ class InstCostVisitor : public InstVisitor<InstCostVisitor, Constant *> {
183183
DenseSet<BasicBlock *> DeadBlocks;
184184
// PHI nodes we have visited before.
185185
DenseSet<Instruction *> VisitedPHIs;
186-
// PHI nodes forming a strongly connected component.
187-
DenseSet<PHINode *> StronglyConnectedPHIs;
188186
// PHI nodes we have visited once without successfully constant folding them.
189187
// Once the InstCostVisitor has processed all the specialization arguments,
190188
// it should be possible to determine whether those PHIs can be folded
@@ -219,7 +217,8 @@ class InstCostVisitor : public InstVisitor<InstCostVisitor, Constant *> {
219217
Cost estimateSwitchInst(SwitchInst &I);
220218
Cost estimateBranchInst(BranchInst &I);
221219

222-
void discoverStronglyConnectedComponent(PHINode *PN, unsigned Depth);
220+
bool discoverTransitivelyIncomngValues(DenseSet<PHINode *> &PhiNodes,
221+
PHINode *PN, unsigned Depth);
223222

224223
Constant *visitInstruction(Instruction &I) { return nullptr; }
225224
Constant *visitPHINode(PHINode &I);

llvm/lib/Transforms/IPO/FunctionSpecialization.cpp

Lines changed: 124 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -267,38 +267,77 @@ Cost InstCostVisitor::estimateBranchInst(BranchInst &I) {
267267
return estimateBasicBlocks(WorkList);
268268
}
269269

270-
void InstCostVisitor::discoverStronglyConnectedComponent(PHINode *PN,
271-
unsigned Depth) {
272-
if (Depth > MaxDiscoveryDepth)
273-
return;
270+
// This function finds candidates for a PHINode that is part of a chain or graph
271+
// of PHINodes that all link to each other. That means, if the original input to
272+
// the chain is a constant, all the other values are also that constant.
273+
//
274+
// The caller of this function will later check that no other nodes are involved
275+
// that are non-constant, and discard it from the possible conversions.
276+
//
277+
// For example:
278+
//
279+
// %a = load %0
280+
// %c = phi [%a, %d]
281+
// %d = phi [%e, %c]
282+
// %e = phi [%c, %f]
283+
// %f = phi [%j, %h]
284+
// %j = phi [%h, %j]
285+
// %h = phi [%g, %c]
286+
//
287+
// This is only showing the PHINodes, not the branches that choose the
288+
// different paths.
289+
//
290+
// A depth limit is used to avoid extreme recursion.
291+
// A max number of incoming phi values ensures that expensive searches
292+
// are avoided.
293+
//
294+
// Returns false if the discovery was aborted due to the above conditions.
295+
bool InstCostVisitor::discoverTransitivelyIncomngValues(
296+
DenseSet<PHINode *> &PHINodes, PHINode *PN, unsigned Depth) {
297+
if (Depth > MaxDiscoveryDepth) {
298+
LLVM_DEBUG(dbgs() << "FnSpecialization: Discover PHI nodes too deep ("
299+
<< Depth << ">" << MaxDiscoveryDepth << ")\n");
300+
return false;
301+
}
274302

275-
if (PN->getNumIncomingValues() > MaxIncomingPhiValues)
276-
return;
303+
if (PN->getNumIncomingValues() > MaxIncomingPhiValues) {
304+
LLVM_DEBUG(
305+
dbgs() << "FnSpecialization: Discover PHI nodes has too many values ("
306+
<< PN->getNumIncomingValues() << ">" << MaxIncomingPhiValues
307+
<< ")\n");
308+
return false;
309+
}
277310

278-
if (!StronglyConnectedPHIs.insert(PN).second)
279-
return;
311+
// Already seen this, no more processing needed.
312+
if (!PHINodes.insert(PN).second)
313+
return true;
280314

281315
for (unsigned I = 0, E = PN->getNumIncomingValues(); I != E; ++I) {
282316
Value *V = PN->getIncomingValue(I);
283317
if (auto *Phi = dyn_cast<PHINode>(V)) {
284318
if (Phi == PN || DeadBlocks.contains(PN->getIncomingBlock(I)))
285319
continue;
286-
discoverStronglyConnectedComponent(Phi, Depth + 1);
320+
if (!discoverTransitivelyIncomngValues(PHINodes, Phi, Depth + 1))
321+
return false;
287322
}
288323
}
324+
return true;
289325
}
290326

291327
Constant *InstCostVisitor::visitPHINode(PHINode &I) {
292328
if (I.getNumIncomingValues() > MaxIncomingPhiValues)
293329
return nullptr;
294330

331+
// PHI nodes reachable transitively through this PHI's incoming values.
332+
DenseSet<PHINode *> TransitivePHIs;
333+
295334
bool Inserted = VisitedPHIs.insert(&I).second;
296-
Constant *Const = nullptr;
297335
SmallVector<PHINode *, 8> UnknownIncomingValues;
298336

299-
auto CanConstantFoldPhi = [&](PHINode *PN) -> bool {
300-
UnknownIncomingValues.clear();
337+
auto canConstantFoldPhiTrivially = [&](PHINode *PN) -> Constant * {
338+
Constant *Const = nullptr;
301339

340+
UnknownIncomingValues.clear();
302341
for (unsigned I = 0, E = PN->getNumIncomingValues(); I != E; ++I) {
303342
Value *V = PN->getIncomingValue(I);
304343

@@ -311,21 +350,22 @@ Constant *InstCostVisitor::visitPHINode(PHINode &I) {
311350
if (!Const)
312351
Const = C;
313352
// Not all incoming values are the same constant. Bail immediately.
314-
else if (C != Const)
315-
return false;
316-
} else if (auto *Phi = dyn_cast<PHINode>(V)) {
317-
// It's not a strongly connected phi. Collect it and bail at the end.
318-
if (!StronglyConnectedPHIs.contains(Phi))
319-
UnknownIncomingValues.push_back(Phi);
320-
} else {
321-
// We can't reason about anything else.
322-
return false;
353+
if (C != Const)
354+
return nullptr;
355+
continue;
323356
}
357+
if (auto *Phi = dyn_cast<PHINode>(V)) {
358+
UnknownIncomingValues.push_back(Phi);
359+
continue;
360+
}
361+
362+
// We can't reason about anything else.
363+
return nullptr;
324364
}
325-
return UnknownIncomingValues.empty();
365+
return UnknownIncomingValues.empty() ? Const : nullptr;
326366
};
327367

328-
if (CanConstantFoldPhi(&I))
368+
if (Constant *Const = canConstantFoldPhiTrivially(&I))
329369
return Const;
330370

331371
if (Inserted) {
@@ -335,18 +375,59 @@ Constant *InstCostVisitor::visitPHINode(PHINode &I) {
335375
return nullptr;
336376
}
337377

378+
// Try to see if we can collect a nest of transitive phis. Bail if
379+
// it's too complex.
338380
for (PHINode *Phi : UnknownIncomingValues)
339-
discoverStronglyConnectedComponent(Phi, 1);
381+
if (!discoverTransitivelyIncomngValues(TransitivePHIs, Phi, 1))
382+
return nullptr;
383+
384+
// A nested set of PHINodes can be constant folded if:
385+
// - It has a constant input.
386+
// - It is always the SAME constant.
387+
auto canConstantFoldNestedPhi = [&](PHINode *PN) -> Constant * {
388+
Constant *Const = nullptr;
340389

341-
bool CannotConstantFoldPhi = false;
342-
for (PHINode *Phi : StronglyConnectedPHIs) {
343-
if (!CanConstantFoldPhi(Phi)) {
344-
CannotConstantFoldPhi = true;
345-
break;
390+
for (unsigned I = 0, E = PN->getNumIncomingValues(); I != E; ++I) {
391+
Value *V = PN->getIncomingValue(I);
392+
393+
// Disregard self-references and dead incoming values.
394+
if (auto *Inst = dyn_cast<Instruction>(V))
395+
if (Inst == PN || DeadBlocks.contains(PN->getIncomingBlock(I)))
396+
continue;
397+
398+
if (Constant *C = findConstantFor(V, KnownConstants)) {
399+
if (!Const)
400+
Const = C;
401+
// Not all incoming values are the same constant. Bail immediately.
402+
if (C != Const)
403+
return nullptr;
404+
continue;
405+
}
406+
if (auto *Phi = dyn_cast<PHINode>(V)) {
407+
// It's not a Transitive phi. Bail out.
408+
if (!TransitivePHIs.contains(Phi))
409+
return nullptr;
410+
continue;
411+
}
412+
413+
// We can't reason about anything else.
414+
return nullptr;
415+
}
416+
return Const;
417+
};
418+
419+
// All TransitivePHIs have to be the SAME constant.
420+
Constant *Retval = nullptr;
421+
for (PHINode *Phi : TransitivePHIs) {
422+
if (Constant *Const = canConstantFoldNestedPhi(Phi)) {
423+
if (!Retval)
424+
Retval = Const;
425+
else if (Retval != Const)
426+
return nullptr;
346427
}
347428
}
348-
StronglyConnectedPHIs.clear();
349-
return CannotConstantFoldPhi ? nullptr : Const;
429+
430+
return Retval;
350431
}
351432

352433
Constant *InstCostVisitor::visitFreezeInst(FreezeInst &I) {
@@ -871,37 +952,37 @@ bool FunctionSpecializer::findSpecializations(Function *F, unsigned FuncSize,
871952
unsigned FuncGrowth) -> bool {
872953
// No check required.
873954
if (ForceSpecialization) {
874-
LLVM_DEBUG(dbgs() << "Force is on\n");
955+
LLVM_DEBUG(dbgs() << "FnSpecialization: Force is on\n");
875956
return true;
876957
}
877958
// Minimum inlining bonus.
878959
if (Score > MinInliningBonus * FuncSize / 100) {
879960
LLVM_DEBUG(dbgs()
880-
<< "FnSpecialization: Min inliningbous: Score = " << Score
881-
<< " > " << MinInliningBonus * FuncSize / 100 << "\n");
961+
<< "FnSpecialization: Sufficient inlining bonus (" << Score
962+
<< " > " << MinInliningBonus * FuncSize / 100 << ")\n");
882963
return true;
883964
}
884965
// Minimum codesize savings.
885966
if (B.CodeSize < MinCodeSizeSavings * FuncSize / 100) {
886967
LLVM_DEBUG(dbgs()
887-
<< "FnSpecialization: Min CodeSize Saving: CodeSize = "
968+
<< "FnSpecialization: Insufficinet CodeSize Saving ("
888969
<< B.CodeSize << " > "
889-
<< MinCodeSizeSavings * FuncSize / 100 << "\n");
970+
<< MinCodeSizeSavings * FuncSize / 100 << ")\n");
890971
return false;
891972
}
892973
// Minimum latency savings.
893974
if (B.Latency < MinLatencySavings * FuncSize / 100) {
894-
LLVM_DEBUG(dbgs()
895-
<< "FnSpecialization: Min Latency Saving: Latency = "
896-
<< B.Latency << " > " << MinLatencySavings * FuncSize / 100
897-
<< "\n");
975+
LLVM_DEBUG(dbgs() << "FnSpecialization: Insufficinet Latency Saving ("
976+
<< B.Latency << " > "
977+
<< MinLatencySavings * FuncSize / 100 << ")\n");
898978
return false;
899979
}
900980
// Maximum codesize growth.
901981
if (FuncGrowth / FuncSize > MaxCodeSizeGrowth) {
902-
LLVM_DEBUG(dbgs() << "FnSpecialization: Max Func Growth: CodeSize = "
903-
<< FuncGrowth / FuncSize << " > "
904-
<< MaxCodeSizeGrowth << "\n");
982+
LLVM_DEBUG(dbgs()
983+
<< "FnSpecialization: Function Growth exceeds threshold ("
984+
<< FuncGrowth / FuncSize << " > " << MaxCodeSizeGrowth
985+
<< ")\n");
905986
return false;
906987
}
907988
return true;

0 commit comments

Comments
 (0)