
Reduction optimization #47


Merged 2 commits on Mar 10, 2024
169 changes: 165 additions & 4 deletions src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
@@ -165,6 +165,129 @@ struct SlicePad final : OpRewritePattern<mlir::stablehlo::SliceOp> {
}
};

// From
// https://github.com/openxla/stablehlo/blob/5d1a9c892500c2e9fecbfedfa66ffe84ff1caf7b/stablehlo/dialect/StablehloOps.cpp#L1498C1-L1532C1
// Returns true when every operand and result of `op` has the same type.
static bool hasSameOperandAndResultTypes(Operation &op) {
Type expected;
if (op.getNumResults() != 0)
expected = op.getResult(0).getType();
if (op.getNumOperands() != 0)
expected = op.getOperand(0).getType();
if (!expected)
return false;

auto typeMatch = [&](Type actual) { return actual == expected; };
return llvm::all_of(op.getOperandTypes(), typeMatch) &&
llvm::all_of(op.getResultTypes(), typeMatch);
}

// Adapted from the StableHLO check that decides whether a reduce can be
// printed in its compact one-line form: the body must be a single commutative
// binary op on rank-0 tensors of the input element type, applied to the block
// arguments in order and returned directly.
static bool isEligibleForCompactPrint(stablehlo::ReduceOp op) {
// Check E1.
auto &block = op.getBody().front();
if (!hasSingleElement(block.without_terminator()))
return false;

Operation &innerOp = *block.begin();

// Check E2.
if (innerOp.getDialect() != op->getDialect())
return false;

if (innerOp.getNumOperands() != 2 ||
!innerOp.hasTrait<mlir::OpTrait::OneResult>() ||
!hasSameOperandAndResultTypes(innerOp) ||
!innerOp.hasTrait<mlir::hlo::OpTrait::IsCommutative>() ||
!innerOp.hasTrait<mlir::OpTrait::ZeroRegions>())
return false;

// Check E3.
if (op.getInputs().empty())
return false;

auto elemType =
op.getInputs()[0].getType().cast<ShapedType>().getElementType();
auto expectedInnerOpType = RankedTensorType::get(/*shape=*/{}, elemType);
if (innerOp.getOperands()[0].getType() != expectedInnerOpType)
return false;

// Check E4.
if (!llvm::equal(block.getArguments(), innerOp.getOperands()))
return false;

// Check E5.
auto retOp = dyn_cast<stablehlo::ReturnOp>(block.getTerminator());
if (!retOp)
return false;

return llvm::equal(innerOp.getResults(), retOp.getOperands());
}
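For reference, these are the same structural checks StableHLO applies when deciding whether a reduce can be printed in its compact `applies` form, i.e. reductions that look like the following (same syntax as the lit test at the bottom of this diff):

%0 = stablehlo.reduce(%arg0 init: %cst) applies stablehlo.add across dimensions = [0] : (tensor<4xf32>, tensor<f32>) -> tensor<f32>

Both patterns below reuse the predicate so that the reduce body is guaranteed to be a single commutative binary op they can clone directly.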

// If every dimension being reduced has extent 1, the reduction applies the
// body op exactly once: reshape the input down to the init value's type and
// combine it with the init value directly.
struct ReduceToReshape final : OpRewritePattern<mlir::stablehlo::ReduceOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(mlir::stablehlo::ReduceOp op,
PatternRewriter &rewriter) const override {
if (op.getInputs().size() != 1)
return failure();
if (!isEligibleForCompactPrint(op))
return failure();
auto inpTy = op.getInputs()[0].getType().cast<RankedTensorType>();
    for (auto idx : op.getDimensions()) {
      if (inpTy.getShape()[idx] != 1)
        return failure();
    }
    // The reshape below produces the init value's rank-0 type, so the rewrite
    // is only valid when the reduce result is itself rank-0, i.e. every input
    // dimension is reduced away.
    if (op->getResult(0).getType() != op.getInitValues()[0].getType())
      return failure();

auto reshaped = rewriter.create<stablehlo::ReshapeOp>(
op.getLoc(), op.getInitValues()[0].getType(), op.getInputs()[0]);

Operation &innerOp = op.getBody().front().front();

IRMapping map;
map.map(innerOp.getOperand(0), op.getInitValues()[0]);
map.map(innerOp.getOperand(1), reshaped);
auto res = rewriter.clone(innerOp, map)->getResult(0);

rewriter.replaceOp(op, res);
return success();
}
};
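A minimal sketch of what ReduceToReshape does (hypothetical IR, not taken from the PR's tests): a reduction whose only reduced dimension has extent 1,

%r = stablehlo.reduce(%x init: %init) applies stablehlo.add across dimensions = [0] : (tensor<1xf32>, tensor<f32>) -> tensor<f32>

rewrites to a reshape plus a single application of the body op:

%0 = stablehlo.reshape %x : (tensor<1xf32>) -> tensor<f32>
%r = stablehlo.add %init, %0 : tensor<f32>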

// Reducing a concatenation along one of the reduced dimensions is the same as
// reducing each concatenated piece and folding the partial results together,
// so the reduce is split into a chain of smaller reduces threaded through the
// init value.
struct ReduceConcat final : OpRewritePattern<mlir::stablehlo::ReduceOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(mlir::stablehlo::ReduceOp op,
PatternRewriter &rewriter) const override {
if (op.getInputs().size() != 1)
return failure();

auto concat = op.getInputs()[0].getDefiningOp<stablehlo::ConcatenateOp>();
if (!concat)
return failure();

auto dim = concat.getDimension();

if (!llvm::is_contained(op.getDimensions(), dim))
return failure();

if (!isEligibleForCompactPrint(op))
return failure();

Operation &innerOp = op.getBody().front().front();

Value prev = op.getInitValues()[0];

for (auto v : concat.getOperands()) {
IRMapping map;
map.map(op.getInitValues()[0], prev);
map.map(op.getInputs()[0], v);
prev = rewriter.clone(*op, map)->getResult(0);
}

rewriter.replaceOp(op, prev);
return success();
}
};
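Each cloned reduce takes one concatenated piece as its input and the running partial result as its init value, so an n-way concatenate unrolls into n chained reductions. The new lit test test/lit_tests/reduceconcat.mlir at the bottom of this diff pins down the resulting IR, including the follow-up ReduceToReshape cleanup of the size-1 pieces.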

struct SliceConcat final : OpRewritePattern<mlir::stablehlo::SliceOp> {
using OpRewritePattern::OpRewritePattern;

@@ -397,6 +520,42 @@ struct AddPad final : OpRewritePattern<mlir::stablehlo::AddOp> {
}
};

// Canonicalizes concatenate ops: a one-operand concatenate of matching type
// folds to its operand, nested concatenates along the same dimension are
// flattened, and operands with zero extent along the concat dimension are
// dropped.
struct ConcatFuse final : OpRewritePattern<mlir::stablehlo::ConcatenateOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(mlir::stablehlo::ConcatenateOp op,
PatternRewriter &rewriter) const override {
if (op->getNumOperands() == 1 &&
op->getOperand(0).getType() == op.getType()) {
rewriter.replaceOp(op, op->getOperand(0));
return success();
}
SmallVector<Value> vals;
bool changed = false;
for (auto v : op->getOperands()) {
if (auto c2 = v.getDefiningOp<stablehlo::ConcatenateOp>()) {
if (c2.getDimension() == op.getDimension()) {
for (auto v2 : c2->getOperands())
vals.push_back(v2);
changed = true;
continue;
}
}
if (v.getType().cast<RankedTensorType>().getShape()[op.getDimension()] ==
0) {
changed = true;
continue;
}
vals.push_back(v);
}
if (!changed)
return failure();
rewriter.replaceOpWithNewOp<stablehlo::ConcatenateOp>(op, op.getType(),
vals);
return success();
}
};
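A small sketch of the fusion (hypothetical IR): a concatenate feeding another concatenate along the same dimension, plus a zero-extent operand,

%0 = stablehlo.concatenate %a, %b, dim = 0 : (tensor<2xf32>, tensor<3xf32>) -> tensor<5xf32>
%1 = stablehlo.concatenate %0, %c, dim = 0 : (tensor<5xf32>, tensor<0xf32>) -> tensor<5xf32>

collapses into one flat concatenate:

%1 = stablehlo.concatenate %a, %b, dim = 0 : (tensor<2xf32>, tensor<3xf32>) -> tensor<5xf32>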

struct ConcatConstProp final
: OpRewritePattern<mlir::stablehlo::ConcatenateOp> {
using OpRewritePattern::OpRewritePattern;
@@ -932,10 +1091,12 @@ struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase<EnzymeHLOOptPass> {
auto context = getOperation()->getContext();
RewritePatternSet patterns(context);
     patterns.add<SlicePad, SliceSlice, AddPad, DotReshapeDot, ConcatConstProp,
-                 /*ScatterToPad, */ BroadcastToReshape, SliceConcat,
-                 SliceSimplification, CosSimplify, SinSimplify, SqrtSimplify,
-                 AddSimplify, SubSimplify, NegateSimplify, MulSimplify,
-                 DivSimplify, PowSimplify>(context);
+                 ConcatFuse,
+                 /*ScatterToPad, */ BroadcastToReshape, ReduceToReshape,
+                 ReduceConcat, SliceConcat, SliceSimplification, CosSimplify,
+                 SinSimplify, SqrtSimplify, AddSimplify, SubSimplify,
+                 NegateSimplify, MulSimplify, DivSimplify, PowSimplify>(
+        context);
mlir::stablehlo::populateStablehloCanonicalizationPatterns(context,
&patterns);

4 changes: 3 additions & 1 deletion test/BUILD
@@ -9,6 +9,7 @@ expand_template(
substitutions = {
"@LIT_SITE_CFG_IN_HEADER@": "# Autogenerated, do not edit.",
"@LLVM_TOOLS_BINARY_DIR@": package_path("@llvm-project//llvm:BUILD"),
"@ENZYMEXLA_BINARY_DIR@": "",
"@LLVM_LIBS_DIR@": package_path("@llvm-project//llvm:BUILD"),
"@ENZYME_SOURCE_DIR@": "",
"@ENZYME_BINARY_DIR@": "",
@@ -30,6 +31,7 @@ exports_files(
":lit.cfg.py",
":lit_site_cfg_py",
"//src/enzyme_ad/jax:enzyme_jax_internal",
"//:enzymexlamlir-opt",
"@llvm-project//clang:builtin_headers_gen",
"@llvm-project//llvm:FileCheck",
"@llvm-project//llvm:count",
@@ -38,7 +40,7 @@
)
for src in glob(
[
"**/*.pyt",
"**/*.pyt", "**/*.mlir",
],
)
]
3 changes: 2 additions & 1 deletion test/lit.cfg.py
@@ -22,7 +22,7 @@
config.test_format = lit.formats.ShTest(execute_external)

# suffixes: A list of file extensions to treat as test files.
config.suffixes = [".pyt"]
config.suffixes = [".pyt", ".mlir"]

# test_source_root: The root path where tests are located.
config.test_source_root = os.path.dirname(__file__)
@@ -35,6 +35,7 @@
# Tweak the PATH to include the tools dir and the scripts dir.
base_paths = [
config.llvm_tools_dir,
config.enzymexla_tools_dir,
config.environment["PATH"],
]
path = os.path.pathsep.join(base_paths) # + config.extra_paths)
5 changes: 5 additions & 0 deletions test/lit.site.cfg.py.in
@@ -7,5 +7,10 @@ config.llvm_tools_dir = "@LLVM_TOOLS_BINARY_DIR@"
if len("@ENZYME_BINARY_DIR@") == 0:
config.llvm_tools_dir = os.getcwd() + "/" + config.llvm_tools_dir

config.enzymexla_tools_dir = "@ENZYMEXLA_BINARY_DIR@"

if len(config.enzymexla_tools_dir) == 0:
config.enzymexla_tools_dir = os.getcwd()

cfgfile = os.path.dirname(os.path.abspath(__file__)) + "/lit.cfg.py"
lit_config.load_config(config, cfgfile)
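Taken together, the test-harness changes wire the new MLIR tests into lit: .mlir is registered as a test suffix, the enzymexlamlir-opt target is exported from the BUILD file, and the site config's enzymexla_tools_dir (empty under Bazel, hence the cwd fallback above) is prepended to PATH by lit.cfg.py so the RUN lines can find the tool.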
24 changes: 24 additions & 0 deletions test/lit_tests/reduceconcat.mlir
@@ -0,0 +1,24 @@
// RUN: enzymexlamlir-opt --enzyme-hlo-opt %s | FileCheck %s

module {

func.func @main(%a : tensor<2xf32>, %b : tensor<1xf32>, %c : tensor<1xf32>) -> tensor<f32> {
%cst0 = arith.constant dense<0.000000e+00> : tensor<f32>
%concat = stablehlo.concatenate %a, %b, %c, dim=0 : (tensor<2xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<4xf32>

%1308 = stablehlo.reduce(%concat init: %cst0) applies stablehlo.add across dimensions = [0] : (tensor<4xf32>, tensor<f32>) -> tensor<f32>

return %1308 : tensor<f32>

}
}

// CHECK: func.func @main(%arg0: tensor<2xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> tensor<f32> {
// CHECK-NEXT: %cst = arith.constant dense<0.000000e+00> : tensor<f32>
// CHECK-NEXT: %0 = stablehlo.reduce(%arg0 init: %cst) applies stablehlo.add across dimensions = [0] : (tensor<2xf32>, tensor<f32>) -> tensor<f32>
// CHECK-NEXT: %1 = stablehlo.reshape %arg1 : (tensor<1xf32>) -> tensor<f32>
// CHECK-NEXT: %2 = stablehlo.add %0, %1 : tensor<f32>
// CHECK-NEXT: %3 = stablehlo.reshape %arg2 : (tensor<1xf32>) -> tensor<f32>
// CHECK-NEXT: %4 = stablehlo.add %2, %3 : tensor<f32>
// CHECK-NEXT: return %4 : tensor<f32>
// CHECK-NEXT: }