Skip to content

Commit b03a851

Browse files
committed
only perform lookup for reverse builders
1 parent f93f942 commit b03a851

File tree

2 files changed

+7
-15
lines changed

2 files changed

+7
-15
lines changed

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -746,18 +746,13 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
746746
return nullptr;
747747
}
748748

749-
Value *pidx;
749+
Value *pidx = nullptr;
750750

751-
switch (mode) {
752-
case DerivativeMode::ForwardMode:
751+
if (isOriginalBlock(*BuilderM.GetInsertBlock())) {
753752
pidx = invertPointerM(dli->getOperand(0), BuilderM);
754-
break;
755-
case DerivativeMode::ReverseModePrimal:
756-
case DerivativeMode::ReverseModeGradient:
757-
case DerivativeMode::ReverseModeCombined:
753+
} else {
758754
pidx =
759755
lookupM(invertPointerM(dli->getOperand(0), BuilderM), BuilderM);
760-
break;
761756
}
762757

763758
if (pidx == nullptr)

enzyme/Enzyme/GradientUtils.h

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -934,16 +934,13 @@ class GradientUtils : public CacheUtility {
934934
if (auto arg = dyn_cast<Argument>(ptr)) {
935935
assert(arg->getParent() == oldFunc);
936936
}
937-
switch (mode) {
938-
case DerivativeMode::ForwardMode:
937+
938+
if (isOriginalBlock(*BuilderM.GetInsertBlock())) {
939939
ptr = invertPointerM(ptr, BuilderM);
940-
break;
941-
case DerivativeMode::ReverseModePrimal:
942-
case DerivativeMode::ReverseModeGradient:
943-
case DerivativeMode::ReverseModeCombined:
940+
} else {
944941
ptr = lookupM(invertPointerM(ptr, BuilderM), BuilderM);
945-
break;
946942
}
943+
947944
return BuilderM.CreateStore(newval, ptr);
948945
}
949946

0 commit comments

Comments
 (0)