@@ -203,7 +203,6 @@ class FwdInpOpt final : public OpRewritePattern<ForwardDiffOp> {
203203
204204 auto in_idx = 0 ;
205205 SmallVector<mlir::Value, 2 > in_args;
206- SmallVector<Type, 2 > in_ty;
207206 SmallVector<ActivityAttr, 2 > newInActivityArgs;
208207 bool changed = false ;
209208 for (auto [idx, act] : llvm::enumerate (inActivity)) {
@@ -219,67 +218,49 @@ class FwdInpOpt final : public OpRewritePattern<ForwardDiffOp> {
219218
220219 case mlir::enzyme::Activity::enzyme_const:
221220 in_args.push_back (inp);
222- in_ty.push_back (inp.getType ());
223221 newInActivityArgs.push_back (iattr);
224222 break ;
225223
226- case Activity::enzyme_dupnoneed:
224+ case Activity::enzyme_dupnoneed: {
227225 // always pass in primal
228226 in_args.push_back (inp);
229- in_ty.push_back (inp.getType ());
230227 in_idx++;
231228
232229 // selectively push or skip directional derivative
233230 inp = uop.getInputs ()[in_idx];
234- if (!isMutable (inp.getType ())) {
235- if (matchPattern (inp, m_Zero ()) ||
236- matchPattern (inp, m_AnyZeroFloat ())) {
237- // skip and promote to const
238- auto new_const = mlir::enzyme::ActivityAttr::get (
239- rewriter.getContext (), mlir::enzyme::Activity::enzyme_const);
240- newInActivityArgs.push_back (new_const);
241- changed = true ;
242- } else {
243- // push derivative value
244- in_ty.push_back (inp.getType ());
245- in_args.push_back (inp);
246- newInActivityArgs.push_back (iattr);
247- }
231+ if (!isMutable (inp.getType ()) &&
232+ (matchPattern (inp, m_Zero ()) ||
233+ matchPattern (inp, m_AnyZeroFloat ()))) {
234+ // skip and promote to const
235+ auto new_const = mlir::enzyme::ActivityAttr::get (
236+ rewriter.getContext (), mlir::enzyme::Activity::enzyme_const);
237+ newInActivityArgs.push_back (new_const);
238+ changed = true ;
248239 } else {
249240 // push derivative value
250- in_ty.push_back (inp.getType ());
251241 in_args.push_back (inp);
252242 newInActivityArgs.push_back (iattr);
253243 }
254244 break ;
245+ }
255246
256247 case Activity::enzyme_dup: {
257248 // always pass in primal
258249 in_args.push_back (inp);
259- in_ty.push_back (inp.getType ());
260250 in_idx++;
261251
262252 // selectively push or skip directional derivative
263253 inp = uop.getInputs ()[in_idx];
264- if (!isMutable (inp.getType ())) {
265- if (matchPattern (inp, m_Zero ()) ||
266- matchPattern (inp, m_AnyZeroFloat ())) {
267-
268- // skip and promote to const
269- auto new_const = mlir::enzyme::ActivityAttr::get (
270- rewriter.getContext (), mlir::enzyme::Activity::enzyme_const);
271- newInActivityArgs.push_back (new_const);
272- changed = true ;
273- } else {
274-
275- // push derivative value
276- in_ty.push_back (inp.getType ());
277- in_args.push_back (inp);
278- newInActivityArgs.push_back (iattr);
279- }
254+ if (!isMutable (inp.getType ()) &&
255+ (matchPattern (inp, m_Zero ()) ||
256+ matchPattern (inp, m_AnyZeroFloat ()))) {
257+ // skip and promote to const
258+ auto new_const = mlir::enzyme::ActivityAttr::get (
259+ rewriter.getContext (), mlir::enzyme::Activity::enzyme_const);
260+ newInActivityArgs.push_back (new_const);
261+ changed = true ;
280262 } else {
281263 // push derivative value
282- in_ty.push_back (inp.getType ());
283264 in_args.push_back (inp);
284265 newInActivityArgs.push_back (iattr);
285266 }
0 commit comments