@@ -1140,26 +1140,51 @@ class AdjointGenerator
11401140 if (Mode == DerivativeMode::ReverseModePrimal)
11411141 return ;
11421142
1143- IRBuilder<> Builder2 (EEI.getParent ());
1144- getReverseBuilder (Builder2);
1143+ switch (Mode) {
1144+ case DerivativeMode::ForwardMode: {
1145+ IRBuilder<> Builder2 (&EEI);
1146+ getForwardBuilder (Builder2);
11451147
1146- Value *orig_vec = EEI.getVectorOperand ();
1148+ Value *orig_vec = EEI.getVectorOperand ();
11471149
1148- if (!gutils->isConstantValue (orig_vec)) {
1149- SmallVector<Value *, 4 > sv;
1150- sv.push_back (gutils->getNewFromOriginal (EEI.getIndexOperand ()));
1150+ auto vec_diffe = gutils->isConstantValue (orig_vec)
1151+ ? ConstantVector::getNullValue (orig_vec->getType ())
1152+ : diffe (orig_vec, Builder2);
1153+ auto diffe =
1154+ Builder2.CreateExtractElement (vec_diffe, EEI.getIndexOperand ());
11511155
1152- size_t size = 1 ;
1153- if (EEI.getType ()->isSized ())
1154- size = (gutils->newFunc ->getParent ()->getDataLayout ().getTypeSizeInBits (
1155- EEI.getType ()) +
1156- 7 ) /
1157- 8 ;
1158- ((DiffeGradientUtils *)gutils)
1159- ->addToDiffe (orig_vec, diffe (&EEI, Builder2), Builder2,
1160- TR.addingType (size, &EEI), sv);
1156+ setDiffe (&EEI, diffe, Builder2);
1157+ return ;
1158+ }
1159+ case DerivativeMode::ReverseModeGradient:
1160+ case DerivativeMode::ReverseModeCombined: {
1161+ IRBuilder<> Builder2 (EEI.getParent ());
1162+ getReverseBuilder (Builder2);
1163+
1164+ Value *orig_vec = EEI.getVectorOperand ();
1165+
1166+ if (!gutils->isConstantValue (orig_vec)) {
1167+ SmallVector<Value *, 4 > sv;
1168+ sv.push_back (gutils->getNewFromOriginal (EEI.getIndexOperand ()));
1169+
1170+ size_t size = 1 ;
1171+ if (EEI.getType ()->isSized ())
1172+ size =
1173+ (gutils->newFunc ->getParent ()->getDataLayout ().getTypeSizeInBits (
1174+ EEI.getType ()) +
1175+ 7 ) /
1176+ 8 ;
1177+ ((DiffeGradientUtils *)gutils)
1178+ ->addToDiffe (orig_vec, diffe (&EEI, Builder2), Builder2,
1179+ TR.addingType (size, &EEI), sv);
1180+ }
1181+ setDiffe (&EEI, Constant::getNullValue (EEI.getType ()), Builder2);
1182+ return ;
1183+ }
1184+ case DerivativeMode::ReverseModePrimal: {
1185+ return ;
1186+ }
11611187 }
1162- setDiffe (&EEI, Constant::getNullValue (EEI.getType ()), Builder2);
11631188 }
11641189
11651190 void visitInsertElementInst (llvm::InsertElementInst &IEI) {
0 commit comments