@@ -338,7 +338,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
338
338
}
339
339
}
340
340
341
- #define getOpFullest (Builder, vtmp, frominst, check ) \
341
+ #define getOpFullest (Builder, vtmp, frominst, lookupInst, check ) \
342
342
({ \
343
343
Value *v = vtmp; \
344
344
BasicBlock *origParent = frominst; \
@@ -362,24 +362,38 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
362
362
if (!DT.dominates (opinst, &*Builder.GetInsertPoint ())) \
363
363
noLookup = true ; \
364
364
} \
365
- if (origParent) \
365
+ origParent = lookupInst; \
366
+ if (BasicBlock *forwardBlock = origParent) \
366
367
if (auto opinst = dyn_cast<Instruction>(v)) { \
367
- v = fixLCSSA (opinst, origParent); \
368
+ if (!isOriginalBlock (*forwardBlock)) { \
369
+ forwardBlock = originalForReverseBlock (*forwardBlock); \
370
+ } \
371
+ if (isPotentialLastLoopValue (opinst, forwardBlock, LI)) { \
372
+ v = fixLCSSA (opinst, forwardBlock); \
373
+ origParent = nullptr ; \
374
+ } \
368
375
} \
369
376
if (!noLookup) \
370
- ___res = lookupM (v, Builder, available, v != val); \
377
+ ___res = lookupM (v, Builder, available, v != val, origParent); \
371
378
} \
372
379
if (___res) \
373
380
assert (___res->getType () == v->getType () && " uw" ); \
374
381
} else { \
375
- if (origParent) \
382
+ origParent = lookupInst; \
383
+ if (BasicBlock *forwardBlock = origParent) \
376
384
if (auto opinst = dyn_cast<Instruction>(v)) { \
377
- v = fixLCSSA (opinst, origParent); \
385
+ if (!isOriginalBlock (*forwardBlock)) { \
386
+ forwardBlock = originalForReverseBlock (*forwardBlock); \
387
+ } \
388
+ if (isPotentialLastLoopValue (opinst, forwardBlock, LI)) { \
389
+ v = fixLCSSA (opinst, forwardBlock); \
390
+ origParent = nullptr ; \
391
+ } \
378
392
} \
379
393
assert (unwrapMode == UnwrapMode::AttemptSingleUnwrap); \
380
394
auto found = available.find (v); \
381
395
assert (found == available.end () || found->second ); \
382
- ___res = lookupM (v, Builder, available, v != val); \
396
+ ___res = lookupM (v, Builder, available, v != val, origParent); \
383
397
if (___res && ___res->getType () != v->getType ()) { \
384
398
llvm::errs () << *newFunc << " \n " ; \
385
399
llvm::errs () << " v = " << *v << " res = " << *___res << " \n " ; \
@@ -390,19 +404,25 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
390
404
___res; \
391
405
})
392
406
// getOpFull(Builder, vtmp, frominst): wrapper around getOpFullest that always
// passes `true` for its final `check` argument.
// Removed (`-`) form: forwarded straight to the four-argument getOpFullest.
// Added (`+`) form: first picks the block handed to getOpFullest's `lookupInst`
// argument -- `scope`, or, when `scope` is null and `val` is an Instruction,
// the block that contains `val` -- while `frominst` is forwarded unchanged.
// `scope` and `val` are not macro parameters; they are read from the
// surrounding code at the expansion site.
#define getOpFull (Builder, vtmp, frominst ) \
393
- getOpFullest (Builder, vtmp, frominst, true )
407
+ ({ \
408
+ BasicBlock *parent = scope; \
409
+ if (parent == nullptr ) \
410
+ if (auto originst = dyn_cast<Instruction>(val)) \
411
+ parent = originst->getParent (); \
412
+ getOpFullest (Builder, vtmp, frominst, parent, true ); \
413
+ })
394
414
// getOpUnchecked(vtmp): calls getOpFullest on `BuilderM` with `false` as the
// final `check` argument. `scope` is used as-is for the block argument(s):
// the removed (`-`) line passes it once (as `frominst`), the added (`+`) line
// passes it twice (as both `frominst` and `lookupInst`). Unlike getOp, no
// fallback to the parent block of `val` is applied when `scope` is null.
// `scope` and `BuilderM` are read from the surrounding code at the expansion
// site.
#define getOpUnchecked (vtmp ) \
395
415
({ \
396
416
BasicBlock *parent = scope; \
397
- getOpFullest (BuilderM, vtmp, parent, false ); \
417
+ getOpFullest (BuilderM, vtmp, parent, parent, false ); \
398
418
})
399
419
// getOp(vtmp): calls getOpFullest on `BuilderM` with `true` as the final
// `check` argument. The block argument is `scope`, or, when `scope` is null
// and `val` is an Instruction, the block that contains `val`. The removed
// (`-`) line passes that block once (as `frominst`); the added (`+`) line
// passes it as both `frominst` and `lookupInst`.
// `scope`, `val` and `BuilderM` are read from the surrounding code at the
// expansion site.
#define getOp (vtmp ) \
400
420
({ \
401
421
BasicBlock *parent = scope; \
402
422
if (parent == nullptr ) \
403
423
if (auto originst = dyn_cast<Instruction>(val)) \
404
424
parent = originst->getParent (); \
405
- getOpFullest (BuilderM, vtmp, parent, true ); \
425
+ getOpFullest (BuilderM, vtmp, parent, parent, true ); \
406
426
})
407
427
408
428
if (isa<Argument>(val) || isa<Constant>(val)) {
@@ -1470,16 +1490,39 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
1470
1490
if (!DT.dominates (inst, &*B.GetInsertPoint ()))
1471
1491
noLookup = true ;
1472
1492
}
1473
- Value *v = fixLCSSA (inst, nextScope);
1474
- if (!noLookup)
1475
- ___res = lookupM (v, B, prevAvailable, v != val);
1493
+ if (!noLookup) {
1494
+ BasicBlock *nS2 = nextScope;
1495
+ Value *v = inst;
1496
+ if (BasicBlock *forwardBlock = nextScope)
1497
+ if (auto opinst = dyn_cast<Instruction>(v)) {
1498
+ if (!isOriginalBlock (*forwardBlock)) {
1499
+ forwardBlock = originalForReverseBlock (*forwardBlock);
1500
+ }
1501
+ if (isPotentialLastLoopValue (opinst, forwardBlock,
1502
+ LI)) {
1503
+ v = fixLCSSA (opinst, forwardBlock);
1504
+ nS2 = nullptr ;
1505
+ }
1506
+ }
1507
+ ___res = lookupM (v, B, prevAvailable, v != val, nS2);
1508
+ }
1476
1509
}
1477
1510
if (___res)
1478
1511
assert (___res->getType () == inst->getType () && " uw" );
1479
1512
} else {
1480
- Value *v = fixLCSSA (inst, nextScope);
1481
- assert (unwrapMode == UnwrapMode::AttemptSingleUnwrap);
1482
- ___res = lookupM (v, B, prevAvailable, v != val);
1513
+ BasicBlock *nS2 = nextScope;
1514
+ Value *v = inst;
1515
+ if (BasicBlock *forwardBlock = nextScope)
1516
+ if (auto opinst = dyn_cast<Instruction>(v)) {
1517
+ if (!isOriginalBlock (*forwardBlock)) {
1518
+ forwardBlock = originalForReverseBlock (*forwardBlock);
1519
+ }
1520
+ if (isPotentialLastLoopValue (opinst, forwardBlock, LI)) {
1521
+ v = fixLCSSA (opinst, forwardBlock);
1522
+ nS2 = nullptr ;
1523
+ }
1524
+ }
1525
+ ___res = lookupM (v, B, prevAvailable, v != val, nS2);
1483
1526
if (___res && ___res->getType () != v->getType ()) {
1484
1527
llvm::errs () << *newFunc << " \n " ;
1485
1528
llvm::errs () << " v = " << *v << " res = " << *___res << " \n " ;
@@ -1771,12 +1814,19 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
1771
1814
return nullptr ;
1772
1815
}
1773
1816
}
1774
- if (scope)
1817
+ BasicBlock *nS2 = scope;
1818
+ if (BasicBlock *forwardBlock = scope)
1775
1819
if (auto opinst = dyn_cast<Instruction>(nval)) {
1776
- nval = fixLCSSA (opinst, scope);
1820
+ if (!isOriginalBlock (*forwardBlock)) {
1821
+ forwardBlock = originalForReverseBlock (*forwardBlock);
1822
+ }
1823
+ if (isPotentialLastLoopValue (opinst, forwardBlock, LI)) {
1824
+ nval = fixLCSSA (opinst, forwardBlock);
1825
+ nS2 = nullptr ;
1826
+ }
1777
1827
}
1778
- auto toreturn =
1779
- lookupM (nval, BuilderM, available, /* tryLegalRecomputeCheck*/ false );
1828
+ auto toreturn = lookupM (nval, BuilderM, available,
1829
+ /* tryLegalRecomputeCheck*/ false , nS2 );
1780
1830
assert (val->getType () == toreturn->getType ());
1781
1831
return toreturn;
1782
1832
}
@@ -4974,7 +5024,7 @@ end:;
4974
5024
4975
5025
Value *GradientUtils::lookupM (Value *val, IRBuilder<> &BuilderM,
4976
5026
const ValueToValueMapTy &incoming_available,
4977
- bool tryLegalRecomputeCheck) {
5027
+ bool tryLegalRecomputeCheck, BasicBlock *scope ) {
4978
5028
4979
5029
assert (mode == DerivativeMode::ReverseModePrimal ||
4980
5030
mode == DerivativeMode::ReverseModeGradient ||
@@ -5014,6 +5064,9 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
5014
5064
}
5015
5065
assert (inst->getParent ()->getParent () == newFunc);
5016
5066
assert (BuilderM.GetInsertBlock ()->getParent () == newFunc);
5067
+ if (scope == nullptr )
5068
+ scope = BuilderM.GetInsertBlock ();
5069
+ assert (scope->getParent () == newFunc);
5017
5070
5018
5071
bool reduceRegister = false ;
5019
5072
@@ -5241,12 +5294,14 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
5241
5294
Instruction *prelcssaInst = inst;
5242
5295
5243
5296
assert (inst->getName () != " <badref>" );
5244
- val = fixLCSSA (inst, BuilderM. GetInsertBlock () );
5297
+ val = fixLCSSA (inst, scope );
5245
5298
if (isa<UndefValue>(val)) {
5246
5299
llvm::errs () << *oldFunc << " \n " ;
5247
5300
llvm::errs () << *newFunc << " \n " ;
5248
5301
llvm::errs () << *BuilderM.GetInsertBlock () << " \n " ;
5302
+ llvm::errs () << *scope << " \n " ;
5249
5303
llvm::errs () << *val << " inst " << *inst << " \n " ;
5304
+ assert (0 && " undef value upon lcssa" );
5250
5305
}
5251
5306
inst = cast<Instruction>(val);
5252
5307
assert (prelcssaInst->getType () == inst->getType ());
@@ -5273,7 +5328,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
5273
5328
(lrc = legalRecompute (prelcssaInst, available, &BuilderM))) {
5274
5329
if ((src = shouldRecompute (prelcssaInst, available, &BuilderM))) {
5275
5330
auto op = unwrapM (prelcssaInst, BuilderM, available,
5276
- UnwrapMode::AttemptSingleUnwrap);
5331
+ UnwrapMode::AttemptSingleUnwrap, scope );
5277
5332
if (op) {
5278
5333
assert (op);
5279
5334
assert (op->getType ());
@@ -5571,8 +5626,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
5571
5626
}
5572
5627
Value *recomp = unwrapM (
5573
5628
getNewFromOriginal (SI->getValueOperand ()), BuilderM,
5574
- ThreadLookup, UnwrapMode::AttemptFullUnwrap,
5575
- /* scope*/ nullptr ,
5629
+ ThreadLookup, UnwrapMode::AttemptFullUnwrap, scope,
5576
5630
/* permitCache*/ false );
5577
5631
if (recomp) {
5578
5632
resultValue = recomp;
@@ -6033,7 +6087,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
6033
6087
" tryLegalRecomputeCheck: " , tryLegalRecomputeCheck);
6034
6088
}
6035
6089
6036
- BasicBlock *scope = inst->getParent ();
6090
+ BasicBlock *scopeI = inst->getParent ();
6037
6091
if (auto origInst = isOriginal (inst)) {
6038
6092
auto found = rematerializableAllocations.find (origInst);
6039
6093
if (found != rematerializableAllocations.end ())
@@ -6048,7 +6102,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
6048
6102
// within the loop, force an entry-level scope so there is no need
6049
6103
// to cache.
6050
6104
if (!cacheWholeAllocation)
6051
- scope = &newFunc->getEntryBlock ();
6105
+ scopeI = &newFunc->getEntryBlock ();
6052
6106
}
6053
6107
} else {
6054
6108
for (auto pair : backwardsOnlyShadows) {
@@ -6057,13 +6111,13 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
6057
6111
pair.second .LI ->contains (pinst->getParent ())) {
6058
6112
auto found = invertedPointers.find (pair.first );
6059
6113
if (found != invertedPointers.end () && found->second == inst) {
6060
- scope = &newFunc->getEntryBlock ();
6114
+ scopeI = &newFunc->getEntryBlock ();
6061
6115
6062
6116
// Prevent the phi node from being stored into the cache by creating
6063
6117
// it before the ensureLookupCached.
6064
6118
if (scopeMap.find (inst) == scopeMap.end ()) {
6065
6119
LimitContext lctx (/* ReverseLimit*/ reverseBlocks.size () > 0 ,
6066
- scope );
6120
+ scopeI );
6067
6121
6068
6122
AllocaInst *cache = createCacheForScope (
6069
6123
lctx, inst->getType (), inst->getName (), /* shouldFree*/ true );
@@ -6078,7 +6132,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
6078
6132
}
6079
6133
}
6080
6134
6081
- ensureLookupCached (inst, /* shouldFree*/ true , scope ,
6135
+ ensureLookupCached (inst, /* shouldFree*/ true , scopeI ,
6082
6136
inst->getMetadata (LLVMContext::MD_tbaa));
6083
6137
bool isi1 = inst->getType ()->isIntegerTy () &&
6084
6138
cast<IntegerType>(inst->getType ())->getBitWidth () == 1 ;
0 commit comments