Add a new RecomposeComplexOps pass, fold slice+copy_ into index_put_ (l…
gpetters94 authored Mar 10, 2023 · 1 parent 2be48c3 · commit 66b1045
Showing 6 changed files with 153 additions and 0 deletions.
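
For context, the rewrite this commit implements corresponds to the following eager-PyTorch equivalence (a minimal illustrative sketch, not part of the commit; shapes are borrowed from the new e2e tests below):

```python
import torch

x = torch.rand(10, 4, 4)
y = torch.rand(4, 4, 4)

# An in-place copy into a slice of x...
x_ref = x.clone()
x_ref[2:6].copy_(y)

# ...can be folded into a single index_put_ on x, with the slice bounds
# materialized as an arange of indices along the sliced dimension.
x_new = x.clone()
x_new.index_put_((torch.arange(2, 6),), y, accumulate=False)

assert torch.equal(x_ref, x_new)
```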
2 changes: 2 additions & 0 deletions e2e_testing/xfail_sets.py
@@ -850,6 +850,8 @@
"DropoutTrainModule_basic",
"StdCorrectionKeepDimModule_basic",
"StdCorrectionNoneModule_basic",
"SliceCopy_Module_basic",
"SliceCopyNegative_Module_basic",
"VarBiasedModule_basic",
"VarCorrectionAllDimReduceModule_basic",
"VarCorrectionEmptyDimModule_basic",
2 changes: 2 additions & 0 deletions include/torch-mlir/Dialect/Torch/Transforms/Passes.h
@@ -98,6 +98,8 @@ std::unique_ptr<OperationPass<ModuleOp>> createRefinePublicReturnPass();
std::unique_ptr<OperationPass<func::FuncOp>>
createDecomposeComplexOpsPass(ArrayRef<std::string> legalOps);

std::unique_ptr<OperationPass<func::FuncOp>> createRecomposeComplexOps();

std::unique_ptr<OperationPass<ModuleOp>> createPreprocessShapeLibraryPass();

std::unique_ptr<OperationPass<ModuleOp>> createReifyShapeCalculationsPass();
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/CMakeLists.txt
@@ -9,6 +9,7 @@ add_mlir_library(TorchMLIRTorchPasses
LowerToBackendContract.cpp
MaximizeValueSemantics.cpp
PrepareForGlobalizeObjectGraph.cpp
RecomposeComplexOps.cpp
ReduceOpVariants.cpp
RefinePublicReturn.cpp
RefineTypes.cpp
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/Passes.cpp
@@ -106,6 +106,7 @@ void mlir::torch::Torch::createTorchSimplificationPipeline(
  // Clean up again to avoid needing to go back around the fixed-point
  // iteration.
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(createRecomposeComplexOps());
// Reduce variants of ops to a smaller set of primitives.
pm.addNestedPass<func::FuncOp>(createReduceOpVariantsPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
103 changes: 103 additions & 0 deletions lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp
@@ -0,0 +1,103 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"

#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;

namespace {
class RecomposeSliceCopy_ : public OpRewritePattern<AtenCopy_Op> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenCopy_Op op,
PatternRewriter &rewriter) const override {
    auto sliceOp = op.getSelf().getDefiningOp<AtenSliceTensorOp>();
    if (!sliceOp)
      return failure();

    // The slice's dim and end must be statically-known constant ints.
int64_t dim;
if (!matchPattern(sliceOp.getDim(), m_TorchConstantInt(&dim)))
return failure();
int64_t end;
if (!matchPattern(sliceOp.getEnd(), m_TorchConstantInt(&end)))
return failure();

Value newEnd = sliceOp.getEnd();
if (end < 0) {
Value dimSize = rewriter.create<AtenSizeIntOp>(
op.getLoc(), sliceOp.getSelf(), sliceOp.getDim());
newEnd =
rewriter.create<AtenAddIntOp>(op.getLoc(), dimSize, sliceOp.getEnd());
}

Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
Value falseVal = rewriter.create<ConstantBoolOp>(op.getLoc(), false);

    // Build the replacement index_put_: materialize the slice bounds as an
    // arange of indices.
BaseTensorType tensorType = op->getResultTypes()[0].cast<BaseTensorType>();
Value range = rewriter.create<AtenArangeStartStepOp>(
op.getLoc(), tensorType, sliceOp.getStart(), newEnd, sliceOp.getStep(),
/*dtype=*/noneVal, /*layout=*/noneVal, /*device=*/noneVal,
/*pin_memory=*/noneVal);

SmallVector<Value> indicesVector;
    // One `none` per dimension preceding the sliced dim, so the arange
    // indexes along `dim`.
    for (int64_t i = 0; i < dim; i++)
      indicesVector.push_back(noneVal);
indicesVector.push_back(range);
Value indices = rewriter.create<PrimListConstructOp>(
op.getLoc(),
Torch::ListType::get(op->getContext(),
Torch::OptionalType::get(tensorType)),
indicesVector);

rewriter.replaceOpWithNewOp<Aten_IndexPutImpl_Op>(
op, op->getResultTypes(), sliceOp.getSelf(), indices, op.getSrc(),
/*accumulate=*/falseVal, /*unsafe=*/falseVal);

return success();
}
};
} // namespace

namespace {
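// Note: reuses DecomposeComplexOpsBase as the generated pass base class;
// this commit does not add a separate tablegen entry for the new pass.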
class RecomposeComplexOps
: public DecomposeComplexOpsBase<RecomposeComplexOps> {
public:
RecomposeComplexOps() = default;
void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);

    // Register recomposition patterns.
patterns.add<RecomposeSliceCopy_>(context);

GreedyRewriteConfig config;
config.useTopDownTraversal = true;
config.maxIterations = GreedyRewriteConfig::kNoLimit;

if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
config))) {
return signalPassFailure();
}
}
};
} // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::Torch::createRecomposeComplexOps() {
return std::make_unique<RecomposeComplexOps>();
}
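
The negative-end normalization (AtenSizeIntOp + AtenAddIntOp) and the None-padding of the index list in the pattern above mirror the eager-mode behavior sketched below. The helper is a hypothetical Python analogue for illustration only; it assumes torch.ops.aten.index_put_ accepts None entries in its index list, which its Tensor?[] indices schema permits.

```python
import torch

def slice_copy_as_index_put_(x, dim, start, end, step, src):
    # Hypothetical analogue of RecomposeSliceCopy_.
    # Normalize a negative end against the dimension size, as the pattern
    # does with AtenSizeIntOp + AtenAddIntOp.
    if end < 0:
        end = end + x.size(dim)
    # One None per dimension preceding `dim`, so the arange indexes along
    # the sliced dimension.
    indices = [None] * dim + [torch.arange(start, end, step)]
    return torch.ops.aten.index_put_(x, indices, src, False)

x = torch.rand(10, 4, 4)
y = torch.rand(4, 4, 4)

x_ref = x.clone()
x_ref[2:-4].copy_(y)  # negative end, as in SliceCopyNegative_Module below

x_new = x.clone()
slice_copy_as_index_put_(x_new, 0, 2, -4, 1, y)
assert torch.equal(x_ref, x_new)
```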
44 changes: 44 additions & 0 deletions python/torch_mlir_e2e_test/test_suite/slice_like.py
@@ -481,3 +481,47 @@ def forward(self, x):
@register_test_case(module_factory=lambda: NarrowVerticalTest2())
def NarrowVerticalTest2_basic(module, tu: TestUtils):
module.forward(tu.rand(6,4))

# ==============================================================================

class SliceCopy_Module(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([10, 4, 4], torch.float32, True),
([4, 4, 4], torch.float32, True),
])
def forward(self, x, y):
xslice = torch.ops.aten.slice(x, 0, 2, 6, 1)
xslice.copy_(y)
return x


@register_test_case(module_factory=lambda: SliceCopy_Module())
def SliceCopy_Module_basic(module, tu: TestUtils):
module.forward(tu.rand(10, 4, 4), tu.rand(4, 4, 4))

# ==============================================================================

class SliceCopyNegative_Module(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
([-1, -1, -1], torch.float32, True),
])
def forward(self, x, y):
xslice = torch.ops.aten.slice(x, 0, 2, -4, 1)
xslice.copy_(y)
return x


@register_test_case(module_factory=lambda: SliceCopyNegative_Module())
def SliceCopyNegative_Module_basic(module, tu: TestUtils):
module.forward(tu.rand(10, 4, 4), tu.rand(4, 4, 4))
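
As a sanity check outside the torch-mlir harness, the same ops can be exercised in eager PyTorch: aten.slice returns a view, so copy_ writes through to the original tensor (illustrative only):

```python
import torch

x = torch.rand(10, 4, 4)
y = torch.rand(4, 4, 4)
expected = x.clone()
expected[2:6] = y

# Same ops as SliceCopy_Module.forward: the slice is a view of x,
# so the in-place copy mutates x itself.
xslice = torch.ops.aten.slice(x, 0, 2, 6, 1)
xslice.copy_(y)
assert torch.equal(x, expected)
```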
