Skip to content

Commit dd8ff11

Browse files
committed
Some cleanup
Signed-off-by: Vimarsh Sathia <vimarsh.sathia@gmail.com>
1 parent 103dcce commit dd8ff11

File tree

1 file changed

+18
-37
lines changed

1 file changed

+18
-37
lines changed

enzyme/Enzyme/MLIR/Dialect/Ops.cpp

Lines changed: 18 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,6 @@ class FwdInpOpt final : public OpRewritePattern<ForwardDiffOp> {
203203

204204
auto in_idx = 0;
205205
SmallVector<mlir::Value, 2> in_args;
206-
SmallVector<Type, 2> in_ty;
207206
SmallVector<ActivityAttr, 2> newInActivityArgs;
208207
bool changed = false;
209208
for (auto [idx, act] : llvm::enumerate(inActivity)) {
@@ -219,67 +218,49 @@ class FwdInpOpt final : public OpRewritePattern<ForwardDiffOp> {
219218

220219
case mlir::enzyme::Activity::enzyme_const:
221220
in_args.push_back(inp);
222-
in_ty.push_back(inp.getType());
223221
newInActivityArgs.push_back(iattr);
224222
break;
225223

226-
case Activity::enzyme_dupnoneed:
224+
case Activity::enzyme_dupnoneed: {
227225
// always pass in primal
228226
in_args.push_back(inp);
229-
in_ty.push_back(inp.getType());
230227
in_idx++;
231228

232229
// selectively push or skip directional derivative
233230
inp = uop.getInputs()[in_idx];
234-
if (!isMutable(inp.getType())) {
235-
if (matchPattern(inp, m_Zero()) ||
236-
matchPattern(inp, m_AnyZeroFloat())) {
237-
// skip and promote to const
238-
auto new_const = mlir::enzyme::ActivityAttr::get(
239-
rewriter.getContext(), mlir::enzyme::Activity::enzyme_const);
240-
newInActivityArgs.push_back(new_const);
241-
changed = true;
242-
} else {
243-
// push derivative value
244-
in_ty.push_back(inp.getType());
245-
in_args.push_back(inp);
246-
newInActivityArgs.push_back(iattr);
247-
}
231+
if (!isMutable(inp.getType()) &&
232+
(matchPattern(inp, m_Zero()) ||
233+
matchPattern(inp, m_AnyZeroFloat()))) {
234+
// skip and promote to const
235+
auto new_const = mlir::enzyme::ActivityAttr::get(
236+
rewriter.getContext(), mlir::enzyme::Activity::enzyme_const);
237+
newInActivityArgs.push_back(new_const);
238+
changed = true;
248239
} else {
249240
// push derivative value
250-
in_ty.push_back(inp.getType());
251241
in_args.push_back(inp);
252242
newInActivityArgs.push_back(iattr);
253243
}
254244
break;
245+
}
255246

256247
case Activity::enzyme_dup: {
257248
// always pass in primal
258249
in_args.push_back(inp);
259-
in_ty.push_back(inp.getType());
260250
in_idx++;
261251

262252
// selectively push or skip directional derivative
263253
inp = uop.getInputs()[in_idx];
264-
if (!isMutable(inp.getType())) {
265-
if (matchPattern(inp, m_Zero()) ||
266-
matchPattern(inp, m_AnyZeroFloat())) {
267-
268-
// skip and promote to const
269-
auto new_const = mlir::enzyme::ActivityAttr::get(
270-
rewriter.getContext(), mlir::enzyme::Activity::enzyme_const);
271-
newInActivityArgs.push_back(new_const);
272-
changed = true;
273-
} else {
274-
275-
// push derivative value
276-
in_ty.push_back(inp.getType());
277-
in_args.push_back(inp);
278-
newInActivityArgs.push_back(iattr);
279-
}
254+
if (!isMutable(inp.getType()) &&
255+
(matchPattern(inp, m_Zero()) ||
256+
matchPattern(inp, m_AnyZeroFloat()))) {
257+
// skip and promote to const
258+
auto new_const = mlir::enzyme::ActivityAttr::get(
259+
rewriter.getContext(), mlir::enzyme::Activity::enzyme_const);
260+
newInActivityArgs.push_back(new_const);
261+
changed = true;
280262
} else {
281263
// push derivative value
282-
in_ty.push_back(inp.getType());
283264
in_args.push_back(inp);
284265
newInActivityArgs.push_back(iattr);
285266
}

0 commit comments

Comments
 (0)