Fold slice+copy_ into index_put_ #1901

Merged 1 commit on Mar 10, 2023
2 changes: 2 additions & 0 deletions e2e_testing/xfail_sets.py
@@ -848,6 +848,8 @@
"DropoutTrainModule_basic",
"StdCorrectionKeepDimModule_basic",
"StdCorrectionNoneModule_basic",
"SliceCopy_Module_basic",
"SliceCopyNegative_Module_basic",
"VarBiasedModule_basic",
"VarCorrectionAllDimReduceModule_basic",
"VarCorrectionEmptyDimModule_basic",
2 changes: 2 additions & 0 deletions include/torch-mlir/Dialect/Torch/Transforms/Passes.h
@@ -98,6 +98,8 @@ std::unique_ptr<OperationPass<ModuleOp>> createRefinePublicReturnPass();
std::unique_ptr<OperationPass<func::FuncOp>>
createDecomposeComplexOpsPass(ArrayRef<std::string> legalOps);

std::unique_ptr<OperationPass<func::FuncOp>> createRecomposeComplexOps();

std::unique_ptr<OperationPass<ModuleOp>> createPreprocessShapeLibraryPass();

std::unique_ptr<OperationPass<ModuleOp>> createReifyShapeCalculationsPass();
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/CMakeLists.txt
@@ -9,6 +9,7 @@ add_mlir_library(TorchMLIRTorchPasses
LowerToBackendContract.cpp
MaximizeValueSemantics.cpp
PrepareForGlobalizeObjectGraph.cpp
RecomposeComplexOps.cpp
ReduceOpVariants.cpp
RefinePublicReturn.cpp
RefineTypes.cpp
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/Passes.cpp
@@ -106,6 +106,7 @@ void mlir::torch::Torch::createTorchSimplificationPipeline(
// Clean up again to avoid needing to go back around the fixed-point
// iteration.
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(createRecomposeComplexOps());
// Reduce variants of ops to a smaller set of primitives.
pm.addNestedPass<func::FuncOp>(createReduceOpVariantsPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
103 changes: 103 additions & 0 deletions lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp
@@ -0,0 +1,103 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"

#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;

namespace {
class RecomposeSliceCopy_ : public OpRewritePattern<AtenCopy_Op> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenCopy_Op op,
PatternRewriter &rewriter) const override {
if (!op.getSelf().getDefiningOp() ||
!isa<AtenSliceTensorOp>(op.getSelf().getDefiningOp()))
return failure();
auto sliceOp = cast<AtenSliceTensorOp>(op.getSelf().getDefiningOp());

// Get indices
int64_t dim;
if (!matchPattern(sliceOp.getDim(), m_TorchConstantInt(&dim)))
return failure();
int64_t end;
if (!matchPattern(sliceOp.getEnd(), m_TorchConstantInt(&end)))
return failure();

Value newEnd = sliceOp.getEnd();
if (end < 0) {
Value dimSize = rewriter.create<AtenSizeIntOp>(
op.getLoc(), sliceOp.getSelf(), sliceOp.getDim());
newEnd =
rewriter.create<AtenAddIntOp>(op.getLoc(), dimSize, sliceOp.getEnd());
}

Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
Value falseVal = rewriter.create<ConstantBoolOp>(op.getLoc(), false);

// Create IndexPut_Op
BaseTensorType tensorType = op->getResultTypes()[0].cast<BaseTensorType>();
Value range = rewriter.create<AtenArangeStartStepOp>(
op.getLoc(), tensorType, sliceOp.getStart(), newEnd, sliceOp.getStep(),
/*dtype=*/noneVal, /*layout=*/noneVal, /*device=*/noneVal,
/*pin_memory=*/noneVal);

SmallVector<Value> indicesVector;
for (auto i = 0; i < dim - 1; i++)
Review comment (Collaborator):

this should be i < dim, not i < dim - 1

Review comment (@AmosLewis, Apr 18, 2023):

I think the correct indices should be (torch.tensor([0, 0, 0]), torch.tensor([1, 2, 3])), the first for dim 0 and the second for dim 1, not just (torch.tensor([1, 2, 3])). In Python that is a[[0, 0, 0], [1, 2, 3]]. I don't think converting dim - 1 to dim would fix this issue.

Review comment (@AmosLewis, Apr 18, 2023):

After converting dim - 1 to dim, the indices come out as

%6 = torch.prim.ListConstruct %none, %5 : (!torch.none, !torch.vtensor<[3],si64>) -> !torch.list<optional<vtensor>>

compared to the previous dim - 1 version:

%6 = torch.prim.ListConstruct %5 : (!torch.vtensor<[3],si64>) -> !torch.list<optional<vtensor>>

Review comment (Collaborator):

That's correct, right?

Review comment (Collaborator):

In [1]: import torch

In [2]: torch.ops.aten.index_put(torch.tensor([[0, 1, 2, 3]]),
   ...:                          (None, torch.tensor([1, 2, 3]),),
   ...:                          torch.tensor([[4, 5, 6]]))
Out[2]: tensor([[0, 4, 5, 6]])

That should work as expected.

Review comment (Collaborator):

Yes, it's correct.

indicesVector.push_back(noneVal);
indicesVector.push_back(range);
Value indices = rewriter.create<PrimListConstructOp>(
op.getLoc(),
Torch::ListType::get(op->getContext(),
Torch::OptionalType::get(tensorType)),
indicesVector);

rewriter.replaceOpWithNewOp<Aten_IndexPutImpl_Op>(
op, op->getResultTypes(), sliceOp.getSelf(), indices, op.getSrc(),
/*accumulate=*/falseVal, /*unsafe=*/falseVal);

return success();
}
};
} // namespace

namespace {
class RecomposeComplexOps
: public DecomposeComplexOpsBase<RecomposeComplexOps> {
public:
RecomposeComplexOps() = default;
void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);

// pattern.add calls go here
patterns.add<RecomposeSliceCopy_>(context);

GreedyRewriteConfig config;
config.useTopDownTraversal = true;
config.maxIterations = GreedyRewriteConfig::kNoLimit;

if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
config))) {
return signalPassFailure();
}
}
};
} // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::Torch::createRecomposeComplexOps() {
return std::make_unique<RecomposeComplexOps>();
}
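
Note: below is a minimal Python model of the rewrite this pattern performs, written against the corrected i < dim placement agreed on in the review thread above; the helper name recomposed_slice_copy is illustrative, not from the patch. A copy_ into a slice along dim becomes one _index_put_impl_ whose index list holds None for each leading dimension and an arange over the slice bounds for dim, with a negative end rebased by the dimension size first.

import torch

def recomposed_slice_copy(x, dim, start, end, step, src):
    # Python model of RecomposeSliceCopy_: fold
    # x.slice(dim, start, end, step).copy_(src) into a single index_put_.
    if end < 0:
        # Mirrors the AtenSizeIntOp + AtenAddIntOp rebasing of a negative end.
        end = x.size(dim) + end
    # None for each leading dim, then an arange over the slice bounds;
    # this is the i < dim placement from the review discussion.
    indices = (None,) * dim + (torch.arange(start, end, step),)
    return torch.ops.aten._index_put_impl_(x, indices, src, False, False)

x = torch.zeros(10, 4, 4)
y = torch.ones(4, 4, 4)
recomposed_slice_copy(x, 0, 2, 6, 1, y)
assert torch.equal(x[2:6], y)  # same effect as x[2:6].copy_(y)

With dim = 0 the None prefix is empty, so the i < dim - 1 and i < dim loop bounds produce the same index list; the bug only shows up for dim >= 1, which is why the two new e2e tests below, both slicing dim 0, pass either way.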
44 changes: 44 additions & 0 deletions python/torch_mlir_e2e_test/test_suite/slice_like.py
@@ -481,3 +481,47 @@ def forward(self, x):
@register_test_case(module_factory=lambda: NarrowVerticalTest2())
def NarrowVerticalTest2_basic(module, tu: TestUtils):
module.forward(tu.rand(6,4))

# ==============================================================================

class SliceCopy_Module(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([10, 4, 4], torch.float32, True),
([4, 4, 4], torch.float32, True),
])
def forward(self, x, y):
xslice = torch.ops.aten.slice(x, 0, 2, 6, 1)
xslice.copy_(y)
return x


@register_test_case(module_factory=lambda: SliceCopy_Module())
def SliceCopy_Module_basic(module, tu: TestUtils):
module.forward(tu.rand(10, 4, 4), tu.rand(4, 4, 4))

# ==============================================================================

class SliceCopyNegative_Module(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
([-1, -1, -1], torch.float32, True),
])
def forward(self, x, y):
xslice = torch.ops.aten.slice(x, 0, 2, -4, 1)
xslice.copy_(y)
return x


@register_test_case(module_factory=lambda: SliceCopyNegative_Module())
def SliceCopyNegative_Module_basic(module, tu: TestUtils):
module.forward(tu.rand(10, 4, 4), tu.rand(4, 4, 4))
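
A quick standalone check of the negative-end arithmetic exercised by this test (my own snippet, assuming the same shapes; not part of the suite): with a dim-0 size of 10, end = -4 rebases to 10 + (-4) = 6, so the slice covers rows 2 through 5 and y's shape of (4, 4, 4) lines up exactly.

import torch

# The folded index_put_ must behave like the eager slice + copy_.
x = torch.rand(10, 4, 4)
expected = x.clone()
y = torch.rand(4, 4, 4)

torch.ops.aten.slice(x, 0, 2, -4, 1).copy_(y)  # what the module traces
expected[2:6] = y  # rebased end: 10 + (-4) = 6
assert torch.equal(x, expected)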