1010#include " mlir/Dialect/LLVMIR/LLVMTypes.h"
1111#include " mlir/IR/AffineExpr.h"
1212#include " mlir/IR/Builders.h"
13+ #include " mlir/IR/Matchers.h"
1314#include " mlir/IR/PatternMatch.h"
1415#include " mlir/IR/Value.h"
1516#include " mlir/Interfaces/MemorySlotInterfaces.h"
@@ -205,7 +206,6 @@ class FwdInpOpt final : public OpRewritePattern<ForwardDiffOp> {
205206 SmallVector<Type, 2 > in_ty;
206207 SmallVector<ActivityAttr, 2 > newInActivityArgs;
207208 bool changed = false ;
208-
209209 for (auto [idx, act] : llvm::enumerate (inActivity)) {
210210 auto iattr = cast<ActivityAttr>(act);
211211 auto val = iattr.getValue ();
@@ -222,64 +222,67 @@ class FwdInpOpt final : public OpRewritePattern<ForwardDiffOp> {
222222 in_ty.push_back (inp.getType ());
223223 newInActivityArgs.push_back (iattr);
224224 break ;
225- case Activity::enzyme_dupnoneed:
226225
226+ case Activity::enzyme_dupnoneed:
227+ // always pass in primal
227228 in_args.push_back (inp);
228229 in_ty.push_back (inp.getType ());
229- newInActivityArgs.push_back (iattr);
230- // else {
231- // if (!isMutable(inp.getType())) {
232- // changed = true;
233- // auto new_constnn = mlir::enzyme::ActivityAttr::get(
234- // rewriter.getContext(),
235- // mlir::enzyme::Activity::enzyme_constnoneed);
236- // newInActivityArgs.push_back(new_constnn);
237- // } else {
238- // in_args.push_back(inp);
239- // in_ty.push_back(inp.getType());
240- // newInActivityArgs.push_back(iattr);
241- // }
242- // }
230+ in_idx++;
231+
232+ // selectively push or skip directional derivative
233+ inp = uop.getInputs ()[in_idx];
234+ if (!isMutable (inp.getType ())) {
235+ if (matchPattern (inp, m_Zero ()) ||
236+ matchPattern (inp, m_AnyZeroFloat ())) {
237+ // skip and promote to const
238+ auto new_const = mlir::enzyme::ActivityAttr::get (
239+ rewriter.getContext (), mlir::enzyme::Activity::enzyme_const);
240+ newInActivityArgs.push_back (new_const);
241+ changed = true ;
242+ } else {
243+ // push derivative value
244+ in_ty.push_back (inp.getType ());
245+ in_args.push_back (inp);
246+ newInActivityArgs.push_back (iattr);
247+ }
248+ } else {
249+ // push derivative value
250+ in_ty.push_back (inp.getType ());
251+ in_args.push_back (inp);
252+ newInActivityArgs.push_back (iattr);
253+ }
243254 break ;
244255
245256 case Activity::enzyme_dup: {
246- ActivityAttr new_dup = iattr;
247- // if (!inp.use_empty()) {
257+ // always pass in primal
248258 in_args.push_back (inp);
249259 in_ty.push_back (inp.getType ());
250- // } else {
251- // changed = true;
252- // // discard return, change attr
253- // new_dup = ActivityAttr::get(rewriter.getContext(),
254- // Activity::enzyme_dupnoneed);
255- // }
256-
257260 in_idx++;
258261
259- // derivative value
262+ // selectively push or skip directional derivative
260263 inp = uop.getInputs ()[in_idx];
261- // if (!inp.use_empty( )) {
262- // activity arg doesn't update
263- in_ty. push_back (inp. getType ());
264- in_args. push_back (inp);
265- // } else {
266- // // no uses, can discard
267- // if (!isMutable(inp.getType())) {
268- // changed = true ;
269- // // check if primal is used
270- // if (new_dup.getValue() == Activity::enzyme_dupnoneed) {
271- // new_dup = ActivityAttr::get(rewriter.getContext(),
272- // Activity::enzyme_constnoneed);
273- // } else {
274- // new_dup = ActivityAttr::get(rewriter.getContext(),
275- // Activity::enzyme_const );
276- // }
277- // } else {
278- // in_ty.push_back(inp.getType());
279- // in_args .push_back(inp);
280- // }
281- // }
282- newInActivityArgs. push_back (new_dup);
264+ if (!isMutable ( inp.getType () )) {
265+ if ( matchPattern (inp, m_Zero ()) ||
266+ matchPattern (inp, m_AnyZeroFloat ())) {
267+
268+ // skip and promote to const
269+ auto new_const = mlir::enzyme::ActivityAttr::get (
270+ rewriter. getContext (), mlir::enzyme::Activity::enzyme_const);
271+ newInActivityArgs. push_back (new_const) ;
272+ changed = true ;
273+ } else {
274+
275+ // push derivative value
276+ in_ty. push_back (inp. getType ());
277+ in_args. push_back (inp);
278+ newInActivityArgs. push_back (iattr );
279+ }
280+ } else {
281+ // push derivative value
282+ in_ty .push_back (inp. getType () );
283+ in_args. push_back (inp);
284+ newInActivityArgs. push_back (iattr);
285+ }
283286 break ;
284287 }
285288 default :
@@ -292,66 +295,14 @@ class FwdInpOpt final : public OpRewritePattern<ForwardDiffOp> {
292295 if (!changed)
293296 return failure ();
294297
298+ // create the new op
295299 ArrayAttr newInActivity =
296300 ArrayAttr::get (rewriter.getContext (),
297301 llvm::ArrayRef<Attribute>(newInActivityArgs.begin (),
298302 newInActivityArgs.end ()));
299- ForwardDiffOp newOp = rewriter.create <ForwardDiffOp>(
300- uop.getLoc (), uop.getResultTypes (), uop.getFnAttr (), uop.getInputs (),
301- newInActivity, uop.getRetActivityAttr (), uop.getWidthAttr (),
302- uop.getStrongZeroAttr ());
303-
304- // Map old uses of uop to newOp
305- auto oldIdx = 0 ;
306- auto newIdx = 0 ;
307- for (auto [idx, old_act, new_act] :
308- llvm::enumerate (retActivity, newInActivityArgs)) {
309-
310- auto iattr = cast<ActivityAttr>(old_act);
311- auto old_val = iattr.getValue ();
312- auto new_val = new_act.getValue ();
313-
314- if (old_val == new_val) {
315- // don't index into op if its a const_noneed
316- if (old_val == Activity::enzyme_constnoneed) {
317- continue ;
318- }
319- // replace use
320- uop.getOutputs ()[oldIdx++].replaceAllUsesWith (
321- newOp.getOutputs ()[newIdx++]);
322- if (old_val == Activity::enzyme_dup) {
323- // 2nd replacement for derivative
324- uop.getOutputs ()[oldIdx++].replaceAllUsesWith (
325- newOp.getOutputs ()[newIdx++]);
326- }
327- } else {
328- // handle all substitutions
329- if (new_val == Activity::enzyme_dupnoneed &&
330- old_val == Activity::enzyme_dup) {
331- ++oldIdx; // skip primal
332- uop.getOutputs ()[oldIdx++].replaceAllUsesWith (
333- newOp.getOutputs ()[newIdx++]);
334- } else if (new_val == mlir::enzyme::Activity::enzyme_constnoneed &&
335- old_val == mlir::enzyme::Activity::enzyme_const) {
336- ++oldIdx; // skip const
337- } else if (new_val == mlir::enzyme::Activity::enzyme_constnoneed &&
338- old_val == mlir::enzyme::Activity::enzyme_dupnoneed) {
339- ++oldIdx; // skip gradient too
340- } else if (new_val == Activity::enzyme_const &&
341- old_val == Activity::enzyme_dup) {
342-
343- uop.getOutputs ()[oldIdx++].replaceAllUsesWith (
344- newOp.getOutputs ()[newIdx++]);
345- ++oldIdx; // skip derivative
346- } else if (new_val == Activity::enzyme_constnoneed &&
347- old_val == Activity::enzyme_dup) {
348- ++oldIdx; // skip primal
349- ++oldIdx; // skip derivative
350- }
351- }
352- }
353-
354- rewriter.eraseOp (uop);
303+ rewriter.replaceOpWithNewOp <ForwardDiffOp>(
304+ uop, uop->getResultTypes (), uop.getFnAttr (), in_args, newInActivity,
305+ uop.getRetActivityAttr (), uop.getWidthAttr (), uop.getStrongZeroAttr ());
355306 return success ();
356307 }
357308};
0 commit comments