Skip to content

Commit 283b9ee

Browse files
authored
Aggressive option for reshape slice and elementwise (EnzymeAD#673)
* Aggressive option for reshape slice and elementwise * fmt * fix * fmt * fix * fix
1 parent ac3d3d2 commit 283b9ee

File tree

7 files changed

+75
-17
lines changed

7 files changed

+75
-17
lines changed

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,13 @@ struct ReshapeDUS final : OpRewritePattern<mlir::stablehlo::ReshapeOp> {
457457
};
458458

459459
struct ReshapeSlice final : OpRewritePattern<mlir::stablehlo::ReshapeOp> {
460-
using OpRewritePattern::OpRewritePattern;
460+
bool onlySingleUser;
461+
462+
ReshapeSlice(bool onlySingleUser, MLIRContext *context,
463+
PatternBenefit benefit = 1,
464+
ArrayRef<StringRef> generatedNames = {})
465+
: OpRewritePattern(context, benefit, generatedNames),
466+
onlySingleUser(onlySingleUser) {}
461467

462468
LogicalResult matchAndRewrite(mlir::stablehlo::ReshapeOp op,
463469
PatternRewriter &rewriter) const override {
@@ -466,7 +472,7 @@ struct ReshapeSlice final : OpRewritePattern<mlir::stablehlo::ReshapeOp> {
466472
if (!slice)
467473
return failure();
468474

469-
if (!llvm::hasSingleElement(slice->getUsers()))
475+
if (onlySingleUser && !llvm::hasSingleElement(slice->getUsers()))
470476
return failure();
471477

472478
SmallVector<int64_t> startIndices(slice.getStartIndices().begin(),
@@ -7427,15 +7433,21 @@ struct ReshapeReduceWindow final
74277433
};
74287434

74297435
struct ReshapeElementwise final : OpRewritePattern<mlir::stablehlo::ReshapeOp> {
7430-
using OpRewritePattern::OpRewritePattern;
7436+
bool onlySingleUser;
7437+
7438+
ReshapeElementwise(bool onlySingleUser, MLIRContext *context,
7439+
PatternBenefit benefit = 1,
7440+
ArrayRef<StringRef> generatedNames = {})
7441+
: OpRewritePattern(context, benefit, generatedNames),
7442+
onlySingleUser(onlySingleUser) {}
74317443

74327444
LogicalResult matchAndRewrite(mlir::stablehlo::ReshapeOp op,
74337445
PatternRewriter &rewriter) const override {
74347446
auto elem = op.getOperand().getDefiningOp();
74357447
if (!elem)
74367448
return failure();
74377449

7438-
if (!llvm::hasSingleElement(elem->getUsers()))
7450+
if (onlySingleUser && !llvm::hasSingleElement(elem->getUsers()))
74397451
return failure();
74407452

74417453
if (!elem->hasTrait<mlir::OpTrait::Elementwise>())
@@ -13827,6 +13839,19 @@ void mlir::transform::addTransposeElementwise(RewritePatternSet &patterns,
1382713839
patterns.insert<TransposeElementwise>(onlySingleUser, &context, benefit);
1382813840
}
1382913841

13842+
void mlir::transform::addReshapeElementwise(RewritePatternSet &patterns,
13843+
bool onlySingleUser,
13844+
MLIRContext &context,
13845+
PatternBenefit benefit) {
13846+
patterns.insert<ReshapeElementwise>(onlySingleUser, &context, benefit);
13847+
}
13848+
13849+
void mlir::transform::addReshapeSlice(RewritePatternSet &patterns,
13850+
bool onlySingleUser, MLIRContext &context,
13851+
PatternBenefit benefit) {
13852+
patterns.insert<ReshapeSlice>(onlySingleUser, &context, benefit);
13853+
}
13854+
1383013855
namespace {
1383113856

1383213857
struct EnzymeHLOOptPass
@@ -13993,9 +14018,9 @@ struct EnzymeHLOOptPass
1399314018

1399414019
if (passses & (2048 * 64)) {
1399514020
// add reshape push up cases here
13996-
patterns.add<ReshapeElementwise, ReshapeOfConcatToConcatOfReshape,
13997-
ReshapeDUS, ReshapeSlice, ReshapePad, ReshapeReduceWindow>(
13998-
context);
14021+
patterns.add<ReshapeElementwise, ReshapeSlice>(true, context);
14022+
patterns.add<ReshapeOfConcatToConcatOfReshape, ReshapeDUS, ReshapePad,
14023+
ReshapeReduceWindow>(context);
1399914024
}
1400014025

1400114026
if (passses & (2048 * 128)) {

src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,5 +56,9 @@ void addConcatenateOpCanon(RewritePatternSet &patterns,
5656
PatternBenefit benefit);
5757
void addTransposeElementwise(RewritePatternSet &patterns, bool onlySingleUser,
5858
MLIRContext &context, PatternBenefit benefit);
59+
void addReshapeElementwise(RewritePatternSet &patterns, bool onlySingleUser,
60+
MLIRContext &context, PatternBenefit benefit);
61+
void addReshapeSlice(RewritePatternSet &patterns, bool onlySingleUser,
62+
MLIRContext &context, PatternBenefit benefit);
5963

6064
} // namespace mlir::transform

src/enzyme_ad/jax/TransformOps/GenerateApplyPatterns.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,8 @@ LogicalResult parseTransform(OpBuilder &builder, Location loc,
149149
opName == "pad_licm" || opName == "elementwise_licm" ||
150150
opName == "concatenate_licm" || opName == "broadcastindim_licm" ||
151151
opName == "reshape_licm" || opName == "transpose_licm" ||
152-
opName == "transpose_elementwise")
152+
opName == "transpose_elementwise" ||
153+
opName == "reshape_elementwise" || opName == "reshape_slice")
153154
state.addAttribute("parameter", builder.getBoolAttr(parameter));
154155
else
155156
state.addAttribute("parameter", builder.getI64IntegerAttr(parameter));

src/enzyme_ad/jax/TransformOps/TransformOps.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,15 @@ void ApplyTransposeElementwisePatterns::populatePatterns(
9797
addTransposeElementwise(patterns, getParameter(), *getContext(),
9898
PatternBenefit(getBenefit().value_or(1)));
9999
}
100+
void ApplyReshapeElementwisePatterns::populatePatterns(
101+
RewritePatternSet &patterns) {
102+
addReshapeElementwise(patterns, getParameter(), *getContext(),
103+
PatternBenefit(getBenefit().value_or(1)));
104+
}
105+
void ApplyReshapeSlicePatterns::populatePatterns(RewritePatternSet &patterns) {
106+
addReshapeSlice(patterns, getParameter(), *getContext(),
107+
PatternBenefit(getBenefit().value_or(1)));
108+
}
100109
void ApplySumToConvPatterns::populatePatterns(RewritePatternSet &patterns) {
101110
addSumToConv(patterns, getParameter(), *getContext(),
102111
PatternBenefit(getBenefit().value_or(0)));

src/enzyme_ad/jax/TransformOps/TransformOps.td

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1182,15 +1182,34 @@ def ApplyReshapeDUSPatterns : EnzymeHLOPatternOp<
11821182
let patterns = ["ReshapeDUS"];
11831183
}
11841184

1185-
def ApplyReshapeSlicePatterns : EnzymeHLOPatternOp<
1185+
def ApplyReshapeSlicePatterns : EnzymeHLOParameterizedPatternOp<
11861186
"reshape_slice"> {
1187-
let patterns = ["ReshapeSlice"];
1187+
let arguments = (ins OptionalAttr<I64Attr>:$benefit, BoolAttr:$parameter);
1188+
let assemblyFormat = "attr-dict";
1189+
// TODO: this should be made better searchable.
1190+
let extraClassDeclaration = [{
1191+
::llvm::SmallVector<::mlir::DictionaryAttr>
1192+
static getPossibleAttrCombinations(::mlir::Builder &builder) {
1193+
return {builder.getDictionaryAttr(
1194+
builder.getNamedAttr("parameter",
1195+
builder.getBoolAttr(true)))};
1196+
}
1197+
}];
11881198
}
11891199

1190-
def ReshapeElementwisePatterns : EnzymeHLOPatternOp<
1191-
"reshape_elementwise"
1192-
> {
1193-
let patterns = ["ReshapeElementwise"];
1200+
def ApplyReshapeElementwisePatterns : EnzymeHLOParameterizedPatternOp<
1201+
"reshape_elementwise"> {
1202+
let arguments = (ins OptionalAttr<I64Attr>:$benefit, BoolAttr:$parameter);
1203+
let assemblyFormat = "attr-dict";
1204+
// TODO: this should be made better searchable.
1205+
let extraClassDeclaration = [{
1206+
::llvm::SmallVector<::mlir::DictionaryAttr>
1207+
static getPossibleAttrCombinations(::mlir::Builder &builder) {
1208+
return {builder.getDictionaryAttr(
1209+
builder.getNamedAttr("parameter",
1210+
builder.getBoolAttr(true)))};
1211+
}
1212+
}];
11941213
}
11951214

11961215
def ReshapeOfConcatPatterns : EnzymeHLOPatternOp<

test/lit_tests/reshapeelementwise.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: enzymexlamlir-opt --enzyme-hlo-generate-td="patterns=reshape_elementwise" --transform-interpreter --enzyme-hlo-remove-transform %s | FileCheck %s
1+
// RUN: enzymexlamlir-opt --enzyme-hlo-generate-td="patterns=reshape_elementwise(1)" --transform-interpreter --enzyme-hlo-remove-transform %s | FileCheck %s
22

33
module {
44
func.func @main(%a : tensor<100x200x300xbf16>, %b: tensor<100x200x300xbf16>) -> tensor<20000x300xbf16> {

test/lit_tests/reshapeslice.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: enzymexlamlir-opt --enzyme-hlo-generate-td="patterns=reshape_slice" --transform-interpreter --enzyme-hlo-remove-transform %s | FileCheck %s
1+
// RUN: enzymexlamlir-opt --enzyme-hlo-generate-td="patterns=reshape_slice(1)" --transform-interpreter --enzyme-hlo-remove-transform %s | FileCheck %s
22

33
module {
44
// Test case where reshape adds a unit dimension at the beginning.
@@ -71,4 +71,4 @@ module {
7171
return %140 : tensor<268x2060xf64>
7272
}
7373

74-
}
74+
}

0 commit comments

Comments
 (0)