|
6 | 6 | //===----------------------------------------------------------------------===// |
7 | 7 |
|
8 | 8 | #include "Ops.h" |
9 | | -#include "Dialect.h" |
10 | 9 | #include "Interfaces/AutoDiffTypeInterface.h" |
11 | | -#include "mlir/Dialect/Arith/IR/Arith.h" |
12 | | -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
13 | 10 | #include "mlir/Dialect/LLVMIR/LLVMTypes.h" |
14 | 11 | #include "mlir/IR/AffineExpr.h" |
15 | 12 | #include "mlir/IR/Builders.h" |
| 13 | +#include "mlir/IR/Matchers.h" |
16 | 14 | #include "mlir/IR/PatternMatch.h" |
17 | 15 | #include "mlir/IR/Value.h" |
18 | | -#include "mlir/Interfaces/ControlFlowInterfaces.h" |
19 | 16 | #include "mlir/Interfaces/MemorySlotInterfaces.h" |
20 | | -#include "mlir/Interfaces/SideEffectInterfaces.h" |
21 | 17 |
|
22 | | -#include "mlir/Dialect/Affine/IR/AffineOps.h" |
23 | | -#include "mlir/Dialect/Arith/Utils/Utils.h" |
24 | 18 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
25 | 19 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
26 | | -#include "mlir/Dialect/OpenMP/OpenMPDialect.h" |
27 | | -#include "mlir/Dialect/SCF/IR/SCF.h" |
28 | | -#include "mlir/IR/IRMapping.h" |
29 | 20 | #include "mlir/IR/IntegerSet.h" |
30 | 21 |
|
31 | 22 | #include "llvm/ADT/STLExtras.h" |
32 | | -#include "llvm/ADT/SetVector.h" |
33 | 23 | #include "llvm/ADT/SmallVector.h" |
34 | | -#include "llvm/Support/Debug.h" |
35 | 24 |
|
36 | | -#include "llvm/ADT/TypeSwitch.h" |
37 | 25 | #include "llvm/Support/ErrorHandling.h" |
38 | 26 | #include "llvm/Support/LogicalResult.h" |
39 | 27 |
|
@@ -187,6 +175,119 @@ static inline bool isMutable(Type type) { |
187 | 175 | return false; |
188 | 176 | } |
189 | 177 |
|
| 178 | +/** |
| 179 | + * |
| 180 | + * Modifies input activites for the FwdDiffOp |
| 181 | + * The activity promotion flow is as follows |
| 182 | + * (depending on variable use): |
| 183 | + * |
| 184 | + * -----> enzyme_dupnoneed |
| 185 | + * / / |
| 186 | + * enzyme_dup / |
| 187 | + * \ v |
| 188 | + * ------> enzyme_const |
| 189 | + * |
| 190 | + */ |
| 191 | +class FwdInpOpt final : public OpRewritePattern<ForwardDiffOp> { |
| 192 | +public: |
| 193 | + using OpRewritePattern<ForwardDiffOp>::OpRewritePattern; |
| 194 | + |
| 195 | + LogicalResult matchAndRewrite(ForwardDiffOp uop, |
| 196 | + PatternRewriter &rewriter) const override { |
| 197 | + |
| 198 | + if (uop.getOutputs().size() == 0) |
| 199 | + return failure(); |
| 200 | + |
| 201 | + auto inActivity = uop.getActivity(); |
| 202 | + |
| 203 | + auto in_idx = 0; |
| 204 | + SmallVector<mlir::Value, 2> in_args; |
| 205 | + SmallVector<ActivityAttr, 2> newInActivityArgs; |
| 206 | + bool changed = false; |
| 207 | + for (auto [idx, act] : llvm::enumerate(inActivity)) { |
| 208 | + auto iattr = cast<ActivityAttr>(act); |
| 209 | + auto val = iattr.getValue(); |
| 210 | + |
| 211 | + // Forward mode Input activities can only take values {dup, dupnoneed, |
| 212 | + // const } |
| 213 | + |
| 214 | + mlir::Value inp = uop.getInputs()[in_idx]; |
| 215 | + |
| 216 | + switch (val) { |
| 217 | + |
| 218 | + case mlir::enzyme::Activity::enzyme_const: |
| 219 | + in_args.push_back(inp); |
| 220 | + newInActivityArgs.push_back(iattr); |
| 221 | + break; |
| 222 | + |
| 223 | + case Activity::enzyme_dupnoneed: { |
| 224 | + // always pass in primal |
| 225 | + in_args.push_back(inp); |
| 226 | + in_idx++; |
| 227 | + |
| 228 | + // selectively push or skip directional derivative |
| 229 | + inp = uop.getInputs()[in_idx]; |
| 230 | + auto ET = inp.getType(); |
| 231 | + auto ETintf = dyn_cast<AutoDiffTypeInterface>(ET); |
| 232 | + |
| 233 | + if (ETintf && !isMutable(ET) && ETintf.isZero(inp).succeeded()) { |
| 234 | + // skip and promote to const |
| 235 | + auto new_const = mlir::enzyme::ActivityAttr::get( |
| 236 | + rewriter.getContext(), mlir::enzyme::Activity::enzyme_const); |
| 237 | + newInActivityArgs.push_back(new_const); |
| 238 | + changed = true; |
| 239 | + } else { |
| 240 | + // push derivative value |
| 241 | + in_args.push_back(inp); |
| 242 | + newInActivityArgs.push_back(iattr); |
| 243 | + } |
| 244 | + break; |
| 245 | + } |
| 246 | + |
| 247 | + case Activity::enzyme_dup: { |
| 248 | + // always pass in primal |
| 249 | + in_args.push_back(inp); |
| 250 | + in_idx++; |
| 251 | + |
| 252 | + // selectively push or skip directional derivative |
| 253 | + inp = uop.getInputs()[in_idx]; |
| 254 | + auto ET = inp.getType(); |
| 255 | + auto ETintf = dyn_cast<AutoDiffTypeInterface>(ET); |
| 256 | + |
| 257 | + if (ETintf && !isMutable(ET) && ETintf.isZero(inp).succeeded()) { |
| 258 | + // skip and promote to const |
| 259 | + auto new_const = mlir::enzyme::ActivityAttr::get( |
| 260 | + rewriter.getContext(), mlir::enzyme::Activity::enzyme_const); |
| 261 | + newInActivityArgs.push_back(new_const); |
| 262 | + changed = true; |
| 263 | + } else { |
| 264 | + // push derivative value |
| 265 | + in_args.push_back(inp); |
| 266 | + newInActivityArgs.push_back(iattr); |
| 267 | + } |
| 268 | + break; |
| 269 | + } |
| 270 | + default: |
| 271 | + llvm_unreachable("unexpected input activity arg"); |
| 272 | + } |
| 273 | + |
| 274 | + in_idx++; |
| 275 | + } |
| 276 | + |
| 277 | + if (!changed) |
| 278 | + return failure(); |
| 279 | + |
| 280 | + // create the new op |
| 281 | + ArrayAttr newInActivity = |
| 282 | + ArrayAttr::get(rewriter.getContext(), |
| 283 | + llvm::ArrayRef<Attribute>(newInActivityArgs.begin(), |
| 284 | + newInActivityArgs.end())); |
| 285 | + rewriter.replaceOpWithNewOp<ForwardDiffOp>( |
| 286 | + uop, uop->getResultTypes(), uop.getFnAttr(), in_args, newInActivity, |
| 287 | + uop.getRetActivityAttr(), uop.getWidthAttr(), uop.getStrongZeroAttr()); |
| 288 | + return success(); |
| 289 | + } |
| 290 | +}; |
190 | 291 | /** |
191 | 292 | * |
192 | 293 | * Modifies return activites for the FwdDiffOp |
@@ -398,7 +499,7 @@ class FwdRetOpt final : public OpRewritePattern<ForwardDiffOp> { |
398 | 499 | void ForwardDiffOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
399 | 500 | MLIRContext *context) { |
400 | 501 |
|
401 | | - patterns.add<FwdRetOpt>(context); |
| 502 | + patterns.add<FwdRetOpt, FwdInpOpt>(context); |
402 | 503 | } |
403 | 504 |
|
404 | 505 | LogicalResult AutoDiffOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
|
0 commit comments