Skip to content

Commit 103dcce

Browse files
committed
impl dup to const
Signed-off-by: Vimarsh Sathia <vimarsh.sathia@gmail.com>
1 parent 3bd7388 commit 103dcce

File tree

2 files changed

+65
-105
lines changed

2 files changed

+65
-105
lines changed

enzyme/Enzyme/MLIR/Dialect/Ops.cpp

Lines changed: 55 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
1111
#include "mlir/IR/AffineExpr.h"
1212
#include "mlir/IR/Builders.h"
13+
#include "mlir/IR/Matchers.h"
1314
#include "mlir/IR/PatternMatch.h"
1415
#include "mlir/IR/Value.h"
1516
#include "mlir/Interfaces/MemorySlotInterfaces.h"
@@ -205,7 +206,6 @@ class FwdInpOpt final : public OpRewritePattern<ForwardDiffOp> {
205206
SmallVector<Type, 2> in_ty;
206207
SmallVector<ActivityAttr, 2> newInActivityArgs;
207208
bool changed = false;
208-
209209
for (auto [idx, act] : llvm::enumerate(inActivity)) {
210210
auto iattr = cast<ActivityAttr>(act);
211211
auto val = iattr.getValue();
@@ -222,64 +222,67 @@ class FwdInpOpt final : public OpRewritePattern<ForwardDiffOp> {
222222
in_ty.push_back(inp.getType());
223223
newInActivityArgs.push_back(iattr);
224224
break;
225-
case Activity::enzyme_dupnoneed:
226225

226+
case Activity::enzyme_dupnoneed:
227+
// always pass in primal
227228
in_args.push_back(inp);
228229
in_ty.push_back(inp.getType());
229-
newInActivityArgs.push_back(iattr);
230-
// else {
231-
// if (!isMutable(inp.getType())) {
232-
// changed = true;
233-
// auto new_constnn = mlir::enzyme::ActivityAttr::get(
234-
// rewriter.getContext(),
235-
// mlir::enzyme::Activity::enzyme_constnoneed);
236-
// newInActivityArgs.push_back(new_constnn);
237-
// } else {
238-
// in_args.push_back(inp);
239-
// in_ty.push_back(inp.getType());
240-
// newInActivityArgs.push_back(iattr);
241-
// }
242-
// }
230+
in_idx++;
231+
232+
// selectively push or skip directional derivative
233+
inp = uop.getInputs()[in_idx];
234+
if (!isMutable(inp.getType())) {
235+
if (matchPattern(inp, m_Zero()) ||
236+
matchPattern(inp, m_AnyZeroFloat())) {
237+
// skip and promote to const
238+
auto new_const = mlir::enzyme::ActivityAttr::get(
239+
rewriter.getContext(), mlir::enzyme::Activity::enzyme_const);
240+
newInActivityArgs.push_back(new_const);
241+
changed = true;
242+
} else {
243+
// push derivative value
244+
in_ty.push_back(inp.getType());
245+
in_args.push_back(inp);
246+
newInActivityArgs.push_back(iattr);
247+
}
248+
} else {
249+
// push derivative value
250+
in_ty.push_back(inp.getType());
251+
in_args.push_back(inp);
252+
newInActivityArgs.push_back(iattr);
253+
}
243254
break;
244255

245256
case Activity::enzyme_dup: {
246-
ActivityAttr new_dup = iattr;
247-
// if (!inp.use_empty()) {
257+
// always pass in primal
248258
in_args.push_back(inp);
249259
in_ty.push_back(inp.getType());
250-
// } else {
251-
// changed = true;
252-
// // discard return, change attr
253-
// new_dup = ActivityAttr::get(rewriter.getContext(),
254-
// Activity::enzyme_dupnoneed);
255-
// }
256-
257260
in_idx++;
258261

259-
// derivative value
262+
// selectively push or skip directional derivative
260263
inp = uop.getInputs()[in_idx];
261-
// if (!inp.use_empty()) {
262-
// activity arg doesn't update
263-
in_ty.push_back(inp.getType());
264-
in_args.push_back(inp);
265-
// } else {
266-
// // no uses, can discard
267-
// if (!isMutable(inp.getType())) {
268-
// changed = true;
269-
// // check if primal is used
270-
// if (new_dup.getValue() == Activity::enzyme_dupnoneed) {
271-
// new_dup = ActivityAttr::get(rewriter.getContext(),
272-
// Activity::enzyme_constnoneed);
273-
// } else {
274-
// new_dup = ActivityAttr::get(rewriter.getContext(),
275-
// Activity::enzyme_const);
276-
// }
277-
// } else {
278-
// in_ty.push_back(inp.getType());
279-
// in_args.push_back(inp);
280-
// }
281-
// }
282-
newInActivityArgs.push_back(new_dup);
264+
if (!isMutable(inp.getType())) {
265+
if (matchPattern(inp, m_Zero()) ||
266+
matchPattern(inp, m_AnyZeroFloat())) {
267+
268+
// skip and promote to const
269+
auto new_const = mlir::enzyme::ActivityAttr::get(
270+
rewriter.getContext(), mlir::enzyme::Activity::enzyme_const);
271+
newInActivityArgs.push_back(new_const);
272+
changed = true;
273+
} else {
274+
275+
// push derivative value
276+
in_ty.push_back(inp.getType());
277+
in_args.push_back(inp);
278+
newInActivityArgs.push_back(iattr);
279+
}
280+
} else {
281+
// push derivative value
282+
in_ty.push_back(inp.getType());
283+
in_args.push_back(inp);
284+
newInActivityArgs.push_back(iattr);
285+
}
283286
break;
284287
}
285288
default:
@@ -292,66 +295,14 @@ class FwdInpOpt final : public OpRewritePattern<ForwardDiffOp> {
292295
if (!changed)
293296
return failure();
294297

298+
// create the new op
295299
ArrayAttr newInActivity =
296300
ArrayAttr::get(rewriter.getContext(),
297301
llvm::ArrayRef<Attribute>(newInActivityArgs.begin(),
298302
newInActivityArgs.end()));
299-
ForwardDiffOp newOp = rewriter.create<ForwardDiffOp>(
300-
uop.getLoc(), uop.getResultTypes(), uop.getFnAttr(), uop.getInputs(),
301-
newInActivity, uop.getRetActivityAttr(), uop.getWidthAttr(),
302-
uop.getStrongZeroAttr());
303-
304-
// Map old uses of uop to newOp
305-
auto oldIdx = 0;
306-
auto newIdx = 0;
307-
for (auto [idx, old_act, new_act] :
308-
llvm::enumerate(retActivity, newInActivityArgs)) {
309-
310-
auto iattr = cast<ActivityAttr>(old_act);
311-
auto old_val = iattr.getValue();
312-
auto new_val = new_act.getValue();
313-
314-
if (old_val == new_val) {
315-
// don't index into op if its a const_noneed
316-
if (old_val == Activity::enzyme_constnoneed) {
317-
continue;
318-
}
319-
// replace use
320-
uop.getOutputs()[oldIdx++].replaceAllUsesWith(
321-
newOp.getOutputs()[newIdx++]);
322-
if (old_val == Activity::enzyme_dup) {
323-
// 2nd replacement for derivative
324-
uop.getOutputs()[oldIdx++].replaceAllUsesWith(
325-
newOp.getOutputs()[newIdx++]);
326-
}
327-
} else {
328-
// handle all substitutions
329-
if (new_val == Activity::enzyme_dupnoneed &&
330-
old_val == Activity::enzyme_dup) {
331-
++oldIdx; // skip primal
332-
uop.getOutputs()[oldIdx++].replaceAllUsesWith(
333-
newOp.getOutputs()[newIdx++]);
334-
} else if (new_val == mlir::enzyme::Activity::enzyme_constnoneed &&
335-
old_val == mlir::enzyme::Activity::enzyme_const) {
336-
++oldIdx; // skip const
337-
} else if (new_val == mlir::enzyme::Activity::enzyme_constnoneed &&
338-
old_val == mlir::enzyme::Activity::enzyme_dupnoneed) {
339-
++oldIdx; // skip gradient too
340-
} else if (new_val == Activity::enzyme_const &&
341-
old_val == Activity::enzyme_dup) {
342-
343-
uop.getOutputs()[oldIdx++].replaceAllUsesWith(
344-
newOp.getOutputs()[newIdx++]);
345-
++oldIdx; // skip derivative
346-
} else if (new_val == Activity::enzyme_constnoneed &&
347-
old_val == Activity::enzyme_dup) {
348-
++oldIdx; // skip primal
349-
++oldIdx; // skip derivative
350-
}
351-
}
352-
}
353-
354-
rewriter.eraseOp(uop);
303+
rewriter.replaceOpWithNewOp<ForwardDiffOp>(
304+
uop, uop->getResultTypes(), uop.getFnAttr(), in_args, newInActivity,
305+
uop.getRetActivityAttr(), uop.getWidthAttr(), uop.getStrongZeroAttr());
355306
return success();
356307
}
357308
};

enzyme/test/MLIR/ForwardMode/canonicalize.mlir

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,16 @@ module {
7777
// CHECK: enzyme.fwddiff @square(%arg0, %arg1) {{.*ret_activity = \[#enzyme<activity enzyme_constnoneed>\]}}
7878
return %cst : f64
7979
}
80-
80+
81+
// -----
82+
83+
func.func @dsq7(%x : f64, %dx : f64) -> (f64,f64) {
84+
%cst = arith.constant 0.0000e+00 : f64
85+
%p, %r = enzyme.fwddiff @square(%x, %cst) { activity=[#enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dup>] } : (f64, f64) -> (f64, f64)
86+
// CHECK: %{{.*}}:2 = enzyme.fwddiff @square(%arg0) {activity = [#enzyme<activity enzyme_const>], ret_activity = [#enzyme<activity enzyme_dup>]}
87+
return %p, %r : f64, f64
88+
}
89+
8190
// -----
8291
// Greedy test
8392
func.func @dsq5(%x : f64, %dx : f64) -> f64 {

0 commit comments

Comments
 (0)