@@ -267,38 +267,77 @@ Cost InstCostVisitor::estimateBranchInst(BranchInst &I) {
267
267
return estimateBasicBlocks (WorkList);
268
268
}
269
269
270
- void InstCostVisitor::discoverStronglyConnectedComponent (PHINode *PN,
271
- unsigned Depth) {
272
- if (Depth > MaxDiscoveryDepth)
273
- return ;
270
+ // This function finds candidates where a PHINode is part of a chain or graph
271
+ // of PHINodes that all link to each other. That means, if the original input to
272
+ // the chain is a constant, all the other values are also that constant.
273
+ //
274
+ // The caller of this function will later check that no other nodes are involved
275
+ // that are non-constant, and discard it from the possible conversions.
276
+ //
277
+ // For example:
278
+ //
279
+ // %a = load %0
280
+ // %c = phi [%a, %d]
281
+ // %d = phi [%e, %c]
282
+ // %e = phi [%c, %f]
283
+ // %f = phi [%j, %h]
284
+ // %j = phi [%h, %j]
285
+ // %h = phi [%g, %c]
286
+ //
287
+ // This is only showing the PHINodes, not the branches that choose the
288
+ // different paths.
289
+ //
290
+ // A depth limit is used to avoid extreme recursion.
291
+ // A max number of incoming phi values ensures that expensive searches
292
+ // are avoided.
293
+ //
294
+ // Returns false if the discovery was aborted due to the above conditions.
295
+ bool InstCostVisitor::discoverTransitivelyIncomngValues (
296
+ DenseSet<PHINode *> &PHINodes, PHINode *PN, unsigned Depth) {
297
+ if (Depth > MaxDiscoveryDepth) {
298
+ LLVM_DEBUG (dbgs () << " FnSpecialization: Discover PHI nodes too deep ("
299
+ << Depth << " >" << MaxDiscoveryDepth << " )\n " );
300
+ return false ;
301
+ }
274
302
275
- if (PN->getNumIncomingValues () > MaxIncomingPhiValues)
276
- return ;
303
+ if (PN->getNumIncomingValues () > MaxIncomingPhiValues) {
304
+ LLVM_DEBUG (
305
+ dbgs () << " FnSpecialization: Discover PHI nodes has too many values ("
306
+ << PN->getNumIncomingValues () << " >" << MaxIncomingPhiValues
307
+ << " )\n " );
308
+ return false ;
309
+ }
277
310
278
- if (!StronglyConnectedPHIs.insert (PN).second )
279
- return ;
311
+ // Already seen this, no more processing needed.
312
+ if (!PHINodes.insert (PN).second )
313
+ return true ;
280
314
281
315
for (unsigned I = 0 , E = PN->getNumIncomingValues (); I != E; ++I) {
282
316
Value *V = PN->getIncomingValue (I);
283
317
if (auto *Phi = dyn_cast<PHINode>(V)) {
284
318
if (Phi == PN || DeadBlocks.contains (PN->getIncomingBlock (I)))
285
319
continue ;
286
- discoverStronglyConnectedComponent (Phi, Depth + 1 );
320
+ if (!discoverTransitivelyIncomngValues (PHINodes, Phi, Depth + 1 ))
321
+ return false ;
287
322
}
288
323
}
324
+ return true ;
289
325
}
290
326
291
327
Constant *InstCostVisitor::visitPHINode (PHINode &I) {
292
328
if (I.getNumIncomingValues () > MaxIncomingPhiValues)
293
329
return nullptr ;
294
330
331
+ // PHI nodes discovered by following the PHI incoming values of this node.
332
+ DenseSet<PHINode *> TransitivePHIs;
333
+
295
334
bool Inserted = VisitedPHIs.insert (&I).second ;
296
- Constant *Const = nullptr ;
297
335
SmallVector<PHINode *, 8 > UnknownIncomingValues;
298
336
299
- auto CanConstantFoldPhi = [&](PHINode *PN) -> bool {
300
- UnknownIncomingValues. clear () ;
337
+ auto canConstantFoldPhiTrivially = [&](PHINode *PN) -> Constant * {
338
+ Constant *Const = nullptr ;
301
339
340
+ UnknownIncomingValues.clear ();
302
341
for (unsigned I = 0 , E = PN->getNumIncomingValues (); I != E; ++I) {
303
342
Value *V = PN->getIncomingValue (I);
304
343
@@ -311,21 +350,22 @@ Constant *InstCostVisitor::visitPHINode(PHINode &I) {
311
350
if (!Const)
312
351
Const = C;
313
352
// Not all incoming values are the same constant. Bail immediately.
314
- else if (C != Const)
315
- return false ;
316
- } else if (auto *Phi = dyn_cast<PHINode>(V)) {
317
- // It's not a strongly connected phi. Collect it and bail at the end.
318
- if (!StronglyConnectedPHIs.contains (Phi))
319
- UnknownIncomingValues.push_back (Phi);
320
- } else {
321
- // We can't reason about anything else.
322
- return false ;
353
+ if (C != Const)
354
+ return nullptr ;
355
+ continue ;
323
356
}
357
+ if (auto *Phi = dyn_cast<PHINode>(V)) {
358
+ UnknownIncomingValues.push_back (Phi);
359
+ continue ;
360
+ }
361
+
362
+ // We can't reason about anything else.
363
+ return nullptr ;
324
364
}
325
- return UnknownIncomingValues.empty ();
365
+ return UnknownIncomingValues.empty () ? Const : nullptr ;
326
366
};
327
367
328
- if (CanConstantFoldPhi (&I))
368
+ if (Constant *Const = canConstantFoldPhiTrivially (&I))
329
369
return Const;
330
370
331
371
if (Inserted) {
@@ -335,18 +375,59 @@ Constant *InstCostVisitor::visitPHINode(PHINode &I) {
335
375
return nullptr ;
336
376
}
337
377
378
+ // Try to see if we can collect a nest of transitive phis. Bail if
379
+ // it's too complex.
338
380
for (PHINode *Phi : UnknownIncomingValues)
339
- discoverStronglyConnectedComponent (Phi, 1 );
381
+ if (!discoverTransitivelyIncomngValues (TransitivePHIs, Phi, 1 ))
382
+ return nullptr ;
383
+
384
+ // A nested set of PHINodes can be constantfolded if:
385
+ // - It has a constant input.
386
+ // - It is always the SAME constant.
387
+ auto canConstantFoldNestedPhi = [&](PHINode *PN) -> Constant * {
388
+ Constant *Const = nullptr ;
340
389
341
- bool CannotConstantFoldPhi = false ;
342
- for (PHINode *Phi : StronglyConnectedPHIs) {
343
- if (!CanConstantFoldPhi (Phi)) {
344
- CannotConstantFoldPhi = true ;
345
- break ;
390
+ for (unsigned I = 0 , E = PN->getNumIncomingValues (); I != E; ++I) {
391
+ Value *V = PN->getIncomingValue (I);
392
+
393
+ // Disregard self-references and dead incoming values.
394
+ if (auto *Inst = dyn_cast<Instruction>(V))
395
+ if (Inst == PN || DeadBlocks.contains (PN->getIncomingBlock (I)))
396
+ continue ;
397
+
398
+ if (Constant *C = findConstantFor (V, KnownConstants)) {
399
+ if (!Const)
400
+ Const = C;
401
+ // Not all incoming values are the same constant. Bail immediately.
402
+ if (C != Const)
403
+ return nullptr ;
404
+ continue ;
405
+ }
406
+ if (auto *Phi = dyn_cast<PHINode>(V)) {
407
+ // It's not a Transitive phi. Bail out.
408
+ if (!TransitivePHIs.contains (Phi))
409
+ return nullptr ;
410
+ continue ;
411
+ }
412
+
413
+ // We can't reason about anything else.
414
+ return nullptr ;
415
+ }
416
+ return Const;
417
+ };
418
+
419
+ // All TransitivePHIs have to be the SAME constant.
420
+ Constant *Retval = nullptr ;
421
+ for (PHINode *Phi : TransitivePHIs) {
422
+ if (Constant *Const = canConstantFoldNestedPhi (Phi)) {
423
+ if (!Retval)
424
+ Retval = Const;
425
+ else if (Retval != Const)
426
+ return nullptr ;
346
427
}
347
428
}
348
- StronglyConnectedPHIs. clear ();
349
- return CannotConstantFoldPhi ? nullptr : Const ;
429
+
430
+ return Retval ;
350
431
}
351
432
352
433
Constant *InstCostVisitor::visitFreezeInst (FreezeInst &I) {
@@ -871,37 +952,37 @@ bool FunctionSpecializer::findSpecializations(Function *F, unsigned FuncSize,
871
952
unsigned FuncGrowth) -> bool {
872
953
// No check required.
873
954
if (ForceSpecialization) {
874
- LLVM_DEBUG (dbgs () << " Force is on\n " );
955
+ LLVM_DEBUG (dbgs () << " FnSpecialization: Force is on\n " );
875
956
return true ;
876
957
}
877
958
// Minimum inlining bonus.
878
959
if (Score > MinInliningBonus * FuncSize / 100 ) {
879
960
LLVM_DEBUG (dbgs ()
880
- << " FnSpecialization: Min inliningbous: Score = " << Score
881
- << " > " << MinInliningBonus * FuncSize / 100 << " \n " );
961
+ << " FnSpecialization: Sufficient inlining bonus ( " << Score
962
+ << " > " << MinInliningBonus * FuncSize / 100 << " ) \n " );
882
963
return true ;
883
964
}
884
965
// Minimum codesize savings.
885
966
if (B.CodeSize < MinCodeSizeSavings * FuncSize / 100 ) {
886
967
LLVM_DEBUG (dbgs ()
887
- << " FnSpecialization: Min CodeSize Saving: CodeSize = "
968
+ << " FnSpecialization: Insufficinet CodeSize Saving ( "
888
969
<< B.CodeSize << " > "
889
- << MinCodeSizeSavings * FuncSize / 100 << " \n " );
970
+ << MinCodeSizeSavings * FuncSize / 100 << " ) \n " );
890
971
return false ;
891
972
}
892
973
// Minimum latency savings.
893
974
if (B.Latency < MinLatencySavings * FuncSize / 100 ) {
894
- LLVM_DEBUG (dbgs ()
895
- << " FnSpecialization: Min Latency Saving: Latency = "
896
- << B.Latency << " > " << MinLatencySavings * FuncSize / 100
897
- << " \n " );
975
+ LLVM_DEBUG (dbgs () << " FnSpecialization: Insufficinet Latency Saving ("
976
+ << B.Latency << " > "
977
+ << MinLatencySavings * FuncSize / 100 << " )\n " );
898
978
return false ;
899
979
}
900
980
// Maximum codesize growth.
901
981
if (FuncGrowth / FuncSize > MaxCodeSizeGrowth) {
902
- LLVM_DEBUG (dbgs () << " FnSpecialization: Max Func Growth: CodeSize = "
903
- << FuncGrowth / FuncSize << " > "
904
- << MaxCodeSizeGrowth << " \n " );
982
+ LLVM_DEBUG (dbgs ()
983
+ << " FnSpecialization: Function Growth exceeds threshold ("
984
+ << FuncGrowth / FuncSize << " > " << MaxCodeSizeGrowth
985
+ << " )\n " );
905
986
return false ;
906
987
}
907
988
return true ;
0 commit comments