diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py
index eb9bd3371a49..b36ad30d2604 100644
--- a/e2e_testing/xfail_sets.py
+++ b/e2e_testing/xfail_sets.py
@@ -850,6 +850,8 @@
     "DropoutTrainModule_basic",
     "StdCorrectionKeepDimModule_basic",
     "StdCorrectionNoneModule_basic",
+    "SliceCopy_Module_basic",
+    "SliceCopyNegative_Module_basic",
     "VarBiasedModule_basic",
     "VarCorrectionAllDimReduceModule_basic",
     "VarCorrectionEmptyDimModule_basic",
diff --git a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h
index 45cd888dc7f5..4cf27639ab90 100644
--- a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h
+++ b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h
@@ -98,6 +98,8 @@ std::unique_ptr<OperationPass<ModuleOp>> createRefinePublicReturnPass();
 std::unique_ptr<OperationPass<func::FuncOp>>
 createDecomposeComplexOpsPass(ArrayRef<std::string> legalOps);
 
+std::unique_ptr<OperationPass<func::FuncOp>> createRecomposeComplexOps();
+
 std::unique_ptr<OperationPass<ModuleOp>> createPreprocessShapeLibraryPass();
 
 std::unique_ptr<OperationPass<ModuleOp>> createReifyShapeCalculationsPass();
diff --git a/lib/Dialect/Torch/Transforms/CMakeLists.txt b/lib/Dialect/Torch/Transforms/CMakeLists.txt
index 77f504f08dbe..ce577cf5bd49 100644
--- a/lib/Dialect/Torch/Transforms/CMakeLists.txt
+++ b/lib/Dialect/Torch/Transforms/CMakeLists.txt
@@ -9,6 +9,7 @@ add_mlir_library(TorchMLIRTorchPasses
   LowerToBackendContract.cpp
   MaximizeValueSemantics.cpp
   PrepareForGlobalizeObjectGraph.cpp
+  RecomposeComplexOps.cpp
   ReduceOpVariants.cpp
   RefinePublicReturn.cpp
   RefineTypes.cpp
diff --git a/lib/Dialect/Torch/Transforms/Passes.cpp b/lib/Dialect/Torch/Transforms/Passes.cpp
index 4455ec1a785c..934ff7c25281 100644
--- a/lib/Dialect/Torch/Transforms/Passes.cpp
+++ b/lib/Dialect/Torch/Transforms/Passes.cpp
@@ -106,6 +106,7 @@ void mlir::torch::Torch::createTorchSimplificationPipeline(
   // Clean up again to avoid needing to go back around the fixed-point
   // iteration.
   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
+  pm.addNestedPass<func::FuncOp>(createRecomposeComplexOps());
   // Reduce variants of ops to a smaller set of primitives.
   pm.addNestedPass<func::FuncOp>(createReduceOpVariantsPass());
   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
diff --git a/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp
new file mode 100644
index 000000000000..7a5269946a48
--- /dev/null
+++ b/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp
@@ -0,0 +1,103 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
+#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
+
+using namespace mlir;
+using namespace mlir::torch;
+using namespace mlir::torch::Torch;
+
+namespace {
+class RecomposeSliceCopy_ : public OpRewritePattern<AtenCopy_Op> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenCopy_Op op,
+                                PatternRewriter &rewriter) const override {
+    if (!op.getSelf().getDefiningOp() ||
+        !isa<AtenSliceTensorOp>(op.getSelf().getDefiningOp()))
+      return failure();
+    auto sliceOp = cast<AtenSliceTensorOp>(op.getSelf().getDefiningOp());
+
+    // Get indices
+    int64_t dim;
+    if (!matchPattern(sliceOp.getDim(), m_TorchConstantInt(&dim)))
+      return failure();
+    int64_t end;
+    if (!matchPattern(sliceOp.getEnd(), m_TorchConstantInt(&end)))
+      return failure();
+
+    // Fold a negative end into `end + self.size(dim)`.
+    Value newEnd = sliceOp.getEnd();
+    if (end < 0) {
+      Value dimSize = rewriter.create<AtenSizeIntOp>(
+          op.getLoc(), sliceOp.getSelf(), sliceOp.getDim());
+      newEnd =
+          rewriter.create<AtenAddIntOp>(op.getLoc(), dimSize, sliceOp.getEnd());
+    }
+
+    Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
+    Value falseVal = rewriter.create<ConstantBoolOp>(op.getLoc(), false);
+
+    // Create IndexPut_Op
+    BaseTensorType tensorType = op->getResultTypes()[0].cast<BaseTensorType>();
+    Value range = rewriter.create<AtenArangeStartStepOp>(
+        op.getLoc(), tensorType, sliceOp.getStart(), newEnd, sliceOp.getStep(),
+        /*dtype=*/noneVal, /*layout=*/noneVal, /*device=*/noneVal,
+        /*pin_memory=*/noneVal);
+
+    // One leading `None` per dimension before `dim`, so `range` indexes `dim`.
+    SmallVector<Value> indicesVector;
+    for (auto i = 0; i < dim; i++)
+      indicesVector.push_back(noneVal);
+    indicesVector.push_back(range);
+    Value indices = rewriter.create<PrimListConstructOp>(
+        op.getLoc(),
+        Torch::ListType::get(op->getContext(),
+                             Torch::OptionalType::get(tensorType)),
+        indicesVector);
+
+    rewriter.replaceOpWithNewOp<Aten_IndexPutImpl_Op>(
+        op, op->getResultTypes(), sliceOp.getSelf(), indices, op.getSrc(),
+        /*accumulate=*/falseVal, /*unsafe=*/falseVal);
+
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+class RecomposeComplexOps
+    : public DecomposeComplexOpsBase<RecomposeComplexOps> {
+public:
+  RecomposeComplexOps() = default;
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    RewritePatternSet patterns(context);
+
+    // pattern.add calls go here
+    patterns.add<RecomposeSliceCopy_>(context);
+
+    GreedyRewriteConfig config;
+    config.useTopDownTraversal = true;
+    config.maxIterations = GreedyRewriteConfig::kNoLimit;
+
+    if (failed(applyPatternsAndFoldGreedily(getOperation(),
+                                            std::move(patterns), config))) {
+      return signalPassFailure();
+    }
+  }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<func::FuncOp>>
+mlir::torch::Torch::createRecomposeComplexOps() {
+  return std::make_unique<RecomposeComplexOps>();
+}
diff --git a/python/torch_mlir_e2e_test/test_suite/slice_like.py b/python/torch_mlir_e2e_test/test_suite/slice_like.py
index 032dadfb8e8b..1e8566826547 100644
--- a/python/torch_mlir_e2e_test/test_suite/slice_like.py
+++ b/python/torch_mlir_e2e_test/test_suite/slice_like.py
@@ -481,3 +481,47 @@ def forward(self, x):
 @register_test_case(module_factory=lambda: NarrowVerticalTest2())
 def NarrowVerticalTest2_basic(module, tu: TestUtils):
     module.forward(tu.rand(6,4))
+
+# ==============================================================================
+
+class SliceCopy_Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([10, 4, 4], torch.float32, True),
+        ([4, 4, 4], torch.float32, True),
+    ])
+    def forward(self, x, y):
+        xslice = torch.ops.aten.slice(x, 0, 2, 6, 1)
+        xslice.copy_(y)
+        return x
+
+
+@register_test_case(module_factory=lambda: SliceCopy_Module())
+def SliceCopy_Module_basic(module, tu: TestUtils):
+    module.forward(tu.rand(10, 4, 4), tu.rand(4, 4, 4))
+
+# ==============================================================================
+
+class SliceCopyNegative_Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x, y):
+        xslice = torch.ops.aten.slice(x, 0, 2, -4, 1)
+        xslice.copy_(y)
+        return x
+
+
+@register_test_case(module_factory=lambda: SliceCopyNegative_Module())
+def SliceCopyNegative_Module_basic(module, tu: TestUtils):
+    module.forward(tu.rand(10, 4, 4), tu.rand(4, 4, 4))
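
For intuition about the rewrite RecomposeSliceCopy_ performs: a copy_ into a slice view is recomposed into a single index_put_ on the base tensor, indexed by arange(start, end, step) along the sliced dimension, with a negative end folded to end + size(dim). The eager-PyTorch sketch below illustrates that equivalence; it is illustrative only, not code from the patch (shapes mirror SliceCopy_Module_basic and SliceCopyNegative_Module_basic, and Tensor.index_put_ stands in as the public counterpart of aten._index_put_impl_):

import torch

x = torch.rand(10, 4, 4)
y = torch.rand(4, 4, 4)

# slice(x, dim=0, start=2, end=6, step=1) is a view, so copy_ mutates x.
expected = x.clone()
expected[2:6].copy_(y)

# The recomposed form: one index_put_ on the whole tensor, indexed by
# arange(start, end, step) along dim 0.
recomposed = x.clone()
recomposed.index_put_((torch.arange(2, 6),), y, accumulate=False)
assert torch.equal(expected, recomposed)

# Negative end, as in SliceCopyNegative_Module_basic: end = -4 is folded
# to -4 + x.size(0) == 6 before the arange is built.
negative = x.clone()
negative[2:-4].copy_(y)
assert torch.equal(expected, negative)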