Skip to content

Commit c3da306

Browse files
authored
[MLIR] Forward Mode arg conversion for dup -> const (EnzymeAD#2349)
* Added skeleton impl * remove return activity stuff Signed-off-by: Vimarsh Sathia <vsathia2@illinois.edu> * impl dup to const Signed-off-by: Vimarsh Sathia <vimarsh.sathia@gmail.com> * Some cleanup Signed-off-by: Vimarsh Sathia <vimarsh.sathia@gmail.com> * use isZero * fmt * make format happy * remove unused retactivity --------- Signed-off-by: Vimarsh Sathia <vsathia2@illinois.edu> Signed-off-by: Vimarsh Sathia <vimarsh.sathia@gmail.com>
1 parent 3afb3b4 commit c3da306

File tree

2 files changed

+125
-15
lines changed

2 files changed

+125
-15
lines changed

enzyme/Enzyme/MLIR/Dialect/Ops.cpp

Lines changed: 115 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,34 +6,22 @@
66
//===----------------------------------------------------------------------===//
77

88
#include "Ops.h"
9-
#include "Dialect.h"
109
#include "Interfaces/AutoDiffTypeInterface.h"
11-
#include "mlir/Dialect/Arith/IR/Arith.h"
12-
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1310
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
1411
#include "mlir/IR/AffineExpr.h"
1512
#include "mlir/IR/Builders.h"
13+
#include "mlir/IR/Matchers.h"
1614
#include "mlir/IR/PatternMatch.h"
1715
#include "mlir/IR/Value.h"
18-
#include "mlir/Interfaces/ControlFlowInterfaces.h"
1916
#include "mlir/Interfaces/MemorySlotInterfaces.h"
20-
#include "mlir/Interfaces/SideEffectInterfaces.h"
2117

22-
#include "mlir/Dialect/Affine/IR/AffineOps.h"
23-
#include "mlir/Dialect/Arith/Utils/Utils.h"
2418
#include "mlir/Dialect/Func/IR/FuncOps.h"
2519
#include "mlir/Dialect/MemRef/IR/MemRef.h"
26-
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
27-
#include "mlir/Dialect/SCF/IR/SCF.h"
28-
#include "mlir/IR/IRMapping.h"
2920
#include "mlir/IR/IntegerSet.h"
3021

3122
#include "llvm/ADT/STLExtras.h"
32-
#include "llvm/ADT/SetVector.h"
3323
#include "llvm/ADT/SmallVector.h"
34-
#include "llvm/Support/Debug.h"
3524

36-
#include "llvm/ADT/TypeSwitch.h"
3725
#include "llvm/Support/ErrorHandling.h"
3826
#include "llvm/Support/LogicalResult.h"
3927

@@ -187,6 +175,119 @@ static inline bool isMutable(Type type) {
187175
return false;
188176
}
189177

178+
/**
179+
*
180+
* Modifies input activites for the FwdDiffOp
181+
* The activity promotion flow is as follows
182+
* (depending on variable use):
183+
*
184+
* -----> enzyme_dupnoneed
185+
* / /
186+
* enzyme_dup /
187+
* \ v
188+
* ------> enzyme_const
189+
*
190+
*/
191+
class FwdInpOpt final : public OpRewritePattern<ForwardDiffOp> {
192+
public:
193+
using OpRewritePattern<ForwardDiffOp>::OpRewritePattern;
194+
195+
LogicalResult matchAndRewrite(ForwardDiffOp uop,
196+
PatternRewriter &rewriter) const override {
197+
198+
if (uop.getOutputs().size() == 0)
199+
return failure();
200+
201+
auto inActivity = uop.getActivity();
202+
203+
auto in_idx = 0;
204+
SmallVector<mlir::Value, 2> in_args;
205+
SmallVector<ActivityAttr, 2> newInActivityArgs;
206+
bool changed = false;
207+
for (auto [idx, act] : llvm::enumerate(inActivity)) {
208+
auto iattr = cast<ActivityAttr>(act);
209+
auto val = iattr.getValue();
210+
211+
// Forward mode Input activities can only take values {dup, dupnoneed,
212+
// const }
213+
214+
mlir::Value inp = uop.getInputs()[in_idx];
215+
216+
switch (val) {
217+
218+
case mlir::enzyme::Activity::enzyme_const:
219+
in_args.push_back(inp);
220+
newInActivityArgs.push_back(iattr);
221+
break;
222+
223+
case Activity::enzyme_dupnoneed: {
224+
// always pass in primal
225+
in_args.push_back(inp);
226+
in_idx++;
227+
228+
// selectively push or skip directional derivative
229+
inp = uop.getInputs()[in_idx];
230+
auto ET = inp.getType();
231+
auto ETintf = dyn_cast<AutoDiffTypeInterface>(ET);
232+
233+
if (ETintf && !isMutable(ET) && ETintf.isZero(inp).succeeded()) {
234+
// skip and promote to const
235+
auto new_const = mlir::enzyme::ActivityAttr::get(
236+
rewriter.getContext(), mlir::enzyme::Activity::enzyme_const);
237+
newInActivityArgs.push_back(new_const);
238+
changed = true;
239+
} else {
240+
// push derivative value
241+
in_args.push_back(inp);
242+
newInActivityArgs.push_back(iattr);
243+
}
244+
break;
245+
}
246+
247+
case Activity::enzyme_dup: {
248+
// always pass in primal
249+
in_args.push_back(inp);
250+
in_idx++;
251+
252+
// selectively push or skip directional derivative
253+
inp = uop.getInputs()[in_idx];
254+
auto ET = inp.getType();
255+
auto ETintf = dyn_cast<AutoDiffTypeInterface>(ET);
256+
257+
if (ETintf && !isMutable(ET) && ETintf.isZero(inp).succeeded()) {
258+
// skip and promote to const
259+
auto new_const = mlir::enzyme::ActivityAttr::get(
260+
rewriter.getContext(), mlir::enzyme::Activity::enzyme_const);
261+
newInActivityArgs.push_back(new_const);
262+
changed = true;
263+
} else {
264+
// push derivative value
265+
in_args.push_back(inp);
266+
newInActivityArgs.push_back(iattr);
267+
}
268+
break;
269+
}
270+
default:
271+
llvm_unreachable("unexpected input activity arg");
272+
}
273+
274+
in_idx++;
275+
}
276+
277+
if (!changed)
278+
return failure();
279+
280+
// create the new op
281+
ArrayAttr newInActivity =
282+
ArrayAttr::get(rewriter.getContext(),
283+
llvm::ArrayRef<Attribute>(newInActivityArgs.begin(),
284+
newInActivityArgs.end()));
285+
rewriter.replaceOpWithNewOp<ForwardDiffOp>(
286+
uop, uop->getResultTypes(), uop.getFnAttr(), in_args, newInActivity,
287+
uop.getRetActivityAttr(), uop.getWidthAttr(), uop.getStrongZeroAttr());
288+
return success();
289+
}
290+
};
190291
/**
191292
*
192293
* Modifies return activites for the FwdDiffOp
@@ -398,7 +499,7 @@ class FwdRetOpt final : public OpRewritePattern<ForwardDiffOp> {
398499
void ForwardDiffOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
399500
MLIRContext *context) {
400501

401-
patterns.add<FwdRetOpt>(context);
502+
patterns.add<FwdRetOpt, FwdInpOpt>(context);
402503
}
403504

404505
LogicalResult AutoDiffOp::verifySymbolUses(SymbolTableCollection &symbolTable) {

enzyme/test/MLIR/ForwardMode/canonicalize.mlir

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,16 @@ module {
7777
// CHECK: enzyme.fwddiff @square(%arg0, %arg1) {{.*ret_activity = \[#enzyme<activity enzyme_constnoneed>\]}}
7878
return %cst : f64
7979
}
80-
80+
81+
// -----
82+
83+
func.func @dsq7(%x : f64, %dx : f64) -> (f64,f64) {
84+
%cst = arith.constant 0.0000e+00 : f64
85+
%p, %r = enzyme.fwddiff @square(%x, %cst) { activity=[#enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dup>] } : (f64, f64) -> (f64, f64)
86+
// CHECK: %{{.*}}:2 = enzyme.fwddiff @square(%arg0) {activity = [#enzyme<activity enzyme_const>], ret_activity = [#enzyme<activity enzyme_dup>]}
87+
return %p, %r : f64, f64
88+
}
89+
8190
// -----
8291
// Greedy test
8392
func.func @dsq5(%x : f64, %dx : f64) -> f64 {

0 commit comments

Comments
 (0)