Skip to content

Commit

Permalink
ForwardMode: extractvalue inst (rust-lang#354)
Browse files Browse the repository at this point in the history
  • Loading branch information
tgymnich authored Oct 13, 2021
1 parent ffe46d5 commit ef3a0ac
Showing 1 changed file with 44 additions and 21 deletions.
65 changes: 44 additions & 21 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -1260,33 +1260,56 @@ class AdjointGenerator
if (EVI.getType()->isPointerTy())
return;

if (Mode == DerivativeMode::ReverseModePrimal)
switch (Mode) {
case DerivativeMode::ForwardMode: {
IRBuilder<> Builder2(&EVI);
getForwardBuilder(Builder2);

Value *orig_aggregate = EVI.getAggregateOperand();

Value *diffe_aggregate =
gutils->isConstantValue(orig_aggregate)
? ConstantAggregate::getNullValue(orig_aggregate->getType())
: diffe(orig_aggregate, Builder2);
Value *diffe =
Builder2.CreateExtractValue(diffe_aggregate, EVI.getIndices());

setDiffe(&EVI, diffe, Builder2);
return;
}
case DerivativeMode::ReverseModeGradient:
case DerivativeMode::ReverseModeCombined: {
IRBuilder<> Builder2(EVI.getParent());
getReverseBuilder(Builder2);

Value *orig_op0 = EVI.getOperand(0);
Value *orig_op0 = EVI.getOperand(0);

IRBuilder<> Builder2(EVI.getParent());
getReverseBuilder(Builder2);
auto prediff = diffe(&EVI, Builder2);

auto prediff = diffe(&EVI, Builder2);
// todo const
if (!gutils->isConstantValue(orig_op0)) {
SmallVector<Value *, 4> sv;
for (auto i : EVI.getIndices())
sv.push_back(ConstantInt::get(Type::getInt32Ty(EVI.getContext()), i));
size_t size = 1;
if (EVI.getType()->isSized())
size =
(gutils->newFunc->getParent()->getDataLayout().getTypeSizeInBits(
EVI.getType()) +
7) /
8;
((DiffeGradientUtils *)gutils)
->addToDiffe(orig_op0, prediff, Builder2, TR.addingType(size, &EVI),
sv);
}

// todo const
if (!gutils->isConstantValue(orig_op0)) {
SmallVector<Value *, 4> sv;
for (auto i : EVI.getIndices())
sv.push_back(ConstantInt::get(Type::getInt32Ty(EVI.getContext()), i));
size_t size = 1;
if (EVI.getType()->isSized())
size = (gutils->newFunc->getParent()->getDataLayout().getTypeSizeInBits(
EVI.getType()) +
7) /
8;
((DiffeGradientUtils *)gutils)
->addToDiffe(orig_op0, prediff, Builder2, TR.addingType(size, &EVI),
sv);
setDiffe(&EVI, Constant::getNullValue(EVI.getType()), Builder2);
return;
}
case DerivativeMode::ReverseModePrimal: {
return;
}
}

setDiffe(&EVI, Constant::getNullValue(EVI.getType()), Builder2);
}

void visitInsertValueInst(llvm::InsertValueInst &IVI) {
Expand Down

0 comments on commit ef3a0ac

Please sign in to comment.