@@ -53,7 +53,7 @@ class AdjointGenerator
53
53
GradientUtils *const gutils;
54
54
const std::vector<DIFFE_TYPE> &constant_args;
55
55
DIFFE_TYPE retType;
56
- TypeResults &TR;
56
+ TypeResults &TR = gutils->TR ;
57
57
std::function<unsigned(Instruction *, CacheType)> getIndex;
58
58
const std::map<CallInst *, const std::map<Argument *, bool>>
59
59
uncacheable_args_map;
@@ -71,7 +71,6 @@ class AdjointGenerator
71
71
AdjointGenerator(
72
72
DerivativeMode Mode, GradientUtils *gutils,
73
73
const std::vector<DIFFE_TYPE> &constant_args, DIFFE_TYPE retType,
74
- TypeResults &TR,
75
74
std::function<unsigned(Instruction *, CacheType)> getIndex,
76
75
const std::map<CallInst *, const std::map<Argument *, bool>>
77
76
uncacheable_args_map,
@@ -84,7 +83,7 @@ class AdjointGenerator
84
83
const SmallPtrSetImpl<BasicBlock *> &oldUnreachable,
85
84
AllocaInst *dretAlloca)
86
85
: Mode(Mode), gutils(gutils), constant_args(constant_args),
87
- retType(retType), TR(TR), getIndex(getIndex),
86
+ retType(retType), getIndex(getIndex),
88
87
uncacheable_args_map(uncacheable_args_map), returnuses(returnuses),
89
88
augmentedReturn(augmentedReturn), replacedReturns(replacedReturns),
90
89
unnecessaryValues(unnecessaryValues),
@@ -435,7 +434,7 @@ class AdjointGenerator
435
434
auto placeholder = cast<PHINode>(&*found->second);
436
435
gutils->invertedPointers.erase(found);
437
436
438
- if (!is_value_needed_in_reverse<ValueType::Shadow>(TR, gutils, &I, Mode,
437
+ if (!is_value_needed_in_reverse<ValueType::Shadow>(gutils, &I, Mode,
439
438
oldUnreachable)) {
440
439
gutils->erase(placeholder);
441
440
return;
@@ -505,7 +504,7 @@ class AdjointGenerator
505
504
506
505
IRBuilder<> BuilderZ(newi);
507
506
// only make shadow where caching needed
508
- if (!is_value_needed_in_reverse<ValueType::Shadow>(TR, gutils, &I, Mode,
507
+ if (!is_value_needed_in_reverse<ValueType::Shadow>(gutils, &I, Mode,
509
508
oldUnreachable)) {
510
509
gutils->erase(placeholder);
511
510
return;
@@ -551,7 +550,7 @@ class AdjointGenerator
551
550
// TODO: In the case of fwd mode this should be true if the loaded value
552
551
// itself is used as a pointer.
553
552
bool needShadow = is_value_needed_in_reverse<ValueType::Shadow>(
554
- TR, gutils, &I, Mode, oldUnreachable);
553
+ gutils, &I, Mode, oldUnreachable);
555
554
556
555
switch (Mode) {
557
556
@@ -564,7 +563,7 @@ class AdjointGenerator
564
563
assert(newip->getType() == type);
565
564
if (Mode == DerivativeMode::ReverseModePrimal && can_modref &&
566
565
is_value_needed_in_reverse<ValueType::Shadow>(
567
- TR, gutils, &I, DerivativeMode::ReverseModeGradient,
566
+ gutils, &I, DerivativeMode::ReverseModeGradient,
568
567
oldUnreachable)) {
569
568
gutils->cacheForReverse(BuilderZ, newip,
570
569
getIndex(&I, CacheType::Shadow));
@@ -632,7 +631,7 @@ class AdjointGenerator
632
631
primalNeededInReverse = true;
633
632
}
634
633
primalNeededInReverse |= is_value_needed_in_reverse<ValueType::Primal>(
635
- TR, gutils, &I, Mode, Seen, oldUnreachable);
634
+ gutils, &I, Mode, Seen, oldUnreachable);
636
635
if (primalNeededInReverse) {
637
636
IRBuilder<> BuilderZ(gutils->getNewFromOriginal(&I));
638
637
inst = gutils->cacheForReverse(BuilderZ, newi,
@@ -8283,8 +8282,8 @@ class AdjointGenerator
8283
8282
} else {
8284
8283
if (!orig->getType()->isFPOrFPVectorTy() &&
8285
8284
TR.query(orig).Inner0().isPossiblePointer()) {
8286
- if (is_value_needed_in_reverse<ValueType::Shadow>(
8287
- TR, gutils, orig, Mode, oldUnreachable)) {
8285
+ if (is_value_needed_in_reverse<ValueType::Shadow>(gutils, orig, Mode,
8286
+ oldUnreachable)) {
8288
8287
subretType = DIFFE_TYPE::DUP_ARG;
8289
8288
shadowReturnUsed = true;
8290
8289
} else
@@ -8364,7 +8363,7 @@ class AdjointGenerator
8364
8363
if (!orig->getType()->isFPOrFPVectorTy() &&
8365
8364
TR.query(orig).Inner0().isPossiblePointer()) {
8366
8365
if (is_value_needed_in_reverse<ValueType::Shadow>(
8367
- TR, gutils, orig, DerivativeMode::ReverseModePrimal,
8366
+ gutils, orig, DerivativeMode::ReverseModePrimal,
8368
8367
oldUnreachable)) {
8369
8368
hasNonReturnUse = true;
8370
8369
}
@@ -8448,7 +8447,7 @@ class AdjointGenerator
8448
8447
if (!pair.second)
8449
8448
Seen[UsageKey(pair.first, ValueType::Primal)] = false;
8450
8449
primalNeededInReverse = is_value_needed_in_reverse<ValueType::Primal>(
8451
- TR, gutils, orig, Mode, Seen, oldUnreachable);
8450
+ gutils, orig, Mode, Seen, oldUnreachable);
8452
8451
}
8453
8452
if (subretused && primalNeededInReverse) {
8454
8453
if (normalReturn != newCall) {
@@ -9352,7 +9351,7 @@ class AdjointGenerator
9352
9351
Seen[UsageKey(pair.first, ValueType::Primal)] = false;
9353
9352
bool primalNeededInReverse =
9354
9353
is_value_needed_in_reverse<ValueType::Primal>(
9355
- TR, gutils, orig, Mode, Seen, oldUnreachable);
9354
+ gutils, orig, Mode, Seen, oldUnreachable);
9356
9355
shouldCache = primalNeededInReverse;
9357
9356
}
9358
9357
@@ -10727,7 +10726,7 @@ class AdjointGenerator
10727
10726
Mode == DerivativeMode::ForwardMode
10728
10727
? false
10729
10728
: is_value_needed_in_reverse<ValueType::Primal>(
10730
- TR, gutils, orig, Mode, Seen, oldUnreachable);
10729
+ gutils, orig, Mode, Seen, oldUnreachable);
10731
10730
10732
10731
bool cacheWholeAllocation = false;
10733
10732
if (gutils->knownRecomputeHeuristic.count(orig)) {
@@ -10957,7 +10956,7 @@ class AdjointGenerator
10957
10956
Mode == DerivativeMode::ForwardModeSplit)
10958
10957
? true
10959
10958
: is_value_needed_in_reverse<ValueType::Shadow>(
10960
- TR, gutils, orig, Mode, oldUnreachable);
10959
+ gutils, orig, Mode, oldUnreachable);
10961
10960
if (!needShadow) {
10962
10961
gutils->invertedPointers.erase(ifound);
10963
10962
gutils->erase(placeholder);
@@ -11175,7 +11174,7 @@ class AdjointGenerator
11175
11174
Seen[UsageKey(pair.first, ValueType::Primal)] = false;
11176
11175
bool primalNeededInReverse =
11177
11176
is_value_needed_in_reverse<ValueType::Primal>(
11178
- TR, gutils, rmat.first, Mode, Seen, oldUnreachable);
11177
+ gutils, rmat.first, Mode, Seen, oldUnreachable);
11179
11178
bool cacheWholeAllocation = false;
11180
11179
if (gutils->knownRecomputeHeuristic.count(rmat.first)) {
11181
11180
if (!gutils->knownRecomputeHeuristic[rmat.first]) {
@@ -11508,7 +11507,7 @@ class AdjointGenerator
11508
11507
return;
11509
11508
}
11510
11509
11511
- bool modifyPrimal = shouldAugmentCall(orig, gutils, TR );
11510
+ bool modifyPrimal = shouldAugmentCall(orig, gutils);
11512
11511
11513
11512
SmallVector<Value *, 8> args;
11514
11513
SmallVector<Value *, 8> pre_args;
@@ -11622,7 +11621,7 @@ class AdjointGenerator
11622
11621
11623
11622
if (Mode == DerivativeMode::ReverseModeCombined && !foreignFunction) {
11624
11623
replaceFunction = legalCombinedForwardReverse(
11625
- orig, *replacedReturns, postCreate, userReplace, gutils, TR,
11624
+ orig, *replacedReturns, postCreate, userReplace, gutils,
11626
11625
unnecessaryInstructions, oldUnreachable, subretused);
11627
11626
if (replaceFunction)
11628
11627
modifyPrimal = false;
@@ -11842,8 +11841,8 @@ class AdjointGenerator
11842
11841
}
11843
11842
11844
11843
if (Mode == DerivativeMode::ReverseModePrimal &&
11845
- is_value_needed_in_reverse<ValueType::Primal>(
11846
- TR, gutils, orig, Mode, oldUnreachable) &&
11844
+ is_value_needed_in_reverse<ValueType::Primal>(gutils, orig, Mode,
11845
+ oldUnreachable) &&
11847
11846
!gutils->unnecessaryIntermediates.count(orig)) {
11848
11847
gutils->cacheForReverse(BuilderZ, dcall,
11849
11848
getIndex(orig, CacheType::Self));
@@ -11876,8 +11875,8 @@ class AdjointGenerator
11876
11875
}
11877
11876
11878
11877
if (subretused) {
11879
- if (is_value_needed_in_reverse<ValueType::Primal>(
11880
- TR, gutils, orig, Mode, oldUnreachable) &&
11878
+ if (is_value_needed_in_reverse<ValueType::Primal>(gutils, orig, Mode,
11879
+ oldUnreachable) &&
11881
11880
!gutils->unnecessaryIntermediates.count(orig)) {
11882
11881
cachereplace = BuilderZ.CreatePHI(orig->getType(), 1,
11883
11882
orig->getName() + "_tmpcacheB");
@@ -11983,8 +11982,8 @@ class AdjointGenerator
11983
11982
}
11984
11983
if (/*!topLevel*/ Mode != DerivativeMode::ReverseModeCombined &&
11985
11984
subretused && !orig->doesNotAccessMemory()) {
11986
- if (is_value_needed_in_reverse<ValueType::Primal>(
11987
- TR, gutils, orig, Mode, oldUnreachable) &&
11985
+ if (is_value_needed_in_reverse<ValueType::Primal>(gutils, orig, Mode,
11986
+ oldUnreachable) &&
11988
11987
!gutils->unnecessaryIntermediates.count(orig)) {
11989
11988
assert(!replaceFunction);
11990
11989
cachereplace = BuilderZ.CreatePHI(orig->getType(), 1,
0 commit comments