Skip to content

Commit 3fd233c

Browse files
authored
Move TypeResults into GradientUtils (rust-lang#660)
1 parent f0f64db commit 3fd233c

File tree

12 files changed

+273
-293
lines changed

12 files changed

+273
-293
lines changed

enzyme/Enzyme/ActivityAnalysis.cpp

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,8 @@ static inline void propagateArgumentInformation(
457457
/// Return whether this instruction is known not to propagate adjoints
458458
/// Note that instructions could return an active pointer, but
459459
/// do not propagate adjoints themselves
460-
bool ActivityAnalyzer::isConstantInstruction(TypeResults &TR, Instruction *I) {
460+
bool ActivityAnalyzer::isConstantInstruction(TypeResults const &TR,
461+
Instruction *I) {
461462
// This analysis may only be called by instructions corresponding to
462463
// the function analyzed by TypeInfo
463464
assert(I);
@@ -787,7 +788,7 @@ bool isValuePotentiallyUsedAsPointer(llvm::Value *val) {
787788
return false;
788789
}
789790

790-
bool ActivityAnalyzer::isConstantValue(TypeResults &TR, Value *Val) {
791+
bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) {
791792
// This analysis may only be called by instructions corresponding to
792793
// the function analyzed by TypeInfo -- however if the Value
793794
// was created outside a function (e.g. global, constant), that is allowed
@@ -1919,7 +1920,7 @@ bool ActivityAnalyzer::isConstantValue(TypeResults &TR, Value *Val) {
19191920
}
19201921

19211922
/// Is the instruction guaranteed to be inactive because of its operands
1922-
bool ActivityAnalyzer::isInstructionInactiveFromOrigin(TypeResults &TR,
1923+
bool ActivityAnalyzer::isInstructionInactiveFromOrigin(TypeResults const &TR,
19231924
llvm::Value *val) {
19241925
// Must be an analyzer only searching up
19251926
assert(directions == UP);
@@ -2175,7 +2176,7 @@ bool ActivityAnalyzer::isInstructionInactiveFromOrigin(TypeResults &TR,
21752176
}
21762177

21772178
/// Is the value free of any active uses
2178-
bool ActivityAnalyzer::isValueInactiveFromUsers(TypeResults &TR,
2179+
bool ActivityAnalyzer::isValueInactiveFromUsers(TypeResults const &TR,
21792180
llvm::Value *val,
21802181
UseActivity PUA,
21812182
Instruction **FoundInst) {
@@ -2365,7 +2366,7 @@ bool ActivityAnalyzer::isValueInactiveFromUsers(TypeResults &TR,
23652366
}
23662367

23672368
/// Is the value potentially actively returned or stored
2368-
bool ActivityAnalyzer::isValueActivelyStoredOrReturned(TypeResults &TR,
2369+
bool ActivityAnalyzer::isValueActivelyStoredOrReturned(TypeResults const &TR,
23692370
llvm::Value *val,
23702371
bool outside) {
23712372
// Must be an analyzer only searching down
@@ -2503,7 +2504,7 @@ bool ActivityAnalyzer::isValueActivelyStoredOrReturned(TypeResults &TR,
25032504
return false;
25042505
}
25052506

2506-
void ActivityAnalyzer::InsertConstantInstruction(TypeResults &TR,
2507+
void ActivityAnalyzer::InsertConstantInstruction(TypeResults const &TR,
25072508
llvm::Instruction *I) {
25082509
ConstantInstructions.insert(I);
25092510
auto found = ReEvaluateValueIfInactiveInst.find(I);
@@ -2522,7 +2523,8 @@ void ActivityAnalyzer::InsertConstantInstruction(TypeResults &TR,
25222523
}
25232524
}
25242525

2525-
void ActivityAnalyzer::InsertConstantValue(TypeResults &TR, llvm::Value *V) {
2526+
void ActivityAnalyzer::InsertConstantValue(TypeResults const &TR,
2527+
llvm::Value *V) {
25262528
ConstantValues.insert(V);
25272529
auto found = ReEvaluateValueIfInactiveValue.find(V);
25282530
if (found != ReEvaluateValueIfInactiveValue.end()) {

enzyme/Enzyme/ActivityAnalysis.h

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -120,11 +120,11 @@ class ActivityAnalyzer {
120120
/// Return whether this instruction is known not to propagate adjoints
121121
/// Note that instructions could return an active pointer, but
122122
/// do not propagate adjoints themselves
123-
bool isConstantInstruction(TypeResults &TR, llvm::Instruction *inst);
123+
bool isConstantInstruction(TypeResults const &TR, llvm::Instruction *inst);
124124

125125
/// Return whether this values is known not to contain derivative
126126
// information, either directly or as a pointer to
127-
bool isConstantValue(TypeResults &TR, llvm::Value *val);
127+
bool isConstantValue(TypeResults const &TR, llvm::Value *val);
128128

129129
private:
130130
llvm::DenseMap<llvm::Instruction *, llvm::SmallPtrSet<llvm::Value *, 4>>
@@ -135,8 +135,8 @@ class ActivityAnalyzer {
135135
llvm::DenseMap<llvm::Value *, llvm::SmallPtrSet<llvm::Instruction *, 4>>
136136
ReEvaluateInstIfInactiveValue;
137137

138-
void InsertConstantInstruction(TypeResults &TR, llvm::Instruction *I);
139-
void InsertConstantValue(TypeResults &TR, llvm::Value *V);
138+
void InsertConstantInstruction(TypeResults const &TR, llvm::Instruction *I);
139+
void InsertConstantValue(TypeResults const &TR, llvm::Value *V);
140140

141141
/// Create a new analyzer starting from an existing Analyzer
142142
/// This is used to perform inductive assumptions
@@ -154,7 +154,8 @@ class ActivityAnalyzer {
154154
}
155155

156156
/// Import known constants from an existing analyzer
157-
void insertConstantsFrom(TypeResults &TR, ActivityAnalyzer &Hypothesis) {
157+
void insertConstantsFrom(TypeResults const &TR,
158+
ActivityAnalyzer &Hypothesis) {
158159
for (auto I : Hypothesis.ConstantInstructions) {
159160
InsertConstantInstruction(TR, I);
160161
}
@@ -164,7 +165,7 @@ class ActivityAnalyzer {
164165
}
165166

166167
/// Import known data from an existing analyzer
167-
void insertAllFrom(TypeResults &TR, ActivityAnalyzer &Hypothesis,
168+
void insertAllFrom(TypeResults const &TR, ActivityAnalyzer &Hypothesis,
168169
llvm::Value *Orig) {
169170
insertConstantsFrom(TR, Hypothesis);
170171
for (auto I : Hypothesis.ActiveInstructions) {
@@ -185,7 +186,7 @@ class ActivityAnalyzer {
185186
bool isFunctionArgumentConstant(llvm::CallInst *CI, llvm::Value *val);
186187

187188
/// Is the instruction guaranteed to be inactive because of its operands
188-
bool isInstructionInactiveFromOrigin(TypeResults &TR, llvm::Value *val);
189+
bool isInstructionInactiveFromOrigin(TypeResults const &TR, llvm::Value *val);
189190

190191
public:
191192
enum class UseActivity {
@@ -199,12 +200,12 @@ class ActivityAnalyzer {
199200
OnlyStores = 2
200201
};
201202
/// Is the value free of any active uses
202-
bool isValueInactiveFromUsers(TypeResults &TR, llvm::Value *val,
203+
bool isValueInactiveFromUsers(TypeResults const &TR, llvm::Value *val,
203204
UseActivity UA,
204205
llvm::Instruction **FoundInst = nullptr);
205206

206207
/// Is the value potentially actively returned or stored
207-
bool isValueActivelyStoredOrReturned(TypeResults &TR, llvm::Value *val,
208+
bool isValueActivelyStoredOrReturned(TypeResults const &TR, llvm::Value *val,
208209
bool outside = false);
209210

210211
private:

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ class AdjointGenerator
5353
GradientUtils *const gutils;
5454
const std::vector<DIFFE_TYPE> &constant_args;
5555
DIFFE_TYPE retType;
56-
TypeResults &TR;
56+
TypeResults &TR = gutils->TR;
5757
std::function<unsigned(Instruction *, CacheType)> getIndex;
5858
const std::map<CallInst *, const std::map<Argument *, bool>>
5959
uncacheable_args_map;
@@ -71,7 +71,6 @@ class AdjointGenerator
7171
AdjointGenerator(
7272
DerivativeMode Mode, GradientUtils *gutils,
7373
const std::vector<DIFFE_TYPE> &constant_args, DIFFE_TYPE retType,
74-
TypeResults &TR,
7574
std::function<unsigned(Instruction *, CacheType)> getIndex,
7675
const std::map<CallInst *, const std::map<Argument *, bool>>
7776
uncacheable_args_map,
@@ -84,7 +83,7 @@ class AdjointGenerator
8483
const SmallPtrSetImpl<BasicBlock *> &oldUnreachable,
8584
AllocaInst *dretAlloca)
8685
: Mode(Mode), gutils(gutils), constant_args(constant_args),
87-
retType(retType), TR(TR), getIndex(getIndex),
86+
retType(retType), getIndex(getIndex),
8887
uncacheable_args_map(uncacheable_args_map), returnuses(returnuses),
8988
augmentedReturn(augmentedReturn), replacedReturns(replacedReturns),
9089
unnecessaryValues(unnecessaryValues),
@@ -435,7 +434,7 @@ class AdjointGenerator
435434
auto placeholder = cast<PHINode>(&*found->second);
436435
gutils->invertedPointers.erase(found);
437436

438-
if (!is_value_needed_in_reverse<ValueType::Shadow>(TR, gutils, &I, Mode,
437+
if (!is_value_needed_in_reverse<ValueType::Shadow>(gutils, &I, Mode,
439438
oldUnreachable)) {
440439
gutils->erase(placeholder);
441440
return;
@@ -505,7 +504,7 @@ class AdjointGenerator
505504

506505
IRBuilder<> BuilderZ(newi);
507506
// only make shadow where caching needed
508-
if (!is_value_needed_in_reverse<ValueType::Shadow>(TR, gutils, &I, Mode,
507+
if (!is_value_needed_in_reverse<ValueType::Shadow>(gutils, &I, Mode,
509508
oldUnreachable)) {
510509
gutils->erase(placeholder);
511510
return;
@@ -551,7 +550,7 @@ class AdjointGenerator
551550
// TODO: In the case of fwd mode this should be true if the loaded value
552551
// itself is used as a pointer.
553552
bool needShadow = is_value_needed_in_reverse<ValueType::Shadow>(
554-
TR, gutils, &I, Mode, oldUnreachable);
553+
gutils, &I, Mode, oldUnreachable);
555554

556555
switch (Mode) {
557556

@@ -564,7 +563,7 @@ class AdjointGenerator
564563
assert(newip->getType() == type);
565564
if (Mode == DerivativeMode::ReverseModePrimal && can_modref &&
566565
is_value_needed_in_reverse<ValueType::Shadow>(
567-
TR, gutils, &I, DerivativeMode::ReverseModeGradient,
566+
gutils, &I, DerivativeMode::ReverseModeGradient,
568567
oldUnreachable)) {
569568
gutils->cacheForReverse(BuilderZ, newip,
570569
getIndex(&I, CacheType::Shadow));
@@ -632,7 +631,7 @@ class AdjointGenerator
632631
primalNeededInReverse = true;
633632
}
634633
primalNeededInReverse |= is_value_needed_in_reverse<ValueType::Primal>(
635-
TR, gutils, &I, Mode, Seen, oldUnreachable);
634+
gutils, &I, Mode, Seen, oldUnreachable);
636635
if (primalNeededInReverse) {
637636
IRBuilder<> BuilderZ(gutils->getNewFromOriginal(&I));
638637
inst = gutils->cacheForReverse(BuilderZ, newi,
@@ -8283,8 +8282,8 @@ class AdjointGenerator
82838282
} else {
82848283
if (!orig->getType()->isFPOrFPVectorTy() &&
82858284
TR.query(orig).Inner0().isPossiblePointer()) {
8286-
if (is_value_needed_in_reverse<ValueType::Shadow>(
8287-
TR, gutils, orig, Mode, oldUnreachable)) {
8285+
if (is_value_needed_in_reverse<ValueType::Shadow>(gutils, orig, Mode,
8286+
oldUnreachable)) {
82888287
subretType = DIFFE_TYPE::DUP_ARG;
82898288
shadowReturnUsed = true;
82908289
} else
@@ -8364,7 +8363,7 @@ class AdjointGenerator
83648363
if (!orig->getType()->isFPOrFPVectorTy() &&
83658364
TR.query(orig).Inner0().isPossiblePointer()) {
83668365
if (is_value_needed_in_reverse<ValueType::Shadow>(
8367-
TR, gutils, orig, DerivativeMode::ReverseModePrimal,
8366+
gutils, orig, DerivativeMode::ReverseModePrimal,
83688367
oldUnreachable)) {
83698368
hasNonReturnUse = true;
83708369
}
@@ -8448,7 +8447,7 @@ class AdjointGenerator
84488447
if (!pair.second)
84498448
Seen[UsageKey(pair.first, ValueType::Primal)] = false;
84508449
primalNeededInReverse = is_value_needed_in_reverse<ValueType::Primal>(
8451-
TR, gutils, orig, Mode, Seen, oldUnreachable);
8450+
gutils, orig, Mode, Seen, oldUnreachable);
84528451
}
84538452
if (subretused && primalNeededInReverse) {
84548453
if (normalReturn != newCall) {
@@ -9352,7 +9351,7 @@ class AdjointGenerator
93529351
Seen[UsageKey(pair.first, ValueType::Primal)] = false;
93539352
bool primalNeededInReverse =
93549353
is_value_needed_in_reverse<ValueType::Primal>(
9355-
TR, gutils, orig, Mode, Seen, oldUnreachable);
9354+
gutils, orig, Mode, Seen, oldUnreachable);
93569355
shouldCache = primalNeededInReverse;
93579356
}
93589357

@@ -10727,7 +10726,7 @@ class AdjointGenerator
1072710726
Mode == DerivativeMode::ForwardMode
1072810727
? false
1072910728
: is_value_needed_in_reverse<ValueType::Primal>(
10730-
TR, gutils, orig, Mode, Seen, oldUnreachable);
10729+
gutils, orig, Mode, Seen, oldUnreachable);
1073110730

1073210731
bool cacheWholeAllocation = false;
1073310732
if (gutils->knownRecomputeHeuristic.count(orig)) {
@@ -10957,7 +10956,7 @@ class AdjointGenerator
1095710956
Mode == DerivativeMode::ForwardModeSplit)
1095810957
? true
1095910958
: is_value_needed_in_reverse<ValueType::Shadow>(
10960-
TR, gutils, orig, Mode, oldUnreachable);
10959+
gutils, orig, Mode, oldUnreachable);
1096110960
if (!needShadow) {
1096210961
gutils->invertedPointers.erase(ifound);
1096310962
gutils->erase(placeholder);
@@ -11175,7 +11174,7 @@ class AdjointGenerator
1117511174
Seen[UsageKey(pair.first, ValueType::Primal)] = false;
1117611175
bool primalNeededInReverse =
1117711176
is_value_needed_in_reverse<ValueType::Primal>(
11178-
TR, gutils, rmat.first, Mode, Seen, oldUnreachable);
11177+
gutils, rmat.first, Mode, Seen, oldUnreachable);
1117911178
bool cacheWholeAllocation = false;
1118011179
if (gutils->knownRecomputeHeuristic.count(rmat.first)) {
1118111180
if (!gutils->knownRecomputeHeuristic[rmat.first]) {
@@ -11508,7 +11507,7 @@ class AdjointGenerator
1150811507
return;
1150911508
}
1151011509

11511-
bool modifyPrimal = shouldAugmentCall(orig, gutils, TR);
11510+
bool modifyPrimal = shouldAugmentCall(orig, gutils);
1151211511

1151311512
SmallVector<Value *, 8> args;
1151411513
SmallVector<Value *, 8> pre_args;
@@ -11622,7 +11621,7 @@ class AdjointGenerator
1162211621

1162311622
if (Mode == DerivativeMode::ReverseModeCombined && !foreignFunction) {
1162411623
replaceFunction = legalCombinedForwardReverse(
11625-
orig, *replacedReturns, postCreate, userReplace, gutils, TR,
11624+
orig, *replacedReturns, postCreate, userReplace, gutils,
1162611625
unnecessaryInstructions, oldUnreachable, subretused);
1162711626
if (replaceFunction)
1162811627
modifyPrimal = false;
@@ -11842,8 +11841,8 @@ class AdjointGenerator
1184211841
}
1184311842

1184411843
if (Mode == DerivativeMode::ReverseModePrimal &&
11845-
is_value_needed_in_reverse<ValueType::Primal>(
11846-
TR, gutils, orig, Mode, oldUnreachable) &&
11844+
is_value_needed_in_reverse<ValueType::Primal>(gutils, orig, Mode,
11845+
oldUnreachable) &&
1184711846
!gutils->unnecessaryIntermediates.count(orig)) {
1184811847
gutils->cacheForReverse(BuilderZ, dcall,
1184911848
getIndex(orig, CacheType::Self));
@@ -11876,8 +11875,8 @@ class AdjointGenerator
1187611875
}
1187711876

1187811877
if (subretused) {
11879-
if (is_value_needed_in_reverse<ValueType::Primal>(
11880-
TR, gutils, orig, Mode, oldUnreachable) &&
11878+
if (is_value_needed_in_reverse<ValueType::Primal>(gutils, orig, Mode,
11879+
oldUnreachable) &&
1188111880
!gutils->unnecessaryIntermediates.count(orig)) {
1188211881
cachereplace = BuilderZ.CreatePHI(orig->getType(), 1,
1188311882
orig->getName() + "_tmpcacheB");
@@ -11983,8 +11982,8 @@ class AdjointGenerator
1198311982
}
1198411983
if (/*!topLevel*/ Mode != DerivativeMode::ReverseModeCombined &&
1198511984
subretused && !orig->doesNotAccessMemory()) {
11986-
if (is_value_needed_in_reverse<ValueType::Primal>(
11987-
TR, gutils, orig, Mode, oldUnreachable) &&
11985+
if (is_value_needed_in_reverse<ValueType::Primal>(gutils, orig, Mode,
11986+
oldUnreachable) &&
1198811987
!gutils->unnecessaryIntermediates.count(orig)) {
1198911988
assert(!replaceFunction);
1199011989
cachereplace = BuilderZ.CreatePHI(orig->getType(), 1,

enzyme/Enzyme/CApi.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ void *EnzymeAnalyzeTypes(EnzymeTypeAnalysisRef TAR, CFnTypeInfo CTI,
228228
}
229229

230230
void *EnzymeGradientUtilsTypeAnalyzer(GradientUtils *G) {
231-
return (void *)&G->my_TR->analyzer;
231+
return (void *)&G->TR.analyzer;
232232
}
233233

234234
void EnzymeRegisterAllocationHandler(char *Name, CustomShadowAlloc AHandle,
@@ -371,15 +371,13 @@ LLVMBasicBlockRef EnzymeGradientUtilsAllocationBlock(GradientUtils *gutils) {
371371
CTypeTreeRef EnzymeGradientUtilsAllocAndGetTypeTree(GradientUtils *gutils,
372372
LLVMValueRef val) {
373373
auto v = unwrap(val);
374-
assert(gutils->my_TR);
375-
TypeTree TT = gutils->my_TR->query(v);
374+
TypeTree TT = gutils->TR.query(v);
376375
TypeTree *pTT = new TypeTree(TT);
377376
return (CTypeTreeRef)pTT;
378377
}
379378

380379
void EnzymeGradientUtilsDumpTypeResults(GradientUtils *gutils) {
381-
assert(gutils->my_TR);
382-
gutils->my_TR->dump();
380+
gutils->TR.dump();
383381
}
384382

385383
void EnzymeGradientUtilsSubTransferHelper(

0 commit comments

Comments
 (0)